# Source code for dgenerate.pipelinewrapper.uris.sdxlrefineruri

# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
#    the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import dgenerate.pipelinewrapper.enums as _enums
import dgenerate.textprocessing as _textprocessing
import dgenerate.types as _types
from dgenerate.pipelinewrapper.uris import exceptions as _exceptions

# Module-level parser used by SDXLRefinerUri.parse for ``--sdxl-refiner`` URIs.
# The list presumably names the optional URI arguments the parser accepts;
# they match the keys that parse() reads back out of the result's ``args``.
_sdxl_refiner_uri_parser = _textprocessing.ConceptUriParser(
    'SDXL Refiner',
    [
        'revision',
        'variant',
        'subfolder',
        'dtype',
    ],
)


class SDXLRefinerUri:
    """
    Representation of ``--sdxl-refiner`` uri
    """

    # pipelinewrapper.uris.util.get_uri_accepted_args_schema metadata

    NAMES = ['SDXL Refiner']

    @staticmethod
    def help():
        """
        Return the raw help text of the ``--sdxl-refiner`` argument.

        :return: whatever ``dgenerate.arguments.get_raw_help_text`` returns for ``--sdxl-refiner``
        """
        # Imported at call time rather than module level,
        # presumably to avoid a circular import -- confirm before hoisting.
        import dgenerate.arguments as _a
        return _a.get_raw_help_text('--sdxl-refiner')

    OPTION_ARGS = {
        'dtype': ['float16', 'bfloat16', 'float32']
    }

    FILE_ARGS = {
        'model': {
            'mode': ['in', 'dir'],
            # '*.ckpt' is the checkpoint extension; this entry
            # was previously misspelled as '*.cpkt'.
            'filetypes': [('Models', ['*.safetensors', '*.pt', '*.pth', '*.ckpt', '*.bin'])]
        }
    }

    # ===

    @property
    def model(self) -> str:
        """
        Model path, huggingface slug
        """
        return self._model

    @property
    def revision(self) -> _types.OptionalString:
        """
        Model repo revision
        """
        return self._revision

    @property
    def variant(self) -> _types.OptionalString:
        """
        Model repo variant, for example ``fp16``
        """
        return self._variant

    @property
    def subfolder(self) -> _types.OptionalPath:
        """
        Model repo subfolder
        """
        return self._subfolder

    @property
    def dtype(self) -> _enums.DataType | None:
        """
        Model dtype (precision)
        """
        return self._dtype

    def __init__(self,
                 model: str,
                 revision: _types.OptionalString = None,
                 variant: _types.OptionalString = None,
                 subfolder: _types.OptionalPath = None,
                 dtype: _enums.DataType | str | None = None):
        """
        :param model: model path
        :param revision: model revision (branch name)
        :param variant: model variant, for example ``fp16``
        :param subfolder: model subfolder
        :param dtype: model data type (precision)

        :raises InvalidSDXLRefinerUriError: If ``dtype`` is passed an invalid data type string.
        """

        self._model = model
        self._revision = revision
        self._variant = variant
        self._subfolder = subfolder

        try:
            # A falsy dtype (None, empty string) is stored as None;
            # anything else is normalized through get_data_type_enum.
            self._dtype = _enums.get_data_type_enum(dtype) if dtype else None
        except ValueError as e:
            # Chain the original ValueError so the root cause stays visible.
            raise _exceptions.InvalidSDXLRefinerUriError(
                f'invalid dtype string, must be one of: {_textprocessing.oxford_comma(_enums.supported_data_type_strings(), "or")}') from e

    @staticmethod
    def parse(uri: _types.Uri) -> 'SDXLRefinerUri':
        """
        Parse an ``--sdxl-refiner`` uri and return an object representing its constituents

        :param uri: string with ``--sdxl-refiner`` uri syntax

        :raise InvalidSDXLRefinerUriError: If the uri cannot be parsed, or its
            ``dtype`` argument is not one of the supported data type strings.

        :return: :py:class:`.SDXLRefinerUri`
        """
        try:
            r = _sdxl_refiner_uri_parser.parse(uri)

            supported_dtypes = _enums.supported_data_type_strings()

            # Validate dtype here so the error message can echo the received value.
            dtype = r.args.get('dtype', None)
            if dtype is not None and dtype not in supported_dtypes:
                raise _exceptions.InvalidSDXLRefinerUriError(
                    f'Torch SDXL refiner "dtype" must be {", ".join(supported_dtypes)}, '
                    f'or left undefined, received: {dtype}')

            return SDXLRefinerUri(
                model=r.concept,
                revision=r.args.get('revision', None),
                variant=r.args.get('variant', None),
                dtype=dtype,
                subfolder=r.args.get('subfolder', None))
        except _textprocessing.ConceptUriParseError as e:
            # Translate parser failures into this URI type's own exception.
            raise _exceptions.InvalidSDXLRefinerUriError(e) from e