# Source code for dgenerate.globalconfig

# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
#    the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


# Module documentation string, assigned explicitly to ``__doc__`` instead of
# being written as a leading string literal.
__doc__ = "Configure dgenerate's global constants."

import contextlib
import importlib
import inspect
import json
import types

import yaml
import toml
import dgenerate.types as _types

# Registry of config variables: maps the name used in the global config file
# to the dotted "module.VARIABLE" path of the module global it controls.
_config_variable_map = dict()


[docs] def register_config_variable(module: str | types.ModuleType, variable_name: str, config_variable_name: str | None = None): """ Register a global config variable that exists inside an arbitrary module. :param module: The module name or object reference. :param variable_name: Name of the variable inside the module. :param config_variable_name: Name to represent the variable in the global config file, if left ``None`` this will be ``variable_name`` in lowercase. """ if isinstance(module, types.ModuleType): module = module.__name__ config_variable_name = variable_name.lower() \ if not config_variable_name else config_variable_name _config_variable_map[config_variable_name] = module + '.' + variable_name
def register_all():
    """
    Register every public, non-module global of the calling module as a config variable.

    The globals of the caller's frame are inspected. Names starting with an
    underscore and values that are themselves modules are skipped, everything
    else is handed to :py:func:`register_config_variable` under its own name.
    """
    caller_frame = inspect.currentframe().f_back
    caller_module = inspect.getmodule(caller_frame)

    for global_name, global_value in caller_frame.f_globals.items():
        if global_name.startswith('_'):
            # private / dunder names are never exposed
            continue
        if isinstance(global_value, types.ModuleType):
            # imported modules are not configuration values
            continue
        register_config_variable(caller_module, global_name)
def _split_constant_path(path: str) -> tuple[str, str]:
    """
    Split a dotted ``module.attribute`` path into module path and attribute name.

    :param path: Dotted path, the last component being the attribute name.
    :raise ValueError: If the module path or the attribute name is empty.
    :return: ``(module_path, attribute_name)``
    """
    module_path, _, attr_name = path.rpartition(".")
    if not module_path or not attr_name:
        raise ValueError(f"Invalid path: {path}")
    return module_path, attr_name


def _get_constant_by_string(path: str):
    """
    Import the module named by ``path`` and return the attribute it points to.

    :param path: Dotted ``module.attribute`` path.
    :raise ValueError: If ``path`` is not of the form ``module.attribute``.
    :raise ImportError: If the module cannot be found or lacks the attribute.
    :return: The current value of the attribute.
    """
    module_path, attr_name = _split_constant_path(path)
    try:
        module = importlib.import_module(module_path)
        return getattr(module, attr_name)
    except (ModuleNotFoundError, AttributeError) as e:
        raise ImportError(f"Could not retrieve '{path}': {e}") from e


def _set_constant_by_string(path: str, value):
    """
    Import the module named by ``path`` and assign ``value`` to the attribute it points to.

    :param path: Dotted ``module.attribute`` path.
    :param value: The value to assign.
    :raise ValueError: If ``path`` is not of the form ``module.attribute``.
    :raise ImportError: If the module cannot be found or the attribute cannot be assigned.
    """
    module_path, attr_name = _split_constant_path(path)
    try:
        module = importlib.import_module(module_path)
        setattr(module, attr_name, value)
    except (ModuleNotFoundError, AttributeError) as e:
        # This is a write, so the message says "set" rather than "retrieve".
        raise ImportError(f"Could not set '{path}': {e}") from e


def _get_config_dict():
    """
    Collect the live value of every registered config variable.

    The values are the objects currently stored in their modules, no copies are made.

    :return: dictionary of config variable name to current value
    """
    return {
        name: _get_constant_by_string(location)
        for name, location in _config_variable_map.items()
    }
def get_config_dict():
    """
    Build a dictionary snapshot of the global configuration.

    The live values are passed through ``partial_deep_copy_container``
    before being handed back to the caller.

    :return: config dictionary
    """
    live_config = _get_config_dict()
    return _types.partial_deep_copy_container(live_config)
