# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
# the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import ast
import collections.abc
import importlib.machinery
import importlib.util
import inspect
import itertools
import os
import sys
import types
import typing

import dgenerate.messages as _messages
import dgenerate.textprocessing as _textprocessing
import dgenerate.types as _types
__doc__ = """
URI based plugin loading system base implementations.
"""

# Shared by every call to load_modules(): maps the key a module was loaded by
# (absolute source file path, or importable module name) to the module object.
LOADED_PLUGIN_MODULES: dict[str, types.ModuleType] = dict()
"""In-memory cache of loaded plugin modules."""
[docs]
class PluginArg:
[docs]
def __init__(self, name: str, type: type = typing.Any, **kwargs):
self.name = name
self.have_default = 'default' in kwargs
self.default = kwargs['default'] if self.have_default else None
self.type = type
@property
def is_hinted_optional(self):
return _types.is_optional(self.type)
@property
def hinted_optional_type(self):
return _types.get_type_of_optional(self.type, get_origin=False)
@property
def base_type(self):
if self.is_hinted_optional:
return _types.get_type(self.hinted_optional_type)
else:
return _types.get_type(self.type)
[docs]
def name_dashup(self) -> 'PluginArg':
r = PluginArg(_textprocessing.dashup(self.name))
r.have_default = self.have_default
r.default = self.default
r.type = r.type
return r
[docs]
def name_dashdown(self) -> 'PluginArg':
r = PluginArg(_textprocessing.dashdown(self.name))
r.have_default = self.have_default
r.default = self.default
r.type = r.type
return r
[docs]
def type_string(self):
if not _types.is_typing_hint(self.type):
return self.type.__name__
return str(self.type).replace('typing.', '')
[docs]
def parse_by_type(self, value: str | typing.Any):
if not isinstance(value, str):
return value
base_type = self.base_type
try:
if not _types.is_typing_hint(base_type) or base_type is typing.Any:
if base_type is bool:
return _types.parse_bool(value)
if any(base_type is t for t in (list, dict, set, typing.Any)):
try:
evaled = ast.literal_eval(value)
except ValueError:
if base_type is typing.Any:
return value
raise
if base_type is not typing.Any and not isinstance(evaled, base_type):
if not self.is_hinted_optional or evaled is not None:
raise ValueError(
f'Literal type "{evaled.__class__.__name__}" '
f'does not match plugin argument "{self.name}" type '
f'hint "{self.type_string()}".')
return evaled
return base_type(value)
if base_type is typing.Union:
try:
evaled = ast.literal_eval(value)
except (ValueError, SyntaxError):
# string
evaled = value
failures = 0
union_types = typing.get_args(self.type)
for t in union_types:
if _types.is_type(t, type(evaled)):
continue
else:
failures += 1
if failures == len(union_types):
raise ValueError(
f'Literal type "{evaled.__class__.__name__}" '
f'does not match plugin argument "{self.name}" type '
f'hint "{self.type_string()}".')
return evaled
return value
except SyntaxError as e:
if base_type is typing.Any:
return value
offset = e.offset - 1 if e.offset > 0 else 0
raise ValueError(f'Syntax Error: {e.text[:offset]}[ERROR HERE>]{e.text[offset:]}')
def __str__(self):
return f'{self.__class__.__name__}(name="{self.name}", type={self.type}, default={repr(self.default)})'
def __repr__(self):
return str(self)
class PluginArgumentError(Exception):
    """
    Signals an invalid argument, either one a plugin was loaded with
    or one supplied when the plugin is executed.
    """
