Source code for dgenerate.plugin

# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
#    the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import ast
import collections.abc
import importlib.machinery
import inspect
import itertools
import os
import sys
import types
import typing

import dgenerate.messages as _messages
import dgenerate.textprocessing as _textprocessing
import dgenerate.types as _types

__doc__ = """
URI based plugin loading system base implementations.
"""

LOADED_PLUGIN_MODULES: dict[str, types.ModuleType] = {}
"""Plugin module in memory cache"""


[docs] class PluginArg:
[docs] def __init__(self, name: str, type: type = typing.Any, **kwargs): self.name = name self.have_default = 'default' in kwargs self.default = kwargs['default'] if self.have_default else None self.options = kwargs.get('options', None) self.files = kwargs.get('files', None) self.type = type
@property def is_hinted_optional(self): return _types.is_optional(self.type) @property def hinted_optional_type(self): return _types.get_type_of_optional(self.type, get_origin=False) @property def base_type(self) -> type: if self.is_hinted_optional: return _types.get_type(self.hinted_optional_type) else: return _types.get_type(self.type)
[docs] def name_dashup(self) -> 'PluginArg': r = PluginArg(_textprocessing.dashup(self.name)) r.have_default = self.have_default r.default = self.default r.type = r.type return r
[docs] def name_dashdown(self) -> 'PluginArg': r = PluginArg(_textprocessing.dashdown(self.name)) r.have_default = self.have_default r.default = self.default r.type = r.type return r
[docs] def type_string(self): if not _types.is_typing_hint(self.type): return self.type.__name__ return str(self.type).replace('typing.', '')
[docs] def validate_non_parsed(self, value: typing.Any): if (self.is_hinted_optional or (self.have_default and self.default is None)) \ and value is None: return None if self.base_type is typing.Any: return value if self.base_type is float and isinstance(value, int): return float(value) if isinstance(value, self.base_type): return value if _types.is_union(self.base_type) and isinstance(value, int): union_types = typing.get_args(self.type) has_int_type = any(t is int or _types.get_type(t) is int for t in union_types) if not has_int_type: for t in union_types: if t is float or _types.get_type(t) is float: return float(value) raise ValueError( f'Literal type "{value.__class__.__name__}" ' f'does not match plugin argument "{self.name}" type ' f'hint "{self.type_string()}".' )
[docs] def parse_by_type(self, value: str | typing.Any): if not isinstance(value, str): if isinstance(value, int) and self.base_type is float: return float(value) return value base_type = self.base_type try: if not _types.is_typing_hint(base_type) or base_type is typing.Any: if base_type is bool: return _types.parse_bool(value) if any(base_type is t for t in (list, tuple, dict, set, typing.Any)): try: evaled = ast.literal_eval(value) except ValueError: if base_type is typing.Any: return value raise if base_type is not typing.Any and not isinstance(evaled, base_type): if base_type is float and isinstance(evaled, int): return float(evaled) if not self.is_hinted_optional or evaled is not None: raise ValueError( f'Literal type "{evaled.__class__.__name__}" ' f'does not match plugin argument "{self.name}" type ' f'hint "{self.type_string()}".') return evaled return base_type(value) if _types.is_union(base_type): try: evaled = ast.literal_eval(value) except (ValueError, SyntaxError): # string evaled = value failures = 0 union_types = typing.get_args(self.type) for t in union_types: if _types.is_type(t, type(evaled)): continue else: failures += 1 if failures == len(union_types) and isinstance(evaled, int): has_int_type = any(t is int or _types.get_type(t) is int for t in union_types) if not has_int_type: for t in union_types: if t is float or _types.get_type(t) is float: evaled = float(evaled) failures = 0 break if failures == len(union_types): raise ValueError( f'Literal type "{evaled.__class__.__name__}" ' f'does not match plugin argument "{self.name}" type ' f'hint "{self.type_string()}".') return evaled return value except SyntaxError as e: if base_type is typing.Any: return value offset = e.offset - 1 if e.offset > 0 else 0 raise ValueError(f'Syntax Error: {e.text[:offset]}[ERROR HERE>]{e.text[offset:]}') from e
def __str__(self): default_part = f', default={repr(self.default)}' if self.have_default else '' return f'{self.__class__.__name__}(name="{self.name}", type={self.type}{default_part})' def __repr__(self): return str(self)
[docs] class PluginArgumentError(Exception): """ Raised when a plugin encounters an error in the arguments it is loaded by. Or errors in arguments used for execution. """ pass
[docs] class Plugin:
[docs] def __init__(self, loaded_by_name: str | None = None, argument_error_type: type[PluginArgumentError] = PluginArgumentError, **kwargs): """ :param loaded_by_name: The name the plugin was loaded by, will be passed by the loader. If ``None`` is passed, the first name mentioned by the plugin implementation will be used. This can simplify using some plugin classes directly without loading them through a loader implementation. :param argument_error_type: This exception type will be raised upon argument errors (invalid arguments) when loading a plugin using a :py:class:`.PluginLoader` implementation. It should match the ``argument_error_type`` given to the :py:class:`.PluginLoader` implementation being used to load the inheritor of this class. :param kwargs: Additional arguments that may arise when using an ``ARGS`` static signature definition with multiple ``NAMES`` in your implementation. """ if loaded_by_name is None: loaded_by_name = self.get_names()[0] self.__loaded_by_name = loaded_by_name self.__argument_error_type = argument_error_type # Store the constructor arguments for __str__ method self.__constructor_args_strs = {'loaded-by-name': loaded_by_name} for arg in self.__class__.get_accepted_args(loaded_by_name): snake_name = _textprocessing.dashdown(arg.name) if snake_name in kwargs: self.__constructor_args_strs[arg.name] = str(kwargs[snake_name]) elif arg.have_default: self.__constructor_args_strs[arg.name] = str(arg.default)
[docs] def argument_error(self, msg: str): """ Return an constructed exception that is suitable for raising as an argument error for this plugin. Example: ``raise self.argument_error('oops!')`` :param msg: exception message :return: the exception object, you must ``raise`` it. """ return self.__argument_error_type(msg)
[docs] @classmethod def get_names(cls) -> list[str]: """ Get the names that this class can be loaded by. :return: """ if hasattr(cls, 'NAMES'): if isinstance(cls.NAMES, str): return [cls.NAMES] else: return cls.NAMES else: return [_types.fullname(cls)]
[docs] @classmethod def get_help(cls, loaded_by_name: str, wrap_width: int | None = None, include_bases: bool = False) -> str: """ Get formatted help information about the plugin. This includes any implemented help strings and an auto formatted description of the plugins accepted arguments. :param loaded_by_name: The name used to load the plugin. Help may vary depending on how many names the plugin implementation handles and what loading it by a certain name does. :param wrap_width: wrap paragraphs to this width. :param include_bases: include argument names and inherited help from base classes? :return: Formatted help string """ help_str = None if hasattr(cls, 'help'): help_str = cls.help(loaded_by_name) if help_str: help_str = inspect.cleandoc(help_str).strip() elif cls.__doc__: help_str = inspect.cleandoc(cls.__doc__).strip() if include_bases: base_help = dict() for base in cls.get_bases(): if hasattr(base, 'inheritable_help'): bhelp_val = base.inheritable_help(loaded_by_name) if bhelp_val: base_help.update(bhelp_val) if hasattr(cls, 'inheritable_help'): bhelp_val = cls.inheritable_help(loaded_by_name) if bhelp_val: base_help.update(bhelp_val) hidden_args = cls.get_hidden_args(loaded_by_name) base_help = '\n\n'.join( message for arg, message in base_help.items() if arg not in hidden_args) if base_help: help_str = help_str + '\n\n' + base_help args_with_defaults = cls.get_accepted_args( loaded_by_name, include_bases=include_bases) arg_descriptors = [] for arg in args_with_defaults: if not arg.have_default: arg_descriptors.append(arg.name + ': ' + arg.type_string()) else: default_value = arg.default if isinstance(default_value, str): default_value = _textprocessing.quote(default_value) arg_descriptors.append(f'{arg.name}: {arg.type_string()} = {default_value}') if arg_descriptors: args_part = f'\n{" " * 4}arguments:\n{" " * 8}{(chr(10) + " " * 8).join(arg_descriptors)}\n' else: args_part = '\n' if help_str: wrap = \ _textprocessing.wrap_paragraphs( help_str, initial_indent=' ' * 4, subsequent_indent=' ' * 4, width=_textprocessing.long_text_wrap_width() if wrap_width is None else wrap_width) return loaded_by_name + f':{args_part}\n' + wrap else: return loaded_by_name + f':{args_part}'
[docs] @classmethod def get_required_args(cls, loaded_by_name: str) -> list[PluginArg]: """ Get a list of required arguments for this plugin class. :param loaded_by_name: The name used to load the plugin. Required arguments may vary by name used to load. :return: list of argument names """ return [a for a in cls.get_accepted_args(loaded_by_name) if not a.have_default]
[docs] @classmethod def get_default_args(cls, loaded_by_name: str) -> list[PluginArg]: """ Get the names and values of arguments for this plugin that possess default values. :param loaded_by_name: The name used to load the plugin. Default arguments may vary by name used to load. :return: list of arguments with default value: (name, value) """ return [a for a in cls.get_accepted_args(loaded_by_name) if a.have_default]
[docs] @classmethod def get_bases(cls) -> list[typing.Type['Plugin']]: """ Return a list of base classes, except for :py:class:`Plugin` :return: list of class type objects """ return list(c for c in _types.get_all_base_classes(cls) if issubclass(c, Plugin) and c is not Plugin)
[docs] @classmethod def get_accepted_args_schema(cls, loaded_by_name: str, include_bases: bool = False): """ Reduce the accepted arguments to a schema dict. Keyed by argument ``name``, content keys include: ``default`` contains any default value, this key may not exist if the argument has no default value. ``types`` contains all accepted types for the argument in string form. ``optional`` can the argument accept the value ``None``? :param loaded_by_name: Plugin loaded by name :param include_bases: Include all base classes except :py:class:`Plugin`? :return: dict """ schema = dict() for arg in cls.get_accepted_args(loaded_by_name, include_bases=include_bases): entry = {} schema[arg.name] = entry def _type_name(t): return (str(t) if t.__module__ != 'builtins' else t.__name__).strip() def _resolve_union(t): name = _type_name(t) if _types.is_union(t): return set(itertools.chain.from_iterable( [_resolve_union(t) for t in arg.type.__args__])) return [name] type_name = _type_name(arg.type) if _types.is_union(arg.type): union_args = _resolve_union(arg.type) if 'NoneType' in union_args: entry['optional'] = True union_args.remove('NoneType') # deterministic order is ideal # for schemas in version control entry['types'] = list(sorted(union_args)) else: entry['optional'] = False entry['types'] = [type_name] if arg.have_default: schema[arg.name]['default'] = arg.default if arg.options: schema[arg.name]['options'] = arg.options if arg.files: schema[arg.name]['files'] = arg.files return schema
[docs] @classmethod def get_hidden_args(cls, loaded_by_name: str) -> set[str]: """ Get argument names that have been explicitly hidden from use or disabled by the plugin for URI use. i.e. ``HIDE_ARGS``. These may be unsupported arguments inherited from a base class, or just arguments the plugin does not want you to use via a URI. These arguments can still be passed manually from code in the interest of maintaining a generic interface, but they may be ignored by the plugin implementation. .. code-block:: python HIDE_ARGS = ['hide-me', 'dont-use-me'] # or HIDE_ARGS = { "plugin_name": ['hide-me', 'dont-use-me'], "alt_plugin_name": ['hide-me-2'] } :param loaded_by_name: The name used to load the plugin. Argument signature may vary by name used to load. :return: hidden argument names """ args_hidden = set() for base in cls.mro(): if hasattr(base, 'HIDE_ARGS'): if isinstance(base.HIDE_ARGS, dict): if loaded_by_name in base.HIDE_ARGS: for arg in base.HIDE_ARGS[loaded_by_name]: args_hidden.add(_textprocessing.dashup(arg)) else: for arg in base.HIDE_ARGS: args_hidden.add(_textprocessing.dashup(arg)) return args_hidden
[docs] @classmethod def get_option_args(cls, loaded_by_name: str) -> dict[str, list]: """ Get argument names that have an associated list of option valid values. i.e. ``OPTION_ARGS``. This returns metadata provided by the plugin about specific arguments which accept a limited set of values, such as shape names like ``circle`` or ``rectangle``` etc. ``OPTION_ARGS`` should be in the form of a dict, where the keys are argument names, and the values are lists of valid values for that argument. .. code-block:: python OPTION_ARGS = { "shape": ['circle', 'rectangle', 'triangle'], } If the plugin supports multiple names for loading, this can be a dict of dicts, where the outer keys are the names used to load the plugin, and the inner keys are the argument names. .. code-block:: python OPTION_ARGS = { "plugin_name": { "shape": ['circle', 'rectangle', 'triangle'] } } These values can be inherited. :param loaded_by_name: The name used to load the plugin. Argument signature may vary by name used to load. :return: options argument names """ args_option = dict() for base in cls.mro(): if hasattr(base, 'OPTION_ARGS'): if isinstance(base.OPTION_ARGS, dict): values = list(base.OPTION_ARGS.values()) if values and isinstance(values[0], dict): # segregated by loaded_by_name if loaded_by_name in base.OPTION_ARGS: options = base.OPTION_ARGS[loaded_by_name] args_option.update({_textprocessing.dashup(k): list(v) for k, v in options.items()}) else: args_option.update({_textprocessing.dashup(k): list(v) for k, v in base.OPTION_ARGS.items()}) else: raise RuntimeError( f'Plugin metadata {cls.__name__}.OPTION_ARGS must be a dict.') return args_option
[docs] @classmethod def get_file_args(cls, loaded_by_name: str) -> dict[str, dict[str, typing.Any]]: """ Get argument names that have an associated list of option valid file types. i.e. ``FILE_ARGS``. This returns metadata provided by the plugin about specific arguments which accept a limited set of file types, such as model file types or config file types. ``FILE_ARGS`` should be a dictionary where keys are argument names, for example: .. code-block:: python FILE_ARGS = { "model": { "mode": "in/out" "filetypes": [('Model', ['*.safetensors', '*.pt'])] } } # only accepts directories FILE_ARGS = { "model": { "mode": "dir" } } # accepts input files or directories FILE_ARGS = { "model": { "mode": ["in", "dir"] "filetypes": [('Model', ['*.safetensors', '*.pt'])] } } # output files or directories # (mutually exclusive with "in" mode) FILE_ARGS = { "model": { "mode": ["out", "dir"] "filetypes": [('Model', ['*.safetensors', '*.pt'])] } } If the plugin supports multiple names for loading, this can be a dict of dicts, where the outer keys are the names used to load the plugin, and the inner keys are the argument names. .. code-block:: python FILE_ARGS = { "plugin_name": { "model": { "mode": "in/out" "filetypes": [('Model', ['*.safetensors', '*.pt'])] } } } These values can be inherited. :param loaded_by_name: The name used to load the plugin. Argument signature may vary by name used to load. :return: options argument names """ args_file = dict() def _descriptor_format(d: dict[str, typing.Any]): if not isinstance(d, dict): raise RuntimeError( f'Plugin metadata {cls.__name__}.FILE_ARGS descriptor ' f'values must be dicts, not {type(d).__name__}.') if 'mode' not in d: raise RuntimeError( f'Plugin metadata {cls.__name__}.FILE_ARGS descriptor missing "mode" key.' ) mode = d['mode'] if isinstance(mode, list): if any(m not in {'in', 'out', 'dir'} for m in mode): raise RuntimeError( f'Plugin metadata {cls.__name__}.FILE_ARGS descriptor "mode" ' f'values must be "in", "out", or "dir", not: {d["mode"]}.' ) else: if 'in' in mode and 'out' in mode: raise RuntimeError( f'Plugin metadata {cls.__name__}.FILE_ARGS descriptor "mode" ' f'cannot contain both "in" and "out" at the same time: {d["mode"]}.' ) elif isinstance(mode, str) and mode not in {'in', 'out', 'dir'}: raise RuntimeError( f'Plugin metadata {cls.__name__}.FILE_ARGS descriptor "mode" ' f'values must be "in", "out", or "dir", not: {d["mode"]}.' ) dir_only = (isinstance(mode, list) and all(m == 'dir' for m in mode)) or mode == 'dir' if 'filetypes' not in d and not dir_only: raise RuntimeError( f'Plugin metadata {cls.__name__}.FILE_ARGS descriptor missing ' f'"filetypes" key and not in directory only "mode".' ) if not dir_only: if not isinstance(d['filetypes'], (list, tuple)) or not all( isinstance(ft, (tuple, list)) and len(ft) == 2 and isinstance(ft[0], str) and isinstance(ft[1], (tuple, list)) for ft in d['filetypes'] ): raise RuntimeError( f'Plugin metadata {cls.__name__}.FILE_ARGS descriptor "filetypes" ' f'must be a list/tuple of lists/tuples, where each tuple contains ' f'a descriptor string and a list of file extensions, not: {d["filetypes"]}.' ) return d for base in cls.mro(): if hasattr(base, 'FILE_ARGS'): if isinstance(base.FILE_ARGS, dict): values = list(base.FILE_ARGS.values()) if (values and isinstance(values[0], dict) and isinstance(list(values[0].values())[0], dict)): # segregated by loaded_by_name if loaded_by_name in base.FILE_ARGS: options = base.FILE_ARGS[loaded_by_name] args_file.update( {_textprocessing.dashup(k): _descriptor_format(v) for k, v in options.items()} ) else: args_file.update( {_textprocessing.dashup(k): _descriptor_format(v) for k, v in base.FILE_ARGS.items()} ) else: raise RuntimeError( f'Plugin metadata {cls.__name__}.FILE_ARGS must be a dict.') return args_file
[docs] @classmethod def get_accepted_args(cls, loaded_by_name: str, include_bases: bool = False): """ Retrieve the argument signature of a plugin implementation. :param loaded_by_name: The name used to load the plugin. Argument signature may vary by name used to load. :param include_bases: Include all base classes except :py:class:`Plugin`? :return: List of argument descriptors, :py:class:`.PluginArg` """ if include_bases: rest = itertools.chain.from_iterable( c._get_accepted_args(loaded_by_name) for c in cls.get_bases()) else: rest = [] arg_name_set = dict() hidden_args = cls.get_hidden_args(loaded_by_name) for a in itertools.chain(cls._get_accepted_args(loaded_by_name), rest): if a.name == 'loaded-by-name': continue if a.name in hidden_args: continue if a.name in arg_name_set: raise RuntimeError( 'Cannot handle base classes with shadowed constructor arguments.') else: arg_name_set[a.name] = a return list(arg_name_set.values())
@classmethod def _get_accepted_args(cls, loaded_by_name: str) -> list[PluginArg]: if hasattr(cls, 'ARGS'): args_with_defaults = [] if isinstance(cls.ARGS, dict): if loaded_by_name in cls.ARGS: args_with_defaults = cls.ARGS.get(loaded_by_name) else: args_with_defaults = cls.ARGS fixed_args = [] for arg in args_with_defaults: if not isinstance(arg, PluginArg): raise RuntimeError( f'{cls.__name__}.ARGS["{loaded_by_name}"] ' f'contained a non PluginArg value: {arg}') fixed_args.append(arg.name_dashup()) return [] if fixed_args is None else fixed_args args_with_defaults = [] option_args = cls.get_option_args(loaded_by_name) file_args = cls.get_file_args(loaded_by_name) spec = list(_types.get_accepted_args_with_defaults(cls.__init__))[1:] hints = typing.get_type_hints(cls.__init__) for arg in spec: name = arg[0] hint = hints.get(name) extra = {} name = _textprocessing.dashup(name) if hint is not None: extra['type'] = hint if name in option_args: extra['options'] = option_args[name] if name in file_args: extra['files'] = file_args[name] if len(arg) == 1: args_with_defaults.append( PluginArg(name, **extra)) else: args_with_defaults.append( PluginArg(name, default=arg[1], **extra)) return args_with_defaults @property def loaded_by_name(self) -> str: """ The name the plugin was loaded by. :return: name """ return self.__loaded_by_name def __str__(self): """ Return a string representation of this plugin instance, showing all constructor arguments. Format: ClassName(arg1=value1, arg2=value2, ...) Child classes can override this method if needed. :return: string representation """ return f'{self.__class__.__name__}({", ".join(f"{k}={v}" for k, v in self.__constructor_args_strs.items())})' def __repr__(self): return str(self)
[docs] class ModuleFileNotFoundError(FileNotFoundError): """ Raised by :py:func:`.load_modules` if a module could not be found on disk. """ pass
[docs] def load_modules(paths: collections.abc.Iterable[str]) -> list[types.ModuleType]: """ Load python modules from a folder, directly from a .py file, or from a python module installed in the environment. Cache them so that repeat requests for loading return an already loaded module. :raises ModuleFileNotFoundError: If a module path could not be found on disk, or when a module could not be loaded from the python environment. :param paths: list of folder/file paths, or references to python modules installed in the environment :return: list of :py:class:`types.ModuleType` """ r = [] for plugin_path in paths: if os.path.exists(plugin_path): plugin_path, ext = os.path.splitext(os.path.abspath(plugin_path)) if not ext: plugin_path = os.path.join(plugin_path, '__init__.py') else: plugin_path += ext if plugin_path in LOADED_PLUGIN_MODULES: mod = LOADED_PLUGIN_MODULES[plugin_path] else: try: mod = importlib.machinery.SourceFileLoader(plugin_path, plugin_path).load_module() except FileNotFoundError as e: raise ModuleFileNotFoundError(e) from e LOADED_PLUGIN_MODULES[plugin_path] = mod r.append(mod) else: try: mod = importlib.import_module(plugin_path) except Exception as e: raise ModuleFileNotFoundError(e) from e LOADED_PLUGIN_MODULES[plugin_path] = mod r.append(mod) return r
PluginArgumentsDef = list[PluginArg] | None
[docs] class PluginNotFoundError(Exception): """ Raised when a plugin could not be located by a name. """ pass
PLUGIN_PATHS = set() """ Plugin paths that are considered by all :py:class:`PluginLoader` instances. This should be updated with :py:func:`import_plugins` """
[docs] def import_plugins(paths: collections.abc.Iterable[str]): """ Set plugin paths that will be considered by all plugin loader instances. :param paths: environment modules, python script paths, directory paths """ PLUGIN_PATHS.update(paths)
[docs] class PluginLoader:
[docs] def __init__(self, base_class=Plugin, description: str = "plugin", reserved_args: PluginArgumentsDef = None, argument_error_type: type[PluginArgumentError] = PluginArgumentError, not_found_error_type: type[PluginNotFoundError] = PluginNotFoundError): """ :param base_class: Base class of plugins, will be used for searching modules. :param description: Short plugin description / name, used in exception messages. :param reserved_args: Constructor arguments that are used by the plugin class which cannot be redefined by implementors of the plugin class. This should be a list of plugin argument descriptors, :py:class:`.PluginArg` :param argument_error_type: This exception type will be raised when the plugin is loaded with invalid URI arguments. :param not_found_error_type: This exception type will be raised when a plugin could not be located by a name specified in a loading URI. """ self.__classes = set() self.__classes_by_name = dict() self.__plugin_module_paths = set() self.__reserved_args = reserved_args if reserved_args else [] self.__argument_error_type = argument_error_type self.__not_found_error_type = not_found_error_type self.__description = description self.__base_class = base_class self.load_plugin_modules(PLUGIN_PATHS)
@property def plugin_module_paths(self) -> frozenset[str]: """ Every module path ever seen by :py:meth:`PluginLoader.load_plugin_modules`. :return: frozen set """ self.load_plugin_modules(PLUGIN_PATHS) return frozenset(self.__plugin_module_paths)
[docs] def add_class(self, cls: type[Plugin]): """ Add an implementation class to this loader. :raises RuntimeError: If the added class specifies a name that already exists in this loader. :param cls: the class """ if cls in self.__classes or (hasattr(cls, 'HIDDEN') and getattr(cls, 'HIDDEN')): # no-op return for name in cls.get_names(): if name in self.__classes_by_name: raise RuntimeError( f'plugin class using the name {name} already exists.') self.__classes_by_name[name] = cls self.__classes.add(cls)
[docs] def add_search_module_string(self, string: str) -> list[type[Plugin]]: """ Add a module string (in sys.modules) that will be searched for implementations. :param string: the module string :return: list of classes that were newly discovered """ classes = self._load_classes([sys.modules[string]]) for cls in classes: self.add_class(cls) return classes
[docs] def add_search_module(self, module: types.ModuleType) -> list[type[Plugin]]: """ Directly add a module object that will be searched for implementations. :param module: the module object :raises ValueError: If ``module`` is not a python module object. :return: list of classes that were newly discovered """ if not isinstance(module, types.ModuleType): raise ValueError('passed object in not a python module') classes = self._load_classes([module]) for cls in classes: self.add_class(cls) return classes
[docs] def load_plugin_modules(self, paths: collections.abc.Iterable[str]) -> list[type[Plugin]]: """ Modules that will be loaded from disk, or the python environment, and searched for implementations. Either python files, or module directories containing __init__.py, or names of python modules installed in the environment. It can be a mix of these. :raises ModuleFileNotFoundError: If a module path could not be found on disk, or when a module could not be loaded from the python environment. :param paths: list of folder/file paths, or references to python modules installed in the environment :return: list of classes that were newly discovered """ paths = set(paths) paths.update(PLUGIN_PATHS) classes = self._load_classes(load_modules( [path for path in paths if path not in self.__plugin_module_paths])) self.__plugin_module_paths.update(paths) for cls in classes: self.add_class(cls) return classes
def _load_classes(self, modules: collections.abc.Iterable[types.ModuleType]): found_classes = set() for mod in modules: def _excluded(cls): try: if cls in self.__classes: return True except TypeError: # handle un-hashable return True if not inspect.isclass(cls): return True if cls is self.__base_class: return True if not issubclass(cls, self.__base_class): return True if hasattr(cls, 'HIDDEN'): return cls.HIDDEN else: return False found_classes.update([value for value in _types.get_public_members(mod).values() if not _excluded(value)]) return list(found_classes)
[docs] def get_available_classes(self) -> list[type[Plugin]]: """ Get classes seen by this plugin loader. :return: list of classes (types) """ self.load_plugin_modules(PLUGIN_PATHS) return list(self.__classes)
[docs] def get_class_by_name(self, plugin_name: _types.Name) -> type[Plugin]: """ Get a plugin class by one of its names. IE: one of the names listed in its ``NAMES`` static attribute. :param plugin_name: a name associated with a plugin class :raises PluginNotFoundError: If the plugin name could not be found. :return: class (type) """ self.load_plugin_modules(PLUGIN_PATHS) cls = self.__classes_by_name.get(plugin_name) if cls is None: raise self.__not_found_error_type( f'Found no {self.__description} with the name: {plugin_name}') return cls
[docs] def get_all_names(self) -> _types.Names: """ Get all plugin names that this loader can see. :return: list of names (strings) """ self.load_plugin_modules(PLUGIN_PATHS) return list(self.__classes_by_name.keys())
[docs] def get_accepted_args_schema(self, include_bases: bool = False) -> dict[str, dict[str, typing.Any]]: """ Get a :py:meth:`Plugin.get_accepted_args_schema` for every plugin class, keyed by callable plugin name. :param include_bases: Include base class arguments? This excludes the base :py:class:`Plugin` :return: dict """ schema = dict() for name in self.get_all_names(): schema[name] = self.get_class_by_name(name).get_accepted_args_schema( name, include_bases=include_bases) return schema
[docs] def get_help(self, plugin_name: _types.Name, wrap_width: int | None = None, include_bases: bool = False) -> str: """ Get a formatted help string for a plugin by one of its loadable names. :param plugin_name: a name associated with the plugin class :param wrap_width: wrap paragraphs to this width. :param include_bases: include argument names and inherited help from base classes? :raises PluginNotFoundError: If the plugin name could not be found. :return: formatted string """ return self.get_class_by_name(plugin_name).get_help( plugin_name, wrap_width=wrap_width, include_bases=include_bases)
[docs] def load(self, uri: _types.Uri, **kwargs) -> Plugin: """ Load an plugin using a URI string containing its name and arguments. :param uri: The URI string :param kwargs: default argument values, will be override by arguments specified in the URI :raises ValueError: If uri is ``None`` :raises RuntimeError: If a plugin is discovered to be using a reserved argument name upon loading it. :raises dgenerate.plugin.PluginArgumentError: If there is an error in the loading arguments for the plugin. :raises dgenerate.plugin.PluginNotFoundError: If the plugin name mentioned in the URI could not be found. :return: plugin instance """ if uri is None: raise ValueError('uri must not be None') self.load_plugin_modules(PLUGIN_PATHS) loaded_by_name = uri.split(';', 1)[0].strip() plugin_class = self.get_class_by_name(loaded_by_name) parser_accepted_args = [a.name for a in plugin_class.get_accepted_args(loaded_by_name)] parser_raw_args = [a.name for a in plugin_class.get_accepted_args(loaded_by_name) if a.base_type not in (int, str, float, bool)] if 'loaded-by-name' in parser_accepted_args: # inheritors of base_class can't define this raise RuntimeError(f'"loaded-by-name" is a reserved {self.__description} module argument, ' 'chose another argument name for your module.') hidden_args = plugin_class.get_hidden_args(loaded_by_name) for module_arg in self.__reserved_args: # reserved args always go into **kwargs # inheritors of base_class if module_arg.name in parser_accepted_args: raise RuntimeError(f'"{module_arg}" is a reserved {self.__description} module argument, ' 'chose another argument name for your module.') if module_arg.name in hidden_args: continue parser_accepted_args.append(module_arg.name) arg_parser = _textprocessing.ConceptUriParser( self.__description, known_args=parser_accepted_args, args_raw=parser_raw_args) try: parsed_args = arg_parser.parse(uri).args except _textprocessing.ConceptUriParseError as e: raise self.__argument_error_type(str(e)) from e args_dict = {} for arg in plugin_class.get_default_args(loaded_by_name): # defaults specified by the implementation class args_dict[_textprocessing.dashdown(arg.name)] = arg.default for reserved_arg in self.__reserved_args: # defaults specified by the loader snake_case = _textprocessing.dashdown(reserved_arg.name) try: if reserved_arg.have_default: args_dict[snake_case] = reserved_arg.parse_by_type( parsed_args.get(reserved_arg.name, reserved_arg.default)) else: if reserved_arg.name in parsed_args: args_dict[snake_case] = reserved_arg.parse_by_type( parsed_args.get(reserved_arg.name)) elif snake_case not in kwargs: # Nothing provided this reserved argument value if reserved_arg.is_hinted_optional: args_dict[snake_case] = None else: raise self.__argument_error_type( f'Missing required argument "{reserved_arg.name}" for {self.__description} ' f'"{loaded_by_name}".') except ValueError as e: raise self.__argument_error_type( f'Argument "{reserved_arg.name}" must match type: "{reserved_arg.type_string()}". ' f'Failure cause: {str(e).strip()}') accepted_args = {_textprocessing.dashup(n.name): n for n in itertools.chain(plugin_class.get_accepted_args(loaded_by_name=loaded_by_name), self.__reserved_args)} # plugin user in code arguments for k, v in kwargs.items(): try: arg = accepted_args[_textprocessing.dashup(k)] except KeyError: raise self.__argument_error_type( f'Unknown plugin argument: "{k}"' ) try: args_dict[k] = arg.validate_non_parsed(v) except ValueError as e: raise self.__argument_error_type( f'Argument "{k}" must match type: "{arg.type_string()}". ' f'Failure cause: {str(e).strip()}') for k, v in parsed_args.items(): # URI overrides everything arg = accepted_args[k] try: args_dict[_textprocessing.dashdown(k)] = arg.parse_by_type(v) except ValueError as e: raise self.__argument_error_type( f'Argument "{k}" must match type: "{arg.type_string()}". ' f'Failure cause: {str(e).strip()}') # Automagic argument args_dict['loaded_by_name'] = loaded_by_name for arg_name, plugin_arg in ((k, v) for k, v in accepted_args.items() if not v.have_default): snake_case = _textprocessing.dashdown(arg_name) if snake_case not in args_dict: if plugin_arg.is_hinted_optional: args_dict[snake_case] = None else: raise self.__argument_error_type( f'Missing required argument "{arg_name}" for {self.__description} "{loaded_by_name}".') try: return plugin_class(**args_dict) except self.__argument_error_type as e: raise self.__argument_error_type( f'Invalid argument given to {self.__description} ' f'"{loaded_by_name}": {str(e).strip()}')
[docs] def loader_help(self, names: _types.Names, plugin_module_paths: _types.OptionalPaths = None, title='plugin', title_plural='plugins', throw=False, log_error=True, include_bases: bool = False): """ Implements ``--sub-command-help`` and ``--image-processor-help`` command line options for example. :param names: arguments (sub-command names, or empty list) :param plugin_module_paths: extra plugin module paths to search :param title: plugin title, used in messages :param title_plural: plural plugin title, used in messages :param throw: throw on error? :param log_error: log errors to stderr? :param include_bases: include argument names from base classes? :raises PluginNotFoundError: ``names`` contained an unknown plugin name :raises ModuleFileNotFoundError: ``plugin_module_paths`` contained a missing module :return: return-code, anything other than 0 is failure """ self.load_plugin_modules(PLUGIN_PATHS) if plugin_module_paths is not None: try: self.load_plugin_modules(plugin_module_paths) except ModuleFileNotFoundError as e: if log_error: _messages.error( f'Plugin module could not be found: {str(e).strip()}' ) if throw: raise return 1 if len(names) == 0: available = ('\n' + ' ' * 4).join(_textprocessing.quote(name) for name in sorted(self.get_all_names())) _messages.log( f'Available {title_plural}:\n\n{" " * 4}{available}') return 0 help_strs = [] for name in names: try: help_strs.append(self.get_help(name, include_bases=include_bases)) except PluginNotFoundError: if log_error: _messages.error( f'An {title} with the name of "{name}" could not be found.' ) if throw: raise return 1 for help_str in help_strs: _messages.log(help_str + '\n', underline=True) return 0