Source code for dgenerate.pipelinewrapper.enums

# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
#    the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import enum
import typing

import torch

import dgenerate.types as _types


class PipelineType(enum.Enum):
    """
    The kind of diffusers pipeline to construct / run.
    """

    #: Text to image: generation is driven by the prompt alone.
    TXT2IMG = 1

    #: Image to image: generation is seeded or controlled by an input image.
    IMG2IMG = 2

    #: Inpainting: generation is seeded or controlled by an input image
    #: together with a mask.
    INPAINT = 3
def get_pipeline_type_enum(id_str: PipelineType | str | None) -> PipelineType:
    """
    Get a :py:class:`.PipelineType` enum value from a string.

    A :py:class:`.PipelineType` value is passed through unchanged. Strings are
    matched case-insensitively after stripping surrounding whitespace.

    :param id_str: one of: "txt2img", "img2img", or "inpaint"
        (or a :py:class:`.PipelineType` value)
    :raises ValueError: if an invalid string value (name) is passed, or if
        ``id_str`` is ``None`` / not a string
    :return: :py:class:`.PipelineType`
    """
    if isinstance(id_str, PipelineType):
        return id_str

    if not isinstance(id_str, str):
        # None is permitted by the annotation but has no mapping; without this
        # guard it (and any other non-string) would fail with AttributeError
        # on .strip() rather than the documented ValueError.
        raise ValueError('invalid PipelineType string')

    try:
        return {'txt2img': PipelineType.TXT2IMG,
                'img2img': PipelineType.IMG2IMG,
                'inpaint': PipelineType.INPAINT}[id_str.strip().lower()]
    except KeyError:
        # suppress the internal KeyError context, it is an implementation detail
        raise ValueError('invalid PipelineType string') from None
def get_pipeline_type_string(pipeline_type_enum: PipelineType) -> str:
    """
    Map a :py:class:`.PipelineType` value to its string name.

    The argument is first normalized through :py:func:`.get_pipeline_type_enum`,
    so a valid pipeline type string is also accepted.

    :param pipeline_type_enum: :py:class:`.PipelineType` value
    :return: one of: "txt2img", "img2img", or "inpaint"
    """
    names_by_type = {
        PipelineType.TXT2IMG: 'txt2img',
        PipelineType.IMG2IMG: 'img2img',
        PipelineType.INPAINT: 'inpaint',
    }
    normalized = get_pipeline_type_enum(pipeline_type_enum)
    return names_by_type[normalized]
class DataType(enum.Enum):
    """
    Model precision (the ``--dtype`` selection).
    """

    #: Automatic selection.
    AUTO = 0

    #: 16 bit floating point.
    FLOAT16 = 1

    #: 32 bit floating point.
    FLOAT32 = 2

    #: 16 bit brain floating point.
    BFLOAT16 = 3
def supported_data_type_strings() -> list[str]:
    """
    List the accepted ``--dtype`` strings.

    A new list is built on every call, so callers may mutate the result.

    :return: list of ``--dtype`` strings
    """
    return 'auto bfloat16 float16 float32'.split()
def supported_data_type_enums() -> list[DataType]:
    """
    List every :py:class:`.DataType` value that ``--dtype`` accepts.

    :return: list of :py:class:`.DataType`, in the same order as
        :py:func:`.supported_data_type_strings`
    """
    return list(map(get_data_type_enum, supported_data_type_strings()))
