# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
# the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import enum
import diffusers
import dgenerate.hfhub as _hfhub
import dgenerate.memoize as _d_memoize
import dgenerate.memory as _memory
import dgenerate.messages as _messages
import dgenerate.pipelinewrapper.enums as _enums
import dgenerate.pipelinewrapper.util as _pipelinewrapper_util
import dgenerate.textprocessing as _textprocessing
import dgenerate.types as _types
from dgenerate.memoize import memoize as _memoize
from dgenerate.pipelinewrapper import constants as _constants
from dgenerate.pipelinewrapper.uris import exceptions as _exceptions
from dgenerate.pipelinewrapper.uris import util as _util
# Parser used by ``ControlNetUri.parse`` for ``--control-nets`` URIs.
# The list holds the URI argument names handed to the parser
# (presumably the set of arguments it will accept -- confirm against ConceptUriParser).
_controlnet_uri_parser = _textprocessing.ConceptUriParser(
    'ControlNet', ['scale', 'start', 'end', 'mode', 'revision', 'variant', 'subfolder', 'dtype', 'quantizer'])

# Object cache backing the memoized ``ControlNetUri._load`` method, its
# memory constraints are enforced through ``ControlNetUri._enforce_cache_size``.
_controlnet_cache = _d_memoize.create_object_cache(
    'controlnet', cache_type=_memory.SizedConstrainedObjectCache
)
[docs]
class FluxControlNetUnionUriModes(enum.IntEnum):
    """
    Represents controlnet modes associated with the Flux Union controlnet.

    Member names (case-insensitive) are accepted as the ``mode`` URI argument by
    ``ControlNetUri._flux_mode_int_from_str``, which resolves them to the integer
    values defined here.
    """

    CANNY = 0
    TILE = 1
    DEPTH = 2
    BLUR = 3
    POSE = 4
    GRAY = 5
    LQ = 6
[docs]
class SDXLControlNetUnionUriModes(enum.IntEnum):
    """
    Represents controlnet modes associated with the SDXL Union controlnet.

    Several names share one integer value. In a Python enum the first name
    defined for a value is the canonical member and every later name with that
    value is an alias of it, so iterating this enum or taking ``len()`` of it
    only counts the six canonical members (values 0 through 5), while lookup
    by name works for aliases as well.
    """

    OPENPOSE = 0
    DEPTH = 1

    # value 2: HED is canonical, the following names are aliases of it
    HED = 2
    PIDI = 2
    SCRIBBLE = 2
    TED = 2

    # value 3: CANNY is canonical, the following names are aliases of it
    CANNY = 3
    LINEART = 3
    ANIME_LINEART = 3
    MLSD = 3

    NORMAL = 4
    SEGMENT = 5
[docs]
class ControlNetUri:
    """
    Representation of ``--control-nets`` URI.

    Holds the model path together with the optional URI arguments
    (``scale``, ``start``, ``end``, ``mode``, ``revision``, ``variant``,
    ``subfolder``, ``dtype``, ``quantizer``) and can load the matching
    diffusers ControlNet model object.
    """

    # pipelinewrapper.uris.util.get_uri_accepted_args_schema metadata
    NAMES = ['Control Net']
[docs]
@staticmethod
def help():
    """
    Return the raw help text of the ``--control-nets`` command line argument.

    :return: whatever ``dgenerate.arguments.get_raw_help_text`` yields for ``--control-nets``
    """
    # imported at call time rather than at module import time, exactly as before
    from dgenerate.arguments import get_raw_help_text

    return get_raw_help_text('--control-nets')
# Arguments that should be hidden from schema
# generation, because they are not parsed from the URI
HIDE_ARGS = {'model-type'}

# URI arguments restricted to a fixed set of choices
OPTION_ARGS = {
    'dtype': ['float16', 'bfloat16', 'float32']
}

# URI arguments that refer to files / directories, with the file
# type globs offered for them. The checkpoint glob is "*.ckpt",
# the previous "*.cpkt" spelling never matched a checkpoint file.
FILE_ARGS = {
    'model': {'mode': ['in', 'dir'], 'filetypes': [('Models', ['*.safetensors', '*.pt', '*.pth', '*.ckpt', '*.bin'])]}
}
# ===
@property
def model(self) -> str:
    """
    Model path or Hugging Face repository slug, as given to the constructor.
    """
    return self._model
@property
def revision(self) -> _types.OptionalString:
    """
    Model repo revision (branch name), as given to the constructor.
    """
    return self._revision
@property
def variant(self) -> _types.OptionalString:
    """
    Model variant, for example ``fp16``, as given to the constructor.
    """
    return self._variant
