# Source code for dgenerate.prompt

# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
#    the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Module docstring: a leading and trailing newline around the one-line summary.
__doc__ = "\nPrompt representation object / prompt parsing.\n"

import ast
import collections.abc
import functools
import re
import typing

import dgenerate.textprocessing as _textprocessing
import dgenerate.types as _types


class PromptEmbeddedArgumentError(Exception):
    """
    Raised when an argument embedded in a prompt (``<name: value>``) cannot be
    validated or applied, for example an unknown ``weighter`` / ``upscaler``
    plugin, an unknown or forbidden argument name, or an unparsable value.
    """
class Prompt:
    """
    Represents a combined positive and optional negative prompt split by a delimiter character.
    """

    def __init__(self,
                 positive: str | None = None,
                 negative: str | None = None,
                 delimiter: str = ';',
                 weighter: _types.OptionalUri = None,
                 upscaler: _types.OptionalUriOrUris = None,
                 embedded_args: dict[str, str] | None = None):
        """
        :param positive: positive prompt component.
        :param negative: negative prompt component.
        :param delimiter: delimiter for stringification.
        :param weighter: ``--prompt-weighter`` plugin URI.
        :param upscaler: ``--prompt-upscaler`` plugin URI, or a sequence of URIs.
        :param embedded_args: embedded prompt arguments parsed from ``<argument: value_text>``.
            The mapping is copied, so later changes to the caller's dictionary
            do not affect this prompt.
        :raise PromptEmbeddedArgumentError: if ``weighter`` or any ``upscaler`` URI
            names a plugin implementation that does not exist.
        """
        # Function-scope imports. NOTE(review): presumably deferred to avoid an
        # import cycle with the plugin modules; confirm before hoisting them.
        import dgenerate.promptupscalers as _promptupscalers
        import dgenerate.promptweighters as _promptweighters

        if weighter is not None and not _promptweighters.prompt_weighter_exists(weighter):
            raise PromptEmbeddedArgumentError(
                f'Unknown prompt "weighter" implementation: {_promptweighters.prompt_weighter_name_from_uri(weighter)}, '
                f'must be one of: {_textprocessing.oxford_comma(_promptweighters.prompt_weighter_names(), "or")}')

        if upscaler is not None:
            # Accept a single URI string as well as a sequence of URIs.
            check_upscalers = [upscaler] if isinstance(upscaler, str) else upscaler
            for u in check_upscalers:
                if not _promptupscalers.prompt_upscaler_exists(u):
                    raise PromptEmbeddedArgumentError(
                        f'Unknown prompt "upscaler" implementation: {_promptupscalers.prompt_upscaler_name_from_uri(u)}, '
                        f'must be one of: {_textprocessing.oxford_comma(_promptupscalers.prompt_upscaler_names(), "or")}'
                    )

        self._positive = positive
        self._negative = negative
        self._delimiter = delimiter
        self._weighter = weighter
        self._upscaler = upscaler
        # Copy so this prompt never shares a mutable dict with its caller
        # (copy_embedded_args_from copies for the same reason).
        self._embedded_args = dict(embedded_args) if embedded_args is not None else dict()

    def __str__(self):
        """
        Render the prompt back to text.

        With both components the result is ``positive + delimiter + ' ' + negative``;
        with only a negative component it is ``delimiter + negative``. Embedded
        arguments (weighter, upscaler, others) are not rendered.
        """
        if self._positive and self._negative:
            return f'{self._positive}{self._delimiter} {self._negative}'
        elif self._positive:
            return self._positive
        elif self._negative:
            return self._delimiter + self._negative
        else:
            return ''

    @property
    def delimiter(self) -> str:
        """
        Positive / Negative delimiter for this prompt, for example ";"
        """
        return self._delimiter

    @property
    def positive(self) -> str | None:
        """
        Positive prompt value.
        """
        return self._positive

    @positive.setter
    def positive(self, value):
        self._positive = value

    @property
    def negative(self) -> str | None:
        """
        Negative prompt value.
        """
        return self._negative

    @negative.setter
    def negative(self, value):
        self._negative = value

    @property
    def weighter(self) -> _types.OptionalUri:
        """
        Embedded prompt weighter URI argument for this prompt if any.
        """
        return self._weighter

    @property
    def upscaler(self) -> _types.OptionalUriOrUris:
        """
        Embedded prompt upscaler URI argument for this prompt if any.
        """
        return self._upscaler

    @property
    def embedded_args(self) -> list[tuple[str, str]]:
        """
        Other embedded arguments parsed out of the prompt, as a new list of
        ``(name, value)`` pairs.
        """
        return list(self._embedded_args.items())

    def __repr__(self):
        return f"'{str(self)}'"

    def copy_embedded_args_from(self, prompt: 'Prompt'):
        """
        Copy the weighter, upscaler and other embedded arguments from another prompt.

        Mutable containers are copied so the two prompts do not share state.

        :param prompt: The prompt to copy embedded arguments from.
        """
        self._weighter = prompt._weighter
        upscaler = prompt._upscaler
        # parse() stores repeated <upscaler: ...> arguments as a list; copy it so
        # mutating one prompt's upscaler list can never affect the other.
        self._upscaler = upscaler.copy() if isinstance(upscaler, list) else upscaler
        self._embedded_args = prompt._embedded_args.copy()

    def set_embedded_args_on(
            self,
            on_object: typing.Any,
            forbidden_checker: typing.Optional[typing.Callable[[str, typing.Any], bool]] = None,
            validate_only: bool = False
    ):
        """
        Set the other embedded arguments parsed out of a prompt on to an object.

        The object should be type hinted using types from :py:mod:`dgenerate.types`

        Specifically, any of:

            * :py:class:`dgenerate.types.Size`
            * :py:class:`dgenerate.types.Sizes`
            * :py:class:`dgenerate.types.OptionalSize`
            * :py:class:`dgenerate.types.OptionalSizes`
            * :py:class:`dgenerate.types.Padding`
            * :py:class:`dgenerate.types.Paddings`
            * :py:class:`dgenerate.types.OptionalPadding`
            * :py:class:`dgenerate.types.OptionalPaddings`
            * :py:class:`dgenerate.types.Boolean`
            * :py:class:`dgenerate.types.OptionalBoolean`
            * :py:class:`dgenerate.types.Float`
            * :py:class:`dgenerate.types.Floats`
            * :py:class:`dgenerate.types.OptionalFloat`
            * :py:class:`dgenerate.types.OptionalFloats`
            * :py:class:`dgenerate.types.Integer`
            * :py:class:`dgenerate.types.Integers`
            * :py:class:`dgenerate.types.OptionalInteger`
            * :py:class:`dgenerate.types.OptionalIntegers`
            * :py:class:`dgenerate.types.String`
            * :py:class:`dgenerate.types.Strings`
            * :py:class:`dgenerate.types.OptionalString`
            * :py:class:`dgenerate.types.OptionalStrings`
            * :py:class:`dgenerate.types.Name`
            * :py:class:`dgenerate.types.Names`
            * :py:class:`dgenerate.types.OptionalName`
            * :py:class:`dgenerate.types.OptionalNames`
            * :py:class:`dgenerate.types.Uri`
            * :py:class:`dgenerate.types.Uris`
            * :py:class:`dgenerate.types.OptionalUri`
            * :py:class:`dgenerate.types.OptionalUris`

        List-typed values (``Sizes``, ``Integers``, ...) are written as Python
        list / tuple literals, e.g. ``<guidance-scales: [5.0, 7.5]>``. Optional
        types accept the literal text ``None``.

        :raise PromptEmbeddedArgumentError: If there was a problem applying the embedded arguments to the object.

        :param on_object: The object to set values on.
        :param forbidden_checker: This is a function that should return ``True`` if an argument name / value is forbidden to use.
        :param validate_only: Only run validation and do not set any values?
        """
        if forbidden_checker and not callable(forbidden_checker):
            raise ValueError('forbidden_checker must be a callable function')

        hints = typing.get_type_hints(on_object)

        def is_forbidden(name, value):
            return forbidden_checker(name, value) if forbidden_checker else False

        def parse_padding(value):
            # 1, 2 or 4 dimension values are valid; a single value is returned bare.
            parsed_value = _textprocessing.parse_dimensions(value)
            length = len(parsed_value)
            if length > 4:
                raise ValueError('too many padding values.')
            if length == 3:
                raise ValueError('3 values is invalid for padding specification.')
            return parsed_value if length > 1 else parsed_value[0]

        def list_of(the_type, v):
            # Parse a Python literal list/tuple and convert every element with the_type.
            error = f"Invalid value, expected iterable and got: {v}"
            try:
                values = ast.literal_eval(v)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
                raise ValueError(error) from e

            # Strings / bytes are iterable, but iterating one would split a scalar
            # literal such as '123' into characters; that is never a valid list value.
            if isinstance(values, (str, bytes)) or not isinstance(values, collections.abc.Iterable):
                raise ValueError(error)

            try:
                return [the_type(item) for item in values]
            except TypeError as e:
                # e.g. float(None) for "[None]"; report it as a parse failure.
                raise ValueError(f"Invalid list element in: {v}") from e

        def optional(the_type, v):
            return None if v == 'None' else the_type(v)

        type_parsers = {
            # dimensions
            _types.Size: _textprocessing.parse_image_size,
            _types.Sizes: functools.partial(
                list_of, _textprocessing.parse_image_size),
            _types.OptionalSize: functools.partial(
                optional, _textprocessing.parse_image_size),
            _types.OptionalSizes: functools.partial(
                optional, functools.partial(list_of, _textprocessing.parse_image_size)),

            # paddings
            _types.Padding: parse_padding,
            _types.Paddings: functools.partial(list_of, parse_padding),
            _types.OptionalPadding: functools.partial(optional, parse_padding),
            _types.OptionalPaddings: functools.partial(optional, functools.partial(list_of, parse_padding)),

            # bool
            _types.Boolean: _types.parse_bool,
            _types.OptionalBoolean: functools.partial(optional, _types.parse_bool),

            # float
            _types.Float: float,
            _types.Floats: functools.partial(list_of, float),
            _types.OptionalFloat: functools.partial(optional, float),
            _types.OptionalFloats: functools.partial(optional, functools.partial(list_of, float)),

            # int
            _types.Integer: int,
            _types.Integers: functools.partial(list_of, int),
            _types.OptionalInteger: functools.partial(optional, int),
            _types.OptionalIntegers: functools.partial(optional, functools.partial(list_of, int)),

            # string
            _types.String: str,
            _types.Strings: functools.partial(list_of, str),
            _types.OptionalString: functools.partial(optional, str),
            _types.OptionalStrings: functools.partial(optional, functools.partial(list_of, str)),

            # name
            _types.Name: str,
            _types.Names: functools.partial(list_of, str),
            _types.OptionalName: functools.partial(optional, str),
            _types.OptionalNames: functools.partial(optional, functools.partial(list_of, str)),

            # uri (Uris is non-optional, like Strings / Names: 'None' is not accepted)
            _types.Uri: str,
            _types.Uris: functools.partial(list_of, str),
            _types.OptionalUri: functools.partial(optional, str),
            _types.OptionalUris: functools.partial(optional, functools.partial(list_of, str)),
        }

        for name, value in self._embedded_args.items():
            name = _textprocessing.dashdown(name)

            if is_forbidden(name, value):
                raise PromptEmbeddedArgumentError(
                    f'Setting diffusion argument "{name}" from the prompt is forbidden.')

            if not hasattr(on_object, name):
                raise PromptEmbeddedArgumentError(
                    f'Unknown embedded prompt argument: {name}')

            hint = hints.get(name, None)

            if hint not in type_parsers:
                # The attribute exists but is not annotated with a type that can be
                # parsed from prompt text, so it may not be set from the prompt.
                raise PromptEmbeddedArgumentError(
                    f'Setting diffusion argument "{name}" from the prompt is forbidden.')

            try:
                parsed_value = type_parsers[hint](value)
            except (ValueError, TypeError, SyntaxError) as e:
                raise PromptEmbeddedArgumentError(
                    f'Could not parse embedded prompt argument: {name}, value: {value}, into type: {hint}') from e

            if not validate_only:
                setattr(on_object, name, parsed_value)

    @staticmethod
    def _get_embeded_args(value: str, embedded_arg_names: list[str] | None):
        """
        Strip ``<name: value>`` embedded arguments out of prompt text.

        Each match, including surrounding whitespace, is replaced by a single space
        and the result is stripped. Repeated ``upscaler`` arguments accumulate into
        a list; for any other repeated name the last occurrence wins.

        :param value: prompt text.
        :param embedded_arg_names: literal argument names to extract. If falsy,
            any name matching ``[a-zA-Z_-][a-zA-Z0-9_-]+`` is extracted.
        :return: ``(cleaned_text, {name: value_or_list_of_values})``
        """
        embedded_args = {}

        def find_arg(match):
            name = match.group(1)
            arg_value = match.group(2)

            if name == 'upscaler' and name in embedded_args:
                # A second upscaler turns the stored value into a list.
                existing = embedded_args[name]
                if isinstance(existing, list):
                    existing.append(arg_value)
                else:
                    embedded_args[name] = [existing, arg_value]
            else:
                embedded_args[name] = arg_value
            return ' '

        if embedded_arg_names:
            cleaned_value = value
            for arg_name in embedded_arg_names:
                # The name must be a capture group: find_arg reads group(1) as the
                # name and group(2) as the value. Escape it so it matches literally.
                cleaned_value = re.sub(
                    rf"\s*<\s*({re.escape(arg_name)})\s*:\s*(.*?)\s*>\s*", find_arg, cleaned_value)
        else:
            cleaned_value = re.sub(
                r"\s*<\s*([a-zA-Z_-][a-zA-Z0-9_-]+)\s*:\s*(.*?)\s*>\s*", find_arg, value)

        return cleaned_value.strip(), embedded_args

    @staticmethod
    def copy(prompt: 'Prompt'):
        """
        Return a copy of another prompt.

        :param prompt: The prompt to copy.
        :return: A copy of the provided prompt.
        """
        new_prompt = Prompt(
            positive=prompt.positive,
            negative=prompt.negative,
            delimiter=prompt.delimiter
        )
        new_prompt.copy_embedded_args_from(prompt)
        return new_prompt

    @staticmethod
    def parse(
            value: str,
            delimiter=';',
            parse_embedded_args: bool = True,
            embedded_arg_names: list[str] | None = None,
    ) -> 'Prompt':
        """
        Parse the positive and negative prompt from a string and return a prompt object.

        The text is split on the first occurrence of ``delimiter``; both halves are
        stripped. Without a delimiter the negative component is ``None``.

        :param value: the string
        :param delimiter: The prompt delimiter character
        :param parse_embedded_args: parse embedded args? ``< arg: value >``
        :param embedded_arg_names: list of embedded argument names to parse, if ``None``, all are parsed.
        :raise ValueError: if value is ``None``
        :raise PromptEmbeddedArgumentError: if an embedded ``weighter`` / ``upscaler``
            names an unknown plugin implementation.
        :return: a new :py:class:`.Prompt`
        """
        if value is None:
            raise ValueError('Input string may not be None.')

        weighter = None
        upscaler = None
        embedded_args = None

        if parse_embedded_args:
            value, embedded_args = Prompt._get_embeded_args(
                value, embedded_arg_names
            )
            # weighter / upscaler are first-class Prompt attributes,
            # not generic embedded arguments.
            weighter = embedded_args.pop('weighter', None)
            upscaler = embedded_args.pop('upscaler', None)

        positive_part, found_delimiter, negative_part = value.partition(delimiter)
        positive = positive_part.strip()
        negative = negative_part.strip() if found_delimiter else None

        return Prompt(positive=positive,
                      negative=negative,
                      delimiter=delimiter,
                      weighter=weighter,
                      upscaler=upscaler,
                      embedded_args=embedded_args)
# Typing aliases for values that hold one prompt or a sequence of prompts.
# typing.Union[X, None] is the same object as typing.Optional[X].
OptionalPrompt = typing.Union[Prompt, None]
Prompts = collections.abc.Sequence[Prompt]
PromptOrPrompts = typing.Union[Prompt, Prompts]
OptionalPrompts = typing.Union[Prompts, None]
OptionalPromptOrPrompts = typing.Union[PromptOrPrompts, None]

# NOTE(review): module_all() presumably derives the public names from this
# module's globals; confirm against dgenerate.types.
__all__ = _types.module_all()