Source code for dgenerate.pipelinewrapper.enums

# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
#    the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import enum

import torch

import dgenerate.types as _types


class PipelineType(enum.Enum):
    """
    Enumeration of the possible diffusers pipeline types.
    """

    TXT2IMG = 1
    """
    Text to image mode: generation from a prompt only.
    """

    IMG2IMG = 2
    """
    Image to image mode: generation is seeded / controlled with an image in some fashion.
    """

    INPAINT = 3
    """
    Inpainting mode: generation is seeded / controlled with an image and a mask in some fashion.
    """
def get_pipeline_type_enum(id_str: PipelineType | str | None) -> PipelineType:
    """
    Get a :py:class:`.PipelineType` enum value from a string.

    A value that is already a :py:class:`.PipelineType` is returned unchanged.
    String matching ignores letter case and surrounding whitespace.

    :param id_str: one of: "txt2img", "img2img", or "inpaint",
        or an existing :py:class:`.PipelineType` value

    :raises ValueError: if an invalid string value (name) is passed, or if
        the value is ``None`` / not a string at all

    :return: :py:class:`.PipelineType`
    """
    if isinstance(id_str, PipelineType):
        return id_str

    # The signature admits ``None``, but there is no pipeline type for it.
    # Report it (and any other non-string) with the documented ValueError
    # instead of letting ``.strip()`` fail with an AttributeError.
    if not isinstance(id_str, str):
        raise ValueError('invalid PipelineType string')

    try:
        return {'txt2img': PipelineType.TXT2IMG,
                'img2img': PipelineType.IMG2IMG,
                'inpaint': PipelineType.INPAINT}[id_str.strip().lower()]
    except KeyError:
        # The KeyError is an implementation detail of the lookup table,
        # keep it out of the traceback shown to the caller.
        raise ValueError('invalid PipelineType string') from None
def get_pipeline_type_string(pipeline_type_enum: PipelineType) -> str:
    """
    Convert a :py:class:`.PipelineType` enum value to a string.

    The argument is first normalized with :py:func:`.get_pipeline_type_enum`,
    so a pipeline type name given as a string is accepted as well and is
    returned in its canonical lowercase form.

    :param pipeline_type_enum: :py:class:`.PipelineType` value

    :raises ValueError: if a string is passed that is not a valid pipeline type name

    :return: one of: "txt2img", "img2img", or "inpaint"
    """
    pipeline_type = get_pipeline_type_enum(pipeline_type_enum)

    return {PipelineType.TXT2IMG: 'txt2img',
            PipelineType.IMG2IMG: 'img2img',
            PipelineType.INPAINT: 'inpaint'}[pipeline_type]
class DataType(enum.Enum):
    """
    Enumeration of model precision choices.
    """

    AUTO = 0
    """Automatic selection."""

    FLOAT16 = 1
    """Floating point, 16 bit."""

    FLOAT32 = 2
    """Floating point, 32 bit."""

    BFLOAT16 = 3
    """Brain floating point, 16 bit."""
def supported_data_type_strings() -> list[str]:
    """
    Return a list of supported ``--dtype`` strings.

    A new list is created on every call, so the caller is free to modify it.

    :return: list of ``--dtype`` strings
    """
    return ['auto', 'bfloat16', 'float16', 'float32']
def supported_data_type_enums() -> list[DataType]:
    """
    Return the supported :py:class:`.DataType` enum values as a list.

    The values appear in the same order as the strings returned by
    :py:func:`.supported_data_type_strings`.

    :return: list of :py:class:`.DataType`
    """
    return list(map(get_data_type_enum, supported_data_type_strings()))
