Source code for dgenerate.pipelinewrapper.schedulers

# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
#    the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import ast
import inspect
import itertools

import diffusers
import diffusers.schedulers
import typing
import dgenerate.types as _types
import numpy

import dgenerate.textprocessing as _textprocessing
import diffusers.loaders
import dgenerate.messages as _messages
import collections.abc


class SchedulerLoadError(Exception):
    """
    Common ancestor of every exception raised while loading a scheduler.
    """
class SchedulerArgumentError(SchedulerLoadError):
    """
    Raised for an invalid argument in a scheduler URI.
    """
class InvalidSchedulerNameError(SchedulerLoadError):
    """
    Raised when the requested scheduler name is not known.
    """
def _resolve_karras_schedulers():
    """
    Resolve ``KarrasDiffusionSchedulers`` enum values to actual scheduler
    classes in ``diffusers.schedulers``.

    Enum members whose name is not an attribute of ``diffusers.schedulers``
    are left out of the mapping.

    :return: ``dict`` mapping enum member to scheduler class
    """
    schedulers_module = diffusers.schedulers
    karras_enum = schedulers_module.KarrasDiffusionSchedulers

    return {
        karras_enum[scheduler_name]: getattr(schedulers_module, scheduler_name)
        for scheduler_name in karras_enum.__members__
        if hasattr(schedulers_module, scheduler_name)
    }


_KARRAS_SCHEDULERS_MAP = _resolve_karras_schedulers()


def _expand_with_compatibles(scheduler_names):
    """
    Extend a collection of scheduler classes with the entries listed in
    each class's ``_compatibles`` attribute.

    Entries of ``_compatibles`` may be ``KarrasDiffusionSchedulers`` enum
    members, which are translated through ``_KARRAS_SCHEDULERS_MAP``, or
    scheduler classes, which are added directly. Enum members that have no
    entry in ``_KARRAS_SCHEDULERS_MAP`` are skipped, so the result never
    contains ``None``. Only the classes passed in are inspected; classes
    added by the expansion are not expanded again.

    :param scheduler_names: iterable of scheduler classes
    :return: ``list`` of unique scheduler classes, in no particular order
    """
    expanded_schedulers = set(scheduler_names)

    # iterate over a snapshot, the set itself is mutated inside the loop
    for scheduler_cls in list(expanded_schedulers):
        compatibles = getattr(scheduler_cls, "_compatibles", None)
        if not compatibles:
            continue

        for compatible in compatibles:
            if isinstance(compatible, diffusers.schedulers.KarrasDiffusionSchedulers):
                resolved = _KARRAS_SCHEDULERS_MAP.get(compatible)
                # an enum member without a resolvable class must not
                # leak into the result as None
                if resolved is not None:
                    expanded_schedulers.add(resolved)
            elif inspect.isclass(compatible) and issubclass(compatible, diffusers.schedulers.SchedulerMixin):
                expanded_schedulers.add(compatible)

    return list(expanded_schedulers)
