# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
# the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import typing
import diffusers
import huggingface_hub
import dgenerate.memoize as _d_memoize
import dgenerate.memory as _memory
import dgenerate.messages as _messages
import dgenerate.pipelinewrapper.cache as _cache
import dgenerate.pipelinewrapper.enums as _enums
import dgenerate.pipelinewrapper.hfutil as _hfutil
import dgenerate.pipelinewrapper.uris as _uris
import dgenerate.textprocessing as _textprocessing
import dgenerate.types as _types
from dgenerate.memoize import memoize as _memoize
[docs]
class InvalidSchedulerName(Exception):
    """
    Raised when a requested scheduler name does not match any scheduler
    that is compatible with the pipeline it is being loaded onto.
    """
[docs]
class SchedulerHelpException(Exception):
    """
    Not an error condition. Raised after scheduler help was printed because "help" was passed
    as a scheduler name argument of :py:meth:`.DiffusionPipelineWrapper.__init__`, for example
    ``scheduler`` or ``sdxl_refiner_scheduler``.

    When :py:meth:`.DiffusionPipelineWrapper.__call__` runs, the help text is printed with
    :py:meth:`dgenerate.messages.log`, and this exception is then raised to unwind the call stack.
    """
def _disabled_safety_checker(images, clip_input):
    # No-op stand-in for a pipeline safety checker: the images pass through
    # untouched and nothing is ever flagged. A 4-dimensional input is treated
    # as a batch and gets one False flag per entry along its first axis.
    flagged = [False] * images.shape[0] if len(images.shape) == 4 else False
    return images, flagged
def _floyd_disabled_safety_checker(images, clip_input):
    # Deep Floyd flavour of the no-op safety checker. The result has a third
    # element, which is always False here (presumably the watermark flag of the
    # IF pipelines -- confirm against diffusers). A 4-dimensional input is
    # treated as a batch and gets one False flag per entry along its first axis.
    flagged = [False] * images.shape[0] if len(images.shape) == 4 else False
    return images, flagged, False
def _set_torch_safety_checker(pipeline: diffusers.DiffusionPipeline, safety_checker: bool):
    # Leave the pipeline alone when the safety checker should stay active.
    if safety_checker:
        return

    # Only swap out a checker that is actually loaded. Assigning a value when the
    # attribute is already None leads to a call on an unassigned feature_extractor,
    # and SDXL pipelines currently have no safety_checker attribute at all.
    if getattr(pipeline, 'safety_checker', None) is None:
        return

    pipeline.safety_checker = _disabled_safety_checker
def _set_floyd_safety_checker(pipeline: diffusers.DiffusionPipeline, safety_checker: bool):
    # Deep Floyd counterpart of _set_torch_safety_checker: when the checker is
    # disabled, replace a loaded (non-None) checker with the no-op Floyd variant.
    if safety_checker:
        return

    if getattr(pipeline, 'safety_checker', None) is None:
        return

    pipeline.safety_checker = _floyd_disabled_safety_checker
[docs]
def scheduler_is_help(name: typing.Optional[str]):
    """
    Check whether a scheduler name is really a request for help, IE: ``"help"``.

    Surrounding whitespace and letter case are ignored.

    :param name: scheduler name to test, may be ``None``
    :return: ``True`` if ``name`` is a help request, otherwise ``False``
    """
    return name is not None and name.strip().lower() == 'help'
[docs]
def load_scheduler(pipeline: typing.Union[diffusers.DiffusionPipeline, diffusers.FlaxDiffusionPipeline],
                   scheduler_name=None, model_path: typing.Optional[str] = None):
    """
    Load a specific compatible scheduler class name onto a huggingface diffusers pipeline object.

    An exact class name match is preferred. If there is none, the first compatible scheduler
    whose class name ends with ``scheduler_name`` is used.

    :param pipeline: pipeline object
    :param scheduler_name: compatible scheduler class name. Pass "help" to receive a print out to STDOUT
        and raise :py:exc:`.SchedulerHelpException`. ``None`` leaves the pipeline's scheduler unchanged.
    :param model_path: Optional model path to be used in the message to STDOUT produced by passing "help"
    :raises SchedulerHelpException: if ``scheduler_name`` is a help request
    :raises InvalidSchedulerName: if no compatible scheduler matches ``scheduler_name``
    :return: ``None``
    """
    if scheduler_name is None:
        return

    compatibles = pipeline.scheduler.compatibles

    if isinstance(pipeline, diffusers.StableDiffusionLatentUpscalePipeline):
        # Seems to only work with this scheduler
        compatibles = [c for c in compatibles if c.__name__ == 'EulerDiscreteScheduler']

    if scheduler_is_help(scheduler_name):
        help_string = _textprocessing.underline(f'Compatible schedulers for "{model_path}" are:') + '\n\n'
        help_string += '\n'.join((" " * 4) + _textprocessing.quote(i.__name__) for i in compatibles) + '\n'
        _messages.log(help_string, underline=True)
        raise SchedulerHelpException(help_string)

    # An exact class name match wins. Suffix matching (which allows shortened names) is
    # only a fallback, so a full class name that also happens to be the suffix of another
    # compatible scheduler's name still selects the scheduler the user asked for.
    selected = next((c for c in compatibles if c.__name__ == scheduler_name), None)
    if selected is None:
        selected = next((c for c in compatibles if c.__name__.endswith(scheduler_name)), None)

    if selected is None:
        options = '\n'.join(sorted(" " * 4 + _textprocessing.quote(c.__name__) for c in compatibles))
        raise InvalidSchedulerName(
            f'Scheduler named "{scheduler_name}" is not a valid compatible scheduler, '
            f'options are:\n\n{options}')

    pipeline.scheduler = selected.from_config(pipeline.scheduler.config)
[docs]
def _estimate_uris_memory_use(uris,
                              parser,
                              auth_token,
                              local_files_only: bool,
                              flax: bool):
    """
    Sum the estimated memory use of a set of LoRA or Textual Inversion style URIs.

    :param uris: a single URI string, a list of URI strings, or ``None`` / empty
    :param parser: URI parser, such as :py:meth:`dgenerate.pipelinewrapper.uris.LoRAUri.parse`.
        Its result must provide ``model``, ``revision``, ``subfolder`` and ``weight_name``.
    :param auth_token: optional huggingface auth token
    :param local_files_only: only look in the local huggingface cache?
    :param flax: estimate flax weights?
    :return: size estimate in bytes, ``0`` when there are no URIs
    """
    if not uris:
        return 0
    if isinstance(uris, str):
        uris = [uris]

    total = 0
    for uri in uris:
        parsed = parser(uri)
        total += _hfutil.estimate_model_memory_use(
            repo_id=parsed.model,
            revision=parsed.revision,
            subfolder=parsed.subfolder,
            weight_name=parsed.weight_name,
            use_auth_token=auth_token,
            local_files_only=local_files_only,
            flax=flax
        )
    return total