@property
def subfolder(self) -> _types.OptionalPath:
    """
    Model repo subfolder, as given to the constructor.
    """
    return self._subfolder
@property
def dtype(self) -> _enums.DataType | None:
    """
    Model dtype (precision), ``None`` when no dtype was given to the constructor.
    """
    return self._dtype
@property
def scale(self) -> float:
    """
    ControlNet guidance scale, as given to the constructor.
    """
    return self._scale
@property
def start(self) -> float:
    """
    ControlNet guidance start point, fraction of inference / timesteps.
    """
    return self._start
@property
def end(self) -> float:
    """
    ControlNet guidance end point, fraction of inference / timesteps.
    """
    return self._end
@property
def mode(self) -> int | None:
    """
    Union ControlNet mode as an integer mode number, ``None`` when no mode was specified.
    """
    return self._mode
@property
def model_type(self) -> _enums.ModelType:
    """
    Model type the ControlNet model is expected to attach to.
    """
    return self._model_type
@property
def quantizer(self) -> _types.OptionalUri:
    """
    --quantizer URI override, ``None`` when not specified.
    """
    return self._quantizer
[docs]
def __init__(self,
             model: str,
             revision: _types.OptionalString,
             variant: _types.OptionalString,
             subfolder: _types.OptionalPath,
             dtype: _enums.DataType | str | None = None,
             scale: float = 1.0,
             start: float = 0.0,
             end: float = 1.0,
             mode: int | str | FluxControlNetUnionUriModes | SDXLControlNetUnionUriModes | None = None,
             quantizer: _types.OptionalUri = None,
             model_type: _enums.ModelType = _enums.ModelType.SD):
    """
    Build a ControlNet URI object from already separated URI constituents.

    :param model: model path
    :param revision: model revision (branch name)
    :param variant: model variant, for example ``fp16``
    :param subfolder: model subfolder
    :param dtype: model data type (precision)
    :param scale: controlnet scale
    :param start: controlnet guidance start value
    :param end: controlnet guidance end value
    :param mode: Flux / SDXL Union controlnet mode. A string is resolved to an
        integer according to ``model_type``, anything else that is not ``None``
        is converted with ``int()``.
    :param quantizer: --quantizer URI override
    :param model_type: Model type this ControlNet will be attached to.

    :raises InvalidControlNetUriError: If ``dtype`` is passed an invalid data type string,
        if ``model`` points to a single file and ``quantizer`` is specified (not supported),
        or if ``mode`` is a string and ``model_type`` is neither SDXL nor Flux.
    """

    # a quantizer can only be combined with repository style loads
    loads_from_single_file = _hfhub.is_single_file_model_load(model)
    if loads_from_single_file and quantizer:
        raise _exceptions.InvalidControlNetUriError(
            'specifying a ControlNet quantizer URI is only supported for Hugging Face '
            'repository loads from a repo slug or disk path, single file loads are not supported.')

    # values stored exactly as passed
    self._model, self._revision = model, revision
    self._variant, self._subfolder = variant, subfolder
    self._quantizer, self._model_type = quantizer, model_type

    if mode is None:
        self._mode = None
    elif not isinstance(mode, str):
        # plain integers and the IntEnum mode members
        self._mode = int(mode)
    elif _enums.model_type_is_sdxl(model_type):
        self._mode = ControlNetUri._sdxl_mode_int_from_str(mode)
    elif _enums.model_type_is_flux(model_type):
        self._mode = ControlNetUri._flux_mode_int_from_str(mode)
    else:
        raise _exceptions.InvalidControlNetUriError(
            f'Torch ControlNet "mode" argument not supported '
            f'for model type: {_enums.get_model_type_string(model_type)}.'
        )

    if dtype:
        try:
            self._dtype = _enums.get_data_type_enum(dtype)
        except ValueError:
            valid_dtypes = _textprocessing.oxford_comma(_enums.supported_data_type_strings(), "or")
            raise _exceptions.InvalidControlNetUriError(
                f'invalid dtype string, must be one of: {valid_dtypes}')
    else:
        self._dtype = None

    self._scale, self._start, self._end = scale, start, end