def get_compatible_schedulers(
        pipeline_cls: type[diffusers.DiffusionPipeline]) -> list[type[diffusers.SchedulerMixin]]:
    """
    Finds all compatible scheduler classes for a given diffusers pipeline class without instantiating it.

    Candidates are collected from the type hints of the pipeline constructor
    and then expanded with each candidate's ``_compatibles``. Pipelines that
    derive from the Stable Diffusion / SDXL LoRA loader mixins additionally
    get ``LCMScheduler``.

    :param pipeline_cls: The pipeline class, for example :py:class:`diffusers.StableDiffusionPipeline`
    :return: A list of compatible scheduler class types, without duplicates, sorted by class name
    """
    if pipeline_cls is diffusers.StableDiffusionLatentUpscalePipeline:
        # Seems to only work with this scheduler
        return [diffusers.EulerDiscreteScheduler]

    if any(pipeline_cls is x for x in (diffusers.IFPipeline,
                                       diffusers.IFInpaintingPipeline,
                                       diffusers.IFImg2ImgPipeline,
                                       diffusers.IFSuperResolutionPipeline,
                                       diffusers.IFInpaintingSuperResolutionPipeline,
                                       diffusers.IFImg2ImgSuperResolutionPipeline)):
        # same here
        return [diffusers.DDPMScheduler]

    def _is_scheduler_class(candidate):
        return inspect.isclass(candidate) and issubclass(candidate, diffusers.schedulers.SchedulerMixin)

    compatible_schedulers = set()

    # Get constructor signature
    init_sig = inspect.signature(pipeline_cls.__init__)

    for param in init_sig.parameters.values():
        param_type = param.annotation

        # Case 1: Direct scheduler class type hint (e.g., `param: DDIMScheduler`)
        if _is_scheduler_class(param_type):
            compatible_schedulers.add(param_type)

        # Case 2: Union type hint (e.g., `Union[DDIMScheduler, EulerScheduler]`)
        elif _types.is_union(param_type):
            compatible_schedulers.update(
                sub_type for sub_type in typing.get_args(param_type)
                if _is_scheduler_class(sub_type))

        # Case 3: Enum-based schedulers (KarrasDiffusionSchedulers)
        elif param_type is diffusers.schedulers.KarrasDiffusionSchedulers:
            compatible_schedulers.update(_KARRAS_SCHEDULERS_MAP.values())

    # Expand using _compatibles, keeping a set so that nothing is listed twice
    compatibles = set(_expand_with_compatibles(compatible_schedulers))

    # an unresolvable entry must never be reported as a scheduler class
    compatibles.discard(None)

    if issubclass(pipeline_cls,
                  (diffusers.loaders.StableDiffusionLoraLoaderMixin,
                   diffusers.loaders.StableDiffusionXLLoraLoaderMixin)):
        # set.add avoids a duplicate when LCMScheduler was already discovered
        compatibles.add(diffusers.LCMScheduler)

    # set iteration order of classes is not stable between runs,
    # sort so that callers always see the same order
    return sorted(compatibles, key=lambda scheduler_cls: scheduler_cls.__name__)
