# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
# the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import collections.abc
import functools
import gc
import hashlib
import inspect
import os.path
import pathlib
import random
import re
import typing
import accelerate
import diffusers
import diffusers.loaders
import diffusers.loaders.single_file_utils
import diffusers.quantizers.quantization_config
import torch
import torch.nn
import transformers
import dgenerate.devicecache as _devicecache
import dgenerate.exceptions as _d_exceptions
import dgenerate.extras.kolors as _kolors
import dgenerate.extras.ultraedit
import dgenerate.filecache as _filecache
import dgenerate.hfhub as _hfhub
import dgenerate.memoize as _d_memoize
import dgenerate.memory as _memory
import dgenerate.messages as _messages
import dgenerate.pipelinewrapper.enums as _enums
import dgenerate.pipelinewrapper.schedulers as _schedulers
import dgenerate.pipelinewrapper.uris as _uris
import dgenerate.pipelinewrapper.util as _util
import dgenerate.promptweighters as _promptweighters
import dgenerate.textprocessing as _textprocessing
import dgenerate.torchutil as _torchutil
import dgenerate.types as _types
from dgenerate.memoize import memoize as _memoize
from dgenerate.pipelinewrapper import constants as _constants
import dgenerate.pipelinewrapper.models as _models
[docs]
class UnsupportedPipelineConfigError(Exception):
    """
    Raised when a diffusers pipeline is asked to be configured in a way that
    the pipeline does not support, for example enabling VAE tiling on a
    pipeline that has no VAE.
    """
[docs]
class InvalidModelFileError(Exception):
    """
    Raised when a file loaded from disk is not in a valid diffusers model format.

    This indicates that there was a problem loading the primary diffusion model.
    The primary model may also be an SDXL refiner model or a Stable Cascade
    decoder model, both of which are considered primary models.
    """
# Object cache (named 'pipeline') holding loaded pipelines; its ``size`` is
# adjusted by _pipeline_to() as cached pipelines leave or enter CPU side memory.
_pipeline_cache = _d_memoize.create_object_cache(
    'pipeline',
    cache_type=_memory.SizedConstrainedObjectCache,
)
def _disabled_safety_checker(images, clip_input):
    # Stand-in for a Stable Diffusion safety checker that never flags anything.
    # A 4 dimensional batch gets one False flag per image, anything else gets a
    # single False flag. ``clip_input`` is accepted for signature compatibility.
    if len(images.shape) != 4:
        return images, False
    return images, [False] * images.shape[0]
def _floyd_disabled_safety_checker(images, clip_input):
    # Stand-in for a DeepFloyd IF safety checker that never flags anything.
    # Returns (images, flags, False); flags is one False per image for a 4
    # dimensional batch, otherwise a single False. The trailing False is
    # presumably the watermark flag IF pipelines expect -- verify.
    if len(images.shape) != 4:
        return images, False, False
    return images, [False] * images.shape[0], False
def _set_sd_safety_checker(pipeline: diffusers.DiffusionPipeline, safety_checker: bool):
    """
    Replace the safety checker of a Stable Diffusion style pipeline with a
    no-op checker when ``safety_checker`` is falsy.

    :param pipeline: the pipeline
    :param safety_checker: keep the safety checker?
    """
    if safety_checker:
        return

    # Only replace a checker that exists and is set. If it is already None,
    # assigning a value results in a call to an unassigned feature_extractor.
    # SDXL pipelines currently do not have the attribute at all.
    if getattr(pipeline, 'safety_checker', None) is None:
        return

    pipeline.safety_checker = _disabled_safety_checker
def _set_floyd_safety_checker(pipeline: diffusers.DiffusionPipeline, safety_checker: bool):
    """
    Replace the safety checker of a DeepFloyd IF pipeline with a no-op
    checker when ``safety_checker`` is falsy.

    :param pipeline: the pipeline
    :param safety_checker: keep the safety checker?
    """
    if safety_checker:
        return

    # leave pipelines without a (set) safety checker untouched
    if getattr(pipeline, 'safety_checker', None) is None:
        return

    pipeline.safety_checker = _floyd_disabled_safety_checker
[docs]
def set_vae_tiling_and_slicing(
        pipeline: diffusers.DiffusionPipeline,
        tiling: bool,
        slicing: bool
):
    """
    Set the ``vae_slicing`` and ``vae_tiling`` status on a diffusers pipeline.

    Enabling a feature requires a VAE that implements ``enable_tiling`` /
    ``enable_slicing``. Disabling a feature is best-effort: it only happens
    when a VAE is present and implements ``disable_tiling`` / ``disable_slicing``.
    Tiling is processed before slicing.

    :raises UnsupportedPipelineConfigError: if the pipeline does not support one or both
        of the provided values for ``vae_tiling`` and ``vae_slicing``

    :param pipeline: pipeline object
    :param tiling: tiling status
    :param slicing: slicing status
    """
    vae = getattr(pipeline, 'vae', None)
    pipeline_name = pipeline.__class__.__name__

    for feature, wanted in (('tiling', tiling), ('slicing', slicing)):
        if vae is None:
            if wanted:
                raise UnsupportedPipelineConfigError(
                    f'--vae-{feature} not supported as no VAE is present for the specified model.')
            continue

        method_name = ('enable_' if wanted else 'disable_') + feature

        if not hasattr(vae, method_name):
            if wanted:
                raise UnsupportedPipelineConfigError(
                    f'--vae-{feature} not supported as loaded VAE does not support it.'
                )
            continue

        verb = 'Enabling' if wanted else 'Disabling'
        _messages.debug_log(f'{verb} VAE {feature} on Pipeline: "{pipeline_name}",',
                            f'VAE: "{vae.__class__.__name__}"')
        getattr(vae, method_name)()
[docs]
def get_pipeline_modules(pipeline: diffusers.DiffusionPipeline):
    """
    Collect the components of a diffusers pipeline that are torch modules.

    Non-module components (tokenizers, schedulers, feature extractors, ``None``
    entries and so on) are left out.

    :param pipeline: the pipeline
    :return: dictionary of modules by name
    """
    modules = {}
    for name, component in pipeline.components.items():
        if isinstance(component, torch.nn.Module):
            modules[name] = component
    return modules
def _set_sequential_cpu_offload_flag(module: diffusers.DiffusionPipeline | torch.nn.Module, value: bool):
    # Tag the object so is_sequential_cpu_offload_enabled() can detect it later.
    setattr(module, 'DGENERATE_SEQUENTIAL_CPU_OFFLOAD', bool(value))
    module_name = module.__class__.__name__
    _messages.debug_log(
        f'setting DGENERATE_SEQUENTIAL_CPU_OFFLOAD={value} on module "{module_name}"')
def _set_cpu_offload_flag(module: diffusers.DiffusionPipeline | torch.nn.Module, value: bool):
    # Tag the object so is_model_cpu_offload_enabled() can detect it later.
    setattr(module, 'DGENERATE_MODEL_CPU_OFFLOAD', bool(value))
    module_name = module.__class__.__name__
    _messages.debug_log(
        f'setting DGENERATE_MODEL_CPU_OFFLOAD={value} on module "{module_name}"')
[docs]
def is_sequential_cpu_offload_enabled(module: diffusers.DiffusionPipeline | torch.nn.Module):
    """
    Check whether a pipeline or torch module was given sequential cpu offload by dgenerate.

    Objects never tagged by dgenerate report ``False``.

    :param module: the module object
    :return: ``True`` or ``False``
    """
    return bool(getattr(module, 'DGENERATE_SEQUENTIAL_CPU_OFFLOAD', False))
[docs]
def is_model_cpu_offload_enabled(module: diffusers.DiffusionPipeline | torch.nn.Module):
    """
    Check whether a pipeline or torch module was given model cpu offload by dgenerate.

    Objects never tagged by dgenerate report ``False``.

    :param module: the module object
    :return: ``True`` or ``False``
    """
    return bool(getattr(module, 'DGENERATE_MODEL_CPU_OFFLOAD', False))
def _disable_to(module, vae=False):
    """
    Replace ``module.to`` with a function that ignores device moves.

    The original bound ``to`` method is stored on the module as
    ``_DGENERATE_ORIGINAL_TO_DISABLED``. That attribute also marks the module
    as patched, so calling this again on the same module does nothing.

    Like :py:meth:`torch.nn.Module.to`, the replacement returns the module.

    :param module: the module to patch
    :param vae: Is this module a VAE? When ``True`` and the module config has
        ``force_upcast`` set, dtype-only calls (``to(dtype)`` or ``to(dtype=...)``)
        are still forwarded to the original method. The pipeline needs to upcast
        such a VAE, and this has to happen even if it is described as 'meta'.
    """
    if hasattr(module, '_DGENERATE_ORIGINAL_TO_DISABLED'):
        return  # Already patched

    original_to = module.to
    module._DGENERATE_ORIGINAL_TO_DISABLED = original_to

    def dummy(*args, **kwargs):
        # Group both dtype-only forms before combining them with the VAE
        # conditions. Otherwise `to(dtype=...)` would slip through on every
        # module, VAE or not.
        dtype_only = (len(args) == 1 and isinstance(args[0], torch.dtype)) or \
                     (len(kwargs) == 1 and 'dtype' in kwargs)

        if vae and dtype_only and \
                getattr(getattr(module, 'config', None), 'force_upcast', False):
            # basically, is this a VAE that the pipeline needs to upcast
            return original_to(*args, **kwargs)

        return module

    module.to = dummy
    _messages.debug_log(
        f'Disabled .to() on module: {_types.fullname(module)}')
[docs]
def enable_sequential_cpu_offload(pipeline: diffusers.DiffusionPipeline,
                                  device: torch.device | str = _torchutil.default_device()):
    """
    Enable sequential offloading on a torch pipeline, in a way dgenerate can keep track of.

    Existing hooks are removed first. Every pipeline module that meets all of these
    conditions is passed to :py:func:`accelerate.cpu_offload`, flagged as sequentially
    offloaded, and has its ``to()`` method disabled:

    * it is not listed in ``pipeline._exclude_from_cpu_offload``
    * it is not quantized with bitsandbytes
    * it is not already flagged as sequentially offloaded

    If ``device`` is a ``cuda``, ``mps`` or ``xpu`` device whose backend is not
    available, the system default device is used as the execution device instead.

    :param pipeline: the pipeline
    :param device: the device
    """
    execution_device = torch.device(device)

    # device type -> (label used in the debug message, lazy availability probe)
    backend_probes = {
        'cuda': ('CUDA', lambda: _torchutil.is_cuda_available()),
        'mps': ('MPS', lambda: _torchutil.is_mps_available()),
        'xpu': ('XPU', lambda: _torchutil.is_xpu_available()),
    }

    probe = backend_probes.get(execution_device.type)
    if probe is not None:
        label, is_available = probe
        if not is_available():
            fallback_device = _torchutil.default_device()
            _messages.debug_log(
                f'enable_sequential_cpu_offload: {label} requested but not available, '
                f'using {fallback_device} for execution device')
            execution_device = torch.device(fallback_device)

    pipeline.remove_all_hooks()

    _set_sequential_cpu_offload_flag(pipeline, True)

    for name, model in get_pipeline_modules(pipeline).items():
        quant, _, _ = _util.check_bnb_status(model)

        skip = name in pipeline._exclude_from_cpu_offload or quant
        if skip or is_sequential_cpu_offload_enabled(model):
            continue

        _set_sequential_cpu_offload_flag(model, True)
        accelerate.cpu_offload(model, execution_device, offload_buffers=len(model._parameters) > 0)
        _disable_to(model, name == 'vae')
[docs]
def enable_model_cpu_offload(pipeline: diffusers.DiffusionPipeline,
                             device: torch.device | str = _torchutil.default_device()):
    """
    Enable sequential model cpu offload on a torch pipeline, in a way dgenerate can keep track of.

    Modules named in ``pipeline.model_cpu_offload_seq`` are hooked in sequence order
    with :py:func:`accelerate.cpu_offload_with_hook`, each chained to the hook of the
    module before it. Remaining modules are moved directly to the execution device
    if they are listed in ``pipeline._exclude_from_cpu_offload``; otherwise each gets
    its own offload hook. Modules loaded in 8 bit via bitsandbytes are skipped.

    If the requested device type is ``cuda``, ``mps`` or ``xpu`` and that backend is
    not available, the system default device is used as the execution device.

    :raises ValueError: if the pipeline has no ``model_cpu_offload_seq``
    :param pipeline: the pipeline
    :param device: the device
    """
    if pipeline.model_cpu_offload_seq is None:
        raise ValueError(
            "Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set."
        )

    torch_device = torch.device(device)

    pipeline.remove_all_hooks()

    pipeline._offload_gpu_id = torch_device.index or getattr(pipeline, "_offload_gpu_id", 0)

    device_type = torch_device.type

    # Check if the requested device type is actually available
    # If not, fall back to the system's default device
    if device_type == 'cuda' and not _torchutil.is_cuda_available():
        fallback_device = _torchutil.default_device()
        _messages.debug_log(
            f'enable_model_cpu_offload: CUDA requested but not available, using {fallback_device} for execution device')
        device = torch.device(fallback_device)
    elif device_type == 'mps' and not _torchutil.is_mps_available():
        fallback_device = _torchutil.default_device()
        _messages.debug_log(
            f'enable_model_cpu_offload: MPS requested but not available, using {fallback_device} for execution device')
        device = torch.device(fallback_device)
    elif device_type == 'xpu' and not _torchutil.is_xpu_available():
        fallback_device = _torchutil.default_device()
        _messages.debug_log(
            f'enable_model_cpu_offload: XPU requested but not available, using {fallback_device} for execution device')
        device = torch.device(fallback_device)
    else:
        device = torch.device(f"{device_type}:{pipeline._offload_gpu_id}")

    # Read the current device type *before* moving. After the move,
    # pipeline.device reports "cpu", and the cache of the device the pipeline
    # came from would never be emptied.
    source_device_type = pipeline.device.type
    if source_device_type != "cpu":
        pipeline.to("cpu", silence_dtype_warnings=True)
        device_mod = getattr(torch, source_device_type, None)
        if device_mod is not None and hasattr(device_mod, "empty_cache") \
                and hasattr(device_mod, "is_available") and device_mod.is_available():
            device_mod.empty_cache()

    _set_cpu_offload_flag(pipeline, True)

    all_model_components = get_pipeline_modules(pipeline)

    hook = None
    pipeline._all_hooks = []

    for model_str in pipeline.model_cpu_offload_seq.split("->"):
        model = all_model_components.pop(model_str, None)
        if not isinstance(model, torch.nn.Module):
            continue

        _, _, is_loaded_in_8bit_bnb = _util.check_bnb_status(model)
        if is_loaded_in_8bit_bnb:
            _messages.debug_log(
                f'Not cpu offloading pipeline module: {model_str}, due to bitsandbytes 8 bit quantization.')
            continue

        # chain to the previous hook so only one module in the sequence
        # occupies the execution device at a time
        _, hook = accelerate.cpu_offload_with_hook(model, device, prev_module_hook=hook)
        _set_cpu_offload_flag(model, True)
        pipeline._all_hooks.append(hook)

    # Modules not named in the offload sequence. Every value is a torch
    # module, because get_pipeline_modules() filters on that.
    for name, model in all_model_components.items():
        # 8 bit bitsandbytes modules cannot be moved, same as in the sequence above
        _, _, is_loaded_in_8bit_bnb = _util.check_bnb_status(model)
        if is_loaded_in_8bit_bnb:
            _messages.debug_log(
                f'Not cpu offloading pipeline module: {name}, due to bitsandbytes 8 bit quantization.')
            continue

        if name in pipeline._exclude_from_cpu_offload:
            model.to(device)
        else:
            _, hook = accelerate.cpu_offload_with_hook(model, device)
            _set_cpu_offload_flag(model, True)
            pipeline._all_hooks.append(hook)
[docs]
def get_torch_device(component: diffusers.DiffusionPipeline | torch.nn.Module) -> torch.device:
    """
    Return the device a pipeline or pipeline component lives on.

    A ``device`` attribute is preferred; otherwise ``get_device()`` is called.

    :raises ValueError: if the component provides neither
    :param component: pipeline or pipeline component.
    :return: :py:class:`torch.device`
    """
    if not hasattr(component, 'device'):
        if not hasattr(component, 'get_device'):
            raise ValueError(f'component type {component.__class__} did not have a '
                             f'device attribute or the function get_device()')
        return component.get_device()
    return component.device
[docs]
def get_torch_device_string(component: diffusers.DiffusionPipeline | torch.nn.Module) -> str:
    """
    Return the device a pipeline or pipeline component lives on, as a string.

    :param component: pipeline or pipeline component.
    :return: device string, e.g. the ``str()`` of :py:func:`get_torch_device`
    """
    device = get_torch_device(component)
    return str(device)