def estimate_pipeline_memory_use(
        pipeline_type: _enums.PipelineTypes,
        model_path: str,
        model_type: _enums.ModelTypes,
        revision='main',
        variant=None,
        subfolder=None,
        vae_uri=None,
        lora_uris=None,
        textual_inversion_uris=None,
        safety_checker=False,
        auth_token=None,
        extra_args=None,
        local_files_only=False):
    """
    Estimate the CPU side memory use of a model.

    :param pipeline_type: :py:class:`dgenerate.pipelinewrapper.PipelineTypes`
    :param model_path: huggingface slug, blob link, path to folder on disk, path to model file.
    :param model_type: :py:class:`dgenerate.pipelinewrapper.ModelTypes`
    :param revision: huggingface repo revision if using a huggingface slug
    :param variant: model file variant desired, for example "fp16"
    :param subfolder: huggingface repo subfolder if using a huggingface slug
    :param vae_uri: optional user specified ``--vae`` URI that will be loaded on to the pipeline.
        When given, the model's own VAE is excluded from the estimate.
    :param lora_uris: optional user specified ``--loras`` URIs that will be loaded on to the pipeline
    :param textual_inversion_uris: optional user specified ``--textual-inversion`` URIs that will be loaded on to the pipeline
    :param safety_checker: consider the safety checker? dgenerate usually loads the safety checker and then retroactively
        disables it if needed, so it usually considers the size of the safety checker model.
    :param auth_token: optional huggingface auth token to access restricted repositories that your account has access to.
    :param extra_args: ``extra_args`` as to be passed to :py:meth:`.create_torch_diffusion_pipeline`
        or :py:meth:`.create_flax_diffusion_pipeline`. Modules present here are excluded from the estimate.
    :param local_files_only: Only ever attempt to look in the local huggingface cache? if ``False`` the huggingface
        API will be contacted when necessary.
    :return: size estimate in bytes.
    """
    if extra_args is None:
        extra_args = dict()

    flax = _enums.model_type_is_flax(model_type)

    usage = _hfutil.estimate_model_memory_use(
        repo_id=model_path,
        revision=revision,
        variant=variant,
        subfolder=subfolder,
        # The model's own VAE is only loaded when nothing replaces it. A replacement can be
        # a user specified VAE URI or a VAE module passed in extra_args. (Previously this was
        # "or", which made it True for every argument pattern the pipeline creator uses.)
        include_vae=not vae_uri and 'vae' not in extra_args,
        safety_checker=safety_checker and 'safety_checker' not in extra_args,
        include_text_encoder='text_encoder' not in extra_args,
        include_text_encoder_2='text_encoder_2' not in extra_args,
        use_auth_token=auth_token,
        local_files_only=local_files_only,
        flax=flax,
        sentencepiece=_enums.model_type_is_floyd(model_type)
    )

    # NOTE(review): the size of the user specified VAE (vae_uri) itself is not added here.

    usage += _estimate_uris_memory_use(lora_uris,
                                       _uris.LoRAUri.parse,
                                       auth_token=auth_token,
                                       local_files_only=local_files_only,
                                       flax=flax)

    usage += _estimate_uris_memory_use(textual_inversion_uris,
                                       _uris.TextualInversionUri.parse,
                                       auth_token=auth_token,
                                       local_files_only=local_files_only,
                                       flax=flax)

    return usage
[docs]
def set_vae_slicing_tiling(pipeline: typing.Union[diffusers.DiffusionPipeline,
                                                  diffusers.FlaxDiffusionPipeline],
                           vae_tiling: bool,
                           vae_slicing: bool):
    """
    Turn VAE tiling and VAE slicing on or off for an already created huggingface diffusers pipeline.

    Enabling a feature is strict: it fails if the pipeline has no VAE or the VAE lacks support.
    Disabling a feature is best effort and is skipped quietly in those cases.

    :param pipeline: pipeline object
    :param vae_tiling: tiling status
    :param vae_slicing: slicing status
    :raises NotImplementedError: if a feature is requested but there is no VAE, or the VAE does not support it
    :return:
    """
    pipeline_name = pipeline.__class__.__name__
    vae = getattr(pipeline, 'vae', None)

    def apply_feature(feature: str, enabled: bool):
        if not enabled:
            # Nothing to turn off when there is no VAE or it has no disable method
            if vae is not None and hasattr(vae, f'disable_{feature}'):
                _messages.debug_log(f'Disabling VAE {feature} on Pipeline: "{pipeline_name}",',
                                    f'VAE: "{vae.__class__.__name__}"')
                getattr(vae, f'disable_{feature}')()
            return

        if vae is None:
            raise NotImplementedError(
                f'--vae-{feature} not supported as no VAE is present for the specified model.')

        if not hasattr(vae, f'enable_{feature}'):
            raise NotImplementedError(
                f'--vae-{feature} not supported as loaded VAE does not support it.'
            )

        _messages.debug_log(f'Enabling VAE {feature} on Pipeline: "{pipeline_name}",',
                            f'VAE: "{vae.__class__.__name__}"')
        getattr(vae, f'enable_{feature}')()

    # Tiling is handled first, so a tiling error is raised before slicing is touched
    apply_feature('tiling', vae_tiling)
    apply_feature('slicing', vae_slicing)
[docs]
class PipelineCreationResult:
    """
    Base result of creating a diffusers pipeline. It holds the created pipeline object.
    """

    # Module names that may be requested through get_pipeline_modules
    _PIPELINE_MODULE_NAMES = frozenset({
        'vae',
        'text_encoder',
        'text_encoder_2',
        'tokenizer',
        'tokenizer_2',
        'safety_checker',
        'feature_extractor',
        'controlnet',
        'scheduler',
        'unet'
    })

    def __init__(self, pipeline):
        self._pipeline = pipeline

    @property
    def pipeline(self):
        """
        The created pipeline object.
        """
        return self._pipeline

    def get_pipeline_modules(self, names: typing.Iterable[str]):
        """
        Get associated pipeline modules, such as ``vae``, as a
        dictionary that maps each name to its module value.

        Possible Module Names:

            * ``vae``
            * ``text_encoder``
            * ``text_encoder_2``
            * ``tokenizer``
            * ``tokenizer_2``
            * ``safety_checker``
            * ``feature_extractor``
            * ``controlnet``
            * ``scheduler``
            * ``unet``

        If a name is not recognized, or the module is not present on the pipeline,
        a :py:exc:`ValueError` is raised describing the offending module name.

        :raise ValueError:

        :param names: module names, such as ``vae``, ``text_encoder``. A single name string is also accepted.
        :return: dictionary
        """
        # A bare string would otherwise be iterated character by character
        if isinstance(names, str):
            names = [names]

        module_values = dict()

        for name in names:
            if name not in self._PIPELINE_MODULE_NAMES:
                raise ValueError(f'"{name}" is not a recognized pipeline module name.')
            if not hasattr(self.pipeline, name):
                raise ValueError(f'Created pipeline does not possess a module named: "{name}".')
            module_values[name] = getattr(self.pipeline, name)

        return module_values
