# Source code for dgenerate.promptupscalers.magicpromptupscaler

# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
#    the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import contextlib
import gc
import typing

import diffusers
import torch
import transformers

import dgenerate.memory as _memory
import dgenerate.messages as _messages
import dgenerate.pipelinewrapper.util as _pipelinewrapper_util
import dgenerate.prompt as _prompt
import dgenerate.promptupscalers.exceptions as _exceptions
import dgenerate.promptupscalers.llmupscalermixin as _llmupscalermixin
import dgenerate.promptupscalers.promptupscaler as _promptupscaler
from dgenerate.pipelinewrapper.uris import get_quantizer_uri_class as _get_quantizer_uri_class
from dgenerate.pipelinewrapper.uris import BNBQuantizerUri as _BNBQuantizerUri


@contextlib.contextmanager
def _with_seed(seed: int | None):
    if seed is None:
        yield
    else:
        orig_state = torch.random.get_rng_state()
        torch.manual_seed(seed)
        try:
            yield
        finally:
            torch.random.set_rng_state(orig_state)


class _TextGenerationPipeline:
    """
    Small callable wrapper that pairs a generative model with its tokenizer
    and runs text generation over a list of prompts in fixed size batches.
    """

    def __init__(self, model, tokenizer, quantized: bool):
        """
        :param model: object exposing ``to``, ``device`` and ``generate``
        :param tokenizer: callable tokenizer that also exposes ``decode``
        :param quantized: when ``True``, :py:meth:`to` never moves the model
        """
        self.quantized = quantized
        self.tokenizer = tokenizer
        self.model = model

    def to(self, device):
        """
        Move the wrapped model to ``device``, this is a no-op for quantized models.

        :param device: target device
        """
        if self.quantized:
            return
        self.model.to(device)

    def __call__(self,
                 prompts: list[str],
                 batch_size: int = 1,
                 max_length: int = 100,
                 **kwargs
                 ):
        """
        Generate text for every prompt.

        :param prompts: input texts
        :param batch_size: number of prompts tokenized and generated per step
        :param max_length: passed to both the tokenizer (truncation length)
            and to ``model.generate``
        :param kwargs: extra keyword arguments forwarded to ``model.generate``
        :return: decoded outputs, in the order the model produced them
        """
        decoded = []

        for start in range(0, len(prompts), batch_size):
            chunk = prompts[start: start + batch_size]

            encoded = self.tokenizer(
                chunk,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length)

            device = self.model.device
            encoded = encoded.to(device)

            with torch.no_grad():
                generated = self.model.generate(
                    **encoded,
                    max_length=max_length,
                    **kwargs)

            # move the encoded batch off of the accelerator before dropping it
            if device.type != 'cpu':
                encoded.to('cpu')

            del encoded

            batch_text = [self.tokenizer.decode(sequence) for sequence in generated]
            decoded.extend(batch_text)

        return decoded