[docs]
def load(self,
         dtype_fallback: _enums.DataType = _enums.DataType.AUTO,
         use_auth_token: _types.OptionalString = None,
         local_files_only: bool = False,
         no_cache: bool = False,
         device_map: str | None = None,
         model_class:
         type[diffusers.ControlNetModel] |
         type[diffusers.ControlNetUnionModel] |
         type[diffusers.SD3ControlNetModel] |
         type[diffusers.FluxControlNetModel] | None = None) -> \
        diffusers.ControlNetModel | \
        diffusers.ControlNetUnionModel | \
        diffusers.SD3ControlNetModel | \
        diffusers.FluxControlNetModel:
    """
    Load a diffusers ControlNet model object from this URI.

    When ``model_class`` is ``None`` the class is chosen from
    :py:attr:`ControlNetUri.model_type`: Flux and SD3 model types get their
    dedicated ControlNet classes, an SDXL model type with a ``mode`` set gets
    :py:class:`diffusers.ControlNetUnionModel`, everything else gets
    :py:class:`diffusers.ControlNetModel`.

    :param dtype_fallback: Fallback datatype if ``dtype`` was not specified in the URI.
    :param use_auth_token: Optional huggingface API auth token, used for downloading
        restricted repos that your account has access to.
    :param local_files_only: Avoid connecting to huggingface to download models and
        only use cached models?
    :param no_cache: If True, force the returned object not to be cached by the memoize decorator.
    :param device_map: device placement strategy for quantized models, defaults to ``None``
    :param model_class: What class of controlnet model should be loaded?
        if ``None`` is specified, load based off :py:attr:`ControlNetUri.model_type`
        and provided URI arguments.

    :raises ModelNotFoundError: If the model could not be found.
    :raises ControlNetUriLoadError: raised by the handler that is handed to
        ``with_hf_errors_as_model_not_found``.

    :return: :py:class:`diffusers.ControlNetModel`, :py:class:`diffusers.ControlNetUnionModel`,
        :py:class:`diffusers.SD3ControlNetModel`, or :py:class:`diffusers.FluxControlNetModel`
    """

    def raise_load_error(error):
        raise _exceptions.ControlNetUriLoadError(
            f'error loading controlnet "{self.model}": {error}') from error

    with _hfhub.with_hf_errors_as_model_not_found(raise_load_error):
        selected_class = model_class

        if selected_class is None:
            attach_type = self.model_type

            if _enums.model_type_is_flux(attach_type):
                selected_class = diffusers.FluxControlNetModel
            elif _enums.model_type_is_sd3(attach_type):
                selected_class = diffusers.SD3ControlNetModel
            elif _enums.model_type_is_sdxl(attach_type) and self.mode is not None:
                # only an SDXL URI that names a mode is treated as a union model
                selected_class = diffusers.ControlNetUnionModel
            else:
                selected_class = diffusers.ControlNetModel

        # arguments stay positional, as in the memoized method's declaration order
        return self._load(dtype_fallback,
                          use_auth_token,
                          local_files_only,
                          no_cache,
                          device_map,
                          selected_class)
@staticmethod
def _enforce_cache_size(new_controlnet_size):
    """
    Apply the ControlNet cache memory constraints ahead of adding a new object.

    :param new_controlnet_size: size of the ControlNet about to be loaded, passed
        on as ``new_object_size`` and exposed to the constraints as ``controlnet_size``
    """
    constraints = _constants.CONTROLNET_CACHE_MEMORY_CONSTRAINTS

    _controlnet_cache.enforce_cpu_mem_constraints(
        constraints,
        size_var='controlnet_size',
        new_object_size=new_controlnet_size)
@_memoize(_controlnet_cache,
          exceptions={'local_files_only'},
          hasher=lambda args: _d_memoize.args_cache_key(
              args, {'self': lambda o: _d_memoize.property_hasher(
                  o, exclude={'scale', 'start', 'end'})}),
          on_hit=lambda key, hit: _d_memoize.simple_cache_hit_debug("Torch ControlNet", key, hit),
          on_create=lambda key, new: _d_memoize.simple_cache_miss_debug("Torch ControlNet", key, new))