[docs]
class TorchPipelineCreationResult(PipelineCreationResult):
    """
    Result of :py:meth:`.create_torch_diffusion_pipeline`: the created Torch pipeline
    and the parsed URIs used to build it.
    """

    parsed_vae_uri: typing.Optional[_uris.TorchVAEUri]
    """
    Parsed VAE URI if one was loaded, otherwise ``None``
    """

    parsed_lora_uris: typing.List[_uris.LoRAUri]
    """
    Parsed LoRA URIs, empty if none were present
    """

    parsed_textual_inversion_uris: typing.List[_uris.TextualInversionUri]
    """
    Parsed Textual Inversion URIs, empty if none were present
    """

    parsed_control_net_uris: typing.List[_uris.TorchControlNetUri]
    """
    Parsed ControlNet URIs, empty if none were present
    """

    def __init__(self,
                 pipeline: diffusers.DiffusionPipeline,
                 parsed_vae_uri: typing.Optional[_uris.TorchVAEUri],
                 parsed_lora_uris: typing.List[_uris.LoRAUri],
                 parsed_textual_inversion_uris: typing.List[_uris.TextualInversionUri],
                 parsed_control_net_uris: typing.List[_uris.TorchControlNetUri]):
        super().__init__(pipeline)
        self.parsed_control_net_uris = parsed_control_net_uris
        self.parsed_textual_inversion_uris = parsed_textual_inversion_uris
        self.parsed_lora_uris = parsed_lora_uris
        self.parsed_vae_uri = parsed_vae_uri

    @property
    def pipeline(self) -> diffusers.DiffusionPipeline:
        """
        The created subclass of :py:class:`diffusers.DiffusionPipeline`
        """
        return super().pipeline

    def call(self, *args, **kwargs) -> diffusers.utils.BaseOutput:
        """
        Invoke the created **pipeline** with the given arguments.

        :param args: positional args forwarded to the pipeline
        :param kwargs: keyword args forwarded to the pipeline
        :return: A subclass of :py:class:`diffusers.utils.BaseOutput`
        """
        target = self.pipeline
        return target(*args, **kwargs)
[docs]
def create_torch_diffusion_pipeline(pipeline_type: _enums.PipelineTypes,
                                    model_path: str,
                                    model_type: _enums.ModelTypes = _enums.ModelTypes.TORCH,
                                    revision: _types.OptionalString = None,
                                    variant: _types.OptionalString = None,
                                    subfolder: _types.OptionalString = None,
                                    dtype: _enums.DataTypes = _enums.DataTypes.AUTO,
                                    vae_uri: _types.OptionalUri = None,
                                    lora_uris: _types.OptionalUriOrUris = None,
                                    textual_inversion_uris: _types.OptionalUriOrUris = None,
                                    control_net_uris: _types.OptionalUriOrUris = None,
                                    scheduler: _types.OptionalString = None,
                                    safety_checker: bool = False,
                                    auth_token: _types.OptionalString = None,
                                    device: str = 'cuda',
                                    extra_modules: typing.Optional[typing.Dict[str, typing.Any]] = None,
                                    model_cpu_offload: bool = False,
                                    sequential_cpu_offload: bool = False,
                                    local_files_only: bool = False) -> TorchPipelineCreationResult:
    """
    Create a :py:class:`diffusers.DiffusionPipeline` in dgenerate's in-memory caching system.

    :param pipeline_type: py:class:`dgenerate.pipelinewrapper.PipelineTypes` enum value
    :param model_path: huggingface slug, huggingface blob link, path to folder on disk, path to file on disk
    :param model_type: py:class:`dgenerate.pipelinewrapper.ModelTypes` enum value
    :param revision: huggingface repo revision (branch)
    :param variant: model weights name variant, for example 'fp16'
    :param subfolder: huggingface repo subfolder if applicable
    :param dtype: Optional py:class:`dgenerate.pipelinewrapper.DataTypes` enum value
    :param vae_uri: Optional ``--vae`` URI string for specifying a specific VAE
    :param lora_uris: Optional ``--loras`` URI strings for specifying LoRA weights
    :param textual_inversion_uris: Optional ``--textual-inversions`` URI strings for specifying Textual Inversion weights
    :param control_net_uris: Optional ``--control-nets`` URI strings for specifying ControlNet models
    :param scheduler: Optional scheduler (sampler) class name, unqualified, or "help" to print supported values
        to STDOUT and raise :py:exc:`dgenerate.pipelinewrapper.SchedulerHelpException`
    :param safety_checker: Safety checker enabled? default is false
    :param auth_token: Optional huggingface API token for accessing repositories that are restricted to your account
    :param device: Optional ``--device`` string, defaults to "cuda"
    :param extra_modules: Extra module arguments to pass directly into
        :py:meth:`diffusers.DiffusionPipeline.from_single_file` or :py:meth:`diffusers.DiffusionPipeline.from_pretrained`
    :param model_cpu_offload: This pipeline has model_cpu_offloading enabled?
    :param sequential_cpu_offload: This pipeline has sequential_cpu_offloading enabled?
    :param local_files_only: Only look in the huggingface cache and do not connect to download models?

    :raises ModelNotFoundError:
    :raises InvalidModelUriError:
    :raises InvalidSchedulerName:
    :raises NotImplementedError:

    :return: :py:class:`.TorchPipelineCreationResult`
    """
    # Snapshot the arguments before any other local is bound. Copying keeps the
    # forwarded kwargs fixed even if the frame's locals are refreshed later.
    call_args = dict(locals())
    try:
        return _create_torch_diffusion_pipeline(**call_args)
    except (huggingface_hub.utils.HFValidationError,
            huggingface_hub.utils.HfHubHTTPError) as e:
        # Chain the huggingface error so its traceback is not lost
        raise _hfutil.ModelNotFoundError(e) from e