class Plugin:
    # Base class for URI-loadable plugin implementations.
    #
    # Inheritors may define the static attributes:
    #   NAMES - a str or list of str, the names the class can be loaded by
    #   ARGS  - a list of PluginArg (or None), or a dict mapping each name in NAMES
    #           to such a list; overrides constructor introspection
    #   HIDDEN - truthy to keep PluginLoader from registering the class
    # and a ``help`` callable invoked as ``cls.help(loaded_by_name)``.

    def __init__(self, loaded_by_name: str, argument_error_type: type[PluginArgumentError] = PluginArgumentError,
                 **kwargs):
        """
        :param loaded_by_name: The name the plugin was loaded by, will be passed by the loader.
        :param argument_error_type: This exception type will be raised upon argument errors (invalid arguments)
            when loading a plugin using a :py:class:`.PluginLoader` implementation. It should match the
            ``argument_error_type`` given to the :py:class:`.PluginLoader` implementation being used
            to load the inheritor of this class.
        :param kwargs: Additional arguments that may arise when using an ``ARGS`` static signature definition
            with multiple ``NAMES`` in your implementation.
        """
        self.__loaded_by_name = loaded_by_name
        self.__argument_error_type = argument_error_type

    def argument_error(self, msg: str):
        """
        Return a constructed exception that is suitable for raising
        as an argument error for this plugin.

        Example: ``raise self.argument_error('oops!')``

        :param msg: exception message
        :return: the exception object, you must ``raise`` it.
        """
        return self.__argument_error_type(msg)

    @classmethod
    def get_names(cls) -> list[str]:
        """
        Get the names that this class can be loaded by.

        :return: a new list built from ``NAMES`` (a single string becomes a one
            element list), or ``[_types.fullname(cls)]`` when ``NAMES`` is not defined.
        """
        if not hasattr(cls, 'NAMES'):
            return [_types.fullname(cls)]
        if isinstance(cls.NAMES, str):
            return [cls.NAMES]
        # copy, so callers cannot mutate the class attribute through the result
        return list(cls.NAMES)

    @classmethod
    def get_help(cls, loaded_by_name: str) -> str:
        """
        Get formatted help information about the plugin.

        This includes any implemented help strings and an auto formatted
        description of the plugins accepted arguments.

        :param loaded_by_name: The name used to load the plugin.
            Help may vary depending on how many names the plugin
            implementation handles and what loading it by a certain
            name does.
        :return: Formatted help string
        """
        help_str = None
        if hasattr(cls, 'help'):
            help_str = cls.help(loaded_by_name)
            if help_str:
                help_str = inspect.cleandoc(help_str).strip()
        elif cls.__doc__:
            # no help() implementation, fall back to the class docstring
            help_str = inspect.cleandoc(cls.__doc__).strip()

        arg_descriptors = []
        for arg in cls.get_accepted_args(loaded_by_name):
            if arg.have_default:
                default_value = arg.default
                if isinstance(default_value, str):
                    default_value = _textprocessing.quote(default_value)
                arg_descriptors.append(f'{arg.name}: {arg.type_string()} = {default_value}')
            else:
                arg_descriptors.append(f'{arg.name}: {arg.type_string()}')

        indent4 = ' ' * 4
        indent8 = ' ' * 8
        if arg_descriptors:
            args_part = f'\n{indent4}arguments:\n{indent8}' + f'\n{indent8}'.join(arg_descriptors) + '\n'
        else:
            args_part = '\n'

        if not help_str:
            return loaded_by_name + f':{args_part}'

        wrap = _textprocessing.wrap_paragraphs(
            help_str,
            initial_indent=indent4,
            subsequent_indent=indent4,
            width=_textprocessing.long_text_wrap_width())
        return loaded_by_name + f':{args_part}\n' + wrap

    @classmethod
    def get_required_args(cls, loaded_by_name: str) -> list[PluginArg]:
        """
        Get a list of required arguments for this plugin class.

        :param loaded_by_name: The name used to load the plugin.
            Required arguments may vary by name used to load.
        :return: list of :py:class:`.PluginArg` that have no default value
        """
        return [a for a in cls.get_accepted_args(loaded_by_name) if not a.have_default]

    @classmethod
    def get_default_args(cls, loaded_by_name: str) -> list[PluginArg]:
        """
        Get the arguments for this plugin that possess default values.

        :param loaded_by_name: The name used to load the plugin.
            Default arguments may vary by name used to load.
        :return: list of :py:class:`.PluginArg` that have a default value
            (see their ``default`` attribute)
        """
        return [a for a in cls.get_accepted_args(loaded_by_name) if a.have_default]

    @classmethod
    def get_accepted_args(cls, loaded_by_name) -> list[PluginArg]:
        """
        Retrieve the argument signature of a plugin implementation.

        When the class defines ``ARGS`` that declaration is used (names are
        passed through ``dashup``). Otherwise the signature of ``__init__``
        (minus ``self``) is introspected, using its type hints where present.

        :param loaded_by_name: The name used to load the plugin.
            Argument signature may vary by name used to load.
        :raises RuntimeError: If ``ARGS`` is a dict without an entry for ``loaded_by_name``,
            or if the declared arguments contain a value that is not a :py:class:`.PluginArg`
        :return: List of argument descriptors, :py:class:`.PluginArg`
        """
        if hasattr(cls, 'ARGS'):
            if isinstance(cls.ARGS, dict):
                if loaded_by_name not in cls.ARGS:
                    raise RuntimeError(
                        'Plugin module implementation bug, args for '
                        f'"{loaded_by_name}" not specified in ARGS dictionary.')
                declared_args = cls.ARGS[loaded_by_name]
            else:
                declared_args = cls.ARGS

            if declared_args is None:
                # PluginArgumentsDef allows None, meaning no arguments;
                # previously this crashed while iterating None.
                return []

            fixed_args = []
            for arg in declared_args:
                if not isinstance(arg, PluginArg):
                    raise RuntimeError(
                        f'{cls.__name__}.ARGS["{loaded_by_name}"] '
                        f'contained a non PluginArg value: {arg}')
                fixed_args.append(arg.name_dashup())
            return fixed_args

        # [1:] drops ``self``
        spec = list(_types.get_accepted_args_with_defaults(cls.__init__))[1:]
        hints = typing.get_type_hints(cls.__init__)

        accepted = []
        for arg in spec:
            name = arg[0]
            extra = {}
            hint = hints.get(name)
            if hint is not None:
                extra['type'] = hint
            if len(arg) > 1:
                # a second element in the spec entry is the default value
                extra['default'] = arg[1]
            accepted.append(PluginArg(_textprocessing.dashup(name), **extra))
        return accepted

    @property
    def loaded_by_name(self) -> str:
        """
        The name the plugin was loaded by.

        :return: name
        """
        return self.__loaded_by_name