def _pipeline_to(pipeline, device: torch.device | str | None):
    """
    Move ``pipeline`` and its modules to ``device``. This is the implementation
    behind :py:func:`pipeline_to`, without its out of memory recovery.

    If the pipeline is in the pipeline object cache and its device differs from
    ``device``, the cache ``size`` is adjusted by the pipeline's cached size:
    decreased for a move to a non-CPU device, increased for a move to the CPU.

    A module is not moved when any of the following holds:

    * it is loaded in 8 bit via bitsandbytes (its ``to`` method is then disabled)
    * it reports the "meta" device (its ``to`` method is then disabled)
    * it is already on ``device``
    * it has model cpu offload enabled and ``device`` is not the CPU

    :py:func:`dgenerate.memory.torch_gc` is called after a move to the CPU.

    :param pipeline: the pipeline
    :param device: the device, ``None`` is a no-op
    """
    if device is None:
        _messages.debug_log(
            f'pipeline_to() Not moving pipeline "{pipeline.__class__.__name__}" '
            f'as specified device was None.')
        return

    if not hasattr(pipeline, 'to'):
        _messages.debug_log(
            f'pipeline_to() Not moving pipeline "{pipeline.__class__.__name__}" to '
            f'"{device}" as it has no to() method.')
        return

    device = torch.device(device)

    pipeline_device = get_torch_device(pipeline)

    all_modules_on_device = all(
        _torchutil.devices_equal(device, get_torch_device(m)) for m in get_pipeline_modules(pipeline).values())

    # does the pipeline itself already report the *target* device?
    pipeline_on_device = _torchutil.devices_equal(device, pipeline_device)

    if pipeline_on_device and all_modules_on_device:
        _messages.debug_log(
            f'pipeline_to() Not moving pipeline "{pipeline.__class__.__name__}" to '
            f'"{device}" as it is already on that device.')
        return

    if pipeline_on_device != all_modules_on_device:
        # really the most likely way for this to occur is if
        # an OOM happened moving a pipeline to the GPU, which
        # is something we want to be able to recover from hence
        # the fall through above
        #
        # This also happens when the pipeline has cpu offload
        # enabled, we can fall through that harmlessly as its
        # modules can never be moved to anything but the CPU
        # and that is accounted for below
        _messages.debug_log(
            f'pipeline_to() Moving pipeline "{pipeline.__class__.__name__}" to "{device}", '
            f'pipeline_on_device={pipeline_on_device}, all_modules_on_device={all_modules_on_device}.')

    if not _torchutil.devices_equal(pipeline_device, device):
        try:
            cache_metadata = _pipeline_cache.get_metadata(pipeline)
        except _d_memoize.ObjectCacheKeyError:
            # does not exist in the cache
            pass
        else:
            # NOTE(review): any move to a non-CPU device counts as leaving CPU side
            # memory, including a move between two non-CPU devices -- confirm intended.
            if device.type != 'cpu':
                _pipeline_cache.size -= cache_metadata.size
                direction = 'leaving'
            else:
                _pipeline_cache.size += cache_metadata.size
                direction = 'entering'

            # report the total cache size after the adjustment
            _messages.debug_log(
                f'Cached {_types.class_and_id_string(pipeline)} Size = '
                f'{cache_metadata.size} Bytes '
                f'({_memory.bytes_best_human_unit(cache_metadata.size)}) '
                f'is {direction} CPU side memory, '
                f'cache size is now '
                f'{_pipeline_cache.size} Bytes '
                f'({_memory.bytes_best_human_unit(_pipeline_cache.size)})')

    for name, value in get_pipeline_modules(pipeline).items():

        is_loaded_in_8bit_bnb = _util.is_loaded_in_8bit_bnb(value)

        if is_loaded_in_8bit_bnb:
            _messages.debug_log(
                f'pipeline_to() Not moving module "{name} = {value.__class__.__name__}" to "{device}" '
                f'as it is loaded in 8bit mode via bitsandbytes.')
            _disable_to(value)
            continue

        current_device = get_torch_device(value)

        if current_device.type == 'meta':
            _messages.debug_log(
                f'pipeline_to() Not moving module "{name} = {value.__class__.__name__}" to "{device}" '
                f'as its device value is "meta".')
            _disable_to(
                value,
                name == 'vae'
            )
            continue

        if _torchutil.devices_equal(current_device, device):
            _messages.debug_log(
                f'pipeline_to() Not moving module "{name} = {value.__class__.__name__}" to "{device}" '
                f'as it is already on that device.')
            continue

        if is_model_cpu_offload_enabled(value) and device.type != 'cpu':
            _messages.debug_log(
                f'pipeline_to() Not moving module "{name} = {value.__class__.__name__}" to "{device}" '
                f'as it has cpu offload enabled and can only move to cpu.')
            continue

        _messages.debug_log(
            f'pipeline_to() Moving module "{name}" of pipeline {_types.fullname(pipeline)} '
            f'from device "{current_device}" to device "{device}"')

        value.to(device)

    if device.type == 'cpu':
        _memory.torch_gc()
[docs]
def pipeline_to(pipeline, device: torch.device | str | None):
    """
    Move a diffusers pipeline to a device if possible, in a way that dgenerate can keep track of.

    If the pipeline is in the pipeline object cache, the cache size statistics
    are updated as the pipeline leaves or enters CPU side memory.

    If ``device==None`` this is a no-op.

    Modules which are meta tensors will not be moved (sequentially offloaded modules)

    Modules which have model cpu offload enabled will not be moved unless they are moving to "cpu"

    Modules loaded in 8 bit via bitsandbytes will not be moved.

    If the move runs out of GPU memory, the pipeline is moved back to the CPU to
    recover memory before the error is raised.

    :raise dgenerate.OutOfMemoryError: if there is not enough memory on the specified device

    :param pipeline: the pipeline
    :param device: the device

    :return: the moved pipeline (the same object that was passed in)
    """
    try:
        _pipeline_to(pipeline=pipeline, device=device)
    except _d_exceptions.TORCH_CUDA_OOM_EXCEPTIONS as e:
        _d_exceptions.raise_if_not_cuda_oom(e)
        # attempt to recover VRAM before rethrowing
        # move any modules back to cpu which have entered VRAM
        _pipeline_to(pipeline=pipeline, device='cpu')
        _memory.torch_gc()
        gc.collect()
        raise _d_exceptions.OutOfMemoryError(e) from e
    except MemoryError as e:
        # probably out of RAM on a move back to the
        # CPU, not much we can do
        gc.collect()
        raise _d_exceptions.OutOfMemoryError('cpu (system memory)') from e

    return pipeline
def _call_args_debug_transformer(key, value):
    # Value transformer for debug logging of pipeline call arguments.
    # Generators are shown by seed and tensors by shape; anything else is
    # returned unchanged. ``key`` is unused.
    if isinstance(value, torch.Generator):
        value = f'torch.Generator(seed={value.initial_seed()})'
    elif isinstance(value, torch.Tensor):
        value = f'torch.Tensor({value.shape})'
    return value
def _warn_prompt_lengths(pipeline, **kwargs):
    # Warn about any prompt in the pipeline call arguments whose token count
    # exceeds the maximum length of the tokenizer that encodes it. Each
    # (label, tokenizer, prompt) combination is warned about at most once per call.
    checks = []
    for rank, suffix, tokenizer_attr in (('Primary', '', 'tokenizer'),
                                         ('Secondary', '_2', 'tokenizer_2'),
                                         ('Tertiary', '_3', 'tokenizer_3')):
        checks.append((f'{rank} positive prompt', kwargs.get(f'prompt{suffix}'), tokenizer_attr))
        checks.append((f'{rank} negative prompt', kwargs.get(f'negative_prompt{suffix}'), tokenizer_attr))

    pipeline_name = pipeline.__class__.__name__
    already_warned = set()

    for label, prompt, tokenizer_attr in checks:
        if not prompt:
            continue

        prompts = prompt if isinstance(prompt, list) else [prompt]

        tokenizer = getattr(pipeline, tokenizer_attr, None)
        if not tokenizer:
            continue

        # T5 style encoders are bounded by max_sequence_length, which has a
        # different default per model family
        if tokenizer_attr == 'tokenizer_3' and pipeline_name.startswith('StableDiffusion3'):
            limit = min(kwargs.get('max_sequence_length', 256), tokenizer.model_max_length)
        elif tokenizer_attr == 'tokenizer_2' and pipeline_name.startswith('Flux'):
            limit = min(kwargs.get('max_sequence_length', 512), tokenizer.model_max_length)
        else:
            limit = tokenizer.model_max_length

        for text in prompts:
            if len(tokenizer.tokenize(text)) <= limit:
                continue

            dedupe_key = f'{label}{tokenizer_attr}{text}'
            if dedupe_key in already_warned:
                continue

            _messages.warning(
                f'{label} exceeds max token length '
                f'of {limit} for the model\'s tokenizer '
                f'and will be truncated: "{text}"'
            )
            already_warned.add(dedupe_key)
# The pipeline most recently called through call_pipeline(), or None when no
# pipeline has been called or the reference has been released.
_LAST_CALLED_PIPELINE: 'diffusers.DiffusionPipeline | None' = None
[docs]
def get_last_called_pipeline() -> diffusers.DiffusionPipeline | None:
    """
    Return the globally tracked pipeline that :py:func:`call_pipeline` was last called with.

    :return: diffusion pipeline object, or ``None`` if a pipeline was never called
        (or the reference has since been released)
    """
    last_pipeline = _LAST_CALLED_PIPELINE
    return last_pipeline
[docs]
def destroy_last_called_pipeline(collect=True):
    """
    Move the globally tracked pipeline last called with :py:func:`call_pipeline` to the
    CPU and drop the global reference to it.

    Does nothing if no pipeline is currently tracked.

    :param collect: call ``gc.collect`` and :py:func:`dgenerate.memory.torch_gc` after
        dropping the reference?
    """
    global _LAST_CALLED_PIPELINE

    if _LAST_CALLED_PIPELINE is None:
        return

    pipeline_to(_LAST_CALLED_PIPELINE, 'cpu')
    _LAST_CALLED_PIPELINE = None

    if not collect:
        return

    gc.collect()
    _memory.torch_gc()
def _evict_last_pipeline(device: torch.device | None):
    # Device cache eviction method: destroy the last called pipeline if it is
    # on ``device``, or unconditionally when ``device`` is None.
    last_pipeline = get_last_called_pipeline()
    if last_pipeline is None:
        return

    if not (device is None or _torchutil.devices_equal(get_torch_device(last_pipeline), device)):
        return

    # get rid of this reference immediately
    # noinspection PyUnusedLocal
    last_pipeline = None

    _messages.debug_log(
        f'{_types.fullname(_devicecache.clear_device_cache)} is attempting to evacuate any previously '
        f'called diffusion pipeline in the VRAM of device: {device}.')

    # potentially free up VRAM on the GPU we are
    # about to move to
    destroy_last_called_pipeline()
def _evict_8bit_bnb_pipelines(device: torch.device | None):
    """
    Device cache eviction method that un-caches pipelines with bitsandbytes 8 bit modules.

    Such pipelines are on the GPU and cannot be moved. They are removed from the
    pipeline object cache when they reside on ``device``, or regardless of device
    when ``device`` is ``None``.

    :param device: the device being cleared, or ``None`` for every device
    """
    # Iterate over a snapshot. un_cache() removes entries while we loop, and if
    # values() is a live view that would fail with "changed size during iteration".
    for cached_pipeline in list(_pipeline_cache.values()):
        bit8 = any(_util.is_loaded_in_8bit_bnb(module) for
                   module in get_pipeline_modules(cached_pipeline.pipeline).values())

        if bit8 and (device is None or _torchutil.devices_equal(device, get_torch_device(cached_pipeline.pipeline))):
            _messages.debug_log(
                f'Clearing out cached 8bit pipeline from cache: {cached_pipeline.pipeline.__class__.__name__}')
            _pipeline_cache.un_cache(cached_pipeline)
# Register this module's eviction methods with the device cache, in this order:
# the last called pipeline first, then cached bitsandbytes 8 bit pipelines.
# Both receive the device being cleared (or None for all devices).
_devicecache.register_eviction_method(_evict_last_pipeline)
_devicecache.register_eviction_method(_evict_8bit_bnb_pipelines)
# noinspection PyCallingNonCallable
[docs]
@torch.inference_mode()
def call_pipeline(pipeline: diffusers.DiffusionPipeline,
                  device: torch.device | str | None = _torchutil.default_device(),
                  prompt_weighter: _promptweighters.PromptWeighter = None,
                  **kwargs):
    """
    Call a diffusers pipeline, offload the last called pipeline to CPU before
    doing so if the last pipeline is not being called in succession

    If the call fails with an out of memory condition, the pipeline is moved back
    to the CPU, memory is collected, and the call is attempted one more time.

    :param pipeline: The pipeline

    :param device: The device to move the pipeline to before calling, it will be
        moved to this device if it is not already on the device. If the pipeline
        does not support moving to specific device, such as with sequentially offloaded
        pipelines which cannot move at all, or cpu offloaded pipelines which can
        only move to CPU, this argument is ignored. If ``None``, the pipeline is not
        moved and its own execution device is left in place.

    :param kwargs: diffusers pipeline keyword arguments

    :param prompt_weighter: Optional prompt weighter for weighted prompt syntaxes.
        It must use the same device as ``device``. Its ``cleanup()`` method is called
        after the pipeline call, including when the call fails.

    :raises dgenerate.OutOfMemoryError: if there is not enough memory on the specified device

    :raises UnsupportedPipelineConfigError:
        If the pipeline is missing certain required modules, such as text encoders,
        or if the prompt weighter device differs from ``device``.

    :return: the result of calling the diffusers pipeline
    """

    global _LAST_CALLED_PIPELINE

    if prompt_weighter is not None:
        if not _torchutil.devices_equal(device, prompt_weighter.device):
            raise UnsupportedPipelineConfigError(
                'dgenerate.pipelinewrapper.call_pipeline: prompt_weighter '
                'must specify the same compute device that pipeline is to be called on. '
                f'Got prompt_weighter={prompt_weighter.device}, and pipeline={device}')

    _messages.debug_log(
        f'Calling Pipeline: "{pipeline.__class__.__name__}",',
        f'Device: "{device}",',
        'Args:',
        lambda: _textprocessing.debug_format_args(
            kwargs, value_transformer=_call_args_debug_transformer))

    # one retry is allowed after an out of memory condition
    enable_retry_pipe = True

    def _cleanup_prompt_weighter():
        # best-effort cleanup, exceptions are logged and ignored
        try:
            _messages.debug_log(
                f'Executing prompt weighter cleanup for "{prompt_weighter.__class__.__name__}"')
            prompt_weighter.cleanup()
        except Exception as e:
            _messages.debug_log(
                f'Ignoring prompt weighter cleanup '
                f'exception in "{prompt_weighter.__class__.__name__}.cleanup()": {e}')

    def _call_prompt_weighter():
        try:
            translated = prompt_weighter.translate_to_embeds(pipeline, kwargs)
        except _d_exceptions.TORCH_CUDA_OOM_EXCEPTIONS as e:
            _d_exceptions.raise_if_not_cuda_oom(e)
            _cleanup_prompt_weighter()
            _memory.torch_gc()
            gc.collect()
            raise _d_exceptions.OutOfMemoryError(e) from e
        except MemoryError as e:
            _cleanup_prompt_weighter()
            gc.collect()
            raise _d_exceptions.OutOfMemoryError('cpu (system memory)') from e
        except Exception:
            _cleanup_prompt_weighter()
            _memory.torch_gc()
            gc.collect()
            raise

        def _debug_string_func():
            return f'{prompt_weighter.__class__.__name__} translated pipeline call args to: ' + \
                _textprocessing.debug_format_args(
                    translated,
                    value_transformer=_call_args_debug_transformer)

        _messages.debug_log(_debug_string_func)

        return translated

    prompt_warning_issued = False

    def _call_pipeline_raw():
        nonlocal prompt_warning_issued
        try:
            if prompt_weighter is None:
                if not prompt_warning_issued:
                    _warn_prompt_lengths(pipeline, **kwargs)
                    prompt_warning_issued = True
                pipeline_to(pipeline, device)
                return pipeline(**kwargs)

            args = _call_prompt_weighter()
            try:
                pipeline_to(pipeline, device)
                pipe_result = pipeline(**args)
            except Exception:
                # release whatever the prompt weighter holds before the
                # error propagates (and before any retry re-translates)
                _cleanup_prompt_weighter()
                raise
            prompt_weighter.cleanup()
            return pipe_result
        except TypeError as e:
            null_call_name = _types.get_null_call_name(e)
            if null_call_name:
                raise UnsupportedPipelineConfigError(
                    'Missing pipeline module?, cannot call: ' + null_call_name)
            raise
        except AttributeError as e:
            null_attr_name = _types.get_null_attr_name(e)
            if null_attr_name:
                raise UnsupportedPipelineConfigError(
                    'Missing pipeline module?, cannot access: ' + null_attr_name)
            raise

    def _torch_oom_handler():
        global _LAST_CALLED_PIPELINE
        if pipeline is _LAST_CALLED_PIPELINE:
            _LAST_CALLED_PIPELINE = None
        # move the torch pipeline back to the CPU
        pipeline_to(pipeline, 'cpu')
        # empty the CUDA cache
        _memory.torch_gc()
        # force garbage collection
        gc.collect()

    def _evict_other_8bit_pipelines():
        # hack to clear out cached 8bit pipelines, only one should ever
        # be in the object cache for sequential calls, this is necessary
        # for efficient main pipeline recall in DiffusionPipelineWrapper
        # which is used for the adetailer post processor
        def _is_stale(cached):
            if cached.pipeline is pipeline:
                return False
            bit8 = any(_util.is_loaded_in_8bit_bnb(module) for
                       module in get_pipeline_modules(cached.pipeline).values())
            return bit8 and _torchutil.devices_equal(device, get_torch_device(cached.pipeline))

        # Select first, then un-cache. un_cache() mutates the cache, which
        # may not be safe while iterating over it directly.
        stale = [c for c in _pipeline_cache.values() if _is_stale(c)]
        if not stale:
            return

        for cached_pipeline in stale:
            _messages.debug_log(
                f'Clearing out cached 8bit pipeline from cache: {cached_pipeline.pipeline.__class__.__name__}')
            _pipeline_cache.un_cache(cached_pipeline)

        # drop local references so the un-cached pipelines can actually be freed
        del stale, cached_pipeline
        gc.collect()
        _memory.torch_gc()

    def _call_pipeline():
        old_execution_device_property = None
        try:
            if device is not None and hasattr(pipeline, '_execution_device'):
                # HACK
                # The device this returns is sometimes wrong and causes issues
                # with a randomly generated tensor (complaining about) being
                # generated on the wrong device as compared to the torch.Generator
                # object being used to generate it, this is a diffusers problem in
                # the code of this private property.
                # Skipped when device is None, torch.device(None) is invalid.
                old_execution_device_property = pipeline.__class__._execution_device
                pipeline.__class__._execution_device = property(lambda s: torch.device(device))

            return _call_pipeline_raw()
        except _d_exceptions.TORCH_CUDA_OOM_EXCEPTIONS as e:
            _d_exceptions.raise_if_not_cuda_oom(e)
            _torch_oom_handler()
            raise _d_exceptions.OutOfMemoryError(e) from e
        except MemoryError as e:
            gc.collect()
            raise _d_exceptions.OutOfMemoryError('cpu (system memory)') from e
        except Exception:
            # same cleanup
            _torch_oom_handler()
            raise
        finally:
            if old_execution_device_property is not None:
                pipeline.__class__._execution_device = old_execution_device_property
            _evict_other_8bit_pipelines()

    if pipeline is _LAST_CALLED_PIPELINE:
        try:
            return _call_pipeline()
        except _d_exceptions.OutOfMemoryError:
            if not enable_retry_pipe:
                raise
            _messages.debug_log(
                f'Attempting to call pipeline '
                f'"{pipeline.__class__.__name__}" again after out '
                f'of memory condition and cleanup.')
            # retry after memory cleanup
            result = _call_pipeline()
            _LAST_CALLED_PIPELINE = pipeline
            return result

    if _LAST_CALLED_PIPELINE is not None and hasattr(_LAST_CALLED_PIPELINE, 'to'):
        _messages.debug_log(
            f'Moving previously called pipeline '
            f'"{_LAST_CALLED_PIPELINE.__class__.__name__}", back to the CPU.')

        pipeline_to(_LAST_CALLED_PIPELINE, 'cpu')

    try:
        result = _call_pipeline()
    except _d_exceptions.OutOfMemoryError:
        if not enable_retry_pipe:
            raise
        _messages.debug_log(
            f'Attempting to call pipeline '
            f'"{pipeline.__class__.__name__}" again after out '
            f'of memory condition and cleanup.')
        # allow for memory cleanup and try again
        # might be able to run now
        result = _call_pipeline()

    _LAST_CALLED_PIPELINE = pipeline
    return result
