# Source code for dgenerate.pipelinewrapper.uris.textualinversionuri

# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
#    the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os.path
import typing

import diffusers

import dgenerate.hfhub as _hfhub
import dgenerate.messages as _messages
import dgenerate.textprocessing as _textprocessing
import dgenerate.types as _types
from dgenerate.pipelinewrapper.uris import exceptions as _exceptions

# Shared parser for ``--textual-inversions`` URI strings, used by
# TextualInversionUri.parse. The listed names are the same keys that
# parse() reads back out of the parse result's ``args`` mapping.
# NOTE(review): presumably the list is the set of argument names the parser
# accepts -- confirm against dgenerate.textprocessing.ConceptUriParser.
_textual_inversion_uri_parser = _textprocessing.ConceptUriParser(
    'Textual Inversion', ['token', 'revision', 'subfolder', 'weight-name'])


def _load_textual_inversion_state_dict(pretrained_model_name_or_path, **kwargs):
    """
    Locate a textual inversion weights file and load its state dict.

    Recognized keyword arguments (all optional): ``cache_dir``, ``force_download``,
    ``proxies``, ``local_files_only``, ``token``, ``revision``, ``subfolder``,
    ``weight_name`` and ``use_safetensors``. Any other keyword argument is ignored.

    When ``weight_name`` is not given, ``learned_embeds.safetensors`` is tried first
    (if safetensors loading is enabled) and ``learned_embeds.bin`` is the fallback.
    The fallback after a failed safetensors attempt is only taken when
    ``use_safetensors`` was not passed (or passed as ``None``); otherwise the
    original exception is re-raised.

    :param pretrained_model_name_or_path: model location handed to diffusers ``_get_model_file``
    :return: tuple of ``(model_file, state_dict)``
    """
    from diffusers.utils.hub_utils import _get_model_file
    from diffusers.models.modeling_utils import load_state_dict

    default_pickle_name = "learned_embeds.bin"
    default_safetensors_name = "learned_embeds.safetensors"

    weight_name = kwargs.pop("weight_name", None)
    use_safetensors = kwargs.pop("use_safetensors", None)

    # An unspecified safetensors preference means: prefer safetensors,
    # but tolerate a failure and retry with the pickle file name.
    pickle_fallback_allowed = use_safetensors is None
    if pickle_fallback_allowed:
        use_safetensors = True

    # Options shared by every file lookup.
    fetch_options = {
        "cache_dir": kwargs.pop("cache_dir", None),
        "force_download": kwargs.pop("force_download", False),
        "proxies": kwargs.pop("proxies", None),
        "local_files_only": kwargs.pop("local_files_only", False),
        "token": kwargs.pop("token", None),
        "revision": kwargs.pop("revision", None),
        "subfolder": kwargs.pop("subfolder", None),
        "user_agent": {
            "file_type": "text_inversion",
            "framework": "pytorch",
        },
    }

    def fetch_and_load(file_name):
        # Resolve the file, then read its state dict.
        resolved = _get_model_file(
            pretrained_model_name_or_path,
            weights_name=file_name,
            **fetch_options,
        )
        return resolved, load_state_dict(resolved)

    # With no explicit weight name the preference flag decides; with an
    # explicit weight name only its extension decides.
    if weight_name is None:
        attempt_safetensors = use_safetensors
    else:
        attempt_safetensors = weight_name.endswith(".safetensors")

    model_file = None
    state_dict = None

    if attempt_safetensors:
        try:
            model_file, state_dict = fetch_and_load(
                weight_name or default_safetensors_name)
        except Exception as exc:
            if not pickle_fallback_allowed:
                raise exc
            # model_file is still None here, so the lookup below runs.

    if model_file is None:
        model_file, state_dict = fetch_and_load(
            weight_name or default_pickle_name)

    return model_file, state_dict