def _load(self,
          dtype_fallback: _enums.DataType = _enums.DataType.AUTO,
          use_auth_token: _types.OptionalString = None,
          local_files_only: bool = False,
          no_cache: bool = False,
          device_map: str | None = None,
          model_class:
          type[diffusers.ControlNetModel] |
          type[diffusers.ControlNetUnionModel] |
          type[diffusers.SD3ControlNetModel] |
          type[diffusers.FluxControlNetModel] = diffusers.ControlNetModel) -> \
        diffusers.ControlNetModel | \
        diffusers.ControlNetUnionModel | \
        diffusers.SD3ControlNetModel | \
        diffusers.FluxControlNetModel:
    """
    Memoized worker behind :py:meth:`ControlNetUri.load`.

    Estimates the memory use of the model, enforces the ControlNet cache
    constraints with that estimate, then loads the model either from a single
    file or from a repository / directory.

    :param dtype_fallback: data type used when the URI did not specify ``dtype``
    :param use_auth_token: optional huggingface API auth token
    :param local_files_only: forwarded to the memory estimate and the model load
    :param no_cache: if True, the returned metadata asks for the object to be skipped by the cache
    :param device_map: forwarded to ``from_pretrained`` (not used for single file loads)
    :param model_class: ControlNet class to load with

    :raises ValueError: if a ``mode`` is set and ``model_class`` is neither
        :py:class:`diffusers.FluxControlNetModel` nor :py:class:`diffusers.ControlNetUnionModel`

    :return: tuple of the loaded model and its ``CachedObjectMetadata``
    """
    mode_capable_classes = {diffusers.FluxControlNetModel, diffusers.ControlNetUnionModel}

    if model_class not in mode_capable_classes and self.mode is not None:
        raise ValueError(
            f'The "mode" argument of ControlNet "{self.model}" is invalid to use '
            'with non Flux / SDXL ControlNet Union models.'
        )

    model_path = _hfhub.download_non_hf_slug_model(self.model)
    is_single_file = _hfhub.is_single_file_model_load(model_path)

    # the dtype from the URI wins over the fallback
    chosen_dtype = self.dtype if self.dtype is not None else dtype_fallback
    torch_dtype = _enums.get_torch_dtype(chosen_dtype)

    quant_config = None
    if self.quantizer:
        quantizer_uri_class = _util.get_quantizer_uri_class(
            self.quantizer,
            _exceptions.InvalidControlNetUriError
        )
        quant_config = quantizer_uri_class.parse(self.quantizer).to_config(torch_dtype)

    # single file loads are estimated without variant / subfolder
    estimate_args = dict(
        repo_id=model_path,
        revision=self.revision,
        use_auth_token=use_auth_token,
        local_files_only=local_files_only
    )
    if not is_single_file:
        estimate_args.update(variant=self.variant, subfolder=self.subfolder)

    estimated_memory_usage = _pipelinewrapper_util.estimate_model_memory_use(**estimate_args)

    # make room in the cache before the model is actually loaded
    self._enforce_cache_size(estimated_memory_usage)

    if is_single_file:
        new_net = model_class.from_single_file(
            model_path,
            revision=self.revision,
            torch_dtype=torch_dtype,
            token=use_auth_token,
            local_files_only=local_files_only)
    else:
        new_net = model_class.from_pretrained(
            model_path,
            revision=self.revision,
            variant=self.variant,
            subfolder=self.subfolder,
            torch_dtype=torch_dtype,
            token=use_auth_token,
            local_files_only=local_files_only,
            quantization_config=quant_config,
            device_map=device_map)

    _messages.debug_log('Estimated Torch ControlNet Memory Use:',
                        _memory.bytes_best_human_unit(estimated_memory_usage))

    _util._patch_module_to_for_sized_cache(_controlnet_cache, new_net)

    metadata = _d_memoize.CachedObjectMetadata(
        size=estimated_memory_usage,
        skip=self.quantizer or no_cache
    )

    # noinspection PyTypeChecker
    return new_net, metadata