[docs]
class PipelineCreationResult:
model_path: _types.OptionalPath
"""
Path the model was loaded from.
"""
parsed_unet_uri: _uris.UNetUri | None
"""
Parsed UNet URI if one was present
"""
parsed_vae_uri: _uris.VAEUri | None
"""
Parsed VAE URI if one was present
"""
parsed_lora_uris: collections.abc.Sequence[_uris.LoRAUri]
"""
Parsed LoRA URIs if any were present
"""
parsed_ip_adapter_uris: collections.abc.Sequence[_uris.IPAdapterUri]
"""
Parsed IP Adapter URIs if any were present
"""
parsed_textual_inversion_uris: collections.abc.Sequence[_uris.TextualInversionUri]
"""
Parsed Textual Inversion URIs if any were present
"""
parsed_controlnet_uris: collections.abc.Sequence[_uris.ControlNetUri]
"""
Parsed ControlNet URIs if any were present
"""
parsed_t2i_adapter_uris: collections.abc.Sequence[_uris.T2IAdapterUri]
"""
Parsed T2IAdapter URIs if any were present
"""
parsed_image_encoder_uri: _uris.ImageEncoderUri | None
"""
Parsed ImageEncoder URI if one was present
"""
parsed_transformer_uri: _uris.TransformerUri | None
"""
Parsed Transformer URI if one was present
"""
[docs]
def load_scheduler(self, scheduler_uri: _types.Uri | None):
    """
    Load a scheduler onto the created pipeline from a URI specification.

    Passing ``None`` reloads the original scheduler the model was loaded with.
    If no new scheduler has been set since then, this is a no-op.

    :param scheduler_uri: The scheduler URI, or ``None``
    """
    target_pipeline = self.pipeline
    _schedulers.load_scheduler(target_pipeline, scheduler_uri)
[docs]
def set_vae_tiling_and_slicing(self, vae_tiling: bool, vae_slicing: bool):
    """
    Set the VAE tiling and slicing status of the created pipeline.

    Delegates to the module level ``set_vae_tiling_and_slicing`` function.

    :raises UnsupportedPipelineConfigError: if the pipeline does not support the
        requested values

    :param vae_tiling: vae tiling?
    :param vae_slicing: vae slicing?
    """
    set_vae_tiling_and_slicing(
        self.pipeline,
        tiling=vae_tiling,
        slicing=vae_slicing,
    )
[docs]
def __init__(self,
model_path: _types.Path,
pipeline: diffusers.DiffusionPipeline,
parsed_unet_uri: _uris.UNetUri | None,
parsed_transformer_uri: _uris.TransformerUri | None,
parsed_vae_uri: _uris.VAEUri | None,
parsed_image_encoder_uri: _uris.ImageEncoderUri | None,
parsed_lora_uris: collections.abc.Sequence[_uris.LoRAUri],
parsed_ip_adapter_uris: collections.abc.Sequence[_uris.IPAdapterUri],
parsed_textual_inversion_uris: collections.abc.Sequence[_uris.TextualInversionUri],
parsed_controlnet_uris: collections.abc.Sequence[_uris.ControlNetUri],
parsed_t2i_adapter_uris: collections.abc.Sequence[_uris.T2IAdapterUri]):
self.model_path = model_path
self.parsed_unet_uri = parsed_unet_uri
self.parsed_vae_uri = parsed_vae_uri
self.parsed_lora_uris = parsed_lora_uris
self.parsed_textual_inversion_uris = parsed_textual_inversion_uris
self.parsed_controlnet_uris = parsed_controlnet_uris
self.parsed_t2i_adapter_uris = parsed_t2i_adapter_uris
self.parsed_ip_adapter_uris = parsed_ip_adapter_uris
self.parsed_image_encoder_uri = parsed_image_encoder_uri
self.parsed_transformer_uri = parsed_transformer_uri
self._pipeline = pipeline
@property
def pipeline(self):
return self._pipeline
[docs]
def get_pipeline_modules(self, names: collections.abc.Iterable[str]):
"""
Get associated pipeline module such as ``vae`` etc, in
a dictionary mapped from name to module value.
Possible Module Names:
* ``unet``
* ``vae``
* ``transformer``
* ``text_encoder``
* ``text_encoder_2``
* ``text_encoder_3``
* ``tokenizer``
* ``tokenizer_2``
* ``tokenizer_3``
* ``safety_checker``
* ``feature_extractor``
* ``image_encoder``
* ``adapter``
* ``controlnet``
* ``scheduler``
If the module is not present or a recognized name, a :py:exc:`ValueError`
will be thrown describing the module that is not part of the pipeline.
:raise ValueError:
:param names: module names, such as ``vae``, ``text_encoder``
:return: dictionary
"""
module_values = dict()
acceptable_lookups = {
'unet',
'vae',
'transformer',
'text_encoder',
'text_encoder_2',
'text_encoder_3',
'tokenizer',
'tokenizer_2',
'tokenizer_3',
'safety_checker',
'feature_extractor',
'image_encoder',
'adapter',
'controlnet',
'scheduler'
}
for name in names:
if name not in acceptable_lookups:
raise ValueError(f'"{name}" is not a recognized pipeline module name.')
if not hasattr(self.pipeline, name):
raise ValueError(f'Created pipeline does not possess a module named: "{name}".')
module_values[name] = getattr(self.pipeline, name)
return module_values
[docs]
def call(self,
device: torch.device | str | None = _torchutil.default_device(),
prompt_weighter: _promptweighters.PromptWeighter | None = None,
**kwargs) -> diffusers.utils.BaseOutput:
"""
Call **pipeline**, see: :py:func:`.call_pipeline`
:param device: move the pipeline to this device before calling
:param prompt_weighter: Optional prompt weighter for weighted prompt syntaxes
:param kwargs: forward kwargs to pipeline
:return: A subclass of :py:class:`diffusers.utils.BaseOutput`
"""
return call_pipeline(self.pipeline,
device,
prompt_weighter,
**kwargs)
[docs]
def create_diffusion_pipeline(
        model_path: str,
        model_type: _enums.ModelType = _enums.ModelType.SD,
        pipeline_type: _enums.PipelineType = _enums.PipelineType.TXT2IMG,
        revision: _types.OptionalString = None,
        variant: _types.OptionalString = None,
        subfolder: _types.OptionalString = None,
        dtype: _enums.DataType = _enums.DataType.AUTO,
        unet_uri: _types.OptionalUri = None,
        transformer_uri: _types.OptionalUri = None,
        vae_uri: _types.OptionalUri = None,
        lora_uris: _types.OptionalUris = None,
        lora_fuse_scale: _types.OptionalFloat = None,
        image_encoder_uri: _types.OptionalUri = None,
        ip_adapter_uris: _types.OptionalUris = None,
        textual_inversion_uris: _types.OptionalUris = None,
        text_encoder_uris: _types.OptionalUris = None,
        controlnet_uris: _types.OptionalUris = None,
        t2i_adapter_uris: _types.OptionalUris = None,
        quantizer_uri: _types.OptionalUri = None,
        quantizer_map: _types.OptionalStrings = None,
        pag: bool = False,
        safety_checker: bool = False,
        original_config: _types.OptionalString = None,
        auth_token: _types.OptionalString = None,
        device: str = _torchutil.default_device(),
        extra_modules: dict[str, typing.Any] | None = None,
        model_cpu_offload: bool = False,
        sequential_cpu_offload: bool = False,
        local_files_only: bool = False,
        missing_submodules_ok: bool = False
) -> PipelineCreationResult:
    """
    Create a :py:class:`diffusers.DiffusionPipeline` in dgenerate's in memory caching system.

    Any ``*_uris`` argument may also be given as a single URI string, and ``quantizer_map``
    may be given as a single module name string; these are wrapped into one element lists.

    :param model_type: :py:class:`dgenerate.pipelinewrapper.ModelType` enum value
    :param model_path: huggingface slug, huggingface blob link, path to folder on disk, path to file on disk
    :param pipeline_type: :py:class:`dgenerate.pipelinewrapper.PipelineType` enum value
    :param revision: huggingface repo revision (branch)
    :param variant: model weights name variant, for example 'fp16'
    :param subfolder: huggingface repo subfolder if applicable
    :param dtype: Optional :py:class:`dgenerate.pipelinewrapper.DataType` enum value
    :param unet_uri: Optional ``--unet`` URI string for specifying a specific UNet
    :param transformer_uri: Optional ``--transformer`` URI string for specifying a specific Transformer,
        currently this is only supported for Stable Diffusion 3 and Flux models.
    :param vae_uri: Optional ``--vae`` URI string for specifying a specific VAE
    :param lora_uris: Optional ``--loras`` URI strings for specifying LoRA weights
    :param lora_fuse_scale: Optional ``--lora-fuse-scale`` global LoRA fuse scale value.
        Once all LoRAs are merged with their individual scales, the merged weights will be fused
        into the pipeline at this scale. The default value is 1.0.
    :param image_encoder_uri: Optional ``--image-encoder`` URI for use with IP Adapter weights or Stable Cascade
    :param ip_adapter_uris: Optional ``--ip-adapters`` URI strings for specifying IP Adapter weights
    :param textual_inversion_uris: Optional ``--textual-inversions`` URI strings for specifying Textual Inversion weights
    :param text_encoder_uris: Optional user specified ``--text-encoders`` URIs that will be loaded on to the
        pipeline in order. A uri value of ``+`` or ``None`` indicates use default, a string value of ``null``
        indicates to explicitly not load that encoder at all.
    :param controlnet_uris: Optional ``--control-nets`` URI strings for specifying ControlNet models
    :param t2i_adapter_uris: Optional ``--t2i-adapters`` URI strings for specifying T2IAdapter models
    :param quantizer_uri: Optional ``--quantizer`` URI value
    :param quantizer_map: Collection of pipeline submodule names to which quantization should be applied when
        ``quantizer_uri`` is provided. Valid values include: ``unet``, ``transformer``, ``text_encoder``,
        ``text_encoder_2``, ``text_encoder_3``, and ``controlnet``. If ``None``, all supported modules will be quantized,
        except for ``controlnet``.
    :param pag: Use perturbed attention guidance?
    :param safety_checker: Safety checker enabled? default is ``False``
    :param original_config: Optional original training config .yaml file path when loading a single file checkpoint.
    :param auth_token: Optional huggingface API token for accessing repositories that are restricted to your account
    :param device: Optional ``--device`` string, defaults to dgenerate's default torch device
        (determined when this module is imported)
    :param extra_modules: Extra module arguments to pass directly into
        :py:meth:`diffusers.DiffusionPipeline.from_single_file` or :py:meth:`diffusers.DiffusionPipeline.from_pretrained`
    :param model_cpu_offload: This pipeline has model_cpu_offloading enabled?
    :param sequential_cpu_offload: This pipeline has sequential_cpu_offloading enabled?
    :param local_files_only: Only look in the huggingface cache and do not connect to download models?
    :param missing_submodules_ok: It is okay if Text Encoders or VAE is missing from the checkpoint?

    :raises ValueError: if ``model_path`` is empty
    :raises InvalidModelFileError:
    :raises InvalidModelUriError:
    :raises InvalidSchedulerNameError:
    :raises UnsupportedPipelineConfigError:
    :raises dgenerate.ModelNotFoundError:
    :raises dgenerate.ConfigNotFoundError:
    :raises dgenerate.NonHFModelDownloadError:
    :raises dgenerate.NonHFConfigDownloadError:
    :raises dgenerate.WebFileCacheOfflineModeException:

    :return: :py:class:`.PipelineCreationResult`
    """
    # Snapshot the arguments into an independent dict. This must be the first
    # statement so that the snapshot contains only the parameters.
    call_args = dict(locals())

    for name, value in call_args.items():
        # allow a single URI string wherever a list of URIs is expected
        # (only values are replaced, so mutating during iteration is safe)
        if name.endswith('_uris') and isinstance(value, str):
            call_args[name] = [value]

    if isinstance(call_args['quantizer_map'], str):
        # a bare module name, otherwise it would be treated as a
        # collection of characters (or matched by substring)
        call_args['quantizer_map'] = [call_args['quantizer_map']]

    with _hfhub.with_hf_errors_as_model_not_found():
        return _create_diffusion_pipeline(**call_args)
[docs]
class PipelineFactory:
    """
    Turns :py:func:`.create_diffusion_pipeline` into a factory that can
    repeatedly create a pipeline with the same arguments, possibly from cache.
    """

    def __init__(self,
                 model_path: str,
                 model_type: _enums.ModelType = _enums.ModelType.SD,
                 pipeline_type: _enums.PipelineType = _enums.PipelineType.TXT2IMG,
                 revision: _types.OptionalString = None,
                 variant: _types.OptionalString = None,
                 subfolder: _types.OptionalString = None,
                 dtype: _enums.DataType = _enums.DataType.AUTO,
                 unet_uri: _types.OptionalUri = None,
                 transformer_uri: _types.OptionalUri = None,
                 vae_uri: _types.OptionalUri = None,
                 lora_uris: _types.OptionalUris = None,
                 lora_fuse_scale: _types.OptionalFloat = None,
                 image_encoder_uri: _types.OptionalUri = None,
                 ip_adapter_uris: _types.OptionalUris = None,
                 textual_inversion_uris: _types.OptionalUris = None,
                 controlnet_uris: _types.OptionalUris = None,
                 t2i_adapter_uris: _types.OptionalUris = None,
                 text_encoder_uris: _types.OptionalUris = None,
                 quantizer_uri: _types.OptionalUri = None,
                 quantizer_map: _types.OptionalStrings = None,
                 pag: bool = False,
                 safety_checker: bool = False,
                 original_config: _types.OptionalString = None,
                 auth_token: _types.OptionalString = None,
                 device: str = _torchutil.default_device(),
                 extra_modules: dict[str, typing.Any] | None = None,
                 model_cpu_offload: bool = False,
                 sequential_cpu_offload: bool = False,
                 local_files_only: bool = False,
                 missing_submodules_ok: bool = False):
        """
        Accepts the same arguments as :py:func:`.create_diffusion_pipeline`,
        see that function for their documentation.
        """
        # Capture every constructor argument (locals() holds only the
        # parameters at this point) minus ``self``. The containers are copied
        # through _types.partial_deep_copy_container, presumably so that later
        # mutation of caller-owned lists/dicts does not change the stored args.
        self._args = {k: v for k, v in
                      _types.partial_deep_copy_container(locals()).items()
                      if k != 'self'}

    def __call__(self) -> PipelineCreationResult:
        """
        Create the pipeline by calling :py:func:`.create_diffusion_pipeline`
        with the arguments given to the constructor.

        :raises InvalidModelFileError:
        :raises ModelNotFoundError:
        :raises InvalidModelUriError:
        :raises InvalidSchedulerNameError:
        :raises UnsupportedPipelineConfigError:
        :raises dgenerate.ConfigNotFoundError:
        :raises dgenerate.NonHFModelDownloadError:
        :raises dgenerate.NonHFConfigDownloadError:
        :raises dgenerate.WebFileCacheOfflineModeException:

        :return: :py:class:`.PipelineCreationResult`
        """
        return create_diffusion_pipeline(**self._args)
def _format_pipeline_creation_debug_arg(arg_name, v):
if isinstance(v, torch.dtype):
return str(v)
if isinstance(v, str):
return f'"{v}"'
if v.__class__.__module__ != 'builtins':
return _types.class_and_id_string(v)
if isinstance(v, list):
return '[' + ', '.join(_format_pipeline_creation_debug_arg(None, v) for v in v) + ']'
if isinstance(v, (set, frozenset)):
return '{' + ', '.join(_format_pipeline_creation_debug_arg(None, v) for v in v) + '}'
if isinstance(v, dict):
return '{' + ', '.join(f'"{k}"={_format_pipeline_creation_debug_arg(None, v)}' for k, v in v.items()) + '}'
return str(v)
def _pipeline_creation_args_debug(backend, cls, method, model, **kwargs):
    """
    Debug-log a pipeline creation call and then perform it.

    The log message is built lazily (only if debug logging asks for it).

    :param backend: label used as the message prefix
    :param cls: pipeline class, used for its ``__name__`` in the message
    :param method: the creation callable, invoked as ``method(model, **kwargs)``
    :param model: model path / slug, passed as the first positional argument
    :param kwargs: keyword arguments forwarded to ``method``
    :return: whatever ``method`` returns
    """

    def describe_call():
        formatted_kwargs = _textprocessing.debug_format_args(
            kwargs, _format_pipeline_creation_debug_arg, as_kwargs=True)
        header = f'{backend} Pipeline Creation Call: {cls.__name__}.{method.__name__}("{model}", '
        return header + formatted_kwargs + ')'

    _messages.debug_log(describe_call)
    return method(model, **kwargs)