def set_from_config_dict(config_dict: dict):
    """
    Apply the settings in a dictionary to the global config.

    The dictionary does not need to be complete, any subset of settings may
    be given as long as every key is a registered config variable name.
    Settings are applied one at a time in iteration order, so keys that
    precede an invalid key have already been applied when the error is raised.

    :param config_dict: The config dictionary

    :raise KeyError: If a configuration key name is not valid.
    """
    for setting_name, setting_value in config_dict.items():
        location = _config_variable_map[setting_name]
        _set_constant_by_string(location, setting_value)
[docs] def serialize_current_config(stream=None, mode: str = 'json') -> str | None: """ Serialize the current global config. :param stream: File like object, if not provided this function will return a string. :param mode: ``json``, ``yaml``, or ``toml`` :return: the serialized config """ config_dict = _get_config_dict() if mode == 'json': if stream: json.dump(config_dict, stream, indent=4) else: return json.dumps(config_dict, indent=4) elif mode == 'yaml': if stream: yaml.dump(config_dict, stream=stream, default_flow_style=False) else: return yaml.dump(config_dict, default_flow_style=False) elif mode == 'toml': if stream: toml.dump(config_dict, stream) else: return toml.dumps(config_dict) else: raise ValueError(f'Unknown serialization mode: {mode}') return None
# Stack of saved configuration dictionaries: push_config() appends to it and
# pop_config() removes and restores the most recent entry.
__config_stack = []
def push_config():
    """
    Push the current global configuration onto the config stack.

    The saved entry holds the live configuration values, it can be
    restored later with :py:func:`pop_config`.
    """
    snapshot = _get_config_dict()
    __config_stack.append(snapshot)
def pop_config():
    """
    Remove the most recently saved configuration from the stack and apply it.

    :raise IndexError: if the stack is empty.
    """
    set_from_config_dict(__config_stack.pop())
@contextlib.contextmanager
def restore_config_context():
    """
    Context manager which pushes the current global configuration to
    the stack and pops it when the ``with`` context ends.

    The configuration is restored whether the ``with`` body completes
    normally or raises.
    """
    # Push before entering the try block: if the push itself fails nothing was
    # saved, so the finally clause must not pop (that would restore an
    # unrelated entry, or raise IndexError and mask the original error).
    push_config()
    try:
        yield
    finally:
        pop_config()
def load_config(content_or_stream, mode: str = 'json'):
    """
    Load global config from a string or a file like object.

    The parsed document may be partial, every key it contains must be a
    registered config variable name. An empty document (one that parses
    to ``None``) leaves the configuration untouched.

    :param content_or_stream: string content or file like object.
    :param mode: ``json``, ``yaml``, or ``toml``

    :raise ValueError: If ``mode`` is unknown, or the document's top level
        is not a mapping.
    :raise KeyError: If a configuration key name is not valid.
    """
    is_text = isinstance(content_or_stream, str)

    if mode == 'json':
        config = json.loads(content_or_stream) \
            if is_text else json.load(content_or_stream)
    elif mode == 'yaml':
        # safe_load accepts both strings and streams
        config = yaml.safe_load(content_or_stream)
    elif mode == 'toml':
        config = toml.loads(content_or_stream) \
            if is_text else toml.load(content_or_stream)
    else:
        raise ValueError(f'Unknown deserialization mode: {mode}')

    if config is None:
        # Empty YAML document or JSON ``null``: nothing to apply.
        return

    if not isinstance(config, dict):
        raise ValueError(
            f'Config must be a mapping of setting names to values, '
            f'got: {type(config).__name__}')

    # Share the apply logic (and KeyError contract) with set_from_config_dict.
    set_from_config_dict(config)
# Public API of this module, computed by dgenerate.types.module_all().
# NOTE(review): presumably this collects the public names defined in this
# module -- confirm against dgenerate.types.
__all__ = _types.module_all()