class ModuleFileNotFoundError(FileNotFoundError):
    """
    Signals that :py:func:`.load_modules` could not locate a requested module on disk.
    """
def _load_source_module(file_path: str) -> types.ModuleType:
    """
    Execute a python source file as a module whose name is its absolute path.

    Follows the stdlib recipe for importing a source file directly
    (spec -> module -> ``exec_module``) instead of the deprecated
    ``SourceFileLoader.load_module()``. The module is placed in ``sys.modules``
    before execution and removed again if execution raises.

    :param file_path: absolute path of the source file
    :raises ModuleFileNotFoundError: if the file does not exist
    :return: the executed module
    """
    # An explicit SourceFileLoader is passed so that files without a ``.py``
    # suffix are still accepted, as they were with ``load_module()``.
    loader = importlib.machinery.SourceFileLoader(file_path, file_path)
    spec = importlib.util.spec_from_file_location(file_path, file_path, loader=loader)
    module = importlib.util.module_from_spec(spec)
    sys.modules[file_path] = module
    try:
        loader.exec_module(module)
    except FileNotFoundError as e:
        sys.modules.pop(file_path, None)
        raise ModuleFileNotFoundError(e) from e
    except BaseException:
        sys.modules.pop(file_path, None)
        raise
    return module


def load_modules(paths: collections.abc.Iterable[str]) -> list[types.ModuleType]:
    """
    Load python modules from a folder, directly from a .py file, or from a python module
    installed in the environment. Cache them so that repeat requests for loading return
    an already loaded module.

    Existing paths are keyed in ``LOADED_PLUGIN_MODULES`` by absolute file path (a folder
    resolves to its ``__init__.py``). Any other string is imported with
    ``importlib.import_module`` and keyed by that string.

    :raises ModuleFileNotFoundError: If a module path could not be found on disk,
        or when a module could not be loaded from the python environment.

    :param paths: list of folder/file paths, or references to python modules installed
        in the environment

    :return: list of :py:class:`types.ModuleType`
    """
    modules = []
    for plugin_path in paths:
        if os.path.exists(plugin_path):
            file_path = os.path.abspath(plugin_path)
            # Use the filesystem to decide what a directory is. Guessing from the
            # extension mis-handled dotted directory names and extensionless files.
            if os.path.isdir(file_path):
                file_path = os.path.join(file_path, '__init__.py')

            module = LOADED_PLUGIN_MODULES.get(file_path)
            if module is None:
                module = _load_source_module(file_path)
                LOADED_PLUGIN_MODULES[file_path] = module
        else:
            try:
                module = importlib.import_module(plugin_path)
            except Exception as e:
                # any import failure, including errors raised by the module itself,
                # is reported as the module not being found
                raise ModuleFileNotFoundError(e) from e
            LOADED_PLUGIN_MODULES[plugin_path] = module
        modules.append(module)
    return modules
# Type alias: an optional list of plugin argument descriptors (PluginArg),
# e.g. the ``reserved_args`` parameter of PluginLoader.
PluginArgumentsDef = list[PluginArg] | None
class PluginNotFoundError(Exception):
    """
    Signals that no plugin is registered under a requested name.
    """