def get_data_type_enum(id_str: DataType | str | None) -> DataType:
    """
    Convert a ``--dtype`` string to its :py:class:`.DataType` enum value

    A value that is already a :py:class:`.DataType` is returned unchanged.
    String matching ignores letter case and surrounding whitespace.

    :param id_str: ``--dtype`` string, or an existing :py:class:`.DataType` value

    :raises ValueError: if an invalid string value (name) is passed, or if
        the value is ``None`` / not a string at all

    :return: :py:class:`.DataType`
    """
    if isinstance(id_str, DataType):
        return id_str

    # The signature admits ``None``, but no data type is defined for it here.
    # Report it (and any other non-string) with the documented ValueError
    # instead of letting ``.strip()`` fail with an AttributeError.
    if not isinstance(id_str, str):
        raise ValueError('invalid DataType string')

    try:
        return {'auto': DataType.AUTO,
                'float16': DataType.FLOAT16,
                'float32': DataType.FLOAT32,
                'bfloat16': DataType.BFLOAT16}[id_str.strip().lower()]
    except KeyError:
        # The KeyError is an implementation detail of the lookup table,
        # keep it out of the traceback shown to the caller.
        raise ValueError('invalid DataType string') from None
def get_data_type_string(data_type_enum: DataType) -> str:
    """
    Produce the ``--dtype`` string that corresponds to a :py:class:`.DataType` value.

    The argument is passed through :py:func:`.get_data_type_enum` before the lookup.

    :param data_type_enum: :py:class:`.DataType` value

    :return: ``--dtype`` string
    """
    resolved = get_data_type_enum(data_type_enum)

    enum_to_name = {
        DataType.AUTO: 'auto',
        DataType.FLOAT16: 'float16',
        DataType.FLOAT32: 'float32',
        DataType.BFLOAT16: 'bfloat16',
    }

    return enum_to_name[resolved]
class ModelType(enum.Enum):
    """
    Enum form of the ``--model-type`` argument.
    """

    SD = 0
    """
    Stable Diffusion, such as SD 1.0 - 2.x
    """

    PIX2PIX = 1
    """
    Stable Diffusion pix2pix, prompt guided editing.
    """

    SDXL = 2
    """
    Stable Diffusion XL
    """

    IF = 3
    """
    Deep Floyd IF, stage 1
    """

    IFS = 4
    """
    Deep Floyd IF superscaler (stage 2)
    """

    IFS_IMG2IMG = 5
    """
    Deep Floyd IF superscaler (stage 2), image to image / variation mode.
    """

    SDXL_PIX2PIX = 6
    """
    Stable Diffusion XL pix2pix, prompt guided editing.
    """

    UPSCALER_X2 = 7
    """
    Stable Diffusion X2 upscaler
    """

    UPSCALER_X4 = 8
    """
    Stable Diffusion X4 upscaler
    """

    S_CASCADE = 9
    """
    Stable Cascade prior
    """

    S_CASCADE_DECODER = 10
    """
    Stable Cascade decoder
    """

    SD3 = 11
    """
    Stable Diffusion 3
    """

    SD3_PIX2PIX = 12
    """
    Stable Diffusion 3 pix2pix, prompt guided editing.
    """

    FLUX = 13
    """
    Flux pipeline
    """

    FLUX_FILL = 14
    """
    Flux infill / outfill pipeline
    """

    FLUX_KONTEXT = 15
    """
    Flux Kontext pipeline
    """

    KOLORS = 16
    """
    Kolors (SDXL + ChatGLM)
    """
def supported_model_type_strings() -> list[str]:
    """
    Return a list of supported ``--model-type`` strings.

    A new list is created on every call, so the caller is free to modify it.

    :return: list of ``--model-type`` strings
    """
    return ['sd',
            'pix2pix',
            'sdxl',
            'sdxl-pix2pix',
            'kolors',
            'upscaler-x2',
            'upscaler-x4',
            'if',
            'ifs',
            'ifs-img2img',
            's-cascade',
            'sd3',
            'sd3-pix2pix',
            'flux',
            'flux-fill',
            'flux-kontext']
def supported_model_type_enums() -> list[ModelType]:
    """
    Return the supported :py:class:`.ModelType` enum values as a list.

    The values appear in the same order as the strings returned by
    :py:func:`.supported_model_type_strings`.

    :return: list of :py:class:`.ModelType`
    """
    return list(map(get_model_type_enum, supported_model_type_strings()))
