# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
# the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import collections.abc
import glob
import shlex
import types
import typing
import dgenerate
import dgenerate.arguments as _arguments
import dgenerate.batchprocess.batchprocessor as _batchprocessor
import dgenerate.batchprocess.configrunnerpluginloader as _configrunnerpluginloader
import dgenerate.invoker as _invoker
import dgenerate.messages as _messages
import dgenerate.pipelinewrapper as _pipelinewrapper
import dgenerate.prompt as _prompt
import dgenerate.renderloop as _renderloop
import dgenerate.textprocessing as _textprocessing
import dgenerate.types as _types
class ConfigRunner(_batchprocessor.BatchProcessor):
    """
    A :py:class:`.BatchProcessor` that can run dgenerate batch processing configs from a string or file.
    """

    @property
    def plugin_module_paths(self) -> frozenset[str]:
        """
        Set of plugin module paths if they were injected into the config runner by ``--plugin-modules``
        or used in a ``\\import_plugins`` statement in a config.

        :return: a set of paths, may be empty but not ``None``
        """
        return frozenset(self._plugin_module_paths)

    def __init__(self,
                 injected_args: typing.Optional[collections.abc.Sequence[str]] = None,
                 render_loop: typing.Optional[_renderloop.RenderLoop] = None,
                 plugin_loader: typing.Optional[_configrunnerpluginloader.ConfigRunnerPluginLoader] = None,
                 version: typing.Union[_types.Version, str] = dgenerate.__version__,
                 throw: bool = False):
        """
        :raises dgenerate.plugin.ModuleFileNotFoundError: If a module path parsed from
            ``--plugin-modules`` in ``injected_args`` could not be found on disk.

        :param injected_args: dgenerate command line arguments in the form of a list, see: shlex module, or sys.argv.
            These arguments will be injected at the end of every dgenerate invocation in the config. ``--plugin-modules``
            are parsed from ``injected_args`` and added to ``plugin_loader``. If ``-v/--verbose`` is present in ``injected_args``
            debugging output will be enabled globally while the config runs, and not just for invocations.

        :param render_loop: RenderLoop instance, if ``None`` is provided one will be created.

        :param plugin_loader: Batch processor plugin loader, if one is not provided one will be created.

        :param version: Config version for ``#! dgenerate x.x.x`` version checks, defaults to ``dgenerate.__version__``

        :param throw: Whether to throw exceptions from :py:func:`dgenerate.invoker.invoke_dgenerate` or handle them.
            If you set this to ``True`` exceptions will propagate out of dgenerate invocations instead of a
            :py:exc:`dgenerate.batchprocess.BatchProcessError` being raised by the created
            :py:class:`dgenerate.batchprocess.BatchProcessor`. A line number where the error occurred can be
            obtained using :py:attr:`dgenerate.batchprocess.BatchProcessor.current_line`.
        """

        def invoker(args):
            # Run one dgenerate invocation. On success, refresh the template
            # variables so the last_* values describe this invocation.
            # Modules selected with \use_modules apply to one invocation only,
            # so they are always cleared afterwards, even when it fails.
            try:
                return_code = \
                    _invoker.invoke_dgenerate(args,
                                              render_loop=self.render_loop,
                                              throw=throw)
                if return_code == 0:
                    self.template_variables.update(self._generate_template_variables())
                return return_code
            finally:
                self.render_loop.model_extra_modules = None

        super().__init__(
            invoker=invoker,
            name='dgenerate',
            version=version,
            injected_args=injected_args if injected_args else [])

        if render_loop is None:
            render_loop = _renderloop.RenderLoop()

        self.render_loop = render_loop

        # ---------------------------------------------------------------
        # Template functions available inside config templates
        # ---------------------------------------------------------------

        def _format_prompt(prompt):
            # Quote one prompt as a single shell token.
            # Use "positive; negative" when both halves are non-empty,
            # otherwise use only the positive half.
            pos = prompt.positive
            neg = prompt.negative

            if pos is None:
                raise _batchprocessor.BatchProcessError('Attempt to format a prompt with no positive prompt value.')

            if pos and neg:
                return shlex.quote(f"{pos}; {neg}")
            return shlex.quote(pos)

        def format_prompt(string_or_iterable):
            # Accept a single Prompt, or an iterable of prompts joined by spaces.
            if isinstance(string_or_iterable, _prompt.Prompt):
                return _format_prompt(string_or_iterable)
            return ' '.join(_format_prompt(p) for p in string_or_iterable)

        def _is_scalar(value):
            # Treat strings and non-iterable values (numbers, paths, None, ...)
            # as one item instead of iterating over them. Iterating would split
            # a string into characters and fail on non-iterables.
            return isinstance(value, str) or not isinstance(value, collections.abc.Iterable)

        def quote(string_or_iterable):
            # Shell quote one value, or quote each item of an iterable and join with spaces.
            if _is_scalar(string_or_iterable):
                return shlex.quote(str(string_or_iterable))
            return ' '.join(shlex.quote(str(s)) for s in string_or_iterable)

        def unquote(string_or_iterable):
            # Shell split one value into a list of tokens.
            # For an iterable, return a list of token lists, one per item.
            if _is_scalar(string_or_iterable):
                return shlex.split(str(string_or_iterable))
            return [shlex.split(str(s)) for s in string_or_iterable]

        def last(list_or_iterable):
            # Index sequences directly and exhaust other iterables. Both paths
            # report an empty input with the same BatchProcessError. Previously
            # an empty list raised a bare IndexError.
            if isinstance(list_or_iterable, collections.abc.Sequence):
                if len(list_or_iterable) == 0:
                    raise _batchprocessor.BatchProcessError(
                        'Usage of template function "last" on an empty iterable.')
                return list_or_iterable[-1]

            try:
                *_, last_item = list_or_iterable
            except ValueError:
                raise _batchprocessor.BatchProcessError(
                    'Usage of template function "last" on an empty iterable.')
            return last_item

        def first(iterable):
            # Return the first item of any iterable.
            try:
                return next(iter(iterable))
            except StopIteration:
                raise _batchprocessor.BatchProcessError(
                    'Usage of template function "first" on an empty iterable.')

        # _generate_template_variables() reads the injected_* / saved_modules /
        # glob entries back out of self.template_variables. They must be seeded
        # here before the full variable set is generated.
        self.template_variables = {
            'injected_args': self.injected_args,
            'injected_device': _arguments.parse_device(self.injected_args)[0],
            'injected_verbose': _arguments.parse_verbose(self.injected_args)[0],
            'injected_plugin_modules': _arguments.parse_plugin_modules(self.injected_args)[0],
            'saved_modules': dict(),
            'glob': glob
        }

        self.template_variables = self._generate_template_variables()

        # Every variable that exists before the config runs is reserved.
        self.reserved_template_variables = set(self.template_variables.keys())

        self.template_functions = {
            'unquote': unquote,
            'quote': quote,
            'format_prompt': format_prompt,
            'format_size': _textprocessing.format_size,
            'last': last,
            'first': first
        }

        def return_zero(func, help_text):
            # Wrap a no-argument function as a directive that ignores its
            # arguments and returns 0. help_text becomes the directive's
            # user-facing help (see generate_directives_help).
            def wrap(args):
                func()
                return 0

            wrap.__doc__ = help_text
            return wrap

        self.directives = {
            'templates_help': self._templates_help_directive,
            'clear_model_cache': return_zero(
                _pipelinewrapper.clear_model_cache,
                help_text='Clear all user specified models from the in memory cache.'),
            'clear_pipeline_cache': return_zero(
                _pipelinewrapper.clear_pipeline_cache,
                help_text='Clear all diffusers pipelines from the in memory cache, '
                          'this will not clear user specified VAEs, UNets, and ControlNet models, '
                          'just pipeline objects which may or may not have automatically loaded those for you.'),
            'clear_unet_cache': return_zero(
                _pipelinewrapper.clear_unet_cache,
                help_text='Clear all user specified UNet models from the in memory cache.'),
            'clear_vae_cache': return_zero(
                _pipelinewrapper.clear_vae_cache,
                help_text='Clear all user specified VAE models from the in memory cache.'),
            'clear_control_net_cache': return_zero(
                _pipelinewrapper.clear_control_net_cache,
                help_text='Clear all user specified ControlNet models from the in memory cache.'),
            'save_modules': self._save_modules_directive,
            'use_modules': self._use_modules_directive,
            'clear_modules': self._clear_modules_directive,
            'gen_seeds': self._gen_seeds_directive,
            'exit': self._exit_directive
        }

        self.plugin_loader = \
            _configrunnerpluginloader.ConfigRunnerPluginLoader() if \
            plugin_loader is None else plugin_loader

        self._plugin_module_paths = set()

        if injected_args:
            # NOTE(review): parse_plugin_modules appears to yield None when
            # --plugin-modules is absent (the injected_plugin_modules variable
            # is typed OptionalPaths). Guard so set.update never receives None.
            injected_plugin_modules = _arguments.parse_plugin_modules(injected_args)[0]
            if injected_plugin_modules:
                self._plugin_module_paths.update(injected_plugin_modules)

        self.plugin_loader.load_plugin_modules(self._plugin_module_paths)
        self.render_loop.image_processor_loader.load_plugin_modules(self._plugin_module_paths)

        # Instantiate every available config runner plugin under its primary name.
        for plugin_class in self.plugin_loader.get_available_classes():
            self.plugin_loader.load(plugin_class.get_names()[0],
                                    config_runner=self,
                                    render_loop=self.render_loop)

        self.directives['import_plugins'] = self._import_plugins_directive

    def _import_plugins_directive(self, plugin_paths: collections.abc.Sequence[str]):
        """
        Imports plugins from within a config, this imports config plugins as well as image processor plugins.
        This has an identical effect to the --plugin-modules argument. You may specify multiple plugin
        module directories or python files containing plugin implementations.
        """
        # NOTE: directive docstrings are shown to users by generate_directives_help.

        if len(plugin_paths) == 0:
            raise _batchprocessor.BatchProcessError(
                '\\import_plugins must be used with at least one argument.')

        self._plugin_module_paths.update(plugin_paths)
        self.render_loop.image_processor_loader.load_plugin_modules(plugin_paths)

        # Only classes introduced by these modules need to be instantiated.
        new_classes = self.plugin_loader.load_plugin_modules(plugin_paths)
        for cls in new_classes:
            self.plugin_loader.load(cls.get_names()[0],
                                    config_runner=self,
                                    render_loop=self.render_loop)
        return 0

    def _exit_directive(self, args: collections.abc.Sequence[str]):
        """
        Causes the dgenerate process to exit with a specific return code.
        This directive accepts one argument, the return code, which is optional
        and 0 by default. It must be an integer value.
        """
        # Use sys.exit rather than the builtin exit(). The builtin is injected
        # by the site module, is meant for interactive use, and may not exist
        # (python -S, frozen or embedded interpreters).
        import sys

        if len(args) == 0:
            sys.exit(0)

        try:
            return_code = int(args[0])
        except ValueError:
            raise _batchprocessor.BatchProcessError(
                f'\\exit return code must be an integer value, received: {args[0]}')

        sys.exit(return_code)

    def _save_modules_directive(self, args: collections.abc.Sequence[str]):
        """
        Save a set of pipeline modules off the last diffusers pipeline used for the
        main model of a dgenerate invocation. The first argument is a variable name
        that the modules will be saved to, which can be referenced later with \\use_modules.
        The rest of the arguments are names of pipeline modules that you want to save to this
        variable as a set of modules that are kept together, usable names are: unet, vae, text_encoder,
        text_encoder_2, tokenizer, tokenizer_2, safety_checker, feature_extractor, controlnet,
        scheduler
        """
        saved_modules = self.template_variables.get('saved_modules')

        if len(args) < 2:
            raise _batchprocessor.BatchProcessError(
                '\\save_modules directive must have at least 2 arguments, '
                'a variable name and one or more module names.')

        # The pipeline wrapper only exists after a dgenerate invocation.
        if self.render_loop.pipeline_wrapper is None:
            raise _batchprocessor.BatchProcessError(
                '\\save_modules directive cannot be used until a '
                'dgenerate invocation has occurred.')

        creation_result = self.render_loop.pipeline_wrapper.recall_main_pipeline()
        saved_modules[args[0]] = creation_result.get_pipeline_modules(args[1:])
        return 0

    def _use_modules_directive(self, args: collections.abc.Sequence[str]):
        """
        Use a set of pipeline modules saved with \\save_modules, accepts one argument,
        the name that set of modules was saved to. The modules are only used for the
        next dgenerate invocation.
        """
        saved_modules = self.template_variables.get('saved_modules')

        if not saved_modules:
            raise _batchprocessor.BatchProcessError(
                '\\use_modules error, no modules are currently saved that can be referenced.')

        if len(args) != 1:
            raise _batchprocessor.BatchProcessError(
                '\\use_modules accepts one argument and one argument only, '
                'the name that the modules were previously saved to with \\save_modules'
            )

        saved_name = args[0]

        # Report unknown names as a BatchProcessError, consistent with
        # \clear_modules, instead of leaking a bare KeyError.
        try:
            modules = saved_modules[saved_name]
        except KeyError:
            raise _batchprocessor.BatchProcessError(
                f'No pipeline modules were saved to the variable name "{saved_name}", '
                f'that name could not be found.')

        # The invoker resets this to None after the next invocation.
        self.render_loop.model_extra_modules = modules
        return 0

    def _clear_modules_directive(self, args: collections.abc.Sequence[str]):
        """
        Clears named sets of pipeline modules saved with \\save_modules, accepts one or more
        arguments, the names that the sets of modules were saved to. When no argument is provided,
        all modules ever saved are cleared.
        """
        saved_modules = self.template_variables.get('saved_modules')

        if not args:
            saved_modules.clear()
            return 0

        for arg in args:
            try:
                del saved_modules[arg]
            except KeyError:
                raise _batchprocessor.BatchProcessError(
                    f'No pipeline modules were saved to the variable name "{arg}", '
                    f'that name could not be found.')
        return 0

    def _gen_seeds_directive(self, args: collections.abc.Sequence[str]):
        """
        Generate N random integer seeds and store them as a list to a template variable name.
        The first argument is the variable name, the second argument is the number of seeds to generate.
        The variable name may not be a reserved template variable.
        """
        if len(args) != 2:
            raise _batchprocessor.BatchProcessError(
                '\\gen_seeds directive takes 2 arguments, template variable '
                'name (to store value at), and number of seeds to generate.')

        # Overwriting a reserved variable (e.g. saved_modules) would break
        # other directives. Reject it, as \set is documented to do.
        if args[0] in self.reserved_template_variables:
            raise _batchprocessor.BatchProcessError(
                f'\\gen_seeds cannot assign to the reserved template variable "{args[0]}".')

        try:
            self.template_variables[args[0]] = \
                [str(s) for s in _renderloop.gen_seeds(int(args[1]))]
        except ValueError:
            raise _batchprocessor.BatchProcessError(
                'The second argument of \\gen_seeds must be an integer value.')
        return 0

    def _config_generate_template_variables_with_types(self) -> dict[str, tuple[type, typing.Any]]:
        """
        Build the ``last_*`` template variables from the render loop's current config.

        Every type-hinted attribute of ``self.render_loop.config`` becomes a variable named
        ``last_<attr>``. Attributes that already start with ``last_`` keep their name, and the
        first attribute that maps to a name wins. ``None`` values of sequence-typed attributes
        are replaced with ``[]``. ``last_images`` and ``last_animations`` are then taken from
        ``self.render_loop.written_images`` / ``written_animations``.

        :return: mapping of variable name to ``(type hint, value)``
        """
        prefix = 'last_'
        config = self.render_loop.config
        template_variables: dict[str, tuple[type, typing.Any]] = {}

        for attr, hint in typing.get_type_hints(config.__class__).items():
            value = getattr(config, attr)
            gen_name = attr if attr.startswith(prefix) else prefix + attr

            if gen_name in template_variables:
                continue

            if _types.is_type_or_optional(hint, collections.abc.Sequence) and value is None:
                value = []

            template_variables[gen_name] = (hint, value)

        template_variables.update({
            'last_images': (collections.abc.Iterable[str], self.render_loop.written_images),
            'last_animations': (collections.abc.Iterable[str], self.render_loop.written_animations),
        })

        return template_variables

    def _generate_template_variables_with_types(self) -> dict[str, tuple[type, typing.Any]]:
        """
        All built-in template variables with their types.

        The config-derived ``last_*`` variables are combined with the injected / bookkeeping
        variables, whose current values are read back from ``self.template_variables``.

        :return: mapping of variable name to ``(type, value)``
        """
        template_variables = self._config_generate_template_variables_with_types()

        typed_builtins = (
            ('injected_args', collections.abc.Sequence[str]),
            ('injected_device', _types.OptionalString),
            ('injected_verbose', _types.OptionalBoolean),
            ('injected_plugin_modules', _types.OptionalPaths),
            ('saved_modules', dict[str, dict[str, typing.Any]]),
            ('glob', types.ModuleType),
        )

        for name, var_type in typed_builtins:
            template_variables[name] = (var_type, self.template_variables.get(name))

        return template_variables

    def _generate_template_variables(self) -> dict[str, typing.Any]:
        """
        :return: built-in template variables mapped to their values (types dropped)
        """
        return {k: v[1] for k, v in self._generate_template_variables_with_types().items()}

    def generate_directives_help(self, directive_names: typing.Optional[typing.Collection[str]] = None):
        """
        Generate the help string for ``--directives-help``

        :param directive_names: Display help for specific directives, if ``None`` or ``[]`` is specified,
            a list of all available directive names is displayed. Names may be given with or
            without their leading backslash.

        :raise ValueError: if given directive names could not be found

        :return: help string
        """
        directives: dict[str, typing.Union[str, typing.Callable]] = self.directives.copy()

        # \set and \print have no entry in self.directives (presumably they are
        # implemented by the base BatchProcessor), so document them as plain strings.
        directives.update({
            'set': 'Sets a template variable, accepts two arguments, the variable name and the value. '
                   'Attempting to set a reserved template variable such as those pre-defined by dgenerate '
                   'will result in an error. The second argument is accepted as a raw value, it is not shell '
                   'parsed in any way, only stripped of leading and trailing whitespace.',
            'print': 'Prints all content to the right to stdout, no shell parsing of the argument occurs.'
        })

        # None or empty: list directive names only. The previous len() call
        # raised TypeError on the documented None default.
        if not directive_names:
            return 'Available config directives:' + '\n\n' + '\n'.join(
                (' ' * 4) + _textprocessing.quote('\\' + i) for i in directives.keys())

        # Strip leading backslashes. dict.fromkeys de-duplicates while keeping the
        # caller's order, so output order is deterministic (a set is not).
        requested = list(dict.fromkeys(n.lstrip('\\') for n in directive_names))

        not_found = [n for n in requested if n not in directives]
        if not_found:
            raise ValueError(
                f'No directives named: {_textprocessing.oxford_comma(not_found, "or")}')

        sections = []
        for name in requested:
            impl = directives[name]

            if isinstance(impl, str):
                doc = impl
            elif impl.__doc__ is not None:
                doc = _textprocessing.justify_left(impl.__doc__).strip()
            else:
                doc = 'No documentation provided.'

            doc = _textprocessing.wrap_paragraphs(
                doc,
                initial_indent=' ' * 4,
                subsequent_indent=' ' * 4,
                width=_textprocessing.long_text_wrap_width())

            sections.append(name + _textprocessing.underline(':\n\n' + doc + '\n'))

        return '\n'.join(sections)

    def generate_template_variables_help(self,
                                         variable_names: typing.Optional[typing.Collection[str]] = None,
                                         show_values: bool = True):
        """
        Generate a help string describing available template variables, their types, and values for use in batch processing.

        This is used for ``--templates-help``

        :param variable_names: Display help for specific variables, if ``None`` or ``[]`` is specified, display all.
        :param show_values: Show the value of the template variable or just the name?

        :raise ValueError: if given variable names could not be found

        :return: a human-readable description of all template variables
        """
        values = self._generate_template_variables_with_types()

        # Include variables added at runtime (e.g. by \gen_seeds), typed by their runtime class.
        for k, v in self.template_variables.items():
            if k not in values:
                values[k] = (v.__class__, v)

        if variable_names:
            not_found = [n for n in variable_names if n not in values]
            if not_found:
                raise ValueError(
                    f'No template variables named: {_textprocessing.oxford_comma(not_found, "or")}')
            values = {n: values[n] for n in variable_names}

        header = 'Config template variables are' if len(values) > 1 else 'Config template variable is'

        help_string = f'{header}:' + '\n\n'

        def wrap(val):
            # Continuation lines align under the value text after "        Value: ".
            return _textprocessing.wrap(
                str(val),
                width=_textprocessing.long_text_wrap_width(),
                subsequent_indent=' ' * 17)

        return help_string + '\n'.join(
            ' ' * 4 + f'Name: {_textprocessing.quote(name)}\n{" " * 8}'
            f'Type: {typed[0]}' + (f'\n{" " * 8}Value: {wrap(typed[1])}' if show_values else '')
            for name, typed in values.items())

    def _templates_help_directive(self, args: collections.abc.Sequence[str]):
        """
        Prints all template variables in the global scope, with their types and values.
        If variable names are given as arguments, only those variables are printed.
        This does not cause the config to exit.
        """
        _messages.log(self.generate_template_variables_help(args) + '\n')
        return 0
# Public export list for this module.
# NOTE(review): presumably _types.module_all() collects the public names
# defined in the calling module; confirm against dgenerate.types.
__all__ = _types.module_all()