[docs]
class TorchPipelineFactory:
    """
    Bundles :py:meth:`.create_torch_diffusion_pipeline` with :py:meth:`.set_vae_slicing_tiling`
    into a factory that can request the same Torch pipeline repeatedly, possibly served from cache.
    """

    def __init__(self,
                 pipeline_type: _enums.PipelineTypes,
                 model_path: str,
                 model_type: _enums.ModelTypes = _enums.ModelTypes.TORCH,
                 revision: _types.OptionalString = None,
                 variant: _types.OptionalString = None,
                 subfolder: _types.OptionalString = None,
                 dtype: _enums.DataTypes = _enums.DataTypes.AUTO,
                 vae_uri: _types.OptionalUri = None,
                 lora_uris: _types.OptionalUriOrUris = None,
                 textual_inversion_uris: _types.OptionalUriOrUris = None,
                 control_net_uris: _types.OptionalUriOrUris = None,
                 scheduler: _types.OptionalString = None,
                 safety_checker: bool = False,
                 auth_token: _types.OptionalString = None,
                 device: str = 'cuda',
                 extra_modules: typing.Optional[typing.Dict[str, typing.Any]] = None,
                 model_cpu_offload: bool = False,
                 sequential_cpu_offload: bool = False,
                 local_files_only: bool = False,
                 vae_tiling=False,
                 vae_slicing=False):
        # VAE tiling / slicing are applied after creation. Every other argument
        # is replayed as-is into create_torch_diffusion_pipeline.
        self._vae_tiling = vae_tiling
        self._vae_slicing = vae_slicing
        self._args = dict(
            pipeline_type=pipeline_type,
            model_path=model_path,
            model_type=model_type,
            revision=revision,
            variant=variant,
            subfolder=subfolder,
            dtype=dtype,
            vae_uri=vae_uri,
            lora_uris=lora_uris,
            textual_inversion_uris=textual_inversion_uris,
            control_net_uris=control_net_uris,
            scheduler=scheduler,
            safety_checker=safety_checker,
            auth_token=auth_token,
            device=device,
            extra_modules=extra_modules,
            model_cpu_offload=model_cpu_offload,
            sequential_cpu_offload=sequential_cpu_offload,
            local_files_only=local_files_only,
        )

    def __call__(self) -> TorchPipelineCreationResult:
        """
        Create (or fetch from cache) the pipeline, then apply the VAE tiling / slicing settings.

        :raises ModelNotFoundError:
        :raises InvalidModelUriError:
        :raises InvalidSchedulerName:
        :raises NotImplementedError:

        :return: :py:class:`.TorchPipelineCreationResult`
        """
        created = create_torch_diffusion_pipeline(**self._args)
        set_vae_slicing_tiling(created.pipeline,
                               vae_tiling=self._vae_tiling,
                               vae_slicing=self._vae_slicing)
        return created
def _select_torch_pipeline_class(pipeline_type: _enums.PipelineTypes,
                                 model_type: _enums.ModelTypes,
                                 lora_uris: _types.OptionalUriOrUris,
                                 textual_inversion_uris: _types.OptionalUriOrUris,
                                 control_net_uris: _types.OptionalUriOrUris,
                                 scheduler: _types.OptionalString):
    """
    Select the :py:class:`diffusers.DiffusionPipeline` subclass used to load a Torch model.

    It also rejects model type / pipeline type / URI combinations that cannot work together.

    :param pipeline_type: :py:class:`dgenerate.pipelinewrapper.PipelineTypes` enum value
    :param model_type: :py:class:`dgenerate.pipelinewrapper.ModelTypes` enum value
    :param lora_uris: user specified LoRA URIs, only checked for presence
    :param textual_inversion_uris: user specified Textual Inversion URIs, only checked for presence
    :param control_net_uris: user specified ControlNet URIs, only checked for presence
    :param scheduler: scheduler name. A "help" request skips the upscaler pipeline type check.
    :raises NotImplementedError: for unsupported combinations
    :return: pipeline class
    """
    if _enums.model_type_is_floyd(model_type):
        if control_net_uris:
            raise NotImplementedError(
                'Deep Floyd --model-type values are not compatible with --control-nets.')
        if textual_inversion_uris:
            raise NotImplementedError(
                'Deep Floyd --model-type values are not compatible with --textual-inversions.')

    if _enums.model_type_is_upscaler(model_type):
        # A scheduler "help" request is allowed with any pipeline type
        if pipeline_type != _enums.PipelineTypes.IMG2IMG and not scheduler_is_help(scheduler):
            raise NotImplementedError(
                'Upscaler models only work with img2img generation, IE: --image-seeds (with no image masks).')

        if model_type == _enums.ModelTypes.TORCH_UPSCALER_X2:
            if lora_uris or textual_inversion_uris:
                raise NotImplementedError(
                    '--model-type torch-upscaler-x2 is not compatible with --lora or --textual-inversions.')

        if model_type == _enums.ModelTypes.TORCH_UPSCALER_X4:
            return diffusers.StableDiffusionUpscalePipeline
        return diffusers.StableDiffusionLatentUpscalePipeline

    sdxl = _enums.model_type_is_sdxl(model_type)
    pix2pix = _enums.model_type_is_pix2pix(model_type)

    if pipeline_type == _enums.PipelineTypes.TXT2IMG:
        if pix2pix:
            raise NotImplementedError(
                'pix2pix models only work in img2img mode and cannot work without --image-seeds.')

        if model_type == _enums.ModelTypes.TORCH_IF:
            return diffusers.IFPipeline
        if model_type == _enums.ModelTypes.TORCH_IFS:
            raise NotImplementedError(
                'Deep Floyd IF super resolution (IFS) only works in img2img mode and cannot work without --image-seeds.')
        if model_type == _enums.ModelTypes.TORCH_IFS_IMG2IMG:
            # Previously this fell through to a Stable Diffusion pipeline class
            raise NotImplementedError(
                'Deep Floyd IF super resolution img2img (IFS img2img) only works in img2img mode '
                'and cannot work without --image-seeds.')
        if control_net_uris:
            return diffusers.StableDiffusionXLControlNetPipeline if sdxl else diffusers.StableDiffusionControlNetPipeline
        return diffusers.StableDiffusionXLPipeline if sdxl else diffusers.StableDiffusionPipeline

    if pipeline_type == _enums.PipelineTypes.IMG2IMG:
        if pix2pix:
            if control_net_uris:
                raise NotImplementedError(
                    'pix2pix models are not compatible with --control-nets.')
            return diffusers.StableDiffusionXLInstructPix2PixPipeline if sdxl else diffusers.StableDiffusionInstructPix2PixPipeline

        if model_type == _enums.ModelTypes.TORCH_IF:
            return diffusers.IFImg2ImgPipeline
        if model_type == _enums.ModelTypes.TORCH_IFS:
            return diffusers.IFSuperResolutionPipeline
        if model_type == _enums.ModelTypes.TORCH_IFS_IMG2IMG:
            return diffusers.IFImg2ImgSuperResolutionPipeline
        if control_net_uris:
            if sdxl:
                return diffusers.StableDiffusionXLControlNetImg2ImgPipeline
            return diffusers.StableDiffusionControlNetImg2ImgPipeline
        return diffusers.StableDiffusionXLImg2ImgPipeline if sdxl else diffusers.StableDiffusionImg2ImgPipeline

    if pipeline_type == _enums.PipelineTypes.INPAINT:
        if pix2pix:
            raise NotImplementedError(
                'pix2pix models only work in img2img mode and cannot work in inpaint mode (with a mask).')

        if model_type == _enums.ModelTypes.TORCH_IF:
            return diffusers.IFInpaintingPipeline
        if model_type == _enums.ModelTypes.TORCH_IFS:
            return diffusers.IFInpaintingSuperResolutionPipeline
        if model_type == _enums.ModelTypes.TORCH_IFS_IMG2IMG:
            # Previously this fell through to a Stable Diffusion inpaint pipeline class
            raise NotImplementedError(
                'Deep Floyd IF super resolution img2img (IFS img2img) models '
                'cannot work in inpaint mode (with a mask).')
        if control_net_uris:
            if sdxl:
                return diffusers.StableDiffusionXLControlNetInpaintPipeline
            return diffusers.StableDiffusionControlNetInpaintPipeline
        return diffusers.StableDiffusionXLInpaintPipeline if sdxl else diffusers.StableDiffusionInpaintPipeline

    # Should be impossible
    raise NotImplementedError('Pipeline type not implemented.')