def _text_encoder_default(uri):
return uri is None or uri.strip() == '+'
def _text_encoder_null(uri):
return uri and uri.strip().lower() == 'null'
def _pipeline_args_hasher(args):
    """
    Compute the memoization cache key for the arguments of ``_create_diffusion_pipeline``.

    URI arguments are hashed through their URI parsers rather than as raw
    strings. ControlNet / T2IAdapter options that are excluded from the
    hash (``scale``, ``start``, ``end``) do not produce a new cache entry.

    :param args: mapping of argument name to value
    :return: cache key produced by ``_d_memoize.args_cache_key``
    """

    def text_encoder_uri_parse(uri):
        # '+' / None select the default encoder; 'help' and 'null' are
        # special values rather than URIs and are hashed as fixed tokens.
        if _text_encoder_default(uri):
            return None
        if uri.strip() == 'help':
            return 'help'
        if _text_encoder_null(uri):
            # case-insensitive, consistent with _text_encoder_null elsewhere
            return 'null'
        return _uris.TextEncoderUri.parse(uri)

    quantizer_uri = args['quantizer_uri']

    custom_hashes = {
        'unet_uri': _uris.uri_hash_with_parser(_uris.UNetUri.parse),
        'transformer_uri': _uris.uri_hash_with_parser(_uris.TransformerUri.parse),
        'vae_uri': _uris.uri_hash_with_parser(_uris.VAEUri.parse),
        'image_encoder_uri': _uris.uri_hash_with_parser(_uris.ImageEncoderUri.parse),
        'lora_uris': _uris.uri_list_hash_with_parser(_uris.LoRAUri.parse),
        'ip_adapter_uris': _uris.uri_list_hash_with_parser(_uris.IPAdapterUri.parse),
        'textual_inversion_uris': _uris.uri_list_hash_with_parser(_uris.TextualInversionUri.parse),
        'text_encoder_uris': _uris.uri_list_hash_with_parser(text_encoder_uri_parse),
        'controlnet_uris': _uris.uri_list_hash_with_parser(
            lambda s: _uris.ControlNetUri.parse(s, model_type=args['model_type']),
            exclude={'scale', 'start', 'end'}),
        't2i_adapter_uris': _uris.uri_list_hash_with_parser(_uris.T2IAdapterUri.parse,
                                                            exclude={'scale'}),
        'quantizer_uri':
            _uris.uri_hash_with_parser(
                _uris.get_quantizer_uri_class(quantizer_uri).parse)
            if quantizer_uri else lambda x: None,
        # order of the module names does not matter
        'quantizer_map': lambda x: hash(tuple(sorted(x))) if x else None
    }

    return _d_memoize.args_cache_key(args, custom_hashes=custom_hashes)
def _pipeline_on_hit(key, hit):
    """
    Memoization hook: debug-log a pipeline cache hit.

    :param key: cache key
    :param hit: cached :py:class:`PipelineCreationResult`
    """
    cached_pipeline = hit.pipeline
    _d_memoize.simple_cache_hit_debug("Torch Pipeline", key, cached_pipeline)
def _pipeline_on_create(key, new):
    """
    Memoization hook: debug-log a pipeline cache miss (a new pipeline was created).

    :param key: cache key
    :param new: newly created :py:class:`PipelineCreationResult`
    """
    created_pipeline = new.pipeline
    _d_memoize.simple_cache_miss_debug('Torch Pipeline', key, created_pipeline)
def _class_has_mixin(cls: type, mixin_name: str) -> bool:
    """
    Check whether any ancestor of ``cls`` has a class name containing ``mixin_name``.

    All ancestors in the method resolution order are inspected (not just the
    direct bases), so a mixin inherited through an intermediate class counts.

    :param cls: class to inspect
    :param mixin_name: substring to look for in ancestor class names
    :return: ``True`` or ``False``
    """
    # __mro__[0] is cls itself, only its ancestors are of interest
    return any(mixin_name in ancestor.__name__ for ancestor in cls.__mro__[1:])


def pipeline_class_supports_textual_inversion(cls: typing.Type['diffusers.DiffusionPipeline']):
    """
    Does a pipeline class support Textual Inversions?

    :param cls: ``diffusers`` pipeline class
    :return: ``True`` or ``False``
    """
    return _class_has_mixin(cls, 'TextualInversionLoaderMixin')


def pipeline_class_supports_lora(cls: typing.Type['diffusers.DiffusionPipeline']):
    """
    Does a pipeline class support LoRAs?

    Matches any ancestor whose name contains ``LoraLoaderMixin``, e.g.
    ``StableDiffusionLoraLoaderMixin``.

    :param cls: ``diffusers`` pipeline class
    :return: ``True`` or ``False``
    """
    return _class_has_mixin(cls, 'LoraLoaderMixin')


def pipeline_class_supports_ip_adapter(cls: typing.Type['diffusers.DiffusionPipeline']):
    """
    Does a pipeline class support IP Adapters?

    Matches any ancestor whose name contains ``IPAdapterMixin``.

    :param cls: ``diffusers`` pipeline class
    :return: ``True`` or ``False``
    """
    return _class_has_mixin(cls, 'IPAdapterMixin')
[docs]
def get_pipeline_class(
        model_type: _enums.ModelType = _enums.ModelType.SD,
        pipeline_type: _enums.PipelineType = _enums.PipelineType.TXT2IMG,
        unet_uri: _types.OptionalUri = None,
        transformer_uri: _types.OptionalUri = None,
        vae_uri: _types.OptionalUri = None,
        lora_uris: _types.OptionalUris = None,
        image_encoder_uri: _types.OptionalUri = None,
        ip_adapter_uris: _types.OptionalUris = None,
        textual_inversion_uris: _types.OptionalUris = None,
        controlnet_uris: _types.OptionalUris = None,
        t2i_adapter_uris: _types.OptionalUris = None,
        pag: bool = False,
        help_mode: bool = False
) -> typing.Type[diffusers.DiffusionPipeline]:
    """
    Get an appropriate ``diffusers`` pipeline class for the provided arguments.

    Unsupported argument combinations are rejected before a class is selected, and
    the selected class is checked for LoRA, Textual Inversion, and IP Adapter support
    when the corresponding URIs are given.

    :param model_type: :py:class:`dgenerate.pipelinewrapper.ModelType` enum value
    :param pipeline_type: :py:class:`dgenerate.pipelinewrapper.PipelineType` enum value
    :param unet_uri: Optional ``--unet`` URI string for specifying a specific UNet
    :param transformer_uri: Optional ``--transformer`` URI string for specifying a specific Transformer,
        currently this is only supported for Stable Diffusion 3 and Flux models.
    :param vae_uri: Optional ``--vae`` URI string for specifying a specific VAE
    :param lora_uris: Optional ``--loras`` URI strings for specifying LoRA weights
    :param image_encoder_uri: Optional ``--image-encoder`` URI for use with IP Adapter weights or Stable Cascade
    :param ip_adapter_uris: Optional ``--ip-adapters`` URI strings for specifying IP Adapter weights
    :param textual_inversion_uris: Optional ``--textual-inversions`` URI strings for specifying Textual Inversion weights
    :param controlnet_uris: Optional ``--control-nets`` URI strings for specifying ControlNet models
    :param t2i_adapter_uris: Optional ``--t2i-adapters`` URI strings for specifying T2IAdapter models
    :param pag: Use perturbed attention guidance?
    :param help_mode: Return the class even if it does not support the selected ``pipeline_type``
    :raises UnsupportedPipelineConfigError:
    :return: pipeline class
    """

    # PAG check
    if pag:
        if model_type not in (_enums.ModelType.SD,
                              _enums.ModelType.SDXL,
                              _enums.ModelType.SD3,
                              _enums.ModelType.KOLORS):
            raise UnsupportedPipelineConfigError(
                'Perturbed attention guidance (--pag*) is only supported with '
                '--model-type sd, sdxl, kolors (txt2img), and sd3.')

        if t2i_adapter_uris:
            raise UnsupportedPipelineConfigError(
                'Perturbed attention guidance (--pag*) is not supported '
                'with --t2i-adapters.')

    # Flux model restrictions
    if _enums.model_type_is_flux(model_type):
        if t2i_adapter_uris:
            raise UnsupportedPipelineConfigError(
                'Flux --model-type values are not compatible with --t2i-adapters.')
        if ip_adapter_uris and not image_encoder_uri:
            raise UnsupportedPipelineConfigError(
                'Must specify --image-encoder when using --ip-adapters with Flux.')
        if ip_adapter_uris and len(ip_adapter_uris) > 1:
            raise UnsupportedPipelineConfigError(
                'Flux --model-type values do not support multiple --ip-adapters.')
        if model_type == _enums.ModelType.FLUX_FILL:
            if pipeline_type != _enums.PipelineType.INPAINT:
                raise UnsupportedPipelineConfigError(
                    'Flux fill --model-type value does not support anything but inpaint mode.'
                )

    # Deep Floyd model restrictions
    if _enums.model_type_is_floyd(model_type):
        if controlnet_uris:
            raise UnsupportedPipelineConfigError(
                'Deep Floyd --model-type values are not compatible with --control-nets.')
        if t2i_adapter_uris:
            raise UnsupportedPipelineConfigError(
                'Deep Floyd --model-type values are not compatible with --t2i-adapters.')
        if vae_uri:
            raise UnsupportedPipelineConfigError(
                'Deep Floyd --model-type values are not compatible with --vae.')
        if image_encoder_uri:
            raise UnsupportedPipelineConfigError(
                'Deep Floyd --model-type values are not compatible with --image-encoder.')

    # Stable Cascade model restrictions
    if _enums.model_type_is_s_cascade(model_type):
        if controlnet_uris:
            raise UnsupportedPipelineConfigError(
                'Stable Cascade --model-type values are not compatible with --control-nets.')
        if t2i_adapter_uris:
            raise UnsupportedPipelineConfigError(
                'Stable Cascade --model-type values are not compatible with --t2i-adapters.')
        if vae_uri:
            raise UnsupportedPipelineConfigError(
                'Stable Cascade --model-type values are not compatible with --vae.')

    # Torch SD3 restrictions
    if _enums.model_type_is_sd3(model_type):
        if t2i_adapter_uris:
            raise UnsupportedPipelineConfigError(
                '--model-type sd3 is not compatible with --t2i-adapters.')
        if unet_uri:
            raise UnsupportedPipelineConfigError(
                '--model-type sd3 is not compatible with --unet.')

    # Torch Kolors restrictions
    if model_type == _enums.ModelType.KOLORS:
        if t2i_adapter_uris:
            raise UnsupportedPipelineConfigError(
                '--model-type kolors is not compatible with --t2i-adapters.')

    if transformer_uri:
        if not _enums.model_type_is_sd3(model_type) and not _enums.model_type_is_flux(model_type):
            raise UnsupportedPipelineConfigError(
                '--transformer is only supported for --model-type sd3 and flux.')

    # Incompatible combinations
    if controlnet_uris and t2i_adapter_uris:
        raise UnsupportedPipelineConfigError(
            '--control-nets and --t2i-adapters cannot be used together.')

    if image_encoder_uri and not ip_adapter_uris and model_type != _enums.ModelType.S_CASCADE:
        raise UnsupportedPipelineConfigError(
            '--image-encoder cannot be specified without --ip-adapters if --model-type is not s-cascade.')

    # Pix2Pix model restrictions
    is_pix2pix = _enums.model_type_is_pix2pix(model_type)
    if is_pix2pix:
        if controlnet_uris:
            raise UnsupportedPipelineConfigError(
                'Pix2Pix --model-type values are not compatible with --control-nets.')
        if t2i_adapter_uris:
            raise UnsupportedPipelineConfigError(
                'Pix2Pix --model-type values are not compatible with --t2i-adapters.')
        if image_encoder_uri and model_type != _enums.ModelType.PIX2PIX:
            raise UnsupportedPipelineConfigError(
                'Only Pix2Pix --model-type pix2pix is compatible '
                'with --image-encoder. Pix2Pix SDXL is not supported.')

    is_sdxl = _enums.model_type_is_sdxl(model_type)
    is_sd3 = _enums.model_type_is_sd3(model_type)

    # SDXL ControlNet Union mode is selected when any SDXL ControlNet URI specifies a "mode"
    sdxl_controlnet_union = False
    parsed_control_net_uris = None

    try:
        if controlnet_uris and is_sdxl:
            parsed_control_net_uris = [_uris.ControlNetUri.parse(s, model_type) for s in controlnet_uris]
            sdxl_controlnet_union = any(s.mode is not None for s in parsed_control_net_uris)
    except _uris.InvalidControlNetUriError as e:
        raise UnsupportedPipelineConfigError(str(e)) from e

    def eq_cn_uri(
            uri1: _uris.ControlNetUri,
            uri2: _uris.ControlNetUri):
        # equality ignoring the per-ControlNet options that may differ in union mode
        equals = True
        for name, val in _types.get_public_attributes(uri1).items():
            if name not in {'scale', 'mode', 'start', 'end'}:
                equals = (equals and val == getattr(uri2, name))
        return equals

    if sdxl_controlnet_union and \
            any(not eq_cn_uri(parsed_control_net_uris[0], u)
                for u in parsed_control_net_uris):
        raise UnsupportedPipelineConfigError(
            'SDXL ControlNet Union mode requires all ControlNet '
            'model URIs to be identical with the exception of the '
            '"scale", "mode", "start", and "end" arguments.'
        )

    # Pipeline class selection
    if _enums.model_type_is_upscaler(model_type):
        if controlnet_uris:
            raise UnsupportedPipelineConfigError(
                'Upscaler models are not compatible with --control-nets.')

        if t2i_adapter_uris:
            raise UnsupportedPipelineConfigError(
                'Upscaler models are not compatible with --t2i-adapters.')

        if image_encoder_uri:
            raise UnsupportedPipelineConfigError(
                'Upscaler models are not compatible with --image-encoder.')

        if pipeline_type != _enums.PipelineType.IMG2IMG and not help_mode:
            raise UnsupportedPipelineConfigError(
                'Upscaler models only work with img2img generation, IE: --image-seeds (with no image masks).')

        pipeline_class = (
            diffusers.StableDiffusionUpscalePipeline
            if model_type == _enums.ModelType.UPSCALER_X4
            else diffusers.StableDiffusionLatentUpscalePipeline
        )
    else:
        if pipeline_type == _enums.PipelineType.TXT2IMG:
            # A single if / elif chain, so that a class chosen for pix2pix
            # in help mode is not overwritten by a later branch.
            if is_pix2pix:
                if not help_mode:
                    raise UnsupportedPipelineConfigError(
                        'Pix2Pix models only work in img2img / inpaint mode and cannot work without --image-seeds.')

                if is_sdxl:
                    pipeline_class = diffusers.StableDiffusionXLInstructPix2PixPipeline
                elif is_sd3:
                    pipeline_class = dgenerate.extras.ultraedit.StableDiffusion3InstructPix2PixPipeline
                else:
                    pipeline_class = diffusers.StableDiffusionInstructPix2PixPipeline

            elif model_type == _enums.ModelType.FLUX_KONTEXT:
                raise UnsupportedPipelineConfigError(
                    'Flux Kontext models only work in img2img / inpaint mode and cannot work without --image-seeds.'
                )

            elif model_type == _enums.ModelType.FLUX_FILL:
                raise UnsupportedPipelineConfigError(
                    'Flux Fill models only work in inpaint mode and cannot work without --image-seeds.'
                )

            elif model_type == _enums.ModelType.IF:
                pipeline_class = diffusers.IFPipeline

            elif model_type == _enums.ModelType.IFS:
                if not help_mode:
                    raise UnsupportedPipelineConfigError(
                        'Deep Floyd IF super-resolution (IFS) only works in '
                        'img2img mode and cannot work without --image-seeds.')
                else:
                    pipeline_class = diffusers.IFSuperResolutionPipeline

            elif model_type == _enums.ModelType.S_CASCADE:
                pipeline_class = diffusers.StableCascadePriorPipeline

            elif model_type == _enums.ModelType.S_CASCADE_DECODER:
                pipeline_class = diffusers.StableCascadeDecoderPipeline

            elif model_type == _enums.ModelType.FLUX:
                if controlnet_uris:
                    pipeline_class = diffusers.FluxControlNetPipeline
                else:
                    pipeline_class = diffusers.FluxPipeline

            elif model_type == _enums.ModelType.SD3:
                # check ControlNets first, so that PAG + ControlNet is
                # rejected instead of silently dropping the ControlNets
                if controlnet_uris:
                    if pag:
                        raise UnsupportedPipelineConfigError(
                            'Stable Diffusion 3 does not support --pag with controlnets.')
                    pipeline_class = diffusers.StableDiffusion3ControlNetPipeline
                elif pag:
                    pipeline_class = diffusers.StableDiffusion3PAGPipeline
                else:
                    pipeline_class = diffusers.StableDiffusion3Pipeline

            elif model_type == _enums.ModelType.KOLORS:
                if controlnet_uris:
                    if pag:
                        raise UnsupportedPipelineConfigError(
                            'Kolors ControlNet mode does not support PAG')
                    else:
                        pipeline_class = _kolors.KolorsControlNetPipeline
                else:
                    if pag:
                        pipeline_class = diffusers.KolorsPAGPipeline
                    else:
                        pipeline_class = diffusers.KolorsPipeline

            elif t2i_adapter_uris:
                # The custom type is a hack to support from_single_file for SD1.5 - 2
                # models with the associated pipeline class which does not inherit
                # the correct mixin to do so but can use the mixin just fine
                pipeline_class = (
                    diffusers.StableDiffusionXLAdapterPipeline
                    if is_sdxl
                    else diffusers.StableDiffusionAdapterPipeline
                )

            elif controlnet_uris:
                if is_sdxl:
                    if pag:
                        if sdxl_controlnet_union:
                            raise UnsupportedPipelineConfigError(
                                'SDXL ControlNet Union mode does not support PAG')
                        pipeline_class = diffusers.StableDiffusionXLControlNetPAGPipeline
                    else:
                        if sdxl_controlnet_union:
                            pipeline_class = \
                                diffusers.StableDiffusionXLControlNetUnionPipeline
                        else:
                            pipeline_class = diffusers.StableDiffusionXLControlNetPipeline
                else:
                    if pag:
                        pipeline_class = diffusers.StableDiffusionControlNetPAGPipeline
                    else:
                        pipeline_class = diffusers.StableDiffusionControlNetPipeline
            else:
                if is_sdxl:
                    if pag:
                        pipeline_class = diffusers.StableDiffusionXLPAGPipeline
                    else:
                        pipeline_class = diffusers.StableDiffusionXLPipeline
                else:
                    if pag:
                        pipeline_class = diffusers.StableDiffusionPAGPipeline
                    else:
                        pipeline_class = diffusers.StableDiffusionPipeline

        elif pipeline_type == _enums.PipelineType.IMG2IMG:

            if controlnet_uris:
                if is_pix2pix:
                    raise UnsupportedPipelineConfigError(
                        'Pix2Pix models are not compatible with --control-nets.')

            if model_type == _enums.ModelType.FLUX_FILL:
                raise UnsupportedPipelineConfigError(
                    'Flux Fill models only work in inpaint mode.'
                )

            if is_pix2pix:
                if is_sdxl:
                    pipeline_class = diffusers.StableDiffusionXLInstructPix2PixPipeline
                elif is_sd3:
                    pipeline_class = dgenerate.extras.ultraedit.StableDiffusion3InstructPix2PixPipeline
                else:
                    pipeline_class = diffusers.StableDiffusionInstructPix2PixPipeline

            elif model_type == _enums.ModelType.IF:
                pipeline_class = diffusers.IFImg2ImgPipeline

            elif model_type == _enums.ModelType.IFS:
                pipeline_class = diffusers.IFSuperResolutionPipeline

            elif model_type == _enums.ModelType.IFS_IMG2IMG:
                pipeline_class = diffusers.IFImg2ImgSuperResolutionPipeline

            elif model_type == _enums.ModelType.S_CASCADE:
                pipeline_class = diffusers.StableCascadePriorPipeline

            elif model_type == _enums.ModelType.S_CASCADE_DECODER:
                raise UnsupportedPipelineConfigError(
                    'Stable Cascade decoder models do not support img2img.')

            elif model_type == _enums.ModelType.FLUX:
                if controlnet_uris:
                    pipeline_class = diffusers.FluxControlNetImg2ImgPipeline
                else:
                    pipeline_class = diffusers.FluxImg2ImgPipeline

            elif model_type == _enums.ModelType.FLUX_KONTEXT:
                if controlnet_uris:
                    raise UnsupportedPipelineConfigError(
                        '--model-type flux-kontext does not support ControlNet models.'
                    )
                pipeline_class = diffusers.FluxKontextPipeline

            elif model_type == _enums.ModelType.SD3:
                if controlnet_uris:
                    raise UnsupportedPipelineConfigError(
                        '--model-type sd3 does not support img2img mode with ControlNet models.')
                if pag:
                    pipeline_class = diffusers.StableDiffusion3PAGImg2ImgPipeline
                else:
                    pipeline_class = diffusers.StableDiffusion3Img2ImgPipeline

            elif model_type == _enums.ModelType.KOLORS:
                if controlnet_uris:
                    if pag:
                        raise UnsupportedPipelineConfigError(
                            'Kolors ControlNet does not support PAG in img2img mode'
                        )
                    pipeline_class = _kolors.KolorsControlNetImg2ImgPipeline
                else:
                    if pag:
                        raise UnsupportedPipelineConfigError(
                            'Kolors does not support PAG in img2img mode'
                        )
                    pipeline_class = diffusers.KolorsImg2ImgPipeline

            elif t2i_adapter_uris:
                raise UnsupportedPipelineConfigError(
                    'img2img mode is not supported with --t2i-adapters.')

            elif controlnet_uris:
                if is_sdxl:
                    if pag:
                        if sdxl_controlnet_union:
                            raise UnsupportedPipelineConfigError(
                                'SDXL ControlNet Union mode does not support PAG')
                        pipeline_class = diffusers.StableDiffusionXLControlNetPAGImg2ImgPipeline
                    else:
                        if sdxl_controlnet_union:
                            pipeline_class = \
                                diffusers.StableDiffusionXLControlNetUnionImg2ImgPipeline
                        else:
                            pipeline_class = diffusers.StableDiffusionXLControlNetImg2ImgPipeline
                else:
                    if pag:
                        raise UnsupportedPipelineConfigError(
                            '--model-type sd (Stable Diffusion 1.5 - 2.*) '
                            'does not support --pag in img2img mode with ControlNet models.')
                    else:
                        pipeline_class = diffusers.StableDiffusionControlNetImg2ImgPipeline
            else:
                if is_sdxl:
                    if pag:
                        pipeline_class = diffusers.StableDiffusionXLPAGImg2ImgPipeline
                    else:
                        pipeline_class = diffusers.StableDiffusionXLImg2ImgPipeline
                else:
                    if pag:
                        raise UnsupportedPipelineConfigError(
                            '--model-type sd (Stable Diffusion 1.5 - 2.*) '
                            'does not support --pag in img2img mode.')
                    else:
                        pipeline_class = diffusers.StableDiffusionImg2ImgPipeline

        elif pipeline_type == _enums.PipelineType.INPAINT:
            if _enums.model_type_is_s_cascade(model_type):
                raise UnsupportedPipelineConfigError(
                    'Stable Cascade model types do not support inpainting.')

            if _enums.model_type_is_upscaler(model_type):
                raise UnsupportedPipelineConfigError(
                    'Stable Diffusion upscaler model types do not support inpainting.')

            if is_pix2pix:
                # NOTE(review): inpaint mode selects the ultraedit classes for all
                # pix2pix variants, unlike txt2img / img2img; confirm intended.
                if is_sdxl:
                    pipeline_class = dgenerate.extras.ultraedit.StableDiffusionXLInstructPix2PixPipeline
                elif is_sd3:
                    pipeline_class = dgenerate.extras.ultraedit.StableDiffusion3InstructPix2PixPipeline
                else:
                    pipeline_class = dgenerate.extras.ultraedit.StableDiffusionInstructPix2PixPipeline

            elif model_type == _enums.ModelType.FLUX:
                if controlnet_uris:
                    pipeline_class = diffusers.FluxControlNetInpaintPipeline
                else:
                    pipeline_class = diffusers.FluxInpaintPipeline

            elif model_type == _enums.ModelType.FLUX_FILL:
                if controlnet_uris:
                    raise UnsupportedPipelineConfigError(
                        '--model-type flux-fill does not support ControlNet models.')
                pipeline_class = diffusers.FluxFillPipeline

            elif model_type == _enums.ModelType.FLUX_KONTEXT:
                if controlnet_uris:
                    raise UnsupportedPipelineConfigError(
                        '--model-type flux-kontext does not support ControlNet models.')
                pipeline_class = diffusers.FluxKontextInpaintPipeline

            elif model_type == _enums.ModelType.IF:
                pipeline_class = diffusers.IFInpaintingPipeline

            elif model_type == _enums.ModelType.IFS:
                pipeline_class = diffusers.IFInpaintingSuperResolutionPipeline

            elif model_type == _enums.ModelType.SD3:
                # PAG is rejected for every SD3 inpaint pipeline, and the class is
                # assigned (not returned) so the support checks below still run
                if pag:
                    raise UnsupportedPipelineConfigError(
                        '--model-type sd3 does not support --pag in inpaint mode.'
                    )
                if controlnet_uris:
                    pipeline_class = diffusers.StableDiffusion3ControlNetInpaintingPipeline
                else:
                    pipeline_class = diffusers.StableDiffusion3InpaintPipeline

            elif model_type == _enums.ModelType.KOLORS:
                if controlnet_uris:
                    if pag:
                        raise UnsupportedPipelineConfigError(
                            'Kolors ControlNet does not support PAG in inpaint mode'
                        )
                    pipeline_class = _kolors.KolorsControlNetInpaintPipeline
                else:
                    if pag:
                        raise UnsupportedPipelineConfigError(
                            'Kolors does not support PAG in inpaint mode'
                        )
                    pipeline_class = _kolors.KolorsInpaintPipeline

            elif t2i_adapter_uris:
                raise UnsupportedPipelineConfigError(
                    'inpaint mode is not supported with --t2i-adapters.')

            elif controlnet_uris:
                if is_sdxl:
                    if pag:
                        raise UnsupportedPipelineConfigError(
                            '--model-type sdxl does not support --pag '
                            'in inpaint mode with ControlNet models.'
                        )
                    else:
                        if sdxl_controlnet_union:
                            pipeline_class = \
                                diffusers.StableDiffusionXLControlNetUnionInpaintPipeline
                        else:
                            pipeline_class = diffusers.StableDiffusionXLControlNetInpaintPipeline
                else:
                    if pag:
                        pipeline_class = diffusers.StableDiffusionControlNetPAGInpaintPipeline
                    else:
                        pipeline_class = diffusers.StableDiffusionControlNetInpaintPipeline
            else:
                if is_sdxl:
                    if pag:
                        pipeline_class = diffusers.StableDiffusionXLPAGInpaintPipeline
                    else:
                        pipeline_class = diffusers.StableDiffusionXLInpaintPipeline
                else:
                    if pag:
                        raise UnsupportedPipelineConfigError(
                            '--model-type sd (Stable Diffusion 1.5 - 2.*) '
                            'does not support --pag in inpaint mode.')
                    else:
                        pipeline_class = diffusers.StableDiffusionInpaintPipeline
        else:
            # Should be impossible
            raise UnsupportedPipelineConfigError('Pipeline type not implemented.')

    if lora_uris and not pipeline_class_supports_lora(pipeline_class):
        raise UnsupportedPipelineConfigError(
            f'Given current arguments, '
            f'--model-type {_enums.get_model_type_string(model_type)} '
            f'(pipeline: {pipeline_class.__name__}) does not support LoRAs.')

    if textual_inversion_uris and not pipeline_class_supports_textual_inversion(pipeline_class):
        raise UnsupportedPipelineConfigError(
            f'Given current arguments, '
            f'--model-type {_enums.get_model_type_string(model_type)} '
            f'(pipeline: {pipeline_class.__name__}) does not support Textual Inversions.')

    if ip_adapter_uris and not pipeline_class_supports_ip_adapter(pipeline_class):
        raise UnsupportedPipelineConfigError(
            f'Given current arguments, '
            f'--model-type {_enums.get_model_type_string(model_type)} '
            f'(pipeline: {pipeline_class.__name__}) does not support IP Adapters.')

    return pipeline_class