def get_data_type_enum(id_str: DataType | str | None) -> DataType:
    """
    Convert a ``--dtype`` string to its :py:class:`.DataType` enum value.

    A :py:class:`.DataType` value is passed through unchanged. Strings are
    matched case-insensitively after stripping surrounding whitespace.

    :param id_str: ``--dtype`` string (or a :py:class:`.DataType` value)
    :raises ValueError: if an invalid string value (name) is passed, or if
        ``id_str`` is ``None`` / not a string
    :return: :py:class:`.DataType`
    """
    if isinstance(id_str, DataType):
        return id_str

    if not isinstance(id_str, str):
        # None is permitted by the annotation but has no mapping; without this
        # guard it (and any other non-string) would fail with AttributeError
        # on .strip() rather than the documented ValueError.
        raise ValueError('invalid DataType string')

    try:
        return {'auto': DataType.AUTO,
                'float16': DataType.FLOAT16,
                'float32': DataType.FLOAT32,
                'bfloat16': DataType.BFLOAT16}[id_str.strip().lower()]
    except KeyError:
        # suppress the internal KeyError context, it is an implementation detail
        raise ValueError('invalid DataType string') from None
def get_data_type_string(data_type_enum: DataType) -> str:
    """
    Map a :py:class:`.DataType` value back to its ``--dtype`` string.

    The argument is first normalized through :py:func:`.get_data_type_enum`,
    so a valid ``--dtype`` string is also accepted.

    :param data_type_enum: :py:class:`.DataType` value
    :return: ``--dtype`` string
    """
    dtype_names = {
        DataType.AUTO: 'auto',
        DataType.FLOAT16: 'float16',
        DataType.FLOAT32: 'float32',
        DataType.BFLOAT16: 'bfloat16',
    }
    normalized = get_data_type_enum(data_type_enum)
    return dtype_names[normalized]
class ModelType(enum.Enum):
    """
    Enumeration of the ``--model-type`` choices.
    """

    #: Stable Diffusion, such as SD 1.0 - 2.x
    TORCH = 0

    #: Stable Diffusion pix2pix prompt guided editing.
    TORCH_PIX2PIX = 1

    #: Stable Diffusion XL
    TORCH_SDXL = 2

    #: Deep Floyd IF stage 1
    TORCH_IF = 3

    #: Deep Floyd IF stage 2 (upscaling stage)
    TORCH_IFS = 4

    #: Deep Floyd IF stage 2 (upscaling stage), image to image / variation mode.
    TORCH_IFS_IMG2IMG = 5

    #: Stable Diffusion XL pix2pix prompt guided editing.
    TORCH_SDXL_PIX2PIX = 6

    #: Stable Diffusion X2 upscaler
    TORCH_UPSCALER_X2 = 7

    #: Stable Diffusion X4 upscaler
    TORCH_UPSCALER_X4 = 8

    #: Stable Cascade prior
    TORCH_S_CASCADE = 9

    #: Stable Cascade decoder
    TORCH_S_CASCADE_DECODER = 10

    #: Stable Diffusion 3
    TORCH_SD3 = 11

    #: Flux pipeline
    TORCH_FLUX = 12

    #: Flux infill / outfill pipeline
    TORCH_FLUX_FILL = 13
def supported_model_type_strings() -> list[str]:
    """
    List the accepted ``--model-type`` strings.

    A new list is built on every call, so callers may mutate the result.

    :return: list of ``--model-type`` strings
    """
    names = (
        'torch',
        'torch-pix2pix',
        'torch-sdxl',
        'torch-sdxl-pix2pix',
        'torch-upscaler-x2',
        'torch-upscaler-x4',
        'torch-if',
        'torch-ifs',
        'torch-ifs-img2img',
        'torch-s-cascade',
        'torch-sd3',
        'torch-flux',
        'torch-flux-fill',
    )
    return list(names)
def supported_model_type_enums() -> list[ModelType]:
    """
    List every :py:class:`.ModelType` value that ``--model-type`` accepts.

    :return: list of :py:class:`.ModelType`, in the same order as
        :py:func:`.supported_model_type_strings`
    """
    return list(map(get_model_type_enum, supported_model_type_strings()))