@_memoize(_cache._TORCH_PIPELINE_CACHE,
          exceptions={'local_files_only'},
          hasher=lambda args: _d_memoize.args_cache_key(args,
                                                        {'vae_uri': _cache.uri_hash_with_parser(
                                                            _uris.TorchVAEUri.parse),
                                                         'lora_uris':
                                                             _cache.uri_list_hash_with_parser(_uris.LoRAUri.parse),
                                                         'textual_inversion_uris':
                                                             _cache.uri_list_hash_with_parser(
                                                                 _uris.TextualInversionUri.parse),
                                                         'control_net_uris':
                                                             _cache.uri_list_hash_with_parser(
                                                                 _uris.TorchControlNetUri.parse)}),
          on_hit=lambda key, hit: _d_memoize.simple_cache_hit_debug("Torch Pipeline", key, hit.pipeline),
          on_create=lambda key, new: _d_memoize.simple_cache_miss_debug('Torch Pipeline', key, new.pipeline))
def _create_torch_diffusion_pipeline(pipeline_type: _enums.PipelineTypes,
                                     model_path: str,
                                     model_type: _enums.ModelTypes = _enums.ModelTypes.TORCH,
                                     revision: _types.OptionalString = None,
                                     variant: _types.OptionalString = None,
                                     subfolder: _types.OptionalString = None,
                                     dtype: _enums.DataTypes = _enums.DataTypes.AUTO,
                                     vae_uri: _types.OptionalUri = None,
                                     lora_uris: _types.OptionalUriOrUris = None,
                                     textual_inversion_uris: _types.OptionalUriOrUris = None,
                                     control_net_uris: _types.OptionalUriOrUris = None,
                                     scheduler: _types.OptionalString = None,
                                     safety_checker: bool = False,
                                     auth_token: _types.OptionalString = None,
                                     device: str = 'cuda',
                                     extra_modules: typing.Optional[typing.Dict[str, typing.Any]] = None,
                                     model_cpu_offload: bool = False,
                                     sequential_cpu_offload: bool = False,
                                     local_files_only: bool = False) -> TorchPipelineCreationResult:
    """
    Memoized implementation of :py:meth:`create_torch_diffusion_pipeline`.
    See that function for parameter documentation.
    """
    if not _enums.model_type_is_torch(model_type):
        raise ValueError('model_type must be a TORCH ModelTypes enum value.')

    # Pipeline class selection (also rejects incompatible argument combinations)
    pipeline_class = _select_torch_pipeline_class(pipeline_type=pipeline_type,
                                                  model_type=model_type,
                                                  lora_uris=lora_uris,
                                                  textual_inversion_uris=textual_inversion_uris,
                                                  control_net_uris=control_net_uris,
                                                  scheduler=scheduler)

    # Block invalid Textual Inversion and LoRA usage. This runs before the memory
    # estimate and cache constraint enforcement below, so an invalid configuration
    # is rejected before any cached pipeline is evicted or the huggingface API is contacted.
    if textual_inversion_uris:
        if model_type == _enums.ModelTypes.TORCH_UPSCALER_X2:
            raise NotImplementedError(
                '--model-type torch-upscaler-x2 cannot be used with textual inversion models.')

        if isinstance(textual_inversion_uris, str):
            textual_inversion_uris = [textual_inversion_uris]

    if lora_uris:
        if _enums.model_type_is_upscaler(model_type):
            raise NotImplementedError(
                'LoRA models cannot be used with upscaler models.')

        if isinstance(lora_uris, str):
            lora_uris = [lora_uris]

    # Like LoRA / Textual Inversion URIs, a single ControlNet URI string is accepted.
    # Iterating the string directly would iterate its characters.
    if isinstance(control_net_uris, str):
        control_net_uris = [control_net_uris]

    vae_override = extra_modules and 'vae' in extra_modules
    controlnet_override = extra_modules and 'controlnet' in extra_modules
    safety_checker_override = extra_modules and 'safety_checker' in extra_modules
    scheduler_override = extra_modules and 'scheduler' in extra_modules

    estimated_memory_usage = estimate_pipeline_memory_use(
        pipeline_type=pipeline_type,
        model_type=model_type,
        model_path=model_path,
        revision=revision,
        variant=variant,
        subfolder=subfolder,
        vae_uri=vae_uri if not vae_override else None,
        lora_uris=lora_uris,
        textual_inversion_uris=textual_inversion_uris,
        safety_checker=safety_checker and not safety_checker_override,
        auth_token=auth_token,
        extra_args=extra_modules,
        local_files_only=local_files_only
    )

    _messages.debug_log(
        f'Creating Torch Pipeline: "{pipeline_class.__name__}", '
        f'Estimated CPU Side Memory Use: {_memory.bytes_best_human_unit(estimated_memory_usage)}')

    _cache.enforce_pipeline_cache_constraints(
        new_pipeline_size=estimated_memory_usage)

    # ControlNet and VAE loading

    # Used during pipeline load
    creation_kwargs = {}

    torch_dtype = _enums.get_torch_dtype(dtype)

    parsed_control_net_uris = []
    parsed_vae_uri = None

    if not scheduler_is_help(scheduler):
        # prevent waiting on VAE load just to get the scheduler
        # help message for the main model

        if vae_uri is not None and not vae_override:
            parsed_vae_uri = _uris.TorchVAEUri.parse(vae_uri)

            creation_kwargs['vae'] = \
                parsed_vae_uri.load(
                    dtype_fallback=dtype,
                    use_auth_token=auth_token,
                    local_files_only=local_files_only)

            _messages.debug_log(lambda:
                                f'Added Torch VAE: "{vae_uri}" to pipeline: "{pipeline_class.__name__}"')

        if control_net_uris and not controlnet_override:
            if _enums.model_type_is_pix2pix(model_type):
                raise NotImplementedError(
                    'Using ControlNets with pix2pix models is not supported.'
                )

            control_nets = []

            for control_net_uri in control_net_uris:
                parsed_control_net_uri = _uris.TorchControlNetUri.parse(control_net_uri)

                parsed_control_net_uris.append(parsed_control_net_uri)

                control_nets.append(parsed_control_net_uri.load(use_auth_token=auth_token,
                                                                dtype_fallback=dtype,
                                                                local_files_only=local_files_only))

                _messages.debug_log(lambda:
                                    f'Added Torch ControlNet: "{control_net_uri}" '
                                    f'to pipeline: "{pipeline_class.__name__}"')

            # A single ControlNet is passed on its own, several are passed as a list
            creation_kwargs['controlnet'] = control_nets[0] if len(control_nets) == 1 else control_nets

    if _enums.model_type_is_floyd(model_type):
        creation_kwargs['watermarker'] = None

    if not safety_checker and not _enums.model_type_is_sdxl(model_type) and not safety_checker_override:
        creation_kwargs['safety_checker'] = None

    if extra_modules is not None:
        creation_kwargs.update(extra_modules)

    # Create Pipeline

    if _hfutil.is_single_file_model_load(model_path):
        if subfolder is not None:
            raise NotImplementedError('Single file model loads do not support the subfolder option.')
        pipeline = pipeline_class.from_single_file(model_path,
                                                   revision=revision,
                                                   variant=variant,
                                                   torch_dtype=torch_dtype,
                                                   use_safe_tensors=model_path.endswith('.safetensors'),
                                                   local_files_only=local_files_only,
                                                   **creation_kwargs)
    else:
        pipeline = pipeline_class.from_pretrained(model_path,
                                                  revision=revision,
                                                  variant=variant,
                                                  torch_dtype=torch_dtype,
                                                  subfolder=subfolder,
                                                  use_auth_token=auth_token,
                                                  local_files_only=local_files_only,
                                                  **creation_kwargs)

    # Select Scheduler

    if not scheduler_override:
        load_scheduler(pipeline=pipeline,
                       model_path=model_path,
                       scheduler_name=scheduler)

    # Textual Inversions and LoRAs

    parsed_textual_inversion_uris = []
    parsed_lora_uris = []

    if textual_inversion_uris:
        for inversion_uri in textual_inversion_uris:
            parsed = _uris.TextualInversionUri.parse(inversion_uri)
            parsed_textual_inversion_uris.append(parsed)
            parsed.load_on_pipeline(pipeline,
                                    use_auth_token=auth_token,
                                    local_files_only=local_files_only)

    if lora_uris:
        for lora_uri in lora_uris:
            parsed = _uris.LoRAUri.parse(lora_uri)
            parsed_lora_uris.append(parsed)
            parsed.load_on_pipeline(pipeline,
                                    use_auth_token=auth_token,
                                    local_files_only=local_files_only)

    # Safety Checker

    if not safety_checker_override:
        if _enums.model_type_is_floyd(model_type):
            _set_floyd_safety_checker(pipeline, safety_checker)
        else:
            _set_torch_safety_checker(pipeline, safety_checker)

    # Model Offloading

    # Tag the pipeline with our own attributes
    pipeline.DGENERATE_SEQUENTIAL_OFFLOAD = sequential_cpu_offload
    pipeline.DGENERATE_CPU_OFFLOAD = model_cpu_offload

    if sequential_cpu_offload and 'cuda' in device:
        pipeline.enable_sequential_cpu_offload(device=device)
    elif model_cpu_offload and 'cuda' in device:
        pipeline.enable_model_cpu_offload(device=device)

    _cache.pipeline_create_update_cache_info(pipeline=pipeline,
                                             estimated_size=estimated_memory_usage)

    _messages.debug_log(f'Finished Creating Torch Pipeline: "{pipeline_class.__name__}"')

    return TorchPipelineCreationResult(
        pipeline=pipeline,
        parsed_vae_uri=parsed_vae_uri,
        parsed_lora_uris=parsed_lora_uris,
        parsed_textual_inversion_uris=parsed_textual_inversion_uris,
        parsed_control_net_uris=parsed_control_net_uris
    )