class TextualInversionUri:
    """
    Representation of ``--textual-inversions`` uri
    """

    # pipelinewrapper.uris.util.get_uri_accepted_args_schema metadata

    NAMES = ['Textual Inversion']

    @staticmethod
    def help():
        """
        Return the result of ``dgenerate.arguments.get_raw_help_text``
        for the ``--textual-inversions`` argument.
        """
        # imported lazily, inside the function, as in the original code
        import dgenerate.arguments as _a
        return _a.get_raw_help_text('--textual-inversions')

    FILE_ARGS = {
        'model': {'mode': ['in', 'dir'],
                  'filetypes': [('Models', ['*.safetensors', '*.pt', '*.pth', '*.ckpt', '*.bin'])]}
    }

    # ===

    @property
    def model(self) -> str:
        """
        Model path, huggingface slug, file path
        """
        return self._model

    @property
    def revision(self) -> _types.OptionalString:
        """
        Model repo revision
        """
        return self._revision

    @property
    def subfolder(self) -> _types.OptionalPath:
        """
        Model repo subfolder
        """
        return self._subfolder

    @property
    def weight_name(self) -> _types.OptionalName:
        """
        Model weight-name
        """
        return self._weight_name

    @property
    def token(self) -> _types.OptionalString:
        """
        Prompt keyword
        """
        return self._token

    def __init__(self,
                 model: str,
                 token: str | None = None,
                 revision: _types.OptionalString = None,
                 subfolder: _types.OptionalPath = None,
                 weight_name: _types.OptionalName = None):
        """
        :param model: model path, huggingface slug, file path
        :param token: prompt keyword
        :param revision: model repo revision
        :param subfolder: model repo subfolder
        :param weight_name: model weight-name
        """
        self._token = token
        self._model = model
        self._revision = revision
        self._subfolder = subfolder
        self._weight_name = weight_name

    def __str__(self):
        return f'{self.__class__.__name__}({str(_types.get_public_attributes(self))})'

    def __repr__(self):
        return str(self)

    @staticmethod
    def load_on_pipeline(pipeline: diffusers.DiffusionPipeline,
                         uris: typing.Iterable[typing.Union["TextualInversionUri", str]],
                         use_auth_token: _types.OptionalString = None,
                         local_files_only: bool = False):
        """
        Load Textual Inversion weights on to a pipeline using one or more URIs

        :param pipeline: :py:class:`diffusers.DiffusionPipeline`
        :param uris: Iterable of :py:class:`TextualInversionUri` or ``str`` Textual Inversion URIs to load
        :param use_auth_token: optional huggingface auth token.
        :param local_files_only: avoid downloading files and only look
            for cached files when the model path is a huggingface slug

        :raises ModelNotFoundError: If the model could not be found.
        :raises dgenerate.pipelinewrapper.uris.exceptions.InvalidTextualInversionUriError: On URI parsing errors.
        :raises dgenerate.pipelinewrapper.uris.exceptions.TextualInversionUriLoadError: On loading errors.
        """

        def catch_all(e):
            # URI parsing errors pass through untouched, anything else
            # is wrapped as a load error that keeps the original cause.
            # NOTE(review): presumably this handler only receives exceptions
            # not already translated by the hfhub context manager -- confirm
            # against dgenerate.hfhub.with_hf_errors_as_model_not_found.
            if isinstance(e, _exceptions.InvalidTextualInversionUriError):
                raise e
            else:
                raise _exceptions.TextualInversionUriLoadError(
                    f'error loading Textual Inversions: {e}') from e

        with _hfhub.with_hf_errors_as_model_not_found(catch_all):
            TextualInversionUri._load_on_pipeline(
                uris=uris,
                pipeline=pipeline,
                use_auth_token=use_auth_token,
                local_files_only=local_files_only)

    @staticmethod
    def _load_on_pipeline(pipeline: diffusers.DiffusionPipeline,
                          uris: typing.Iterable[typing.Union["TextualInversionUri", str]],
                          use_auth_token: _types.OptionalString = None,
                          local_files_only: bool = False):
        """
        Implementation of :py:meth:`TextualInversionUri.load_on_pipeline`
        without the exception translation wrapper.

        :raises RuntimeError: if the pipeline has no ``load_textual_inversion``
            attribute, or if an SDXL / Flux state dict lacks the expected keys.
        """
        pipeline_name = pipeline.__class__.__name__

        if not hasattr(pipeline, 'load_textual_inversion'):
            raise RuntimeError(f'Pipeline: {pipeline_name} '
                               f'does not support loading textual inversions.')

        # The pipeline type cannot change between URIs, decide once.
        is_sdxl = pipeline_name.startswith('StableDiffusionXL')
        is_flux = pipeline_name.startswith('Flux')

        for textual_inversion_uri in uris:
            if not isinstance(textual_inversion_uri, TextualInversionUri):
                textual_inversion_uri = TextualInversionUri.parse(textual_inversion_uri)

            model_path = _hfhub.download_non_hf_slug_model(textual_inversion_uri.model)

            if is_sdxl or is_flux:
                # These pipelines have two text encoders, so the state dict
                # is loaded manually and each part is applied to its encoder.
                filename, dicts = _load_textual_inversion_state_dict(
                    model_path,
                    revision=textual_inversion_uri.revision,
                    subfolder=textual_inversion_uri.subfolder,
                    weight_name=textual_inversion_uri.weight_name,
                    local_files_only=local_files_only,
                    token=use_auth_token
                )

                if is_sdxl:
                    if 'clip_l' not in dicts or 'clip_g' not in dicts:
                        raise RuntimeError(
                            'clip_l or clip_g not found in SDXL textual '
                            f'inversion model "{textual_inversion_uri.model}" state dict, '
                            'unsupported model format.')
                else:
                    if 'clip_l' not in dicts:
                        raise RuntimeError(
                            'clip_l not found in Flux textual '
                            f'inversion model "{textual_inversion_uri.model}" state dict, '
                            'unsupported model format.')

                # token is the file name (no extension) with spaces
                # replaced by underscores when the user does not provide
                # a prompt token
                if textual_inversion_uri.token is None:
                    file_stem = os.path.splitext(os.path.basename(filename))[0]
                    token = file_stem.replace(' ', '_')
                else:
                    token = textual_inversion_uri.token

                pipeline.load_textual_inversion(dicts['clip_l'],
                                                token=token,
                                                text_encoder=pipeline.text_encoder,
                                                tokenizer=pipeline.tokenizer,
                                                hf_token=use_auth_token)

                if is_sdxl:
                    pipeline.load_textual_inversion(dicts['clip_g'],
                                                    token=token,
                                                    text_encoder=pipeline.text_encoder_2,
                                                    tokenizer=pipeline.tokenizer_2,
                                                    hf_token=use_auth_token)

                # the t5 part is optional for Flux, it is only applied when present
                if is_flux and 't5' in dicts:
                    pipeline.load_textual_inversion(dicts['t5'],
                                                    token=token,
                                                    text_encoder=pipeline.text_encoder_2,
                                                    tokenizer=pipeline.tokenizer_2,
                                                    hf_token=use_auth_token)
            else:
                pipeline.load_textual_inversion(model_path,
                                                token=textual_inversion_uri.token,
                                                revision=textual_inversion_uri.revision,
                                                subfolder=textual_inversion_uri.subfolder,
                                                weight_name=textual_inversion_uri.weight_name,
                                                local_files_only=local_files_only,
                                                hf_token=use_auth_token)

            _messages.debug_log(f'Added Textual Inversion: "{textual_inversion_uri}" '
                                f'to pipeline: "{pipeline.__class__.__name__}"')

    @staticmethod
    def parse(uri: _types.Uri) -> 'TextualInversionUri':
        """
        Parse a ``--textual-inversions`` uri and return an object representing its constituents

        :param uri: string with ``--textual-inversions`` uri syntax

        :raise InvalidTextualInversionUriError: if the uri parser reports a parse error

        :return: :py:class:`.TextualInversionUri`
        """
        try:
            r = _textual_inversion_uri_parser.parse(uri)

            return TextualInversionUri(model=r.concept,
                                       token=r.args.get('token', None),
                                       weight_name=r.args.get('weight-name', None),
                                       revision=r.args.get('revision', None),
                                       subfolder=r.args.get('subfolder', None))
        except _textprocessing.ConceptUriParseError as e:
            raise _exceptions.InvalidTextualInversionUriError(e) from e