def get_model_type_enum(id_str: ModelType | str) -> ModelType:
    """
    Convert a ``--model-type`` string to its :py:class:`.ModelType` enum value.

    A :py:class:`.ModelType` value is passed through unchanged. Strings are
    matched case-insensitively after stripping surrounding whitespace.

    Every string produced by :py:func:`.get_model_type_string` is accepted,
    including ``torch-s-cascade-decoder``, so the two functions round-trip.

    :param id_str: ``--model-type`` string (or a :py:class:`.ModelType` value)
    :raises ValueError: if an invalid string value (name) is passed, or if
        ``id_str`` is not a string
    :return: :py:class:`.ModelType`
    """
    if isinstance(id_str, ModelType):
        return id_str

    if not isinstance(id_str, str):
        # without this guard a non-string (e.g. None) would fail with
        # AttributeError on .strip() rather than the documented ValueError
        raise ValueError('invalid ModelType string')

    try:
        return {'torch': ModelType.TORCH,
                'torch-pix2pix': ModelType.TORCH_PIX2PIX,
                'torch-sdxl': ModelType.TORCH_SDXL,
                'torch-if': ModelType.TORCH_IF,
                'torch-ifs': ModelType.TORCH_IFS,
                'torch-ifs-img2img': ModelType.TORCH_IFS_IMG2IMG,
                'torch-sdxl-pix2pix': ModelType.TORCH_SDXL_PIX2PIX,
                'torch-upscaler-x2': ModelType.TORCH_UPSCALER_X2,
                'torch-upscaler-x4': ModelType.TORCH_UPSCALER_X4,
                'torch-s-cascade': ModelType.TORCH_S_CASCADE,
                # emitted by get_model_type_string(); accepted here so that
                # string -> enum -> string conversions are symmetric
                'torch-s-cascade-decoder': ModelType.TORCH_S_CASCADE_DECODER,
                'torch-sd3': ModelType.TORCH_SD3,
                'torch-flux': ModelType.TORCH_FLUX,
                'torch-flux-fill': ModelType.TORCH_FLUX_FILL}[id_str.strip().lower()]
    except KeyError:
        # suppress the internal KeyError context, it is an implementation detail
        raise ValueError('invalid ModelType string') from None
def get_model_type_string(model_type_enum: ModelType) -> str:
    """
    Map a :py:class:`.ModelType` value back to its ``--model-type`` string.

    The argument is first normalized through :py:func:`.get_model_type_enum`,
    so a valid ``--model-type`` string is also accepted.

    :param model_type_enum: :py:class:`.ModelType` value
    :return: ``--model-type`` string
    """
    model_type_names = {
        ModelType.TORCH: 'torch',
        ModelType.TORCH_PIX2PIX: 'torch-pix2pix',
        ModelType.TORCH_SDXL: 'torch-sdxl',
        ModelType.TORCH_IF: 'torch-if',
        ModelType.TORCH_IFS: 'torch-ifs',
        ModelType.TORCH_IFS_IMG2IMG: 'torch-ifs-img2img',
        ModelType.TORCH_SDXL_PIX2PIX: 'torch-sdxl-pix2pix',
        ModelType.TORCH_UPSCALER_X2: 'torch-upscaler-x2',
        ModelType.TORCH_UPSCALER_X4: 'torch-upscaler-x4',
        ModelType.TORCH_S_CASCADE: 'torch-s-cascade',
        ModelType.TORCH_S_CASCADE_DECODER: 'torch-s-cascade-decoder',
        ModelType.TORCH_SD3: 'torch-sd3',
        ModelType.TORCH_FLUX: 'torch-flux',
        ModelType.TORCH_FLUX_FILL: 'torch-flux-fill',
    }
    normalized = get_model_type_enum(model_type_enum)
    return model_type_names[normalized]