[docs]
class FlaxPipelineCreationResult(PipelineCreationResult):
    """
    Result of :py:meth:`.create_flax_diffusion_pipeline`: the created Flax pipeline,
    its Flax parameter objects, and the parsed URIs used to build it.
    """

    flax_params: typing.Dict[str, typing.Any]
    """
    Flax specific parameters object for the pipeline
    """

    parsed_vae_uri: typing.Optional[_uris.FlaxVAEUri]
    """
    Parsed VAE URI if one was present, otherwise ``None``
    """

    flax_vae_params: typing.Optional[typing.Dict[str, typing.Any]]
    """
    Flax specific parameters object for the VAE
    """

    parsed_control_net_uris: typing.List[_uris.FlaxControlNetUri]
    """
    Parsed ControlNet URIs, empty if none were present
    """

    flax_control_net_params: typing.Optional[typing.Dict[str, typing.Any]]
    """
    Flax specific parameters object for the ControlNet
    """

    def __init__(self,
                 pipeline: diffusers.FlaxDiffusionPipeline,
                 flax_params: typing.Dict[str, typing.Any],
                 parsed_vae_uri: typing.Optional[_uris.FlaxVAEUri],
                 flax_vae_params: typing.Optional[typing.Dict[str, typing.Any]],
                 parsed_control_net_uris: typing.List[_uris.FlaxControlNetUri],
                 flax_control_net_params: typing.Optional[typing.Dict[str, typing.Any]]):
        super().__init__(pipeline)
        self.flax_params = flax_params
        self.parsed_vae_uri = parsed_vae_uri
        self.flax_vae_params = flax_vae_params
        self.parsed_control_net_uris = parsed_control_net_uris
        self.flax_control_net_params = flax_control_net_params

    @property
    def pipeline(self) -> diffusers.FlaxDiffusionPipeline:
        """
        The created subclass of :py:class:`diffusers.FlaxDiffusionPipeline`
        """
        return super().pipeline

    def call(self, *args, **kwargs) -> diffusers.utils.BaseOutput:
        """
        Invoke the created **pipeline** with the given arguments.

        :param args: positional args forwarded to the pipeline
        :param kwargs: keyword args forwarded to the pipeline
        :return: A subclass of :py:class:`diffusers.utils.BaseOutput`
        """
        target = self.pipeline
        return target(*args, **kwargs)