# Allowed string values for scheduler constructor arguments that only accept
# a fixed set of choices. Keyed by scheduler class, then by the constructor
# argument name (underscore form).
#
# get_scheduler_uri_schema() reports these lists under the 'options' key of
# an argument's schema entry, and load_scheduler() rejects URI argument
# values which are not in the list for that argument.
#
# Scheduler classes missing from this table simply have no option
# restricted arguments, both consumers fall back to an empty dict.
_scheduler_option_args = {
    diffusers.DDIMScheduler: {
        "beta_schedule": ["linear", "scaled_linear", "squaredcos_cap_v2"],
        "prediction_type": ["epsilon", "sample", "v_prediction"],
        "timestep_spacing": ["leading", "trailing", "linspace"]
    },
    diffusers.DDPMScheduler: {
        "beta_schedule": ["linear", "scaled_linear", "squaredcos_cap_v2", "sigmoid"],
        "prediction_type": ["epsilon", "sample", "v_prediction"],
        "timestep_spacing": ["leading", "trailing", "linspace"],
        'variance_type': ["fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned",
                          "learned_range"]
    },
    diffusers.DDPMWuerstchenScheduler: {
        # Note: This scheduler doesn't use the standard beta_schedule, prediction_type, or timestep_spacing
        # It uses custom parameters: scaler and s
    },
    diffusers.DEISMultistepScheduler: {
        "beta_schedule": ["linear", "scaled_linear", "squaredcos_cap_v2"],
        "prediction_type": ["epsilon", "sample", "v_prediction"],
        "timestep_spacing": ["leading", "trailing", "linspace"]
    },
    diffusers.DPMSolverMultistepScheduler: {
        "beta_schedule": ["linear", "scaled_linear", "squaredcos_cap_v2"],
        "prediction_type": ["epsilon", "sample", "v_prediction"],
        "timestep_spacing": ["leading", "trailing", "linspace"],
        "algorithm_type": ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"],
        "solver_type": ["midpoint", "heun"],
        "final_sigmas_type": ["zero", "sigma_min"],
        "variance_type": ["learned", "learned_range"]
    },
    diffusers.DPMSolverSDEScheduler: {
        "beta_schedule": ["linear", "scaled_linear"],
        "prediction_type": ["epsilon", "sample", "v_prediction"],
        "timestep_spacing": ["leading", "trailing", "linspace"]
    },
    diffusers.DPMSolverSinglestepScheduler: {
        "beta_schedule": ["linear", "scaled_linear", "squaredcos_cap_v2"],
        "prediction_type": ["epsilon", "sample", "v_prediction"],
        "algorithm_type": ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"],
        "solver_type": ["midpoint", "heun"],
        "final_sigmas_type": ["zero", "sigma_min"],
        "variance_type": ["learned", "learned_range"]
    },
    diffusers.EDMEulerScheduler: {
        "sigma_schedule": ["karras", "exponential"],
        "prediction_type": ["epsilon", "sample", "v_prediction"],
        "final_sigmas_type": ["zero", "sigma_min"]
    },
    diffusers.EulerAncestralDiscreteScheduler: {
        "beta_schedule": ["linear", "scaled_linear"],
        "prediction_type": ["epsilon", "sample", "v_prediction"],
        "timestep_spacing": ["leading", "trailing", "linspace"]
    },
    diffusers.EulerDiscreteScheduler: {
        "beta_schedule": ["linear", "scaled_linear", "squaredcos_cap_v2"],
        "prediction_type": ["epsilon", "sample", "v_prediction"],
        "timestep_spacing": ["leading", "trailing", "linspace"],
        "timestep_type": ["discrete", "continuous"],
        "interpolation_type": ["linear", "log_linear"],
        "final_sigmas_type": ["zero", "sigma_min"]
    },
    diffusers.FlowMatchEulerDiscreteScheduler: {
        "time_shift_type": ["exponential", "linear"]
    },
    diffusers.HeunDiscreteScheduler: {
        "beta_schedule": ["linear", "scaled_linear"],
        "prediction_type": ["epsilon", "sample", "v_prediction"],
        "timestep_spacing": ["leading", "trailing", "linspace"]
    },
    diffusers.KDPM2AncestralDiscreteScheduler: {
        "beta_schedule": ["linear", "scaled_linear"],
        "prediction_type": ["epsilon", "sample", "v_prediction"],
        "timestep_spacing": ["leading", "trailing", "linspace"]
    },
    diffusers.KDPM2DiscreteScheduler: {
        "beta_schedule": ["linear", "scaled_linear"],
        "prediction_type": ["epsilon", "sample", "v_prediction"],
        "timestep_spacing": ["leading", "trailing", "linspace"]
    },
    diffusers.LCMScheduler: {
        "beta_schedule": ["linear", "scaled_linear", "squaredcos_cap_v2"],
        "prediction_type": ["epsilon", "sample", "v_prediction"],
        "timestep_spacing": ["leading", "trailing", "linspace"]
    },
    diffusers.LMSDiscreteScheduler: {
        "beta_schedule": ["linear", "scaled_linear"],
        "prediction_type": ["epsilon", "sample", "v_prediction"],
        "timestep_spacing": ["leading", "trailing", "linspace"]
    },
    diffusers.PNDMScheduler: {
        "beta_schedule": ["linear", "scaled_linear", "squaredcos_cap_v2"],
        "prediction_type": ["epsilon", "sample", "v_prediction"],
        "timestep_spacing": ["leading", "trailing", "linspace"]
    },
    diffusers.UniPCMultistepScheduler: {
        "beta_schedule": ["linear", "scaled_linear", "squaredcos_cap_v2"],
        "prediction_type": ["epsilon", "sample", "v_prediction"],
        "timestep_spacing": ["leading", "trailing", "linspace"],
        "solver_type": ["bh1", "bh2"],
        "final_sigmas_type": ["zero", "sigma_min"]
    }
}
def get_scheduler_uri_schema(scheduler: type[diffusers.SchedulerMixin] | list[type[diffusers.SchedulerMixin]]):
    """
    Return a schema describing initialization arguments from a ``diffusers`` scheduler type,
    or list of scheduler types.

    This returns a set of schemas keyed by scheduler name, which are identical to the
    schema format returned by :py:meth:`dgenerate.plugin.Plugin.get_accepted_args_schema`.

    Arguments which cannot be passed through a URI such as class references are omitted.

    Every argument entry carries an ``optional`` flag (``True`` when the type hint
    admits ``None``) and a sorted ``types`` list. ``default`` is present when the
    constructor declares a default, and ``options`` is present when the argument
    only accepts a fixed set of values. Argument names are reported in dashed form,
    the first entry is always the ``clone-config`` argument.

    :param scheduler: ``diffusers`` scheduler type, or list of them.
    :return: ``dict`` schema.
    """
    if not isinstance(scheduler, list):
        scheduler = [scheduler]

    # The helpers do not depend on the scheduler being processed,
    # so they are defined once instead of once per loop iteration.

    def _type_name(t):
        # builtins are reported by bare name, anything else by its string form
        return (str(t) if t.__module__ != 'builtins' else t.__name__).strip()

    def _resolve_union(t):
        # flatten a (possibly nested) union into the names of its member types,
        # recursing over the members of *t* itself
        if _types.is_union(t):
            return set(itertools.chain.from_iterable(
                _resolve_union(member) for member in typing.get_args(t)))
        return [_type_name(t)]

    def _filter_types(typs):
        o = set()
        for t in typs:
            if t.startswith('typing.List'):
                o.add('list')
            elif t == "<class 'numpy.ndarray'>":
                o.add('list')
            elif t.startswith("<class"):
                # class references cannot be expressed in a URI
                pass
            else:
                o.add(t)
        # sorted so that the reported type order is stable
        return sorted(o)

    schema = dict()

    for class_type in scheduler:
        # first argument is the config cloning behavior
        parameter_schema = {
            'clone-config': {
                'optional': False,
                'default': True,
                'types': ['bool']
            }
        }

        schema[class_type.__name__] = parameter_schema

        option_args = _scheduler_option_args.get(class_type, dict())

        for parameter_name, parameter in inspect.signature(class_type.__init__).parameters.items():
            if parameter_name == 'self':
                continue

            annotation = parameter.annotation

            if _types.is_union(annotation):
                union_args = set(_resolve_union(annotation))
                # always record the flag, consumers index it unconditionally
                optional = 'NoneType' in union_args
                union_args.discard('NoneType')
                filtered_types = _filter_types(sorted(union_args))
            else:
                optional = False
                filtered_types = _filter_types([_type_name(annotation)])

            if not filtered_types:
                # nothing about this argument is representable in a URI
                continue

            parameter_details = {
                'optional': optional,
                'types': filtered_types
            }

            default = parameter.default

            if isinstance(default, list):
                if not all(isinstance(i, typing.SupportsIndex) or
                           not isinstance(i, collections.abc.Iterable) for i in default):
                    # cannot support multiple dimensions
                    continue

            if isinstance(default, numpy.ndarray):
                if default.ndim != 1:
                    # cannot support multiple dimensions
                    continue

            if default is not inspect.Parameter.empty:
                parameter_details['default'] = default

            if parameter.name in option_args:
                parameter_details['options'] = option_args[parameter.name]

            parameter_schema[_textprocessing.dashup(parameter_name)] = parameter_details

    return schema
