# Source code for dgenerate.promptweighters.sdembedpromptweighter

# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
#    the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import contextlib
import gc
import inspect
import re
import typing

import torch

import dgenerate.extras.sd_embed as _sd_embed
import dgenerate.messages as _messages
import dgenerate.pipelinewrapper.enums as _enums
import dgenerate.promptweighters.exceptions as _exceptions
import dgenerate.promptweighters.promptweighter as _promptweighter


class SdEmbedPromptWeighter(_promptweighter.PromptWeighter):
    r"""
    Implements prompt weighting syntax for Stable Diffusion 1/2, Stable Diffusion XL,
    Stable Diffusion 3, Stable Cascade, and Flux using sd_embed.

    sd_embed uses a Stable Diffusion Web UI compatible prompt syntax.

    See: https://github.com/xhinker/sd_embed

    NOWRAP!
    @misc{sd_embed_2024,
      author       = {Shudong Zhu(Andrew Zhu)},
      title        = {Long Prompt Weighted Stable Diffusion Embedding},
      howpublished = {\url{https://github.com/xhinker/sd_embed}},
      year         = {2024},
    }

    NOWRAP!
    --model-type torch
    --model-type torch-pix2pix
    --model-type torch-upscaler-x4
    --model-type torch-sdxl
    --model-type torch-sdxl-pix2pix
    --model-type torch-s-cascade
    --model-type torch-sd3
    --model-type torch-flux

    The secondary prompt option for SDXL --sdxl-second-prompts is supported by this prompt
    weighter implementation. However, --sdxl-refiner-second-prompts is not supported and will
    be ignored with a warning message.

    The secondary prompt option for SD3 --sd3-second-prompts is not supported by this prompt
    weighter implementation. Neither is --sd3-third-prompts. The prompts from these arguments
    will be ignored with a warning message.

    The secondary prompt option for Flux --flux-second-prompts is supported by this prompt weighter.

    Flux does not support negative prompting in either prompt, negative prompts given for
    Flux will be ignored with a warning message.
    """

    NAMES = ['sd-embed']

    # --model-type values this weighter is able to produce embeddings for.
    _SUPPORTED_MODEL_TYPES = frozenset({
        _enums.ModelType.TORCH,
        _enums.ModelType.TORCH_PIX2PIX,
        _enums.ModelType.TORCH_UPSCALER_X4,
        _enums.ModelType.TORCH_SDXL,
        _enums.ModelType.TORCH_SDXL_PIX2PIX,
        _enums.ModelType.TORCH_S_CASCADE,
        _enums.ModelType.TORCH_SD3,
        _enums.ModelType.TORCH_FLUX
    })

    # Pipeline arguments this weighter generates itself. If a caller already
    # supplied any of them they would be overwritten, so refuse instead.
    # The "*_pooled" spellings are the ones written for Stable Cascade.
    _FORBIDDEN_CALL_ARGS = frozenset({
        'prompt_embeds',
        'pooled_prompt_embeds',
        'negative_prompt_embeds',
        'negative_pooled_prompt_embeds',
        'prompt_embeds_pooled',
        'negative_prompt_embeds_pooled'
    })

    # Matches prompt, negative_prompt, prompt_2, negative_prompt_3, etc.
    _PROMPT_ARG_RE = re.compile(r'^(prompt|negative_prompt)(_\d+)?$')

    def __init__(self, **kwargs):
        """
        :param kwargs: forwarded unchanged to the base ``PromptWeighter``.

        :raises PromptWeightingUnsupported: if ``self.model_type`` (set by the base
            class) is not one of the model types this weighter supports.
        """
        super().__init__(**kwargs)

        if self.model_type not in self._SUPPORTED_MODEL_TYPES:
            raise _exceptions.PromptWeightingUnsupported(
                f'Prompt weighting not supported for --model-type: '
                f'{_enums.get_model_type_string(self.model_type)}')

        # Every tensor handed out by translate_to_embeds, released by cleanup().
        self._tensors = list()

    @staticmethod
    @contextlib.contextmanager
    def _clip_skipped(text_encoders, clip_skip):
        """
        Temporarily drop the last ``clip_skip`` encoder layers of each CLIP text encoder.

        The original layer lists are always restored on exit, even when encoding
        raises; leaving an encoder truncated would affect every other use of the
        pipeline elsewhere in dgenerate.

        :param text_encoders: text encoders exposing ``text_model.encoder.layers``
        :param clip_skip: number of trailing layers to drop, values ``<= 0`` are a no-op
        """
        if clip_skip <= 0:
            yield
            return

        originals = [(encoder, encoder.text_model.encoder.layers) for encoder in text_encoders]
        try:
            for encoder, layers in originals:
                encoder.text_model.encoder.layers = layers[:-clip_skip]
            yield
        finally:
            for encoder, layers in originals:
                encoder.text_model.encoder.layers = layers

    def _encode(self, pipeline, device, clip_skip, positive, negative, positive_2, negative_2):
        """
        Run the sd_embed encoder that matches the class name of ``pipeline``.

        :return: ``(pos_conditioning, neg_conditioning, pos_pooled, neg_pooled)``,
            entries the selected encoder does not produce are ``None``.
        """
        name = pipeline.__class__.__name__

        # StableDiffusion3 / StableDiffusionXL must be tested before the plain
        # StableDiffusion prefix, which would otherwise match them as well.
        if name.startswith('StableDiffusion3'):
            if positive_2 or negative_2:
                _messages.log(
                    'Prompt weighting is not supported by --prompt-weighter '
                    '"sd-embed" for --sd3-second-prompts, that prompt is being ignored.',
                    level=_messages.WARNING)
            with self._clip_skipped((pipeline.text_encoder, pipeline.text_encoder_2), clip_skip):
                return _sd_embed.get_weighted_text_embeddings_sd3(
                    pipe=pipeline,
                    prompt=positive,
                    neg_prompt=negative,
                    pad_last_block=True,
                    use_t5_encoder=pipeline.tokenizer_3 is not None,
                    device=device)

        if name.startswith('StableCascade'):
            with self._clip_skipped((pipeline.text_encoder,), clip_skip):
                return _sd_embed.get_weighted_text_embeddings_s_cascade(
                    pipe=pipeline,
                    prompt=positive,
                    neg_prompt=negative,
                    device=device)

        if name.startswith('StableDiffusionXL'):
            if pipeline.tokenizer is not None:
                with self._clip_skipped((pipeline.text_encoder, pipeline.text_encoder_2), clip_skip):
                    if positive_2 or negative_2:
                        return _sd_embed.get_weighted_text_embeddings_sdxl_2p(
                            pipe=pipeline,
                            prompt=positive,
                            prompt_2=positive_2 or None,
                            neg_prompt=negative,
                            neg_prompt_2=negative_2 or None,
                            device=device)
                    return _sd_embed.get_weighted_text_embeddings_sdxl(
                        pipe=pipeline,
                        prompt=positive,
                        neg_prompt=negative,
                        device=device)

            # No first tokenizer: the refiner, which only has text_encoder_2.
            if positive_2 or negative_2:
                _messages.log(
                    'Prompt weighting is not supported by --prompt-weighter '
                    '"sd-embed" for --sdxl-refiner-second-prompts, that prompt is being ignored.',
                    level=_messages.WARNING)
            with self._clip_skipped((pipeline.text_encoder_2,), clip_skip):
                return _sd_embed.get_weighted_text_embeddings_sdxl_refiner(
                    pipe=pipeline,
                    prompt=positive,
                    neg_prompt=negative,
                    device=device)

        if name.startswith('StableDiffusion'):
            pos_conditioning, neg_conditioning = _sd_embed.get_weighted_text_embeddings_sd15(
                pipe=pipeline,
                prompt=positive,
                neg_prompt=negative,
                pad_last_block=False,
                clip_skip=clip_skip,
                device=device)
            return pos_conditioning, neg_conditioning, None, None

        if name.startswith('Flux'):
            if negative or negative_2:
                _messages.log(
                    'Flux does not support negative prompting, negative prompts are '
                    'being ignored by --prompt-weighter "sd-embed".',
                    level=_messages.WARNING)
            pos_conditioning, pos_pooled = _sd_embed.get_weighted_text_embeddings_flux1(
                pipe=pipeline,
                prompt=positive,
                prompt2=positive_2 or None,
                device=device)
            return pos_conditioning, None, pos_pooled, None

        return None, None, None, None

    @torch.inference_mode()
    def translate_to_embeds(self,
                            pipeline,
                            device: str,
                            args: dict[str, typing.Any]) -> dict[str, typing.Any]:
        """
        Replace the textual prompt arguments in ``args`` with weighted prompt embeddings.

        :param pipeline: the diffusers pipeline the returned arguments are intended for.
        :param device: device string handed to sd_embed for text encoding.
        :param args: pipeline call arguments; ``prompt``, ``negative_prompt``,
            ``prompt_2``, ``negative_prompt_2``, ``prompt_3``, ``negative_prompt_3``
            and ``clip_skip`` are read when present.

        :raises PromptWeightingUnsupported: if ``args`` already contains embeddings this
            method is responsible for producing, if ``pipeline.__call__`` has no
            ``prompt_embeds`` parameter, or if the pipeline class is not supported.

        :return: a copy of ``args`` with every ``prompt*`` / ``negative_prompt*`` key
            removed and the produced embedding keys added.
        """
        # We are responsible for generating these arguments,
        # if they exist already then we cannot do our job.
        if any(a in self._FORBIDDEN_CALL_ARGS for a in args):
            raise _exceptions.PromptWeightingUnsupported(
                f'Prompt weighting not supported for --model-type: {_enums.get_model_type_string(self.model_type)}, '
                f'in mode: {_enums.get_pipeline_type_string(self.pipeline_type)}')

        if 'prompt_embeds' not in inspect.signature(pipeline.__call__).parameters:
            # pipeline does not support passing prompt embeddings directly
            raise _exceptions.PromptWeightingUnsupported(
                f'Prompt weighting not supported for --model-type: {_enums.get_model_type_string(self.model_type)}, '
                f'in mode: {_enums.get_pipeline_type_string(self.pipeline_type)}')

        # 'StableDiffusion' also covers the StableDiffusionXL / StableDiffusion3 prefixes.
        if not pipeline.__class__.__name__.startswith(('StableDiffusion', 'Flux', 'StableCascade')):
            raise _exceptions.PromptWeightingUnsupported(
                f'Prompt weighting not supported for --model-type: {_enums.get_model_type_string(self.model_type)}')

        # clip_skip may be present with a value of None, treat that like "absent"
        # so the numeric comparisons downstream do not raise TypeError.
        clip_skip = args.get('clip_skip') or 0

        if args.get('prompt_3') or args.get('negative_prompt_3'):
            _messages.log(
                'Prompt weighting is not supported by --prompt-weighter '
                '"sd-embed" for --sd3-third-prompts, that prompt is being ignored.',
                level=_messages.WARNING)

        positive = args.get('prompt') or ''
        negative = args.get('negative_prompt') or ''
        positive_2 = args.get('prompt_2') or ''
        negative_2 = args.get('negative_prompt_2') or ''

        # Every prompt* / negative_prompt* argument is replaced by embeddings below.
        output = {name: value for name, value in args.items()
                  if not self._PROMPT_ARG_RE.match(name)}

        if hasattr(pipeline, 'maybe_convert_prompt'):
            # support refiner, which only has tokenizer_2
            tk = pipeline.tokenizer if pipeline.tokenizer is not None else pipeline.tokenizer_2

            if positive:
                positive = pipeline.maybe_convert_prompt(positive, tokenizer=tk)
            if negative:
                negative = pipeline.maybe_convert_prompt(negative, tokenizer=tk)

            if pipeline.tokenizer is not None:
                # refiner not supported for secondary prompt
                if positive_2:
                    positive_2 = pipeline.maybe_convert_prompt(positive_2, tokenizer=pipeline.tokenizer_2)
                if negative_2:
                    negative_2 = pipeline.maybe_convert_prompt(negative_2, tokenizer=pipeline.tokenizer_2)

        pos_conditioning, neg_conditioning, pos_pooled, neg_pooled = self._encode(
            pipeline, device, clip_skip, positive, negative, positive_2, negative_2)

        # Stable Cascade pipelines spell their pooled embedding arguments differently.
        if self.model_type == _enums.ModelType.TORCH_S_CASCADE:
            pooled_key, neg_pooled_key = 'prompt_embeds_pooled', 'negative_prompt_embeds_pooled'
        else:
            pooled_key, neg_pooled_key = 'pooled_prompt_embeds', 'negative_pooled_prompt_embeds'

        for key, tensor in (('prompt_embeds', pos_conditioning),
                            ('negative_prompt_embeds', neg_conditioning),
                            (pooled_key, pos_pooled),
                            (neg_pooled_key, neg_pooled)):
            if tensor is not None:
                self._tensors.append(tensor)
                output[key] = tensor

        return output

    def cleanup(self):
        """
        Drop every tensor produced by :py:meth:`translate_to_embeds` and ask
        Python / torch to reclaim the memory.
        """
        # Tensor.to() is not in-place: the previous per-tensor .to('cpu') only made a
        # throwaway host copy. Releasing our references is what allows the memory to
        # be freed (provided nothing else still holds the tensors).
        self._tensors.clear()
        gc.collect()
        torch.cuda.empty_cache()