def _enforce_pipeline_cache_size(new_pipeline_size):
    """
    Apply the pipeline cache's CPU-side memory constraints before a new pipeline is created.

    This is a thin forwarder to ``_pipeline_cache.enforce_cpu_mem_constraints``.
    It passes ``_constants.PIPELINE_CACHE_MEMORY_CONSTRAINTS`` as the constraint
    set and ``'pipeline_size'`` as the size variable name. What "enforcing" does
    (e.g. evicting cached pipelines) is decided by the cache object; that is not
    visible here. NOTE(review): confirm against the cache implementation.

    :param new_pipeline_size: estimated CPU-side memory footprint, in bytes, of
        the pipeline that is about to be created and cached
    :return: ``None``
    """
    constraint_expressions = _constants.PIPELINE_CACHE_MEMORY_CONSTRAINTS

    # 'pipeline_size' names the variable the constraint expressions use to
    # refer to the size of the incoming object.
    _pipeline_cache.enforce_cpu_mem_constraints(
        constraint_expressions,
        new_object_size=new_pipeline_size,
        size_var='pipeline_size',
    )
@_memoize(_pipeline_cache,
exceptions={'local_files_only', 'missing_submodules_ok'},
hasher=_pipeline_args_hasher,
extra_identities=[lambda m: m.pipeline],
on_hit=_pipeline_on_hit,
on_create=_pipeline_on_create)
def _create_diffusion_pipeline(
model_path: str,
model_type: _enums.ModelType = _enums.ModelType.SD,
pipeline_type: _enums.PipelineType = _enums.PipelineType.TXT2IMG,
revision: _types.OptionalString = None,
variant: _types.OptionalString = None,
subfolder: _types.OptionalString = None,
dtype: _enums.DataType = _enums.DataType.AUTO,
unet_uri: _types.OptionalUri = None,
transformer_uri: _types.OptionalUri = None,
vae_uri: _types.OptionalUri = None,
lora_uris: _types.OptionalUris = None,
lora_fuse_scale: _types.OptionalFloat = None,
image_encoder_uri: _types.OptionalUri = None,
ip_adapter_uris: _types.OptionalUris = None,
textual_inversion_uris: _types.OptionalUris = None,
text_encoder_uris: _types.OptionalUris = None,
controlnet_uris: _types.OptionalUris = None,
t2i_adapter_uris: _types.OptionalUris = None,
quantizer_uri: _types.OptionalUri = None,
quantizer_map: _types.OptionalStrings = None,
pag: bool = False,
safety_checker: bool = False,
original_config: _types.OptionalString = None,
auth_token: _types.OptionalString = None,
device: str = _torchutil.default_device(),
extra_modules: dict[str, typing.Any] | None = None,
model_cpu_offload: bool = False,
sequential_cpu_offload: bool = False,
local_files_only: bool = False,
missing_submodules_ok: bool = False
) -> PipelineCreationResult:
# Ensure model path is specified
if not model_path:
raise ValueError('model_path must be specified.')
# Offload checks
if model_cpu_offload and sequential_cpu_offload:
raise UnsupportedPipelineConfigError(
'model_cpu_offload and sequential_cpu_offload may not be enabled simultaneously.')
# Device check
if not _torchutil.is_valid_device_string(device):
raise UnsupportedPipelineConfigError(
f'Invalid device argument, {_torchutil.invalid_device_message(device, cap=False)}')
# Quantizer map check
quantizer_map_vals = [
'unet',
'transformer',
'text_encoder',
'text_encoder_2',
'text_encoder_3',
'controlnet'
]
if quantizer_map is not None:
for map_value in quantizer_map:
if map_value not in quantizer_map_vals:
raise UnsupportedPipelineConfigError(
f'Unknown quantizer_map value: {map_value}, '
f'must be one of: {_textprocessing.oxford_comma(quantizer_map_vals, "or")}'
)
if not _hfhub.is_single_file_model_load(model_path) and original_config:
raise UnsupportedPipelineConfigError(
'Loading original config .yaml file is not supported '
'when loading from a Hugging Face repo.'
)
if quantizer_uri and quantizer_uri.split(';')[0].strip() in {'bnb', 'bitsandbytes'}:
if dtype is _enums.DataType.AUTO:
# Default to all modules to float32 if no dtype is specified when using bitsandbytes
dtype = _enums.DataType.FLOAT32
if original_config:
# only instance where it really makes sense to raise config not found
# everything else is indicative of a missing model
with _hfhub.with_hf_errors_as_config_not_found():
original_config = _hfhub.download_non_hf_slug_config(original_config)
model_path = _hfhub.download_non_hf_slug_model(model_path)
if (_hfhub.is_single_file_model_load(model_path)
and _enums.model_type_is_kolors(model_type)
):
raise UnsupportedPipelineConfigError(
'Kolors models cannot be loaded from a single file.'
)
model_index = _util.fetch_model_index_dict(
model_path,
subfolder=subfolder,
revision=revision,
use_auth_token=auth_token,
local_files_only=local_files_only
)
if '_class_name' in model_index:
model_class_name = model_index['_class_name']
model_checks = [
(_enums.model_type_is_flux, ('^Flux.*', 'Flux')),
(_enums.model_type_is_sd3, ('^StableDiffusion3.*', 'Stable Diffusion 3')),
(_enums.model_type_is_sdxl, ('^StableDiffusionXL.*', 'Stable Diffusion XL')),
(_enums.model_type_is_sd15, ('^StableDiffusion[^X3].*', 'Stable Diffusion')),
(_enums.model_type_is_sd2, ('^StableDiffusion[^X3].*', 'Stable Diffusion')),
(_enums.model_type_is_s_cascade, ('^StableCascade.*', 'Stable Cascade')),
(_enums.model_type_is_kolors, ('^Kolors.*', 'Kolors')),
(_enums.model_type_is_floyd, ('^IF.*', 'Deep Floyd')),
]
# exceptions to the rules above
# where left and right evaluate True
model_fallback_checks = [
(lambda x: x == _enums.ModelType.SD, '^LatentConsistency.*')
]
for check_func, (pattern, title) in model_checks:
if check_func(model_type) and re.match(pattern, model_class_name) is None:
# are there any exceptions to this such as legacy configs?
if not any(
check(model_type) and
re.match(pattern, model_class_name) is not None
for check, pattern in model_fallback_checks
):
raise UnsupportedPipelineConfigError(
f'{model_path} is not a {title} model, '
f'incorrect --model-type value: {_enums.get_model_type_string(model_type)}'
)
pipeline_class = get_pipeline_class(
model_type=model_type,
pipeline_type=pipeline_type,
unet_uri=unet_uri,
transformer_uri=transformer_uri,
vae_uri=vae_uri,
lora_uris=lora_uris,
image_encoder_uri=image_encoder_uri,
ip_adapter_uris=ip_adapter_uris,
textual_inversion_uris=textual_inversion_uris,
controlnet_uris=controlnet_uris,
t2i_adapter_uris=t2i_adapter_uris,
pag=pag
)
text_encoder_count = len(
[a for a in inspect.getfullargspec(pipeline_class.__init__).args if a.startswith('text_encoder')])
if not text_encoder_uris:
text_encoder_uris = []
if len(text_encoder_uris) > text_encoder_count:
raise UnsupportedPipelineConfigError('To many text encoder URIs specified.')
if extra_modules is not None:
_messages.debug_log('Checking extra_modules for meta tensors...')
for module in extra_modules.items():
if module[1] is None:
continue
_messages.debug_log(f'Checking extra module {module[0]} = {module[1].__class__}...')
try:
if get_torch_device(module[1]).type == 'meta':
_messages.debug_log(f'"{module[0]}" has meta tensors.')
_disable_to(
module[1],
module[0] == 'vae'
)
except ValueError:
_messages.debug_log(
f'Unable to get device of {module[0]} = {module[1].__class__}')
extra_modules = extra_modules.copy()
else:
extra_modules = dict()
unet_override = 'unet' in extra_modules
vae_override = 'vae' in extra_modules
controlnet_override = 'controlnet' in extra_modules
adapter_override = 'adapter' in extra_modules
image_encoder_override = 'image_encoder' in extra_modules
safety_checker_override = 'safety_checker' in extra_modules
transformer_override = 'transformer' in extra_modules
if 'text_encoder' in extra_modules and text_encoder_count == 0:
raise UnsupportedPipelineConfigError('To many text encoders specified.')
if 'text_encoder_2' in extra_modules and text_encoder_count < 2:
raise UnsupportedPipelineConfigError('To many text encoders specified.')
if 'text_encoder_3' in extra_modules and text_encoder_count < 3:
raise UnsupportedPipelineConfigError('To many text encoders specified.')
# noinspection PyTypeChecker
text_encoders: list[str] = list(text_encoder_uris)
if len(text_encoders) > 0 and _text_encoder_null(text_encoders[0]):
extra_modules['text_encoder'] = None
if len(text_encoders) > 1 and _text_encoder_null(text_encoders[1]):
extra_modules['text_encoder_2'] = None
if len(text_encoders) > 2 and _text_encoder_null(text_encoders[2]):
extra_modules['text_encoder_3'] = None
text_encoder_override = 'text_encoder' in extra_modules
text_encoder_2_override = 'text_encoder_2' in extra_modules
text_encoder_3_override = 'text_encoder_3' in extra_modules
if len(text_encoders) > 0 and text_encoder_override:
text_encoders[0] = None
if len(text_encoders) > 1 and text_encoder_2_override:
text_encoders[1] = None
if len(text_encoders) > 2 and text_encoder_3_override:
text_encoders[2] = None
# If we are not auto caching UNet/Transformer, VAE, and text encoders
# into other caches, we should estimate a cache size for the pipeline
# that includes them
pipeline_cached_with_submodules = \
model_cpu_offload or sequential_cpu_offload or quantizer_uri
estimated_memory_usage = estimate_pipeline_cache_footprint(
model_type=model_type,
model_path=model_path,
revision=revision,
variant=variant,
subfolder=subfolder,
include_unet_or_transformer=pipeline_cached_with_submodules or bool(lora_uris),
include_vae=pipeline_cached_with_submodules,
include_text_encoders=pipeline_cached_with_submodules or bool(lora_uris),
lora_uris=lora_uris,
image_encoder_uri=image_encoder_uri,
ip_adapter_uris=ip_adapter_uris,
textual_inversion_uris=textual_inversion_uris,
safety_checker=safety_checker and not safety_checker_override,
auth_token=auth_token,
extra_args=extra_modules,
local_files_only=local_files_only
)
_messages.debug_log(
f'Creating Torch Pipeline: "{pipeline_class.__name__}", '
f'Estimated CPU Side Memory Use: {_memory.bytes_best_human_unit(estimated_memory_usage)}')
# Helper function to determine if quantization should be applied to a module
def should_apply_quantizer(module_name):
if not quantizer_uri:
return False
if quantizer_map is None:
return True
return module_name in quantizer_map
# Helper function to determine device_map - always use the selected device currently
def get_device_map_for_quantizer(quantizer_uri):
return device if quantizer_uri else None
# we need to manually emulate bitsandbytes 'compute_dtype'
# casting for sdnq as it has trouble dequanting to anything
# but float16 on forward which messes with diffusers
# in various ways when a model is loaded in float32
sdnq_cast_hack = False
if quantizer_uri and (
_uris.get_quantizer_uri_class(quantizer_uri) is
_uris.SDNQQuantizerUri
):
sdnq_cast_hack = True
uri_quant_check = []
manual_quantizer_components = set()
# Check text encoder URIs
for idx, encoder_uri in enumerate(text_encoder_uris):
if encoder_uri and encoder_uri.lower() not in {'+', 'help', 'null'}:
parsed_uri = _uris.TextEncoderUri.parse(encoder_uri)
uri_quant_check.append(parsed_uri)
if parsed_uri.quantizer:
encoder_name = f'text_encoder{"_2" if idx == 1 else "_3" if idx == 2 else ""}'
manual_quantizer_components.add(encoder_name)
quantizer_class = _uris.get_quantizer_uri_class(parsed_uri.quantizer)
if quantizer_class is _uris.SDNQQuantizerUri:
sdnq_cast_hack = True
# Check transformer URI
if transformer_uri:
parsed_uri = _uris.TransformerUri.parse(transformer_uri)
uri_quant_check.append(parsed_uri)
if parsed_uri.quantizer:
manual_quantizer_components.add('transformer')
quantizer_class = _uris.get_quantizer_uri_class(parsed_uri.quantizer)
if quantizer_class is _uris.SDNQQuantizerUri:
sdnq_cast_hack = True
# Check unet URI
if unet_uri:
parsed_uri = _uris.UNetUri.parse(unet_uri)
uri_quant_check.append(parsed_uri)
if parsed_uri.quantizer:
manual_quantizer_components.add('unet')
quantizer_class = _uris.get_quantizer_uri_class(parsed_uri.quantizer)
if quantizer_class is _uris.SDNQQuantizerUri:
sdnq_cast_hack = True
# Check controlnet URIs
if controlnet_uris:
for controlnet_uri in controlnet_uris:
parsed_uri = _uris.ControlNetUri.parse(controlnet_uri, model_type=model_type)
uri_quant_check.append(parsed_uri)
if parsed_uri.quantizer:
manual_quantizer_components.add('controlnet')
quantizer_class = _uris.get_quantizer_uri_class(parsed_uri.quantizer)
if quantizer_class is _uris.SDNQQuantizerUri:
sdnq_cast_hack = True
if quantizer_uri or any(p.quantizer for p in uri_quant_check):
# for now, just knock out anything cached on the gpu, such as the last pipeline
# the quantized pipeline modules are likely going to go straight onto the GPU
# immediately, and they are guaranteed to be of non-trivial size
_devicecache.clear_device_cache(device)
# Granular component caching for quantization scenarios
cached_component_paths = {}
manual_quantizers = {}
components_to_cache = []
# Determine if we need granular caching
needs_granular_caching = (quantizer_uri and (_hfhub.is_single_file_model_load(model_path) or lora_uris)) or \
(manual_quantizer_components and lora_uris)
if needs_granular_caching:
# Determine which components to cache
if lora_uris:
# For LoRA + quantizer: cache only components affected by LoRA
components_to_cache = _get_affected_components_for_lora(pipeline_class)
# Also include any manual components with quantizers that are LoRA-affected
lora_affected_components = set(components_to_cache)
for component in manual_quantizer_components:
if component in lora_affected_components:
if component not in components_to_cache:
components_to_cache.append(component)
else:
# For single file + quantizer: cache all quantizable components
components_to_cache = _get_quantizable_components(model_type)
# Filter based on quantizer_map ONLY for components that would get auto-quantized
# Components with manual quantizers should always be cached when LoRAs are involved
if quantizer_map is not None:
# Keep components that either:
# 1. Have manual quantizers (always cache these when LoRAs are involved)
# 2. Are in the quantizer_map (for auto-quantization)
if lora_uris:
# When LoRAs are involved, always cache components with manual quantizers
components_to_cache = [c for c in components_to_cache
if c in manual_quantizer_components or c in quantizer_map]
else:
# For single file without LoRAs, only respect quantizer_map
components_to_cache = [c for c in components_to_cache if c in quantizer_map]
# Collect manual URIs for components that will be cached
manual_component_uris = {}
if unet_uri and 'unet' in components_to_cache:
manual_component_uris['unet'] = unet_uri
if transformer_uri and 'transformer' in components_to_cache:
manual_component_uris['transformer'] = transformer_uri
# Handle text encoder URIs
if text_encoder_uris:
for idx, encoder_uri in enumerate(text_encoder_uris):
if not _text_encoder_default(encoder_uri):
encoder_name = f'text_encoder{"_2" if idx == 1 else "_3" if idx == 2 else ""}'
if encoder_name in components_to_cache:
manual_component_uris[encoder_name] = encoder_uri
# Perform granular caching
if components_to_cache:
cached_component_paths, manual_quantizers = _cache_components_granular(
model_path=model_path,
model_type=model_type,
pipeline_type=pipeline_type,
components_to_cache=components_to_cache,
lora_uris=lora_uris,
lora_fuse_scale=lora_fuse_scale,
revision=revision,
variant=variant,
subfolder=subfolder,
dtype=dtype,
original_config=original_config,
auth_token=auth_token,
manual_component_uris=manual_component_uris,
vae_uri=vae_uri
)
_messages.debug_log(
f"Granular caching complete. Cached components: {list(cached_component_paths.keys())}")
if manual_component_uris:
_messages.debug_log(
f"Manual component URIs processed: {list(manual_component_uris.keys())}")
if manual_quantizers:
_messages.debug_log(
f"Manual quantizers preserved: {list(manual_quantizers.keys())}")
else:
manual_quantizers = {}
# ControlNet and VAE loading
# Used during pipeline load
creation_kwargs = {}
torch_dtype = _enums.get_torch_dtype(dtype)
parsed_controlnet_uris = []
parsed_t2i_adapter_uris = []
parsed_image_encoder_uri = None
parsed_unet_uri = None
parsed_vae_uri = None
parsed_transformer_uri = None
pipe_params = inspect.signature(pipeline_class.__init__).parameters
def load_text_encoder(uri: _uris.TextEncoderUri):
return uri.load(
variant_fallback=variant,
dtype_fallback=dtype,
original_config=original_config,
use_auth_token=auth_token,
local_files_only=local_files_only,
no_cache=bool(lora_uris) or model_cpu_offload or sequential_cpu_offload,
missing_ok=missing_submodules_ok,
device_map=get_device_map_for_quantizer(uri.quantizer)
)
def load_vae(uri: _uris.VAEUri):
vae_model = uri.load(
dtype_fallback=dtype,
original_config=original_config,
use_auth_token=auth_token,
local_files_only=local_files_only,
no_cache=model_cpu_offload or sequential_cpu_offload,
missing_ok=missing_submodules_ok
)
if sdnq_cast_hack:
og_decode = vae_model.decode
def sdnq_decode(latents, *args, **kwargs):
if getattr(vae_model.config, 'use_post_quant_conv', False):
cur_dtype = vae_model.post_quant_conv.weight.dtype
else:
cur_dtype = _enums.get_torch_dtype(dtype)
return og_decode(latents.to(
dtype=vae_model.dtype if cur_dtype is None else cur_dtype), *args, **kwargs)
vae_model.decode = sdnq_decode
return vae_model
def sdnq_forward(og_forward, model, *args, **kwargs):
args = list(args)
for i, arg in enumerate(args):
if isinstance(arg, torch.Tensor):
args[i] = arg.to(dtype=model.dtype)
for k,v in kwargs.items():
if isinstance(v, torch.Tensor):
kwargs[k] = v.to(dtype=model.dtype)
return og_forward(*args, **kwargs)
def controlnet_quant_forward(og_forward, model, *args, **kwargs):
"""
Forward function for quantized controlnets that casts inputs to the model's dtype.
This is needed because diffusers doesn't handle controlnet quantization state internally.
Note: Preserves specific parameters that are used as indices in embedding layers.
Specifically, 'controlnet_mode' parameter in FLUX ControlNets is used as indices
in the controlnet_mode_embedder embedding layer and must remain as integer tensors.
"""
# Known parameters that should remain as integer types for embedding layer indices
# controlnet_mode: Used in FLUX ControlNet Union models with nn.Embedding layer
# Note: SDXL ControlNet Union uses control_type with Timesteps projection (no integer requirement)
embedding_index_params = {'controlnet_mode'}
# Cast positional arguments to model dtype (no special handling needed)
args = list(args)
for i, arg in enumerate(args):
if isinstance(arg, torch.Tensor):
args[i] = arg.to(dtype=model.dtype)
# Handle keyword arguments with special cases for embedding indices
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
# Preserve specific embedding index parameters or any integer tensor
if (k in embedding_index_params or
v.dtype in (torch.int64, torch.long, torch.int32, torch.int)):
continue # Don't cast embedding index tensors
kwargs[k] = v.to(dtype=model.dtype)
return og_forward(*args, **kwargs)
def load_unet(uri: _uris.UNetUri, unet_class):
unet_model = uri.load(
variant_fallback=variant,
dtype_fallback=dtype,
original_config=original_config,
use_auth_token=auth_token,
local_files_only=local_files_only,
no_cache=bool(lora_uris) or
bool(ip_adapter_uris) or
model_cpu_offload or
sequential_cpu_offload,
device_map=get_device_map_for_quantizer(uri.quantizer),
unet_class=unet_class
)
if sdnq_cast_hack:
unet_model.forward = functools.partial(
sdnq_forward,
unet_model.forward,
unet_model
)
return unet_model
def load_transformer(uri: _uris.TransformerUri, transformer_class):
transformer_model = uri.load(
variant_fallback=variant,
dtype_fallback=dtype,
original_config=original_config,
use_auth_token=auth_token,
local_files_only=local_files_only,
no_cache=bool(lora_uris) or model_cpu_offload or sequential_cpu_offload,
device_map=get_device_map_for_quantizer(uri.quantizer),
transformer_class=transformer_class
)
if sdnq_cast_hack:
transformer_model.forward = functools.partial(
sdnq_forward,
transformer_model.forward,
transformer_model
)
return transformer_model
def load_default_text_encoder(encoder, encoder_name):
should_quantize = should_apply_quantizer(encoder_name)
# Check if we have a cached version from granular caching
if encoder_name in cached_component_paths:
# Use manual quantizer if available, otherwise use global quantizer
text_encoder_quantizer = manual_quantizers.get(encoder_name, quantizer_uri if should_quantize else None)
return load_text_encoder(
_uris.TextEncoderUri(
encoder=encoder,
model=cached_component_paths[encoder_name],
variant=variant,
dtype=dtype,
quantizer=text_encoder_quantizer
)
)
else:
if _hfhub.is_single_file_model_load(model_path) and should_quantize:
raise UnsupportedPipelineConfigError(
'Cannot use global --quantizer URI when attempting to '
f'load default text encoder for "{encoder_name}" from a single file checkpoint. '
f'You must specify this text encoder manually with --text-encoders, it is likely missing '
f'from the model checkpoint.'
)
return load_text_encoder(
_uris.TextEncoderUri(
encoder=encoder,
model=model_path,
variant=variant,
revision=revision,
subfolder=encoder_subfolder,
dtype=dtype,
quantizer=quantizer_uri if should_quantize else None
)
)
text_encoder_override_states = [
text_encoder_override,
text_encoder_2_override,
text_encoder_3_override
]
# Load Text Encoders
for idx, (name, param) in enumerate(
[n for n in sorted(model_index.items(), key=lambda x: x[0])
if n[0].startswith('text_encoder') and n[1][0] is not None]):
if text_encoder_override_states[idx]:
continue
if _hfhub.is_single_file_model_load(model_path):
encoder_subfolder = name
else:
encoder_subfolder = os.path.join(subfolder, name) if subfolder else name
encoder_class = param[1]
if text_encoder_uris and len(text_encoder_uris) > idx:
encoder_uri = text_encoder_uris[idx]
if _text_encoder_default(encoder_uri):
creation_kwargs[name] = load_default_text_encoder(encoder_class, name)
elif not (needs_granular_caching and name in components_to_cache):
creation_kwargs[name] = load_text_encoder(
_uris.TextEncoderUri.parse(encoder_uri)
)
# If granular caching is needed and this encoder is in components_to_cache,
# skip direct loading - it will be handled by granular caching
else:
creation_kwargs[name] = load_default_text_encoder(encoder_class, name)
# Load VAE
if not vae_override:
if vae_uri:
parsed_vae_uri = _uris.VAEUri.parse(vae_uri)
creation_kwargs['vae'] = load_vae(parsed_vae_uri)
_messages.debug_log(lambda:
f'Added Torch VAE: "{vae_uri}" to pipeline: "{pipeline_class.__name__}"')
elif 'vae' in pipe_params:
if _hfhub.is_single_file_model_load(model_path):
vae_subfolder = 'vae'
else:
vae_subfolder = os.path.join(subfolder, 'vae') if subfolder else 'vae'
vae_param = pipe_params['vae'].annotation
vae_encoder_name = vae_param.__name__
if _types.is_union(vae_param):
try:
vae_encoder_name = model_index['vae'][1]
except (KeyError, IndexError):
_messages.debug_log(
'Skipping auto VAE caching due to model '
'configuration not having a VAE key.')
vae_encoder_name = None
except FileNotFoundError as e:
raise UnsupportedPipelineConfigError(
'Could not find VAE configuration data.') from e
if vae_encoder_name not in _uris.VAEUri.supported_encoder_names():
raise UnsupportedPipelineConfigError(
f'Unsupported VAE encoder type: {vae_encoder_name}'
)
if vae_encoder_name is not None:
vae_extract_from_checkpoint = _hfhub.is_single_file_model_load(model_path)
try:
creation_kwargs['vae'] = \
load_vae(_uris.VAEUri(
encoder=vae_encoder_name,
model=model_path,
variant=variant,
revision=revision,
subfolder=vae_subfolder,
extract=vae_extract_from_checkpoint,
dtype=dtype
))
except _d_exceptions.ModelNotFoundError:
if vae_extract_from_checkpoint:
raise
creation_kwargs['vae'] = \
load_vae(_uris.VAEUri(
encoder=vae_encoder_name,
model=model_path,
revision=revision,
subfolder=vae_subfolder,
dtype=dtype
))
# Load UNet
if not unet_override:
unet_parameter = 'unet'
if model_type == _enums.ModelType.S_CASCADE:
unet_parameter = 'prior'
elif model_type == _enums.ModelType.S_CASCADE_DECODER:
unet_parameter = 'decoder'
unet_class = diffusers.UNet2DConditionModel if unet_parameter == 'unet' \
else diffusers.models.unets.StableCascadeUNet
if unet_uri is not None and not (needs_granular_caching and 'unet' in components_to_cache):
parsed_unet_uri = _uris.UNetUri.parse(unet_uri)
if _enums.model_type_is_kolors(model_type) and parsed_unet_uri.quantizer:
raise UnsupportedPipelineConfigError(
f'--model-type {_enums.get_model_type_string(model_type)} does not support '
f'loading a --unet with quantization applied.'
)
creation_kwargs[unet_parameter] = load_unet(
parsed_unet_uri, unet_class=unet_class
)
_messages.debug_log(lambda:
f'Added Torch UNet: "{unet_uri}" to pipeline: "{pipeline_class.__name__}"')
elif 'unet' in pipe_params:
if _hfhub.is_single_file_model_load(model_path):
unet_subfolder = unet_parameter
else:
unet_subfolder = os.path.join(subfolder, unet_parameter) if subfolder else unet_parameter
# Check if we have a cached version from granular caching
if 'unet' in cached_component_paths:
unet_model_path = cached_component_paths['unet']
unet_subfolder = None # No subfolder for cached component
unet_revision = None # No revision for cached component
# Use manual quantizer if available, otherwise use global quantizer
unet_quantizer = manual_quantizers.get('unet',
quantizer_uri if should_apply_quantizer("unet") else None)
else:
unet_model_path = model_path
unet_revision = revision
unet_quantizer = quantizer_uri if should_apply_quantizer("unet") else None
creation_kwargs['unet'] = \
load_unet(
_uris.UNetUri(
model=unet_model_path,
variant=variant,
revision=unet_revision,
subfolder=unet_subfolder,
dtype=dtype,
quantizer=unet_quantizer
), unet_class=unet_class)
# Load Transformer
if _enums.model_type_is_sd3(model_type):
transformer_class = diffusers.SD3Transformer2DModel
elif _enums.model_type_is_flux(model_type):
transformer_class = diffusers.FluxTransformer2DModel
else:
transformer_class = None
if not transformer_override:
if transformer_uri is not None and not (
needs_granular_caching and transformer_class is not None and 'transformer' in components_to_cache):
assert transformer_class is not None
parsed_transformer_uri = _uris.TransformerUri.parse(transformer_uri)
creation_kwargs['transformer'] = load_transformer(
parsed_transformer_uri,
transformer_class=transformer_class
)
_messages.debug_log(lambda:
f'Added Torch Transformer: "{transformer_uri}" to '
f'pipeline: "{pipeline_class.__name__}"')
elif 'transformer' in pipe_params:
assert transformer_class is not None
if _hfhub.is_single_file_model_load(model_path):
transformer_subfolder = 'transformer'
else:
transformer_subfolder = os.path.join(subfolder, 'transformer') if subfolder else 'transformer'
# Check if we have a cached version from granular caching
if 'transformer' in cached_component_paths:
transformer_model_path = cached_component_paths['transformer']
transformer_subfolder = None # No subfolder for cached component
transformer_revision = None # No revision for cached component
# Use manual quantizer if available, otherwise use global quantizer
transformer_quantizer = manual_quantizers.get('transformer', quantizer_uri if should_apply_quantizer(
"transformer") else None)
else:
transformer_model_path = model_path
transformer_revision = revision
transformer_quantizer = quantizer_uri if should_apply_quantizer("transformer") else None
creation_kwargs['transformer'] = load_transformer(
_uris.TransformerUri(
model=transformer_model_path,
variant=variant,
revision=transformer_revision,
subfolder=transformer_subfolder,
dtype=dtype,
quantizer=transformer_quantizer
), transformer_class=transformer_class)
# load image encoder
if image_encoder_uri is not None and not image_encoder_override:
parsed_image_encoder_uri = _uris.ImageEncoderUri.parse(image_encoder_uri)
if _enums.model_type_is_sd3(model_type):
# image encoder does not participate in offloading for SD3
no_cache_image_encoder = model_cpu_offload
else:
no_cache_image_encoder = model_cpu_offload or sequential_cpu_offload
loaded_image_encoder = parsed_image_encoder_uri.load(
dtype_fallback=dtype,
use_auth_token=auth_token,
local_files_only=local_files_only,
no_cache=no_cache_image_encoder,
image_encoder_class=
_models.SiglipImageEncoder
if _enums.model_type_is_sd3(model_type) else
transformers.CLIPVisionModelWithProjection
)
if isinstance(loaded_image_encoder, _models.SiglipImageEncoder):
creation_kwargs['image_encoder'] = loaded_image_encoder.image_encoder
creation_kwargs['feature_extractor'] = loaded_image_encoder.feature_extractor
else:
creation_kwargs['image_encoder'] = loaded_image_encoder
_messages.debug_log(lambda:
f'Added Torch Image Encoder: "{image_encoder_uri}" to '
f'pipeline: "{pipeline_class.__name__}"')
# Load T2I Adapters
if t2i_adapter_uris and not adapter_override:
t2i_adapters = None
for t2i_adapter_uri in t2i_adapter_uris:
parsed_t2i_adapter_uri = _uris.T2IAdapterUri.parse(t2i_adapter_uri)
parsed_t2i_adapter_uris.append(parsed_t2i_adapter_uri)
new_adapter = parsed_t2i_adapter_uri.load(
use_auth_token=auth_token,
dtype_fallback=dtype,
local_files_only=local_files_only,
no_cache=model_cpu_offload or sequential_cpu_offload
)
_messages.debug_log(lambda:
f'Added Torch T2IAdapter: "{t2i_adapter_uri}" '
f'to pipeline: "{pipeline_class.__name__}"')
if t2i_adapters is not None:
if not isinstance(t2i_adapters, list):
t2i_adapters = [t2i_adapters, new_adapter]
else:
t2i_adapters.append(new_adapter)
else:
t2i_adapters = new_adapter
if isinstance(t2i_adapters, list):
creation_kwargs['adapter'] = diffusers.MultiAdapter(t2i_adapters)
else:
creation_kwargs['adapter'] = t2i_adapters
# Load ControlNets
if controlnet_uris and not controlnet_override:
controlnets = None
sdxl_cn_union = None
for controlnet_uri in controlnet_uris:
parsed_controlnet_uri = _uris.ControlNetUri.parse(
uri=controlnet_uri,
model_type=model_type
)
parsed_controlnet_uris.append(parsed_controlnet_uri)
# Apply global quantizer if controlnet doesn't have
# its own quantizer and should be quantized
controlnet_uri_to_load = parsed_controlnet_uri
if not parsed_controlnet_uri.quantizer and should_apply_quantizer('controlnet'):
# Create a new URI with the global quantizer
controlnet_uri_to_load = _uris.ControlNetUri(
model=parsed_controlnet_uri.model,
revision=parsed_controlnet_uri.revision,
variant=parsed_controlnet_uri.variant,
subfolder=parsed_controlnet_uri.subfolder,
dtype=parsed_controlnet_uri.dtype,
scale=parsed_controlnet_uri.scale,
start=parsed_controlnet_uri.start,
end=parsed_controlnet_uri.end,
mode=parsed_controlnet_uri.mode,
quantizer=quantizer_uri,
model_type=parsed_controlnet_uri.model_type
)
new_net = controlnet_uri_to_load.load(
use_auth_token=auth_token,
dtype_fallback=dtype,
local_files_only=local_files_only,
no_cache=model_cpu_offload or sequential_cpu_offload,
device_map=get_device_map_for_quantizer(controlnet_uri_to_load.quantizer)
)
# Apply casting hack for quantized controlnets
if controlnet_uri_to_load.quantizer:
new_net.forward = functools.partial(
controlnet_quant_forward,
new_net.forward,
new_net
)
_messages.debug_log(lambda:
f'Added Torch ControlNet: "{controlnet_uri}" '
f'to pipeline: "{pipeline_class.__name__}"')
if sdxl_cn_union is not None:
continue
if isinstance(new_net, diffusers.ControlNetUnionModel):
# first model determines controlnet model,
# the rest of the specifications just provide the mode
sdxl_cn_union = new_net
continue
if controlnets is not None:
if not isinstance(controlnets, list):
controlnets = [controlnets, new_net]
else:
controlnets.append(new_net)
else:
controlnets = new_net
if sdxl_cn_union is not None:
controlnets = sdxl_cn_union
if isinstance(controlnets, list):
# not handled internally for whatever reason like the other pipelines
if _enums.model_type_is_sd3(model_type):
creation_kwargs['controlnet'] = diffusers.SD3MultiControlNetModel(controlnets)
elif _enums.model_type_is_flux(model_type):
creation_kwargs['controlnet'] = diffusers.FluxMultiControlNetModel(controlnets)
else:
creation_kwargs['controlnet'] = controlnets
else:
creation_kwargs['controlnet'] = controlnets
if _enums.model_type_is_floyd(model_type):
creation_kwargs['watermarker'] = None
if not safety_checker and \
(_enums.model_type_is_sd15(model_type) or
_enums.model_type_is_floyd(model_type)) and not safety_checker_override:
creation_kwargs['safety_checker'] = None
creation_kwargs.update(extra_modules)
def _handle_generic_pipeline_load_failure(e):
exc_msg = str(e)
_messages.debug_log(
f'Failed to load primary pipeline model: "{model_path}", reason: {exc_msg}')
if model_path in exc_msg:
if 'restricted' in exc_msg:
# the gated repo message is far more useful to the user
raise InvalidModelFileError(exc_msg) from e
else:
raise InvalidModelFileError(f'invalid model file or repo slug: {model_path}') from e
raise InvalidModelFileError(e) from e
if _hfhub.is_single_file_model_load(model_path):
if subfolder is not None:
raise UnsupportedPipelineConfigError(
'Single file model loads do not support the subfolder option.')
try:
_enforce_pipeline_cache_size(estimated_memory_usage)
pipeline = _pipeline_creation_args_debug(
backend='Torch',
cls=pipeline_class,
method=pipeline_class.from_single_file,
original_config=original_config,
model=model_path,
token=auth_token,
revision=revision,
variant=variant,
torch_dtype=torch_dtype,
use_safe_tensors=model_path.endswith('.safetensors'),
local_files_only=local_files_only,
**creation_kwargs)
except diffusers.loaders.single_file.SingleFileComponentError as e:
_handle_single_file_component_error(e)
except (ValueError, TypeError, NameError, OSError) as e:
_handle_generic_pipeline_load_failure(e)
else:
try:
pipeline = _pipeline_creation_args_debug(
backend='Torch',
cls=pipeline_class,
method=pipeline_class.from_pretrained,
model=model_path,
token=auth_token,
revision=revision,
variant=variant,
torch_dtype=torch_dtype,
subfolder=subfolder,
local_files_only=local_files_only,
**creation_kwargs)
except (ValueError, TypeError, NameError, OSError) as e:
_handle_generic_pipeline_load_failure(e)
if hasattr(pipeline, 'vae') and \
_enums.model_type_is_sd3(model_type):
# patch to enable tiling at all resolutions
if pipeline.vae.quant_conv is None:
pipeline.vae.quant_conv = lambda x: x
if pipeline.vae.post_quant_conv is None:
pipeline.vae.post_quant_conv = lambda x: x
# Textual Inversions, LoRAs, IP Adapters
parsed_textual_inversion_uris = []
parsed_lora_uris = []
parsed_ip_adapter_uris = []
if textual_inversion_uris:
for inversion_uri in textual_inversion_uris:
parsed = _uris.TextualInversionUri.parse(inversion_uri)
parsed_textual_inversion_uris.append(parsed)
_uris.TextualInversionUri.load_on_pipeline(
pipeline=pipeline,
uris=parsed_textual_inversion_uris,
use_auth_token=auth_token,
local_files_only=local_files_only)
if lora_uris and not quantizer_uri:
# LoRAs are fused into the model and cached
# to disk in the case that quantization is requested
for lora_uri in lora_uris:
parsed = _uris.LoRAUri.parse(lora_uri)
parsed_lora_uris.append(parsed)
_uris.LoRAUri.load_on_pipeline(
pipeline=pipeline,
uris=parsed_lora_uris,
fuse_scale=lora_fuse_scale if lora_fuse_scale is not None else 1.0,
use_auth_token=auth_token,
local_files_only=local_files_only)
if ip_adapter_uris:
for ip_adapter_uri in ip_adapter_uris:
parsed = _uris.IPAdapterUri.parse(ip_adapter_uri)
parsed_ip_adapter_uris.append(parsed)
_uris.IPAdapterUri.load_on_pipeline(
pipeline=pipeline,
uris=parsed_ip_adapter_uris,
use_auth_token=auth_token,
local_files_only=local_files_only)
if ip_adapter_uris and (not hasattr(pipeline, 'image_encoder') or pipeline.image_encoder is None):
raise UnsupportedPipelineConfigError(
'Using --ip-adapters but missing required --image-encoder specification, '
'your --ip-adapters specification did not include an image encoder model and '
'you must specify one manually.')
# Safety Checker
if not safety_checker_override:
if _enums.model_type_is_floyd(model_type):
_set_floyd_safety_checker(pipeline, safety_checker)
else:
_set_sd_safety_checker(pipeline, safety_checker)
# Model Offloading
# SD3 image_encoder needs to be excluded to avoid meta tensor errors.
if _enums.model_type_is_sd3(model_type) and sequential_cpu_offload and 'image_encoder' in creation_kwargs:
pipeline._exclude_from_cpu_offload.append("image_encoder")
if not device.startswith('cpu'):
if sequential_cpu_offload:
enable_sequential_cpu_offload(pipeline, device)
elif model_cpu_offload:
enable_model_cpu_offload(pipeline, device)
_messages.debug_log(f'Finished Creating Torch Pipeline: "{pipeline_class.__name__}"')
# modules quantized in 8 bit by bitsandbytes cannot be moved off the GPU
# there is a sequence in call_pipeline that garbage collects them out of the
# cache before calling other pipelines, there is also a dgenerate.devicecache
# hook to clear them out when requested
for module in get_pipeline_modules(pipeline).values():
if _util.is_loaded_in_8bit_bnb(module):
_disable_to(module)
# noinspection PyTypeChecker
return PipelineCreationResult(
model_path=model_path,
pipeline=pipeline,
parsed_unet_uri=parsed_unet_uri,
parsed_transformer_uri=parsed_transformer_uri,
parsed_vae_uri=parsed_vae_uri,
parsed_lora_uris=parsed_lora_uris,
parsed_image_encoder_uri=parsed_image_encoder_uri,
parsed_ip_adapter_uris=parsed_ip_adapter_uris,
parsed_textual_inversion_uris=parsed_textual_inversion_uris,
parsed_controlnet_uris=parsed_controlnet_uris,
parsed_t2i_adapter_uris=parsed_t2i_adapter_uris
), _d_memoize.CachedObjectMetadata(size=estimated_memory_usage)
def _handle_single_file_component_error(e):
    """
    Translate a diffusers single file component error into an :py:exc:`UnsupportedPipelineConfigError`
    whose message names the dgenerate option that would supply the missing component.

    This function always raises; the original exception is chained as the cause.

    :param e: the caught ``SingleFileComponentError``
    :raises UnsupportedPipelineConfigError: always
    """
    message = str(e)

    if 'text_encoder' in message:
        error = UnsupportedPipelineConfigError(
            f'Single file load error, missing --text-encoders / --second-model-text-encoders:\n{e}')
    elif 'vae =' in message:
        error = UnsupportedPipelineConfigError(
            f'Single file load error, missing --vae:\n{e}')
    else:
        error = UnsupportedPipelineConfigError(
            f'Single file load error, missing component:\n{e}')

    raise error from e