def load_scheduler(pipeline: diffusers.DiffusionPipeline, scheduler_uri: _types.Uri | None):
    """
    Load a specific compatible scheduler class name onto a huggingface diffusers pipeline object.

    Passing ``None`` to the URI reloads the original scheduler that the pipeline was loaded with,
    if no new scheduler has been set since then, this is a no-op.

    The scheduler name is the part of the URI before the first ``;``. A name that exactly
    matches a compatible scheduler class is preferred; otherwise the name is treated as a
    prefix, and the alphabetically first compatible scheduler starting with it is used.
    An empty name never matches.

    The URI argument ``clone-config`` (default ``True``) selects whether the new scheduler is
    created from the config of the pipeline's original scheduler with the URI arguments applied
    as overrides, or constructed directly from the URI arguments alone.

    :raises InvalidSchedulerNameError: If an invalid scheduler name is specified specifically.
    :raises SchedulerArgumentError: If invalid arguments are supplied to the scheduler via the URI.

    :param pipeline: pipeline object
    :param scheduler_uri: Compatible scheduler URI.
    """

    if scheduler_uri is None:
        if hasattr(pipeline, '_DGENERATE_ORIGINAL_SCHEDULER'):
            pipeline.scheduler = pipeline._DGENERATE_ORIGINAL_SCHEDULER
        return

    # guard against unresolvable entries, everything below reads __name__
    compatibles = [c for c in get_compatible_schedulers(pipeline.__class__) if c is not None]

    def _get_uri_arg_value(
            scheduler: str,
            value: typing.Any,
            arg_name: str,
            optional: bool,
            types: list):
        # Convert the raw URI value into a python value according
        # to the types the schema lists for the argument.

        if isinstance(value, list):
            return value
        elif optional and value.lower() == 'none':
            return None
        elif any(t == 'list' for t in types):
            try:
                val = ast.literal_eval(value)
                if not isinstance(val, (list, tuple, set)):
                    return [val]
                else:
                    return val
            except (ValueError, SyntaxError) as e:
                raise SchedulerArgumentError(
                    f'{scheduler} argument "{arg_name}" '
                    f'must be a singular literal, list, '
                    f'tuple, or set value in python syntax.'
                ) from e
        elif any(t == 'float' for t in types):
            try:
                return float(value)
            except ValueError as e:
                raise SchedulerArgumentError(
                    f'{scheduler} argument "{arg_name}" '
                    f'must be a floating point value.'
                ) from e
        elif any(t == 'int' for t in types):
            try:
                return int(value)
            except ValueError as e:
                raise SchedulerArgumentError(
                    f'{scheduler} argument "{arg_name}" '
                    f'must be an integer value.'
                ) from e
        elif any(t == 'bool' for t in types):
            try:
                return _types.parse_bool(value)
            except ValueError as e:
                raise SchedulerArgumentError(
                    f'{scheduler} argument "{arg_name}" '
                    f'must be a boolean value.'
                ) from e

        try:
            # string literal?
            return ast.literal_eval(value)
        except (ValueError, SyntaxError):
            # token (string)
            return value

    requested_name = scheduler_uri.split(';')[0].strip()

    scheduler_type = None

    # an empty name is a prefix of every class name, never let it match
    if requested_name:
        scheduler_type = next(
            (c for c in compatibles if c.__name__ == requested_name), None)

        if scheduler_type is None:
            # abbreviated name, pick deterministically among the candidates
            scheduler_type = min(
                (c for c in compatibles if c.__name__.startswith(requested_name)),
                key=lambda c: c.__name__,
                default=None)

    if scheduler_type is None:
        raise InvalidSchedulerNameError(
            f'Scheduler named "{scheduler_uri}" is not a valid compatible scheduler, '
            'options are:\n\n' + '\n'.join(
                sorted(' ' * 4 + _textprocessing.quote(i.__name__.split('.')[-1]) for i in compatibles)))

    schema = get_scheduler_uri_schema(scheduler_type)[scheduler_type.__name__]

    parser = _textprocessing.ConceptUriParser(
        'Scheduler',
        known_args=list(schema.keys()),
        args_raw=[k for k, v in schema.items() if any(t == 'list' for t in v['types'])])

    try:
        result = parser.parse(scheduler_uri)
    except _textprocessing.ConceptUriParseError as e:
        raise SchedulerArgumentError(e) from e

    args = dict()

    # names (underscore form) of the supplied arguments which may be None
    optional_args = set()

    for k, v in result.args.items():
        # a schema entry without the flag is treated as not optional
        optional = schema[k].get('optional', False)
        arg = _textprocessing.dashdown(k)

        if optional:
            optional_args.add(arg)

        args[arg] = _get_uri_arg_value(
            scheduler_type.__name__, v, k, optional, schema[k]['types'])

    option_args = _scheduler_option_args.get(scheduler_type, dict())

    for arg, value in args.items():
        if arg not in option_args:
            continue

        if value is None and arg in optional_args:
            # explicit none for an optional argument is not an option violation
            continue

        if value not in option_args[arg]:
            raise SchedulerArgumentError(
                f'Invalid value "{value}" for argument "{_textprocessing.dashup(arg)}" '
                f'of scheduler "{scheduler_type.__name__}", '
                f'valid options are: '
                f'{_textprocessing.oxford_comma(option_args[arg], "or")}')

    _messages.debug_log(
        f'Constructing Scheduler: "{scheduler_type.__name__}", URI Args: {args}')

    try:
        if not hasattr(pipeline, '_DGENERATE_ORIGINAL_SCHEDULER'):
            # first time
            pipeline._DGENERATE_ORIGINAL_SCHEDULER = pipeline.scheduler

        clone_config = args.pop('clone_config', True)

        if clone_config:
            # init from original scheduler config
            # apply any user overrides over top
            pipeline.scheduler = scheduler_type.from_config(
                pipeline._DGENERATE_ORIGINAL_SCHEDULER.config,
                **args
            )
        else:
            # raw init with possible overrides to defaults
            pipeline.scheduler = scheduler_type(**args)
    except Exception as e:
        raise SchedulerArgumentError(
            f'Error constructing scheduler "{scheduler_type.__name__}" '
            f'with given URI argument values, encountered error: {e}') from e

    _messages.debug_log(
        f'Scheduler: "{scheduler_type.__name__}", '
        f'Successfully added to pipeline: {pipeline.__class__.__name__}')