[docs]
def create_flax_diffusion_pipeline(pipeline_type: _enums.PipelineTypes,
                                   model_path: str,
                                   model_type: _enums.ModelTypes = _enums.ModelTypes.FLAX,
                                   revision: _types.OptionalString = None,
                                   subfolder: _types.OptionalString = None,
                                   dtype: _enums.DataTypes = _enums.DataTypes.AUTO,
                                   vae_uri: _types.OptionalUri = None,
                                   control_net_uris: _types.OptionalUriOrUris = None,
                                   scheduler: _types.OptionalString = None,
                                   safety_checker: bool = False,
                                   auth_token: _types.OptionalString = None,
                                   extra_modules: typing.Optional[typing.Dict[str, typing.Any]] = None,
                                   local_files_only: bool = False) -> FlaxPipelineCreationResult:
    """
    Create a :py:class:`diffusers.FlaxDiffusionPipeline` in dgenerate's in-memory caching system.

    :param pipeline_type: py:class:`dgenerate.pipelinewrapper.PipelineTypes` enum value
    :param model_path: huggingface slug, huggingface blob link, path to folder on disk, path to file on disk
    :param model_type: Currently only accepts :py:attr:`dgenerate.pipelinewrapper.ModelTypes.FLAX`
    :param revision: huggingface repo revision (branch)
    :param subfolder: huggingface repo subfolder if applicable
    :param dtype: Optional py:class:`dgenerate.pipelinewrapper.DataTypes` enum value
    :param vae_uri: Optional Flax specific ``--vae`` URI string for specifying a specific VAE
    :param control_net_uris: Optional ``--control-nets`` URI strings for specifying ControlNet models
    :param scheduler: Optional scheduler (sampler) class name, unqualified, or "help" to print supported values
        to STDOUT and raise :py:exc:`dgenerate.pipelinewrapper.SchedulerHelpException`
    :param safety_checker: Safety checker enabled? default is false
    :param auth_token: Optional huggingface API token for accessing repositories that are restricted to your account
    :param extra_modules: Extra module arguments to pass directly into :py:meth:`diffusers.FlaxDiffusionPipeline.from_pretrained`
    :param local_files_only: Only look in the huggingface cache and do not connect to download models?

    :raises ModelNotFoundError:
    :raises InvalidModelUriError:
    :raises InvalidSchedulerName:
    :raises NotImplementedError:

    :return: :py:class:`.FlaxPipelineCreationResult`
    """
    # Snapshot the arguments before any other local is bound. Copying keeps the
    # forwarded kwargs fixed even if the frame's locals are refreshed later.
    call_args = dict(locals())
    try:
        return _create_flax_diffusion_pipeline(**call_args)
    except (huggingface_hub.utils.HFValidationError,
            huggingface_hub.utils.HfHubHTTPError) as e:
        # Chain the huggingface error so its traceback is not lost
        raise _hfutil.ModelNotFoundError(e) from e
class FlaxPipelineFactory:
    """
    Turns :py:meth:`.create_flax_diffusion_pipeline` into a factory
    that can recreate the same Flax pipeline over again, possibly from cache.
    """

    def __init__(self, pipeline_type: _enums.PipelineTypes,
                 model_path: str,
                 model_type: _enums.ModelTypes = _enums.ModelTypes.FLAX,
                 revision: _types.OptionalString = None,
                 subfolder: _types.OptionalString = None,
                 dtype: _enums.DataTypes = _enums.DataTypes.AUTO,
                 vae_uri: _types.OptionalUri = None,
                 control_net_uris: _types.OptionalUriOrUris = None,
                 scheduler: _types.OptionalString = None,
                 safety_checker: bool = False,
                 auth_token: _types.OptionalString = None,
                 extra_modules: typing.Optional[typing.Dict[str, typing.Any]] = None,
                 local_files_only: bool = False):
        """
        Capture the arguments that every call of this factory will forward to
        :py:meth:`.create_flax_diffusion_pipeline`.

        All parameters have the same meaning as the identically named
        parameters of :py:meth:`.create_flax_diffusion_pipeline`.
        """
        # Listed explicitly (in signature order) rather than scraped from
        # locals(), so that adding a local variable here later cannot leak
        # an unexpected keyword argument into the creation call.
        self._args = {
            'pipeline_type': pipeline_type,
            'model_path': model_path,
            'model_type': model_type,
            'revision': revision,
            'subfolder': subfolder,
            'dtype': dtype,
            'vae_uri': vae_uri,
            'control_net_uris': control_net_uris,
            'scheduler': scheduler,
            'safety_checker': safety_checker,
            'auth_token': auth_token,
            'extra_modules': extra_modules,
            'local_files_only': local_files_only,
        }

    def __call__(self) -> FlaxPipelineCreationResult:
        """
        Create (or fetch from the in-memory cache) the Flax pipeline described
        by the arguments captured at construction time.

        :raises ModelNotFoundError:
        :raises InvalidModelUriError:
        :raises InvalidSchedulerName:
        :raises NotImplementedError:
        :return: :py:class:`.FlaxPipelineCreationResult`
        """
        return create_flax_diffusion_pipeline(**self._args)
@_memoize(_cache._FLAX_PIPELINE_CACHE,
          exceptions={'local_files_only'},
          hasher=lambda args: _d_memoize.args_cache_key(args,
                                                        {'vae_uri': _cache.uri_hash_with_parser(
                                                            _uris.FlaxVAEUri.parse),
                                                         'control_net_uris':
                                                             _cache.uri_list_hash_with_parser(
                                                                 _uris.FlaxControlNetUri.parse)}),
          on_hit=lambda key, hit: _d_memoize.simple_cache_hit_debug("Flax Pipeline", key, hit.pipeline),
          on_create=lambda key, new: _d_memoize.simple_cache_miss_debug('Flax Pipeline', key, new.pipeline))