class MagicPromptUpscaler(_llmupscalermixin.LLMPromptUpscalerMixin, _promptupscaler.PromptUpscaler):
    """
    Upscale prompts using magicprompt or other LLMs via transformers.

    The "part" argument indicates which parts of the prompt to act on,
    possible values are: "both", "positive", and "negative"

    The "model" specifies the model path for magicprompt, the default
    value is: "Gustavosta/MagicPrompt-Stable-Diffusion". This can be a
    folder on disk or a Hugging Face repository slug.

    The "dtype" argument specifies the torch dtype (compute dtype) to load the
    model with, this defaults to: float32, and may be one of: float32, float16,
    or bfloat16.

    The "seed" argument can be used to specify a seed for prompt generation.

    The "variations" argument specifies how many variations should be produced.

    The "max-length" argument is the max prompt length for a generated prompt,
    this value defaults to 100.

    The "temperature" argument sets the sampling temperature to use when
    generating prompts. Larger values increase creativity but decrease
    factuality.

    The "top_k" argument sets the "top_k" generation value, i.e. randomly
    sample from the "top_k" most likely tokens at each generation step.
    Set this to 1 for greedy decoding.

    The "top_p" argument sets the "top_p" generation value, i.e. randomly sample
    at each generation step from the top most likely tokens whose probabilities
    add up to "top_p".

    The "system" argument sets the system instruction for the LLM.

    The "preamble" argument sets a text input preamble for the LLM, this
    preamble will be removed from the output generated by the LLM.

    The "remove-prompt" argument specifies whether to remove the original
    prompt from the generated text.

    The "prepend-prompt" argument specifies whether to forcefully prepend
    the original prompt to the generated prompt, this might be necessary if
    you want a continuation with some models, the original prompt will be
    prepended with a space at the end.

    The "batch" argument enables and disables batching prompt text into the
    LLM, setting this to False tells the plugin that you only want the LLM to
    ever process one prompt at a time, this might be useful if you are memory
    constrained, but processing is much slower.

    The "max-batch" argument allows you to adjust how many prompts can be
    processed by the LLM simultaneously, processing too many prompts at once
    will run your system out of memory, processing too little prompts at once
    will be slow. Specifying "None" indicates unlimited batch size.

    The "quantizer" argument allows you to specify a quantization backend for
    loading the LLM, this is the same syntax and supported backends as with the
    dgenerate --quantizer argument.

    The "block-regex" argument is a python syntax regex that will block prompts
    that match the regex, the prompt will be regenerated until the regex does
    not match, up to "max-attempts". This regex is case-insensitive.

    The "max-attempts" argument specifies how many times to reattempt to
    generate a prompt if it is blocked by "block-regex"

    The "smart-truncate" argument enables intelligent truncation of the prompt
    generated by the LLM, i.e. it will remove incomplete sentences from the end
    of the prompt utilizing spaCy NLP.

    The "cleanup-config" argument allows you to specify a custom LLM output
    cleanup configuration file in .json, .toml, or .yaml format. This file can
    be used to run custom pattern substitutions or python functions over the
    LLMs raw output, and overrides the built-in cleanup excluding
    "smart-truncate" which occurs before your configuration.
    """

    NAMES = ['magicprompt']

    OPTION_ARGS = {
        'part': ['both', 'positive', 'negative'],
        'dtype': ['float32', 'float16', 'bfloat16']
    }

    FILE_ARGS = {
        'model': {'mode': 'dir'},
        'cleanup-config': {'mode': 'in', 'filetypes': [('Cleanup Config', ('*.json', '*.toml', '*.yaml', '*.yml'))]}
    }

    def __init__(self,
                 part: str = 'both',
                 model: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
                 dtype: str = 'float32',
                 seed: int | None = None,
                 variations: int = 1,
                 max_length: int = 100,
                 temperature: float = 0.7,
                 top_k: int = 50,
                 top_p: float = 1.0,
                 system: str | None = None,
                 preamble: str | None = None,
                 remove_prompt: bool = False,
                 prepend_prompt: bool = False,
                 batch: bool = True,
                 max_batch: int | None = 50,
                 quantizer: str | None = None,
                 block_regex: str | None = None,
                 max_attempts: int = 10,
                 smart_truncate: bool = False,
                 cleanup_config: str | None = None,
                 **kwargs
                 ):
        """
        Validate the plugin arguments, estimate the model size, and load
        (or fetch from the object cache) the text generation pipeline.

        The meaning of each named argument is described in the class docstring.
        Invalid values are reported by raising the exception produced
        by ``self.argument_error``.

        :param kwargs: child class forwarded arguments
        """
        super().__init__(**kwargs,
                         part=part,
                         block_regex=block_regex,
                         max_attempts=max_attempts,
                         cleanup_mode='magic' if 'magicprompt' in model.lower() else 'other',
                         smart_truncate=smart_truncate,
                         cleanup_config=cleanup_config)

        dtype = dtype.lower()

        if dtype not in {'float32', 'float16', 'bfloat16'}:
            raise self.argument_error('Argument "dtype" must be either float32, float16, or bfloat16.')

        if quantizer:
            try:
                quantizer_class = _get_quantizer_uri_class(quantizer)
                quantization_config = quantizer_class.parse(quantizer).to_config(dtype)
            except Exception as e:
                raise self.argument_error(f'Error loading "quantizer" argument "{quantizer}": {e}') from e
        else:
            quantization_config = None

        part = part.lower()

        if part not in {'both', 'positive', 'negative'}:
            raise self.argument_error(
                'Argument "part" must be one of: "both", "positive", or "negative"'
            )

        if max_length < 1:
            raise self.argument_error(
                'Cannot specify "max-length" less than 1.'
            )

        if variations < 1:
            raise self.argument_error(
                'Argument "variations" may not be less than 1.')

        if temperature < 0.0:
            # the bound being checked is 0, the message previously said 1
            raise self.argument_error(
                'Argument "temperature" may not be less than 0.')

        if top_k < 1:
            raise self.argument_error(
                'Argument "top-k" may not be less than 1.')

        if top_p < 0.0:
            raise self.argument_error(
                'Argument "top-p" may not be less than 0.')

        if top_p > 1.0:
            raise self.argument_error(
                'Argument "top-p" may not be greater than 1.')

        if max_batch is not None and max_batch < 1:
            raise self.argument_error(
                'Argument "max-batch" may not be less than 1.')

        model_files = list(
            _pipelinewrapper_util.fetch_model_files_with_size(
                model,
                local_files_only=self.local_files_only,
                extensions={'.safetensors', '.bin'})
        )

        # Prefer .safetensors weights so duplicated .bin weights are not
        # counted twice, but only when some exist. Filtering unconditionally
        # would drop every file of a sharded .bin-only model and
        # produce a size estimate of zero.
        safetensors_files = [m for m in model_files if m[0].endswith('.safetensors')]
        if safetensors_files:
            model_files = safetensors_files

        estimated_size = sum(model_entry[1] for model_entry in model_files)

        _messages.debug_log(
            f'Estimated the size of LLM model: '
            f'{model}, as: {estimated_size} Bytes ({_memory.bytes_best_human_unit(estimated_size)})')

        def load_method():
            # quantized models are loaded with device_map=self.device,
            # so the device memory guard runs before loading
            if quantization_config is not None:
                self.memory_guard_device(self.device, self.size_estimate)

            torch_dtype = {
                'float32': torch.float32,
                'float16': torch.float16,
                'bfloat16': torch.bfloat16
            }[dtype]

            if isinstance(quantization_config, diffusers.BitsAndBytesConfig):
                # default the 4 bit compute dtype to the requested dtype
                if quantization_config.load_in_4bit and quantization_config.bnb_4bit_compute_dtype is None:
                    quantization_config.bnb_4bit_compute_dtype = torch_dtype

            return self._load_pipeline(model, dtype=torch_dtype, quantization_config=quantization_config)

        # must be set before loading, load_method reads self.size_estimate
        self.set_size_estimate(estimated_size)

        self._pipeline = self.load_object_cached(
            tag=model + (quantizer if quantizer else '') + dtype,
            estimated_size=estimated_size,
            method=load_method
        )

        self._system = system
        self._preamble = preamble
        self._remove_prompt = remove_prompt
        self._seed = seed
        self._max_length = max_length
        self._temperature = temperature
        self._top_k = top_k
        self._top_p = top_p
        self._variations = variations
        self._accepts_batch = batch
        self._max_batch = max_batch
        self._part = part
        self._quantizer = quantizer
        self._max_attempts = max_attempts
        self._prepend_prompt = prepend_prompt

    def _load_pipeline(self,
                       model_name: str,
                       dtype: torch.dtype,
                       quantization_config: typing.Optional[typing.Any] = None) -> _TextGenerationPipeline:
        """
        Load the causal LM and its tokenizer and wrap them in a pipeline object.

        :param model_name: model folder or Hugging Face repository slug
        :param dtype: torch dtype to load the model with
        :param quantization_config: optional quantization config, when given
            the model is loaded with ``device_map=self.device``
        :return: :py:class:`_TextGenerationPipeline`
        """
        try:
            model = transformers.AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=True,
                torch_dtype=dtype,
                quantization_config=quantization_config,
                device_map=self.device if quantization_config else None,
                local_files_only=self.local_files_only
            )
        except Exception as e:
            raise self.argument_error(f'Could not load model "{model_name}": {e}') from e

        # honor local_files_only for the tokenizer as well as the model
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_name,
            local_files_only=self.local_files_only
        )

        # pad with the models EOS token so that batches can be padded
        tokenizer.pad_token_id = model.config.eos_token_id

        return _TextGenerationPipeline(
            model,
            tokenizer,
            quantized=quantization_config is not None
        )

    def _to(self, device: str | torch.device):
        """
        Move the loaded pipeline to a device.

        :param device: target device
        """
        self._pipeline.to(device)

    def _generate_prompts(self, original_prompts: list[str]) -> list[str]:
        """
        Build the LLM input for each prompt, run generation honoring the
        batching settings, and clean each generated text.

        :param original_prompts: prompt texts to upscale
        :return: one cleaned generated prompt per input prompt
        """

        def build_query(text):
            # prefix the preamble, separated from the text by a single space
            if self._preamble:
                return self._preamble + (' ' if not self._preamble.endswith(' ') else '') + text
            return text

        if self._system:
            formatted_prompts = [
                f"<|system|> {self._system} <|user|> {build_query(query)} <|assistant|>"
                for query in original_prompts
            ]
        else:
            formatted_prompts = [build_query(query) for query in original_prompts]

        if not self._accepts_batch:
            # one prompt per pipeline call
            generated_prompts = []
            for ptext in formatted_prompts:
                generated_prompts.extend(self._call_pipeline([ptext]))
        elif self._max_batch is not None:
            # at most max_batch prompts per pipeline call
            generated_prompts = []
            for batch_segment in range(0, len(formatted_prompts), self._max_batch):
                segment = formatted_prompts[batch_segment:batch_segment + self._max_batch]
                generated_prompts.extend(self._call_pipeline(segment))
        else:
            # unlimited batch size
            generated_prompts = self._call_pipeline(formatted_prompts)

        generated_prompts = [
            self._clean_prompt(
                formatted_prompt,
                generated_prompt,
                remove_prefixes=[self._system, self._preamble],
                remove_prompt=self._remove_prompt,
                prepend=original_prompt if self._prepend_prompt else None,
            )
            for original_prompt, formatted_prompt, generated_prompt in zip(
                original_prompts,
                formatted_prompts,
                generated_prompts
            )
        ]

        return generated_prompts

    def _call_pipeline(self, prompts: list[str]):
        """
        Run the pipeline over ``prompts`` as a single batch using the
        configured sampling settings.

        :param prompts: formatted LLM input texts
        :return: generated texts as returned by the pipeline
        """
        return self._pipeline(
            prompts,
            max_length=self._max_length,
            temperature=self._temperature,
            top_k=self._top_k,
            top_p=self._top_p,
            do_sample=True,
            batch_size=len(prompts)
        )

    @contextlib.contextmanager
    def _with_device(self):
        """
        Context manager that places the model for generation and
        always runs garbage collection on exit.

        When a quantizer is in use the model is never moved. Otherwise the
        model is moved to ``self.device`` on entry and back to the cpu on exit.
        """
        if self._quantizer:
            try:
                yield
            finally:
                # run cleanup even when generation raises, e.g. on CUDA OOM
                gc.collect()
                _memory.torch_gc()
        else:
            try:
                self.memory_guard_device(self.device, self.size_estimate)
                self._to(self.device)
                yield
            finally:
                self._to('cpu')
                gc.collect()
                _memory.torch_gc()

    @property
    def accepts_batch(self) -> bool:
        """
        This prompt upscaler can accept a batch of prompts for efficient execution.

        :return: ``True``, unless the constructor argument ``batch`` was passed ``False``
        """
        return self._accepts_batch

    def upscale(self, prompts: _prompt.PromptOrPrompts) -> _prompt.PromptOrPrompts:
        """
        Upscale one prompt or a batch of prompts.

        The incoming prompts are repeated ``variations`` times before being
        processed, generation runs with the configured seed applied.

        :param prompts: a single prompt or multiple prompts

        :raises dgenerate.promptupscalers.exceptions.PromptUpscalerProcessingError:
            if more than one prompt is given while batching is disabled, on
            CUDA out of memory, or on a transformers pipeline exception.

        :return: the result of ``self._process_prompts``
        """
        if isinstance(prompts, _prompt.Prompt):
            prompts = [prompts]

        if len(prompts) > 1 and not self.accepts_batch:
            raise _exceptions.PromptUpscalerProcessingError(
                f'magicprompt prompt upscaler cannot accept batch input when '
                f'the argument "batch" is set to False.'
            )

        prompts = list(prompts) * self._variations

        try:
            with _with_seed(self._seed), self._with_device():
                return self._process_prompts(prompts)
        except torch.cuda.OutOfMemoryError as e:
            prompt_count = len(prompts)
            if prompt_count > 1:
                raise _exceptions.PromptUpscalerProcessingError(
                    f'magicprompt prompt upscaler could not '
                    f'process {prompt_count} incoming prompt(s) due to CUDA '
                    f'out of memory error, try using the argument "batch=False" '
                    f'to process only one prompt at a time (this is slow).') from e
            raise _exceptions.PromptUpscalerProcessingError(
                f'magicprompt prompt upscaler could not '
                f'process prompt due to CUDA out of memory error: {prompts[0]}'
            ) from e
        except transformers.pipelines.PipelineException as e:
            raise _exceptions.PromptUpscalerProcessingError(
                f'magicprompt prompt upscaler could not process prompt(s) due '
                f'to transformers pipeline exception: {e}'
            ) from e