def get_converted_checkpoint_cache_dir():
    """
    Get cache directory where dgenerate stores any checkpoints that needed to be converted into
    diffusers directory format on disk to function, this process is used to
    support quantization on single file checkpoints.

    This is ``~/.cache/dgenerate/diffusers_converted`` by default, or the value of the
    environmental variable ``DGENERATE_CACHE`` joined with ``diffusers_converted``.

    The directory (and any missing parents) is created if it does not already exist.

    :return: string (directory path)
    """
    user_cache_path = os.environ.get('DGENERATE_CACHE')
    if user_cache_path is not None:
        # Use the same leaf directory name as the default location, as documented above.
        path = os.path.join(user_cache_path, 'diffusers_converted')
    else:
        path = os.path.expanduser(os.path.join('~', '.cache', 'dgenerate', 'diffusers_converted'))
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)
    return path
# Root directory for converted / LoRA-fused component caches, resolved once at import time.
# get_converted_checkpoint_cache_dir() also creates the directory, so importing this
# module has that filesystem side effect.
_to_diffusers_cache_dir = get_converted_checkpoint_cache_dir()
# Component-specific caches for granular caching
def _get_component_cache_store(component_name: str):
    """
    Open the key/value store that maps cache keys to saved component directories
    for a single pipeline component.

    The store file is ``<converted cache dir>/<component_name>/cache.db``; the
    component directory is created if it does not exist yet.

    :param component_name: pipeline component attribute name, e.g. ``unet``
    :return: ``_filecache.KeyValueStore`` for that component
    """
    store_dir = pathlib.Path(_to_diffusers_cache_dir, component_name)
    store_dir.mkdir(parents=True, exist_ok=True)

    db_file = os.path.join(_to_diffusers_cache_dir, component_name, 'cache.db')
    return _filecache.KeyValueStore(db_file)