[docs]
@staticmethod
def parse(uri: _types.Uri,
          model_type=_enums.ModelType.SD) -> 'ControlNetUri':
    """
    Parse a ``--control-nets`` uri specification and return an object representing its constituents

    :param uri: string with ``--control-nets`` uri syntax
    :param model_type: model type that the ControlNet will be attached to.

    :raise InvalidControlNetUriError: if the URI cannot be parsed, if ``dtype`` is not a
        supported data type string, if ``scale``, ``start`` or ``end`` is not a floating
        point number, if ``start`` is greater than ``end``, or if ``mode`` is invalid or
        is given for a model type that is neither SDXL nor Flux. Errors raised by the
        :py:class:`.ControlNetUri` constructor propagate as well.

    :return: :py:class:`.ControlNetUri`
    """
    try:
        r = _controlnet_uri_parser.parse(uri)

        dtype = r.args.get('dtype')

        supported_dtypes = _enums.supported_data_type_strings()
        if dtype is not None and dtype not in supported_dtypes:
            raise _exceptions.InvalidControlNetUriError(
                f'Torch ControlNet "dtype" must be {", ".join(supported_dtypes)}, '
                f'or left undefined, received: {dtype}')

        # scale, start and end share the same validation, they are checked
        # in this order and fall back to the listed default when absent
        floats = {}
        for name, default in (('scale', 1.0), ('start', 0.0), ('end', 1.0)):
            raw = r.args.get(name, default)
            try:
                floats[name] = float(raw)
            except ValueError as e:
                raise _exceptions.InvalidControlNetUriError(
                    f'Torch ControlNet "{name}" must be a floating point number, received: {raw}') from e

        if floats['start'] > floats['end']:
            raise _exceptions.InvalidControlNetUriError(
                'Torch ControlNet "start" must be less than or equal to "end".')

        mode = r.args.get('mode', None)
        if mode is not None:
            # mode names / numbers are only meaningful for union capable model types
            if _enums.model_type_is_sdxl(model_type):
                mode = ControlNetUri._sdxl_mode_int_from_str(mode)
            elif _enums.model_type_is_flux(model_type):
                mode = ControlNetUri._flux_mode_int_from_str(mode)
            else:
                raise _exceptions.InvalidControlNetUriError(
                    f'Torch ControlNet "mode" argument not supported '
                    f'for model type: {_enums.get_model_type_string(model_type)}.'
                )

        return ControlNetUri(
            model=r.concept,
            revision=r.args.get('revision', None),
            variant=r.args.get('variant', None),
            subfolder=r.args.get('subfolder', None),
            dtype=dtype,
            scale=floats['scale'],
            start=floats['start'],
            end=floats['end'],
            mode=mode,
            quantizer=r.args.get('quantizer', None),
            model_type=model_type
        )
    except _textprocessing.ConceptUriParseError as e:
        raise _exceptions.InvalidControlNetUriError(e) from e
@staticmethod
def _sdxl_mode_int_from_str(mode):
    """
    Convert an SDXL Union ControlNet ``mode`` value into its integer mode number.

    :param mode: string holding either an integer mode number, or the
        case-insensitive name of a :py:class:`SDXLControlNetUnionUriModes`
        member (alias names included).

    :raises InvalidControlNetUriError: if ``mode`` is neither an integer nor a
        known mode name, or if the integer is outside the range of defined modes.

    :return: integer mode number
    """
    try:
        try:
            mode = int(mode)
        except ValueError:
            # not a number, look it up by name, name lookup also resolves aliases
            mode = SDXLControlNetUnionUriModes[mode.upper()].value
    except KeyError as e:
        # __members__ contains alias names too, iterating the enum would skip
        # them even though they are accepted above. Only built on this error path.
        modes = _textprocessing.oxford_comma(
            [name.lower() for name in SDXLControlNetUnionUriModes.__members__], "or")

        raise _exceptions.InvalidControlNetUriError(
            f'Torch SDXL Union ControlNet "mode" must be an integer, '
            f'or one of: {modes}. received: {mode}') from e

    # len() counts canonical members only (aliases excluded), which
    # is the number of distinct mode values, numbered from zero
    mode_count = len(SDXLControlNetUnionUriModes)

    if mode >= mode_count or mode < 0:
        raise _exceptions.InvalidControlNetUriError(
            f'Torch SDXL Union ControlNet "mode" must be less than '
            f'{mode_count} and greater than or equal to zero, '
            f'mode number {mode} does not exist.')

    return mode
@staticmethod
def _flux_mode_int_from_str(mode):
    """
    Convert a Flux Union ControlNet ``mode`` value into its integer mode number.

    :param mode: string holding either an integer mode number, or the
        case-insensitive name of a :py:class:`FluxControlNetUnionUriModes` member.

    :raises InvalidControlNetUriError: if ``mode`` is neither an integer nor a
        known mode name, or if the integer is outside the range of defined modes.

    :return: integer mode number
    """
    try:
        try:
            mode = int(mode)
        except ValueError:
            # not a number, look it up by member name
            mode = FluxControlNetUnionUriModes[mode.upper()].value
    except KeyError as e:
        # the list of valid names is only needed for this error message
        modes = _textprocessing.oxford_comma(
            [name.lower() for name in FluxControlNetUnionUriModes.__members__], "or")

        raise _exceptions.InvalidControlNetUriError(
            f'Torch Flux Union ControlNet "mode" must be an integer, '
            f'or one of: {modes}. received: {mode}') from e

    # modes are numbered from zero, so the member count is the exclusive upper bound
    mode_count = len(FluxControlNetUnionUriModes)

    if mode >= mode_count or mode < 0:
        raise _exceptions.InvalidControlNetUriError(
            f'Torch Flux Union ControlNet "mode" must be less than '
            f'{mode_count} and greater than or equal to zero, '
            f'mode number {mode} does not exist.')

    return mode