def get_model_type_enum(id_str: ModelType | str) -> ModelType:
    """
    Resolve a ``--model-type`` string into its :py:class:`.ModelType` enum value.

    A value that is already a :py:class:`.ModelType` is handed back untouched.
    Strings are matched after stripping surrounding whitespace and lowercasing.

    :param id_str: ``--model-type`` string

    :raises ValueError: if an invalid string value (name) is passed

    :return: :py:class:`.ModelType`
    """
    if isinstance(id_str, ModelType):
        return id_str

    name_to_enum = {
        'sd': ModelType.SD,
        'pix2pix': ModelType.PIX2PIX,
        'sdxl': ModelType.SDXL,
        'kolors': ModelType.KOLORS,
        'if': ModelType.IF,
        'ifs': ModelType.IFS,
        'ifs-img2img': ModelType.IFS_IMG2IMG,
        'sdxl-pix2pix': ModelType.SDXL_PIX2PIX,
        'upscaler-x2': ModelType.UPSCALER_X2,
        'upscaler-x4': ModelType.UPSCALER_X4,
        's-cascade': ModelType.S_CASCADE,
        'sd3': ModelType.SD3,
        'sd3-pix2pix': ModelType.SD3_PIX2PIX,
        'flux': ModelType.FLUX,
        'flux-fill': ModelType.FLUX_FILL,
        'flux-kontext': ModelType.FLUX_KONTEXT,
    }

    normalized = id_str.strip().lower()

    try:
        return name_to_enum[normalized]
    except KeyError:
        raise ValueError('invalid ModelType string')
def get_model_type_string(model_type_enum: ModelType) -> str:
    """
    Produce the ``--model-type`` string that corresponds to a :py:class:`.ModelType` value.

    The argument is passed through :py:func:`.get_model_type_enum` before the lookup.

    :param model_type_enum: :py:class:`.ModelType` value

    :return: ``--model-type`` string
    """
    resolved = get_model_type_enum(model_type_enum)

    enum_to_name = {
        ModelType.SD: 'sd',
        ModelType.PIX2PIX: 'pix2pix',
        ModelType.SDXL: 'sdxl',
        ModelType.KOLORS: 'kolors',
        ModelType.IF: 'if',
        ModelType.IFS: 'ifs',
        ModelType.IFS_IMG2IMG: 'ifs-img2img',
        ModelType.SDXL_PIX2PIX: 'sdxl-pix2pix',
        ModelType.UPSCALER_X2: 'upscaler-x2',
        ModelType.UPSCALER_X4: 'upscaler-x4',
        ModelType.S_CASCADE: 's-cascade',
        ModelType.S_CASCADE_DECODER: 's-cascade-decoder',
        ModelType.SD3: 'sd3',
        ModelType.SD3_PIX2PIX: 'sd3-pix2pix',
        ModelType.FLUX: 'flux',
        ModelType.FLUX_FILL: 'flux-fill',
        ModelType.FLUX_KONTEXT: 'flux-kontext',
    }

    return enum_to_name[resolved]