# One key/value store per component kind that can be cached on disk,
# opened when this module is imported.
_component_cache_stores = {
    component_name: _get_component_cache_store(component_name)
    for component_name in (
        'text_encoder',
        'text_encoder_2',
        'text_encoder_3',
        'transformer',
        'unet',
    )
}
def _get_component_cache_key(
        component_name: str,
        model_path: str,
        model_type: _enums.ModelType,
        lora_uris: _types.OptionalUris,
        lora_fuse_scale: float | None,
        revision: str | None,
        variant: str | None,
        subfolder: str | None,
        dtype: _enums.DataType,
        original_config: str | None,
        manual_component_uris: dict[str, str] | None = None
) -> str:
    """
    Build the on-disk cache lookup key for one converted pipeline component.

    The key is a ``|`` separated string of the inputs that influence the saved
    component. If ``manual_component_uris`` contains a URI for ``component_name``,
    ``|manual:<uri>`` is appended, with the quantizer removed from the URI when it
    can be parsed. Components are saved in original precision and quantization is
    only applied when loading from cache, so the quantizer must not split the cache.

    LoRA URIs are keyed in their given order. LoRA fusion is not necessarily
    commutative, so a different order yields a different key. Sorting ``lora_uris``
    here would improve cache hits at the cost of that correctness.

    :return: the cache key string
    """
    manual_suffix = ''

    if manual_component_uris and component_name in manual_component_uris:
        raw_uri = manual_component_uris[component_name]
        try:
            if component_name == 'unet':
                src = _uris.UNetUri.parse(raw_uri)
                keyed_uri = _uris.UNetUri(
                    model=src.model,
                    revision=src.revision,
                    variant=src.variant,
                    subfolder=src.subfolder,
                    dtype=src.dtype,
                    quantizer=None
                )
            elif component_name == 'transformer':
                src = _uris.TransformerUri.parse(raw_uri)
                keyed_uri = _uris.TransformerUri(
                    model=src.model,
                    revision=src.revision,
                    variant=src.variant,
                    subfolder=src.subfolder,
                    dtype=src.dtype,
                    quantizer=None
                )
            elif component_name.startswith('text_encoder'):
                src = _uris.TextEncoderUri.parse(raw_uri)
                keyed_uri = _uris.TextEncoderUri(
                    encoder=src.encoder,
                    model=src.model,
                    revision=src.revision,
                    variant=src.variant,
                    subfolder=src.subfolder,
                    dtype=src.dtype,
                    quantizer=None
                )
            else:
                # Unknown component kind: key on the URI text verbatim
                keyed_uri = raw_uri
            manual_suffix = f"|manual:{str(keyed_uri)}"
        except Exception:
            # URI could not be parsed / rebuilt: key on the URI text verbatim
            manual_suffix = f"|manual:{raw_uri}"

    # quantizer_uri is deliberately absent: components are cached in original
    # precision and the quantizer is only applied when loading from cache.
    key_fields = [
        str(component_name),
        str(model_path),
        str(lora_uris),
        str(lora_fuse_scale),
        str(model_type),
        str(revision),
        str(variant),
        str(subfolder),
        str(dtype),
        str(original_config),
    ]
    return '|'.join(key_fields) + manual_suffix