def model_type_is_upscaler(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` value
    names an upscaler model (its ``--model-type`` string contains ``upscaler``).

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value
    :return: bool
    """
    return 'upscaler' in get_model_type_string(model_type)
def model_type_is_sdxl(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` value
    names an SDXL model (its ``--model-type`` string contains ``sdxl``,
    which includes ``torch-sdxl-pix2pix``).

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value
    :return: bool
    """
    return 'sdxl' in get_model_type_string(model_type)
def model_type_is_sd3(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` value
    names an SD3 model (its ``--model-type`` string contains ``sd3``).

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value
    :return: bool
    """
    return 'sd3' in get_model_type_string(model_type)
def model_type_is_flux(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` value
    names a Flux model (its ``--model-type`` string contains ``flux``,
    which includes ``torch-flux-fill``).

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value
    :return: bool
    """
    return 'flux' in get_model_type_string(model_type)
def model_type_is_s_cascade(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` value
    names a Stable Cascade related model (its ``--model-type`` string
    contains ``s-cascade``).

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value
    :return: bool
    """
    return 's-cascade' in get_model_type_string(model_type)
def model_type_is_torch(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` value
    names a Torch model (its ``--model-type`` string contains ``torch``).

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value
    :return: bool
    """
    return 'torch' in get_model_type_string(model_type)
def model_type_is_pix2pix(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` value
    names a pix2pix type model (its ``--model-type`` string contains ``pix2pix``).

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value
    :return: bool
    """
    return 'pix2pix' in get_model_type_string(model_type)
def model_type_is_floyd(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` value
    names a Deep Floyd "if" or "ifs" type model.

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value
    :return: bool
    """
    floyd_types = (ModelType.TORCH_IF,
                   ModelType.TORCH_IFS,
                   ModelType.TORCH_IFS_IMG2IMG)
    return get_model_type_enum(model_type) in floyd_types
def model_type_is_floyd_if(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` value
    names a Deep Floyd "if" (stage 1) type model.

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value
    :return: bool
    """
    # enum members are singletons, so identity comparison is exact
    return get_model_type_enum(model_type) is ModelType.TORCH_IF
def model_type_is_floyd_ifs(model_type: ModelType | str) -> bool:
    """
    Check whether a ``--model-type`` string or :py:class:`.ModelType` value
    names a Deep Floyd "ifs" (stage 2) type model, including its
    img2img variant.

    :param model_type: ``--model-type`` string or :py:class:`.ModelType` enum value
    :return: bool
    """
    resolved = get_model_type_enum(model_type)
    return resolved in (ModelType.TORCH_IFS, ModelType.TORCH_IFS_IMG2IMG)
def get_torch_dtype(dtype: DataType | torch.dtype | str | None) -> torch.dtype | None:
    """
    Return a :py:class:`torch.dtype` datatype from a :py:class:`.DataType` value,
    or a string, or a :py:class:`torch.dtype` datatype itself.

    Passing ``None`` results in ``None`` being returned.

    Passing 'auto' or :py:attr:`DataType.AUTO` results in ``None`` being returned.

    Strings are matched case-insensitively after stripping surrounding
    whitespace (consistent with :py:func:`.get_data_type_enum`). In addition to
    the ``--dtype`` strings, the string ``'float64'`` is accepted here.

    :param dtype: :py:class:`.DataType`, string, :py:class:`torch.dtype`, None
    :raises ValueError: if an invalid string value (name) is passed, or a value
        of an unsupported type is passed
    :return: :py:class:`torch.dtype`
    """
    if dtype is None:
        return None

    if isinstance(dtype, torch.dtype):
        return dtype

    if isinstance(dtype, DataType):
        dtype = get_data_type_string(dtype)

    if not isinstance(dtype, str):
        # previously this failed with AttributeError on .lower(),
        # contradicting the documented ValueError
        raise ValueError('invalid DataType string')

    try:
        return {'bfloat16': torch.bfloat16,
                'float16': torch.float16,
                'float32': torch.float32,
                'float64': torch.float64,
                'auto': None}[dtype.strip().lower()]
    except KeyError:
        # suppress the internal KeyError context, it is an implementation detail
        raise ValueError('invalid DataType string') from None
# Public export list for this module, computed by the project helper
# dgenerate.types.module_all().
# NOTE(review): presumably module_all() collects the public (non-underscore)
# names defined in the calling module — confirm against dgenerate.types.
__all__ = _types.module_all()