def model_type_is_sd15(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` enum value
    represents an SD1.5 model.

    These model types may also be able to load SD2 checkpoints,
    specifically: :py:attr:`.ModelType.SD` can.

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value

    :return: bool
    """
    sd15_types = (ModelType.SD, ModelType.PIX2PIX, ModelType.UPSCALER_X2)
    return get_model_type_enum(model_type) in sd15_types
def model_type_is_sd2(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` enum value
    represents an SD 2.X compatible model.

    These model types may also be able to load SD1.5 checkpoints,
    specifically: :py:attr:`.ModelType.SD` can.

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value

    :return: bool
    """
    sd2_types = (ModelType.SD, ModelType.UPSCALER_X4)
    return get_model_type_enum(model_type) in sd2_types
def model_type_is_upscaler(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` enum value
    represents an upscaler model.

    The decision is made on the canonical ``--model-type`` string: any name
    containing ``upscaler`` qualifies.

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value

    :return: bool
    """
    return 'upscaler' in get_model_type_string(model_type)
def model_type_is_sdxl(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` enum value
    represents an SDXL model.

    The decision is made on the canonical ``--model-type`` string: any name
    containing ``sdxl`` qualifies.

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value

    :return: bool
    """
    return 'sdxl' in get_model_type_string(model_type)
def model_type_is_kolors(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` enum value
    represents a Kolors model.

    The decision is made on the canonical ``--model-type`` string: any name
    containing ``kolors`` qualifies.

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value

    :return: bool
    """
    return 'kolors' in get_model_type_string(model_type)
def model_type_is_sd3(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` enum value
    represents an SD3 model.

    The decision is made on the canonical ``--model-type`` string: any name
    containing ``sd3`` qualifies.

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value

    :return: bool
    """
    return 'sd3' in get_model_type_string(model_type)
def model_type_is_flux(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` enum value
    represents a Flux model.

    The decision is made on the canonical ``--model-type`` string: any name
    containing ``flux`` qualifies.

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value

    :return: bool
    """
    return 'flux' in get_model_type_string(model_type)
def model_type_is_s_cascade(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` enum value
    represents a Stable Cascade related model.

    The decision is made on the canonical ``--model-type`` string: any name
    containing ``s-cascade`` qualifies.

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value

    :return: bool
    """
    return 's-cascade' in get_model_type_string(model_type)
def model_type_is_pix2pix(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` enum value
    represents a pix2pix type model.

    The decision is made on the canonical ``--model-type`` string: any name
    containing ``pix2pix`` qualifies.

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value

    :return: bool
    """
    return 'pix2pix' in get_model_type_string(model_type)
def model_type_is_floyd(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` enum value
    represents a Floyd "if" or "ifs" type model.

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value

    :return: bool
    """
    floyd_types = (ModelType.IF, ModelType.IFS, ModelType.IFS_IMG2IMG)
    return get_model_type_enum(model_type) in floyd_types
def model_type_is_floyd_if(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` enum value
    represents a Floyd "if" type model.

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value

    :return: bool
    """
    return get_model_type_enum(model_type) is ModelType.IF
def model_type_is_floyd_ifs(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` enum value
    represents a Floyd "ifs" type model.

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value

    :return: bool
    """
    ifs_types = (ModelType.IFS, ModelType.IFS_IMG2IMG)
    return get_model_type_enum(model_type) in ifs_types
def get_torch_dtype(dtype: DataType | torch.dtype | str | None) -> torch.dtype | None:
    """
    Return a :py:class:`torch.dtype` datatype from a :py:class:`.DataType` value, or a string,
    or a :py:class:`torch.dtype` datatype itself.

    Passing ``None`` results in ``None`` being returned.

    Passing 'auto' or :py:attr:`DataType.AUTO` results in ``None`` being returned.

    Strings are matched ignoring letter case and surrounding whitespace, the same
    way :py:func:`.get_data_type_enum` treats ``--dtype`` strings. Besides the
    ``--dtype`` names, the string 'float64' is also accepted here.

    :param dtype: :py:class:`.DataType`, string, :py:class:`torch.dtype`, None

    :raises ValueError: if an invalid string value (name) is passed

    :return: :py:class:`torch.dtype`
    """
    if dtype is None:
        return None

    if isinstance(dtype, torch.dtype):
        return dtype

    if isinstance(dtype, DataType):
        # Reduce the enum to its string name so a single table handles both forms.
        dtype = get_data_type_string(dtype)

    try:
        return {'bfloat16': torch.bfloat16,
                'float16': torch.float16,
                'float32': torch.float32,
                'float64': torch.float64,
                'auto': None}[dtype.strip().lower()]
    except KeyError:
        # The KeyError is an implementation detail of the lookup table,
        # keep it out of the traceback shown to the caller.
        raise ValueError('invalid DataType string') from None
# Names exported by ``from <this module> import *``. The list is produced by the
# project helper ``dgenerate.types.module_all`` (presumably it gathers this
# module's public names -- confirm against that helper's implementation).
__all__ = _types.module_all()