def _get_affected_components_for_lora(pipeline_class) -> list[str]:
if hasattr(pipeline_class, '_lora_loadable_modules'):
return list(pipeline_class._lora_loadable_modules)
return []
def _get_quantizable_components(model_type: _enums.ModelType) -> list[str]:
    """
    List the components of a checkpoint that are converted and cached prior to
    quantization for a given model type.

    :param model_type: the model type
    :return: ``['transformer']`` for SD3 / Flux, ``['unet', 'text_encoder', 'text_encoder_2']``
        for SDXL, otherwise ``['unet', 'text_encoder']``. A new list is returned on every call.
    """
    if _enums.model_type_is_sd3(model_type) or _enums.model_type_is_flux(model_type):
        # Single file SD3 / Flux checkpoints typically only contain the transformer;
        # their text encoders are usually separate standard models needing no conversion
        return ['transformer']

    # Older architectures (SD1.5, SD2, SDXL): single files often bundle UNet and text encoders
    if _enums.model_type_is_sdxl(model_type):
        return ['unet', 'text_encoder', 'text_encoder_2']

    return ['unet', 'text_encoder']
def _minimal_pipeline_default_encoder_name(encoder_param: str, model_type: _enums.ModelType) -> str:
    """
    Return the ``TextEncoderUri`` encoder class name for a pipeline text encoder slot.

    Mirrors the text encoder component types of the diffusers pipelines:

    * SD3: ``text_encoder`` / ``text_encoder_2`` are ``CLIPTextModelWithProjection``,
      ``text_encoder_3`` is ``T5EncoderModel``
    * Flux: ``text_encoder`` is ``CLIPTextModel``, ``text_encoder_2`` is ``T5EncoderModel``
    * SDXL: ``text_encoder_2`` is ``CLIPTextModelWithProjection``

    Everything else falls back to ``CLIPTextModel``.

    :param encoder_param: pipeline constructor parameter name, e.g. ``text_encoder_2``
    :param model_type: the model type
    :return: encoder class name
    """
    if _enums.model_type_is_sd3(model_type):
        return 'T5EncoderModel' if encoder_param == 'text_encoder_3' else 'CLIPTextModelWithProjection'
    if _enums.model_type_is_flux(model_type):
        return 'T5EncoderModel' if encoder_param == 'text_encoder_2' else 'CLIPTextModel'
    if _enums.model_type_is_sdxl(model_type) and encoder_param == 'text_encoder_2':
        return 'CLIPTextModelWithProjection'
    return 'CLIPTextModel'


def _create_minimal_pipeline_for_component_extraction(
        model_path: str,
        model_type: _enums.ModelType,
        pipeline_type: _enums.PipelineType,
        components_needed: list[str],
        revision: str | None,
        variant: str | None,
        subfolder: str | None,
        dtype: _enums.DataType,
        original_config: str | None,
        auth_token: str | None,
        lora_uris: _types.OptionalUris,
        lora_fuse_scale: float | None,
        manual_component_uris: dict[str, str] | None = None,
        vae_uri: _types.OptionalUri = None
):
    """
    Create a pipeline used only to extract individual components that are then
    saved to the converted checkpoint cache.

    Every ``text_encoder*`` parameter of the pipeline constructor is filled in,
    either from ``manual_component_uris`` or, best effort, from ``model_path``.
    This is done to avoid ``SingleFileComponentError`` during pipeline creation.
    Manual ``unet`` / ``transformer`` URIs are only loaded for components listed
    in ``components_needed``. Manual components are loaded with their quantizer
    removed, and the removed quantizer is reported back to the caller. LoRAs in
    ``lora_uris`` are loaded onto the pipeline with fuse scale ``lora_fuse_scale``
    (``1.0`` when ``None``).

    :return: tuple ``(pipeline, manual_quantizers)``, where ``manual_quantizers`` maps
        component name to the quantizer URI that was stripped from its manual URI
    :raises UnsupportedPipelineConfigError: if a single file load is missing a component
    """
    pipeline_class = get_pipeline_class(
        model_type=model_type,
        pipeline_type=pipeline_type,
        help_mode=True  # Allow creation even if pipeline_type doesn't match
    )

    torch_dtype = _enums.get_torch_dtype(dtype)
    creation_kwargs = {}
    manual_quantizers = {}  # Store quantizer info from manual URIs

    # Get all text encoder parameters that the pipeline expects
    pipe_params = inspect.signature(pipeline_class.__init__).parameters
    text_encoder_params = [name for name in pipe_params.keys() if name.startswith('text_encoder')]

    # Load ALL text encoders that the pipeline expects, even if not being cached
    # This prevents SingleFileComponentError during pipeline creation
    for encoder_param in text_encoder_params:
        if manual_component_uris and encoder_param in manual_component_uris:
            parsed_uri = _uris.TextEncoderUri.parse(manual_component_uris[encoder_param])

            # Save quantizer info, the component itself is loaded in original precision
            if parsed_uri.quantizer:
                manual_quantizers[encoder_param] = parsed_uri.quantizer

            # Reconstruct URI without quantizer
            uri_without_quantizer = _uris.TextEncoderUri(
                encoder=parsed_uri.encoder,
                model=parsed_uri.model,
                revision=parsed_uri.revision,
                variant=parsed_uri.variant,
                subfolder=parsed_uri.subfolder,
                dtype=parsed_uri.dtype,
                quantizer=None  # Explicitly set to None
            )

            creation_kwargs[encoder_param] = uri_without_quantizer.load(
                variant_fallback=variant,
                dtype_fallback=dtype,
                original_config=original_config,
                use_auth_token=auth_token,
                local_files_only=False,
                no_cache=True,
                missing_ok=False
            )
        else:
            # Load default text encoder for this parameter, best effort.
            # We need to load it even if we're not caching it to prevent errors
            try:
                if _hfhub.is_single_file_model_load(model_path):
                    encoder_subfolder = encoder_param
                else:
                    encoder_subfolder = os.path.join(subfolder, encoder_param) if subfolder else encoder_param

                default_encoder_uri = _uris.TextEncoderUri(
                    encoder=_minimal_pipeline_default_encoder_name(encoder_param, model_type),
                    model=model_path,
                    variant=variant,
                    revision=revision,
                    subfolder=encoder_subfolder,
                    dtype=dtype,
                    quantizer=None
                )

                creation_kwargs[encoder_param] = default_encoder_uri.load(
                    variant_fallback=variant,
                    dtype_fallback=dtype,
                    original_config=original_config,
                    use_auth_token=auth_token,
                    local_files_only=False,
                    no_cache=True,
                    missing_ok=True  # Allow missing for optional encoders
                )
            except Exception as e:
                _messages.debug_log(f"Failed to load default {encoder_param}: {e}")
                # Continue without this encoder - some may be optional

    # Load manual components that are not text encoders
    if manual_component_uris:
        for component in components_needed:
            if component not in manual_component_uris or component.startswith('text_encoder'):
                continue

            if component == 'unet':
                unet_class = diffusers.UNet2DConditionModel
                parsed_uri = _uris.UNetUri.parse(manual_component_uris[component])

                # Save quantizer info, the component itself is loaded in original precision
                if parsed_uri.quantizer:
                    manual_quantizers[component] = parsed_uri.quantizer

                # Reconstruct URI without quantizer
                uri_without_quantizer = _uris.UNetUri(
                    model=parsed_uri.model,
                    revision=parsed_uri.revision,
                    variant=parsed_uri.variant,
                    subfolder=parsed_uri.subfolder,
                    dtype=parsed_uri.dtype,
                    quantizer=None  # Explicitly set to None
                )

                creation_kwargs['unet'] = uri_without_quantizer.load(
                    variant_fallback=variant,
                    dtype_fallback=dtype,
                    original_config=original_config,
                    use_auth_token=auth_token,
                    local_files_only=False,
                    no_cache=True,
                    unet_class=unet_class
                )
            elif component == 'transformer':
                if _enums.model_type_is_sd3(model_type):
                    transformer_class = diffusers.SD3Transformer2DModel
                elif _enums.model_type_is_flux(model_type):
                    transformer_class = diffusers.FluxTransformer2DModel
                else:
                    continue  # Skip if no appropriate transformer class

                parsed_uri = _uris.TransformerUri.parse(manual_component_uris[component])

                # Save quantizer info, the component itself is loaded in original precision
                if parsed_uri.quantizer:
                    manual_quantizers[component] = parsed_uri.quantizer

                # Reconstruct URI without quantizer
                uri_without_quantizer = _uris.TransformerUri(
                    model=parsed_uri.model,
                    revision=parsed_uri.revision,
                    variant=parsed_uri.variant,
                    subfolder=parsed_uri.subfolder,
                    dtype=parsed_uri.dtype,
                    quantizer=None  # Explicitly set to None
                )

                creation_kwargs['transformer'] = uri_without_quantizer.load(
                    variant_fallback=variant,
                    dtype_fallback=dtype,
                    original_config=original_config,
                    use_auth_token=auth_token,
                    local_files_only=False,
                    no_cache=True,
                    transformer_class=transformer_class
                )

    # Load VAE if specified - this is needed for single file checkpoints that don't contain a VAE
    if vae_uri:
        parsed_vae_uri = _uris.VAEUri.parse(vae_uri)
        creation_kwargs['vae'] = parsed_vae_uri.load(
            dtype_fallback=dtype,
            original_config=original_config,
            use_auth_token=auth_token,
            local_files_only=False,
            no_cache=True,
            missing_ok=False
        )

    # Load the pipeline with all required components
    if _hfhub.is_single_file_model_load(model_path):
        try:
            pipeline = pipeline_class.from_single_file(
                model_path,
                original_config=original_config,
                token=auth_token,
                revision=revision,
                variant=variant,
                torch_dtype=torch_dtype,
                use_safe_tensors=model_path.endswith('.safetensors'),
                **creation_kwargs
            )
        except diffusers.loaders.single_file.SingleFileComponentError as e:
            _handle_single_file_component_error(e)
        except Exception as e:
            _messages.debug_log(f"Failed to create minimal pipeline from single file: {e}")
            raise
    else:
        try:
            pipeline = pipeline_class.from_pretrained(
                model_path,
                token=auth_token,
                revision=revision,
                variant=variant,
                torch_dtype=torch_dtype,
                subfolder=subfolder,
                **creation_kwargs
            )
        except Exception as e:
            _messages.debug_log(f"Failed to create minimal pipeline from pretrained: {e}")
            raise

    # Apply LoRAs if specified
    if lora_uris:
        parsed_lora_uris = [_uris.LoRAUri.parse(uri) for uri in lora_uris]
        _uris.LoRAUri.load_on_pipeline(
            pipeline=pipeline,
            uris=parsed_lora_uris,
            fuse_scale=lora_fuse_scale if lora_fuse_scale is not None else 1.0,
            use_auth_token=auth_token
        )

    return pipeline, manual_quantizers
def _cache_components_granular(
        model_path: str,
        model_type: _enums.ModelType,
        pipeline_type: _enums.PipelineType,
        components_to_cache: list[str],
        lora_uris: _types.OptionalUris,
        lora_fuse_scale: float | None,
        revision: str | None,
        variant: str | None,
        subfolder: str | None,
        dtype: _enums.DataType,
        original_config: str | None,
        auth_token: str | None,
        manual_component_uris: dict[str, str] | None = None,
        vae_uri: _types.OptionalUri = None
) -> tuple[dict[str, str], dict[str, str]]:
    """
    Make sure each component in ``components_to_cache`` is saved on disk in diffusers
    format (with ``lora_uris`` applied), loading and saving any that are not cached yet.

    Cached entries are looked up per component with :py:func:`_get_component_cache_key`.
    An entry is reused only if its recorded directory still exists. Missing components
    are extracted from a minimal pipeline, saved with ``save_pretrained``, and recorded
    in that component's cache store.

    :return: tuple ``(cached_component_paths, manual_quantizers)``.
        ``cached_component_paths`` maps component name to its saved directory. Components
        that the pipeline did not have (or had as ``None``) are absent.
        ``manual_quantizers`` maps component name to the quantizer URI given in that
        component's manual URI, if it has one.
    """
    cached_component_paths = {}
    components_to_load = []
    # Computed once here and reused when saving below
    cache_keys = {}

    # Check which components are already cached
    for component in components_to_cache:
        cache_key = _get_component_cache_key(
            component, model_path, model_type, lora_uris, lora_fuse_scale,
            revision, variant, subfolder, dtype, original_config,
            manual_component_uris
        )
        cache_keys[component] = cache_key

        cache_store = _component_cache_stores[component]
        with cache_store:
            cached_path = cache_store.get(cache_key)
            if cached_path and os.path.exists(cached_path):
                cached_component_paths[component] = cached_path
                _messages.debug_log(f"Using cached {component} from: {cached_path}")
            else:
                components_to_load.append(component)

    # Quantizer info from manual URIs is needed even if every component is cached
    manual_quantizers = {}
    if manual_component_uris:
        for component, uri in manual_component_uris.items():
            if component not in components_to_cache:
                continue
            try:
                if component == 'unet':
                    parsed_uri = _uris.UNetUri.parse(uri)
                elif component == 'transformer':
                    parsed_uri = _uris.TransformerUri.parse(uri)
                elif component.startswith('text_encoder'):
                    parsed_uri = _uris.TextEncoderUri.parse(uri)
                else:
                    continue
                if parsed_uri.quantizer:
                    manual_quantizers[component] = parsed_uri.quantizer
            except Exception as e:
                _messages.debug_log(f"Error extracting quantizer from {component} URI: {e}")

    # If all components are cached, return early
    if not components_to_load:
        return cached_component_paths, manual_quantizers

    if lora_uris:
        _messages.warning(
            f'Model "{model_path}" is having LoRAs '
            f'fused into specific components, that will then be cached on disk '
            f'prior to quantization. This is a one time task per LoRA scale value, '
            f'please be patient...'
        )
    else:
        _messages.warning(
            f'Model "{model_path}" components are being converted to '
            f'diffusers format and cached on disk prior to quantization. '
            f'This is a one time task, please be patient...'
        )

    # Create pipeline and extract needed components
    with _d_memoize.disable_memoization_context():
        pipeline, extraction_quantizers = _create_minimal_pipeline_for_component_extraction(
            model_path=model_path,
            model_type=model_type,
            pipeline_type=pipeline_type,
            components_needed=components_to_load,
            revision=revision,
            variant=variant,
            subfolder=subfolder,
            dtype=dtype,
            original_config=original_config,
            auth_token=auth_token,
            lora_uris=lora_uris,
            lora_fuse_scale=lora_fuse_scale,
            manual_component_uris=manual_component_uris,
            vae_uri=vae_uri
        )

    # Merge instead of overwrite. The extraction helper only reports quantizers for
    # non text encoder components it had to load (components_to_load), so replacing
    # the dict would drop the quantizer of a manual unet / transformer that was
    # already cached.
    manual_quantizers.update(extraction_quantizers)

    if lora_uris:
        # LoRAs were loaded with a fuse scale; unload the adapter layers before saving
        # (assumes load_on_pipeline fused them into the base weights - confirm)
        pipeline.unload_lora_weights()

    # Cache each component individually
    for component in components_to_load:
        if not hasattr(pipeline, component):
            _messages.debug_log(
                f"Component {component} not found in pipeline, skipping cache")
            continue

        component_obj = getattr(pipeline, component)
        if component_obj is None:
            _messages.debug_log(
                f"Component {component} is None, skipping cache")
            continue

        cache_store = _component_cache_stores[component]
        cache_key = cache_keys[component]

        # Cache path layout: <cache dir>/<component>/<sha256 of key>
        component_dir = os.path.join(_to_diffusers_cache_dir, component)
        pathlib.Path(component_dir).mkdir(parents=True, exist_ok=True)

        with cache_store:
            cache_uuid = hashlib.sha256(cache_key.encode('utf-8')).hexdigest()
            cache_path = os.path.join(component_dir, cache_uuid)

            # Ensure unique path, e.g. do not reuse a directory left behind by an interrupted save
            while os.path.exists(cache_path):
                cache_uuid = hashlib.sha256((cache_key + str(random.random())).encode('utf-8')).hexdigest()
                cache_path = os.path.join(component_dir, cache_uuid)

        # Save component
        _messages.debug_log(f"Saving cached {component} to: {cache_path}")
        component_obj.save_pretrained(cache_path, variant=variant)

        # Store in cache
        with cache_store:
            cache_store[cache_key] = cache_path
        cached_component_paths[component] = cache_path

    return cached_component_paths, manual_quantizers
# Names exported by ``from <this module> import *``; the list is produced by the
# project helper _types.module_all() (presumably this module's public names - confirm).
__all__ = _types.module_all()