class PluginLoader:
    """
    Discovers plugin implementation classes in python modules and constructs
    plugin instances from URI strings that name a plugin and its arguments.
    """

    def __init__(self,
                 base_class=Plugin,
                 description: str = "plugin",
                 reserved_args: PluginArgumentsDef = None,
                 argument_error_type: type[PluginArgumentError] = PluginArgumentError,
                 not_found_error_type: type[PluginNotFoundError] = PluginNotFoundError):
        """
        :param base_class: Base class of plugins, will be used for searching modules.
        :param description: Short plugin description / name, used in exception messages.
        :param reserved_args: Constructor arguments that are used by the plugin class which
            cannot be redefined by implementors of the plugin class. This should be a
            list of plugin argument descriptors, :py:class:`.PluginArg`
        :param argument_error_type: This exception type will be raised when the plugin is loaded
            with invalid URI arguments.
        :param not_found_error_type: This exception type will be raised when a plugin could
            not be located by a name specified in a loading URI.
        """
        self.__classes = set()
        self.__classes_by_name = dict()
        self.__plugin_module_paths = set()
        # copied so later mutation of the caller's list does not affect this loader
        self.__reserved_args = list(reserved_args) if reserved_args else []
        self.__argument_error_type = argument_error_type
        self.__not_found_error_type = not_found_error_type
        self.__description = description
        self.__base_class = base_class

    @property
    def plugin_module_paths(self) -> frozenset[str]:
        """
        Every module path ever seen by :py:meth:`PluginLoader.load_plugin_modules`.

        :return: frozen set
        """
        return frozenset(self.__plugin_module_paths)

    def add_class(self, cls: type[Plugin]):
        """
        Add an implementation class to this loader.

        Classes already added, or with a truthy ``HIDDEN`` attribute, are ignored.

        :raises RuntimeError: If the added class specifies a name that already exists in this loader.
            In that case none of the class's names are registered.
        :param cls: the class
        """
        if cls in self.__classes or getattr(cls, 'HIDDEN', False):
            # no-op
            return

        names = list(cls.get_names())

        # Validate every name before registering any, so a conflict cannot
        # leave the class partially registered.
        for name in names:
            if name in self.__classes_by_name:
                raise RuntimeError(
                    f'plugin class using the name {name} already exists.')

        for name in names:
            self.__classes_by_name[name] = cls
        self.__classes.add(cls)

    def _register_modules(self, modules: collections.abc.Iterable[types.ModuleType]) -> list[type[Plugin]]:
        # Discover implementations in ``modules`` and add each one with add_class.
        classes = self._load_classes(modules)
        for cls in classes:
            self.add_class(cls)
        return classes

    def add_search_module_string(self, string: str) -> list[type[Plugin]]:
        """
        Add a module string (in sys.modules) that will be searched for implementations.

        :param string: the module string
        :raises KeyError: If ``string`` is not a key of ``sys.modules``.
        :return: list of classes that were newly discovered
        """
        return self._register_modules([sys.modules[string]])

    def add_search_module(self, module: types.ModuleType) -> list[type[Plugin]]:
        """
        Directly add a module object that will be searched for implementations.

        :param module: the module object
        :raises ValueError: If ``module`` is not a python module object.
        :return: list of classes that were newly discovered
        """
        if not isinstance(module, types.ModuleType):
            raise ValueError('passed object is not a python module')
        return self._register_modules([module])

    def load_plugin_modules(self, paths: collections.abc.Iterable[str]) -> list[type[Plugin]]:
        """
        Modules that will be loaded from disk, or the python environment, and searched for implementations.

        Either python files, or module directories containing __init__.py, or
        names of python modules installed in the environment.

        It can be a mix of these. Paths already seen by this loader are skipped.

        :raises ModuleFileNotFoundError: If a module path could not be found on disk,
            or when a module could not be loaded from the python environment.

        :param paths: list of folder/file paths, or references to python modules installed
            in the environment

        :return: list of classes that were newly discovered
        """
        # materialize once: ``paths`` is iterated twice below and may be a one-shot iterator
        paths = list(paths)
        new_paths = [path for path in paths if path not in self.__plugin_module_paths]
        modules = load_modules(new_paths)
        self.__plugin_module_paths.update(paths)
        return self._register_modules(modules)

    def _load_classes(self, modules: collections.abc.Iterable[types.ModuleType]):
        """
        Collect public members of ``modules`` that are strict subclasses of the
        base class, not already added, and not marked ``HIDDEN``.

        :param modules: modules to search
        :return: list of discovered classes (order is not defined)
        """

        def _excluded(candidate):
            try:
                if candidate in self.__classes:
                    return True
            except TypeError:
                # handle un-hashable
                return True
            if not inspect.isclass(candidate):
                return True
            if candidate is self.__base_class:
                return True
            if not issubclass(candidate, self.__base_class):
                return True
            return bool(getattr(candidate, 'HIDDEN', False))

        found_classes = set()
        for mod in modules:
            found_classes.update(
                value for value in _types.get_public_members(mod).values() if not _excluded(value))
        return list(found_classes)

    def get_available_classes(self) -> list[type[Plugin]]:
        """
        Get classes seen by this plugin loader.

        :return: list of classes (types)
        """
        return list(self.__classes)

    def get_class_by_name(self, plugin_name: _types.Name) -> type[Plugin]:
        """
        Get a plugin class by one of its names.

        IE: one of the names listed in its ``NAMES`` static attribute.

        :param plugin_name: a name associated with a plugin class
        :raises PluginNotFoundError: If the plugin name could not be found.
        :return: class (type)
        """
        cls = self.__classes_by_name.get(plugin_name)
        if cls is None:
            raise self.__not_found_error_type(
                f'Found no {self.__description} with the name: {plugin_name}')
        return cls

    def get_all_names(self) -> _types.Names:
        """
        Get all plugin names that this loader can see.

        :return: list of names (strings)
        """
        return list(self.__classes_by_name.keys())

    def get_help(self, plugin_name: _types.Name) -> str:
        """
        Get a formatted help string for a plugin by one of its loadable names.

        :param plugin_name: a name associated with the plugin class
        :raises PluginNotFoundError: If the plugin name could not be found.
        :return: formatted string
        """
        return self.get_class_by_name(plugin_name).get_help(plugin_name)

    def load(self, uri: _types.Uri, **kwargs) -> Plugin:
        """
        Load a plugin using a URI string containing its name and arguments.

        Argument values are resolved in increasing order of precedence: defaults
        declared by the implementation, reserved argument values (from the URI or
        the loader's defaults), ``kwargs``, and finally values given in the URI.
        ``loaded_by_name`` is always set to the name taken from the URI.

        :param uri: The URI string
        :param kwargs: default argument values, will be overridden by arguments specified in the URI
        :raises ValueError: If uri is ``None``
        :raises RuntimeError: If a plugin is discovered to be using a reserved argument name upon loading it.
        :raises PluginArgumentError: If there is an error in the loading arguments for the plugin.
        :raises PluginNotFoundError: If the plugin name mentioned in the URI could not be found.
        :raises TypeError: If the plugin constructor raises a ``TypeError`` that is not
            a missing positional argument error.
        :return: plugin instance
        """
        if uri is None:
            raise ValueError('uri must not be None')

        call_by_name = uri.split(';', 1)[0].strip()
        plugin_class = self.get_class_by_name(call_by_name)

        # Fetch the implementation's signature once and reuse it below.
        plugin_args = plugin_class.get_accepted_args(call_by_name)

        parser_accepted_args = [a.name for a in plugin_args]
        # arguments whose type is not a simple scalar are given to the parser
        # as ``args_raw`` (see ConceptUriParser for what that implies)
        parser_raw_args = [a.name for a in plugin_args
                           if a.base_type not in (int, str, float, bool)]

        if 'loaded-by-name' in parser_accepted_args:
            # inheritors of base_class can't define this
            raise RuntimeError(f'"loaded-by-name" is a reserved {self.__description} module argument, '
                               'choose another argument name for your module.')

        for module_arg in self.__reserved_args:
            # reserved args always go into **kwargs of inheritors of base_class,
            # so implementations may not declare them again
            if module_arg.name in parser_accepted_args:
                raise RuntimeError(f'"{module_arg.name}" is a reserved {self.__description} module argument, '
                                   'choose another argument name for your module.')
            parser_accepted_args.append(module_arg.name)

        arg_parser = _textprocessing.ConceptUriParser(
            self.__description,
            known_args=parser_accepted_args,
            args_raw=parser_raw_args)

        try:
            parsed_args = arg_parser.parse(uri).args
        except _textprocessing.ConceptUriParseError as e:
            raise self.__argument_error_type(str(e)) from e

        # defaults specified by the implementation class
        args_dict = {_textprocessing.dashdown(arg.name): arg.default
                     for arg in plugin_class.get_default_args(call_by_name)}

        # reserved arguments, with defaults specified by the loader
        for reserved_arg in self.__reserved_args:
            snake_case = _textprocessing.dashdown(reserved_arg.name)
            try:
                if reserved_arg.have_default:
                    args_dict[snake_case] = reserved_arg.parse_by_type(
                        parsed_args.get(reserved_arg.name, reserved_arg.default))
                elif reserved_arg.name in parsed_args:
                    args_dict[snake_case] = reserved_arg.parse_by_type(
                        parsed_args.get(reserved_arg.name))
                elif snake_case not in kwargs:
                    # Nothing provided this reserved argument value
                    if reserved_arg.is_hinted_optional:
                        args_dict[snake_case] = None
                    else:
                        raise self.__argument_error_type(
                            f'Missing required argument "{reserved_arg.name}" for {self.__description} '
                            f'"{call_by_name}".')
            except ValueError as e:
                raise self.__argument_error_type(
                    f'Argument "{reserved_arg.name}" must match type: "{reserved_arg.type_string()}". '
                    f'Failure cause: {str(e).strip()}') from e

        # plugin load user arguments
        args_dict.update(kwargs)

        accepted_args = {_textprocessing.dashup(n.name): n for n in
                         itertools.chain(plugin_args, self.__reserved_args)}

        for k, v in parsed_args.items():
            # URI overrides everything
            arg = accepted_args[k]
            try:
                args_dict[_textprocessing.dashdown(k)] = arg.parse_by_type(v)
            except ValueError as e:
                raise self.__argument_error_type(
                    f'Argument "{k}" must match type: "{arg.type_string()}". '
                    f'Failure cause: {str(e).strip()}') from e

        # Automagic argument
        args_dict['loaded_by_name'] = call_by_name

        for arg_name, plugin_arg in accepted_args.items():
            if plugin_arg.have_default:
                continue
            snake_case = _textprocessing.dashdown(arg_name)
            if snake_case in args_dict:
                continue
            if plugin_arg.is_hinted_optional:
                args_dict[snake_case] = None
            else:
                raise self.__argument_error_type(
                    f'Missing required argument "{arg_name}" for {self.__description} "{call_by_name}".')

        try:
            return plugin_class(**args_dict)
        except TypeError as e:
            msg = str(e)
            if 'required positional argument' in msg:
                raise self.__argument_error_type(msg) from e
            # Previously any other TypeError was swallowed and load() returned None.
            raise
        except self.__argument_error_type as e:
            raise self.__argument_error_type(
                f'Invalid argument given to {self.__description} '
                f'"{call_by_name}": {str(e).strip()}') from e

    def loader_help(self,
                    names: _types.Names,
                    plugin_module_paths: _types.OptionalPaths = None,
                    title='plugin',
                    title_plural='plugins',
                    throw=False,
                    log_error=True):
        """
        Implements ``--sub-command-help`` and ``--image-processor-help``
        command line options for example.

        :param names: arguments (sub-command names, or empty list)
        :param plugin_module_paths: plugin module paths to search
        :param title: plugin title, used in messages
        :param title_plural: plural plugin title, used in messages
        :param throw: throw on error?
        :param log_error: log errors to stderr?

        :raises PluginNotFoundError: ``names`` contained an unknown plugin name
        :raises ModuleFileNotFoundError: ``plugin_module_paths`` contained a missing module

        :return: return-code, anything other than 0 is failure
        """
        if plugin_module_paths is not None:
            try:
                self.load_plugin_modules(plugin_module_paths)
            except ModuleFileNotFoundError as e:
                if log_error:
                    _messages.log(
                        f'Plugin module could not be found: {str(e).strip()}',
                        level=_messages.ERROR)
                if throw:
                    raise
                return 1

        if not names:
            # no specific names requested, list everything available
            available = ('\n' + ' ' * 4).join(_textprocessing.quote(name) for name in self.get_all_names())
            _messages.log(
                f'Available {title_plural}:\n\n{" " * 4}{available}')
            return 0

        help_strs = []
        for name in names:
            try:
                help_strs.append(self.get_help(name))
            except PluginNotFoundError:
                if log_error:
                    _messages.log(
                        f'An {title} with the name of "{name}" could not be found.',
                        level=_messages.ERROR)
                if throw:
                    raise
                return 1

        for help_str in help_strs:
            _messages.log(help_str + '\n', underline=True)
        return 0