def _create_flax_diffusion_pipeline(pipeline_type: _enums.PipelineTypes,
                                    model_path: str,
                                    model_type: _enums.ModelTypes = _enums.ModelTypes.FLAX,
                                    revision: _types.OptionalString = None,
                                    subfolder: _types.OptionalString = None,
                                    dtype: _enums.DataTypes = _enums.DataTypes.AUTO,
                                    vae_uri: _types.OptionalUri = None,
                                    control_net_uris: _types.OptionalUriOrUris = None,
                                    scheduler: _types.OptionalString = None,
                                    safety_checker: bool = False,
                                    auth_token: _types.OptionalString = None,
                                    extra_modules: typing.Optional[typing.Dict[str, typing.Any]] = None,
                                    local_files_only: bool = False) -> FlaxPipelineCreationResult:
    """
    Build a Flax diffusers pipeline.

    Results are memoized in ``_cache._FLAX_PIPELINE_CACHE`` by the decorator.
    The decorator hashes ``vae_uri`` and ``control_net_uris`` with their URI
    parsers. ``local_files_only`` is listed in ``exceptions``, presumably
    excluding it from the cache key; confirm against ``dgenerate.memoize``.

    Parameters are the same as :py:meth:`.create_flax_diffusion_pipeline`.

    :raises ValueError: if ``model_type`` is not a Flax model type
    :raises NotImplementedError: for unsupported pipeline types, more than one
        ControlNet, or ControlNet combined with img2img / inpaint
    :return: :py:class:`.FlaxPipelineCreationResult`
    """
    if not _enums.model_type_is_flax(model_type):
        raise ValueError('model_type must be a FLAX ModelTypes enum value.')

    # OptionalUriOrUris admits a lone URI string. Without normalizing it,
    # len() below would count characters and reject a single ControlNet URI
    # as "multiple --control-nets".
    if isinstance(control_net_uris, str):
        control_net_uris = [control_net_uris]

    if control_net_uris and len(control_net_uris) > 1:
        raise NotImplementedError('Flax does not support multiple --control-nets.')

    has_control_nets = bool(control_net_uris)

    if pipeline_type == _enums.PipelineTypes.TXT2IMG:
        if has_control_nets:
            pipeline_class = diffusers.FlaxStableDiffusionControlNetPipeline
        else:
            pipeline_class = diffusers.FlaxStableDiffusionPipeline
    elif pipeline_type == _enums.PipelineTypes.IMG2IMG:
        if has_control_nets:
            raise NotImplementedError('Flax does not support img2img mode with --control-nets.')
        pipeline_class = diffusers.FlaxStableDiffusionImg2ImgPipeline
    elif pipeline_type == _enums.PipelineTypes.INPAINT:
        if has_control_nets:
            raise NotImplementedError('Flax does not support inpaint mode with --control-nets.')
        pipeline_class = diffusers.FlaxStableDiffusionInpaintPipeline
    else:
        raise NotImplementedError('Pipeline type not implemented.')

    # Components the caller supplied directly via extra_modules. Those are
    # not loaded here, and user values win over anything computed below.
    overridden_modules = set(extra_modules) if extra_modules else set()
    vae_override = 'vae' in overridden_modules
    controlnet_override = 'controlnet' in overridden_modules
    safety_checker_override = 'safety_checker' in overridden_modules
    scheduler_override = 'scheduler' in overridden_modules
    feature_extractor_override = 'feature_extractor' in overridden_modules

    estimated_memory_usage = estimate_pipeline_memory_use(
        pipeline_type=pipeline_type,
        model_type=model_type,
        model_path=model_path,
        revision=revision,
        subfolder=subfolder,
        vae_uri=vae_uri if not vae_override else None,
        safety_checker=safety_checker and not safety_checker_override,
        auth_token=auth_token,
        extra_args=extra_modules,
        local_files_only=local_files_only
    )

    _messages.debug_log(
        f'Creating Flax Pipeline: "{pipeline_class.__name__}", '
        f'Estimated CPU Side Memory Use: {_memory.bytes_best_human_unit(estimated_memory_usage)}')

    # Make room in the pipeline cache before loading anything new.
    _cache.enforce_pipeline_cache_constraints(
        new_pipeline_size=estimated_memory_usage)

    creation_kwargs = {}
    vae_params = None
    control_net_params = None

    flax_dtype = _enums.get_flax_dtype(dtype)

    parsed_control_net_uris = []
    parsed_flax_vae_uri = None

    if not scheduler_is_help(scheduler):
        # prevent waiting on VAE load just get the scheduler
        # help message for the main model

        if vae_uri is not None and not vae_override:
            parsed_flax_vae_uri = _uris.FlaxVAEUri.parse(vae_uri)

            creation_kwargs['vae'], vae_params = parsed_flax_vae_uri.load(
                dtype_fallback=dtype,
                use_auth_token=auth_token,
                local_files_only=local_files_only)
            _messages.debug_log(lambda:
                                f'Added Flax VAE: "{vae_uri}" to pipeline: "{pipeline_class.__name__}"')

        if control_net_uris and not controlnet_override:
            control_net_uri = control_net_uris[0]

            parsed_flax_control_net_uri = _uris.FlaxControlNetUri.parse(control_net_uri)
            parsed_control_net_uris.append(parsed_flax_control_net_uri)

            control_net, control_net_params = parsed_flax_control_net_uri.load(
                use_auth_token=auth_token,
                dtype_fallback=dtype,
                local_files_only=local_files_only)

            _messages.debug_log(lambda:
                                f'Added Flax ControlNet: "{control_net_uri}" '
                                f'to pipeline: "{pipeline_class.__name__}"')

            creation_kwargs['controlnet'] = control_net

    if extra_modules is not None:
        creation_kwargs.update(extra_modules)

    if not safety_checker and not safety_checker_override:
        creation_kwargs['safety_checker'] = None

    def _from_pretrained():
        # Reads creation_kwargs at call time, so the retry below sees the
        # feature_extractor=None that the except-branch adds.
        return pipeline_class.from_pretrained(model_path,
                                              revision=revision,
                                              dtype=flax_dtype,
                                              subfolder=subfolder,
                                              use_auth_token=auth_token,
                                              local_files_only=local_files_only,
                                              **creation_kwargs)

    try:
        pipeline, params = _from_pretrained()
    except ValueError as e:
        # odd diffusers bug: loading can fail with a ValueError that mentions
        # feature_extractor. Retry once with feature_extractor=None. If the
        # caller supplied their own feature_extractor, a retry would resend
        # identical arguments, so re-raise the original error instead.
        if 'feature_extractor' not in str(e) or feature_extractor_override:
            raise
        creation_kwargs['feature_extractor'] = None
        pipeline, params = _from_pretrained()

    # Separately loaded component weights are merged into the pipeline's
    # parameter dict under their component names.
    if vae_params is not None:
        params['vae'] = vae_params

    if control_net_params is not None:
        params['controlnet'] = control_net_params

    if not scheduler_override:
        load_scheduler(pipeline=pipeline,
                       model_path=model_path,
                       scheduler_name=scheduler)

    if not safety_checker and not safety_checker_override:
        pipeline.safety_checker = None

    _cache.pipeline_create_update_cache_info(pipeline=pipeline,
                                             estimated_size=estimated_memory_usage)

    _messages.debug_log(f'Finished Creating Flax Pipeline: "{pipeline_class.__name__}"')

    return FlaxPipelineCreationResult(
        pipeline=pipeline,
        flax_params=params,
        parsed_vae_uri=parsed_flax_vae_uri,
        flax_vae_params=vae_params,
        parsed_control_net_uris=parsed_control_net_uris,
        flax_control_net_params=control_net_params
    )
# Declare this module's public API. Presumably _types.module_all() derives
# the exported names from the calling module's namespace; confirm against
# dgenerate.types.
__all__ = _types.module_all()