# Source code for dgenerate.batchprocess.batchprocessor

# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
#    the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import collections.abc
import io
import itertools
import os
import re
import shlex
import typing

import jinja2

import dgenerate.arguments as _arguments
import dgenerate.messages as _messages
import dgenerate.textprocessing as _textprocessing
import dgenerate.types as _types


class PeekReader:
    """
    Iterate over a ``TextIO`` object one line at a time with one line of lookahead.

    Each iteration step produces the pair ``(current_line, next_line)``.

    **next_line** is ``None`` when **current_line** is the final line of the file.
    """

    def __init__(self, file: typing.TextIO):
        """
        :param file: The ``TextIO`` reader to wrap.
        """
        self._file = file
        # Line that was already pulled from the file as lookahead
        # but has not yet been handed out as a "current" line.
        self._last_next_line = None

    def __iter__(self):
        return self

    def __next__(self):
        buffered = self._last_next_line

        if buffered is None:
            # Nothing buffered: read directly, letting StopIteration
            # propagate when the underlying file is exhausted.
            self._cur_line = next(self._file)
        else:
            # Hand out the lookahead line saved by the previous step.
            self._last_next_line = None
            self._cur_line = buffered

        # Peek one line ahead, None signals that the end was reached.
        self._next_line = next(self._file, None)

        if self._next_line is not None:
            self._last_next_line = self._next_line

        return self._cur_line, self._next_line


class BatchProcessError(Exception):
    """
    Thrown by :py:meth:`.BatchProcessor.run_file` and :py:meth:`.BatchProcessor.run_string`
    when an error in a batch processing script is encountered.
    """
    pass
class BatchProcessor:
    """
    Implements dgenerates batch processing scripts in a generified manner.

    This is the bare-bones implementation of the shell with nothing
    implemented for you except the ``\\print`` and ``\\set`` directives.

    If you wish to create this object to run a dgenerate configuration, use
    :py:class:`dgenerate.batchprocess.ConfigRunner`
    """

    invoker: typing.Callable[[collections.abc.Sequence[str]], int]
    """
    Invoker function, responsible for executing lines recognized as shell commands.
    """

    name: _types.Name
    """
    Name of this batch processor, currently used in the hash bang version check directive and messages.
    """

    version: _types.Version
    """
    Version tuple for the version check hash bang directive.
    """

    template_variables: dict[str, typing.Any]
    """
    Live template variables.
    """

    reserved_template_variables: set[str]
    """
    These template variables cannot be set with the ``\\set`` directive
    """

    template_functions: dict[str, typing.Callable[[typing.Any], typing.Any]]
    """
    Functions available when templating is occurring.
    """

    directives: dict[str, typing.Optional[typing.Callable[[collections.abc.Sequence[str]], int]]]
    """
    Batch process directives, shell commands starting with a backslash.

    Dictionary of callable(list) -> int.

    The function should return a return code, 0 for success, anything else for failure.
    """

    injected_args: collections.abc.Sequence[str]
    """
    Shell arguments to inject at the end of every invocation.
    """

    expand_vars: typing.Callable[[str], str]
    """
    A function for expanding environmental variables, defaults to :py:func:`os.path.expandvars`
    """

    def __init__(self,
                 invoker: typing.Callable[[collections.abc.Sequence[str]], int],
                 name: _types.Name,
                 version: typing.Union[_types.Version, str],
                 template_variables: typing.Optional[dict[str, typing.Any]] = None,
                 reserved_template_variables: typing.Optional[set[str]] = None,
                 template_functions: typing.Optional[
                     dict[str, typing.Callable[[typing.Any], typing.Any]]] = None,
                 directives: typing.Optional[
                     dict[str, typing.Optional[typing.Callable[[list], None]]]] = None,
                 injected_args: typing.Optional[collections.abc.Sequence[str]] = None):
        """
        :param invoker: A function for invoking lines recognized as shell commands, should return a return code.
        :param name: The name of this batch processor, currently used in the version check directive and messages
        :param version: Version for version check hash bang directive.
        :param template_variables: Live template variables, the initial environment, this dictionary will be
            modified during runtime.
        :param reserved_template_variables: These template variable names cannot be set with the \\set directive.
        :param template_functions: Functions available to Jinja2
        :param directives: batch processing directive handlers, for: *\\\\directives*. This is a dictionary
            of names to functions which accept a single parameter, a list of directive arguments, and return
            a return code.
        :param injected_args: Arguments to be injected at the end of user specified arguments for every shell invocation.
            If ``-v/--verbose`` is present in ``injected_args`` debugging output will be enabled globally while the config
            runs, and not just for invocations. Passing ``-v/--verbose`` also disables handling of unhandled non
            :py:exc:`SystemExit` exceptions raised by config directive implementations, a stack trace will be
            printed when these exceptions are encountered.

        :raise ValueError: if ``version`` is given as a sequence that does not contain exactly three components.
        """

        self._template_functions = None

        self.invoker = invoker
        self.name = name

        # Compare against None rather than truthiness so that an empty
        # container supplied by the caller is kept (and mutated live, as
        # documented) instead of being silently swapped for a new object.
        self.template_variables = template_variables if template_variables is not None else dict()

        self.reserved_template_variables = \
            reserved_template_variables if reserved_template_variables is not None else set()

        self.template_functions = template_functions if template_functions is not None else dict()

        self.directives = directives if directives is not None else dict()

        # When True, exceptions raised by directive implementations are
        # re-raised untouched instead of being wrapped in BatchProcessError.
        self._directive_exceptions = False

        self.injected_args = injected_args if injected_args else []

        self._current_line = 0
        self._running_template_continuation = False

        if isinstance(version, str):
            self.version = _textprocessing.parse_version(version)
        else:
            self.version = tuple(version)
            if len(self.version) != 3:
                raise ValueError(
                    f'version tuple expected to contain three components: (major, minor, patch). received: {self.version}')

        self.expand_vars = os.path.expandvars

    @property
    def current_line(self) -> int:
        """
        The current line in the file being processed.
        """
        # NOTE(review): this value comes from ``enumerate`` starting at 0
        # in ``_run_file``, so it is a zero-based index.
        return self._current_line

    def render_template(self, string: str):
        """
        Render a template from a string

        Every entry of :py:attr:`.BatchProcessor.template_functions` is exposed to the
        template both as a global and as a filter. The rendered text is passed
        through :py:attr:`.BatchProcessor.expand_vars` before being returned.

        :raise BatchProcessError: on Jinja2 template syntax errors.

        :param string: the string containing the Jinja2 template.
        :return: rendered string
        """

        jinja_env = jinja2.Environment()

        for name, func in self.template_functions.items():
            jinja_env.globals[name] = func
            jinja_env.filters[name] = func

        try:
            return self.expand_vars(
                jinja_env.from_string(string).
                render(**self.template_variables))
        except jinja2.TemplateSyntaxError as e:
            raise BatchProcessError(f'Template Syntax Error: {str(e).strip()}') from e

    def _look_for_version_mismatch(self, line_idx, line):
        """
        Inspect a comment line for a ``#! <name> <major>.<minor>.<patch>`` version
        directive and log a warning if it is incompatible with :py:attr:`.BatchProcessor.version`.

        A differing major version, or a current minor version lower than the
        requested one, produces a warning. Nothing is raised.

        :param line_idx: index of the line, used only in the warning text.
        :param line: the raw line to inspect.
        """

        # the name is escaped so that regex metacharacters
        # inside of it are matched literally
        versioning = re.match(r'#!\s+' + re.escape(self.name) + r'\s+([0-9]+\.[0-9]+\.[0-9]+)', line)

        if not versioning:
            return

        config_file_version = versioning.group(1)
        config_file_version_parts = config_file_version.split('.')

        cur_major_version = self.version[0]
        config_major_version = int(config_file_version_parts[0])
        cur_minor_version = self.version[1]
        config_minor_version = int(config_file_version_parts[1])

        version_str = '.'.join(str(i) for i in self.version)

        if cur_major_version != config_major_version:
            _messages.log(
                f'Failed version check (major version missmatch) on line {line_idx}, '
                f'running an incompatible version of {self.name}! You are running version {version_str} '
                f'and the config file specifies the required version: {config_file_version}',
                underline=True, level=_messages.WARNING)
        elif cur_minor_version < config_minor_version:
            # config_file_version is already a dotted string, joining its
            # characters with '.' again would garble it (1...2...3)
            _messages.log(
                f'Failed version check (current minor version less than requested) '
                f'on line {line_idx}, running an incompatible version of {self.name}! '
                f'You are running version {version_str} and the config file specifies '
                f'the required version: {config_file_version}',
                underline=True, level=_messages.WARNING)

    def _jinja_user_define(self, name, value):
        """
        Define a template variable on behalf of the script.

        :raise BatchProcessError: if ``name`` collides with a template function
            or is listed in :py:attr:`.BatchProcessor.reserved_template_variables`.

        :param name: variable name
        :param value: variable value
        """
        if name in self.template_functions:
            raise BatchProcessError(
                f'Cannot define template variable "{name}" on line {self.current_line}, '
                f'as that name is taken by a template function.')
        if name in self.reserved_template_variables:
            raise BatchProcessError(
                f'Cannot define template variable "{name}" on line {self.current_line}, '
                f'as that name is a reserved variable name.')
        self.template_variables[name] = value

    def _directive_handlers(self, line):
        """
        Try to handle a completed line as a directive or a template block.

        Handles the built in ``\\set`` and ``\\print`` directives, lines starting
        with ``{`` (rendered as a template and then executed as a script), and any
        other backslash directive registered in :py:attr:`.BatchProcessor.directives`.

        :raise BatchProcessError: on malformed or unknown directives, on a
            non-zero directive return code, or when a directive raises and
            directive exceptions are not being passed through.

        :param line: the completed line
        :return: ``True`` if the line was consumed here, ``False`` if it
            should be treated as a shell invocation.
        """

        # Match the directive token exactly, a prefix test would also
        # swallow registered directives such as \settings or \printer
        directive_token = line.split(' ', 1)[0]

        if directive_token == '\\set':
            directive_args = line.split(' ', 2)
            if len(directive_args) == 3:
                self._jinja_user_define(directive_args[1].strip(),
                                        self.render_template(directive_args[2].strip()))
                return True
            raise BatchProcessError(
                f'\\set directive received less than 2 arguments, '
                f'syntax is: \\set name value')

        if directive_token == '\\print':
            directive_args = line.split(' ', 1)
            if len(directive_args) == 2:
                _messages.log(self.render_template(directive_args[1].strip()))
                return True
            raise BatchProcessError(
                f'\\print directive received no arguments, '
                f'syntax is: \\print value')

        if line.startswith('{'):
            # restore the previous state instead of forcing False, so that a
            # nested template continuation does not clear the outer flag
            previous_state = self._running_template_continuation
            self._running_template_continuation = True
            try:
                self.run_string(self.render_template(line))
            finally:
                self._running_template_continuation = previous_state
            return True

        if line.startswith('\\'):
            directive_args = line.split(' ', 1)

            directive = directive_args[0].lstrip('\\')

            impl = self.directives.get(directive)
            if impl is None:
                raise BatchProcessError(f'Unknown directive "\\{directive}".')

            directive_args = directive_args[1:]

            try:
                if directive_args:
                    return_code = impl(
                        shlex.split(self.render_template(directive_args[0].strip())))
                else:
                    return_code = impl([])

                if return_code != 0:
                    raise BatchProcessError(
                        f'Directive error return code: {return_code}')
            except Exception as e:
                if self._directive_exceptions:
                    # pass through untouched so the stack trace is preserved
                    raise
                raise BatchProcessError(e) from e
            return True

        return False

    def _lex_and_run_invocation(self, invocation_string):
        """
        Render, shell-lex, and execute a line as an invocation.

        :py:attr:`.BatchProcessor.injected_args` are appended to the lexed
        arguments before :py:attr:`.BatchProcessor.invoker` is called.

        :raise BatchProcessError: if the line cannot be lexed, or the
            invoker returns a non-zero return code.

        :param invocation_string: the un-rendered invocation line
        """
        raw_templated_string = self.render_template(invocation_string)

        try:
            shell_lexed = shlex.split(raw_templated_string)
        except ValueError as e:
            raise BatchProcessError(e) from e

        shell_lexed.extend(self.injected_args)

        raw_injected_args = ' '.join(str(a) for a in self.injected_args)

        if raw_injected_args:
            cmd_info = raw_templated_string + ' ' + raw_injected_args
        else:
            cmd_info = raw_templated_string

        header = 'Processing Arguments: '
        args_wrapped = \
            _textprocessing.wrap(
                cmd_info,
                width=_textprocessing.long_text_wrap_width() - len(header),
                subsequent_indent=' ' * len(header))

        _messages.log(header + args_wrapped, underline=True)

        return_code = self.invoker(shell_lexed)

        if return_code != 0:
            raise BatchProcessError(
                f'Invocation error return code: {return_code}')

    def _run_file(self, stream: typing.TextIO):
        """
        Core line processing loop, reads the stream line by line, assembles
        continuations, and dispatches completed lines.

        :param stream: A filestream in text read mode
        """
        continuation = ''
        template_continuation = False
        normal_continuation = False

        def run_continuation(cur_line):
            # Finish the accumulated continuation with cur_line, reset the
            # continuation state, then run it as a directive or an invocation.
            nonlocal continuation, template_continuation, normal_continuation

            if not template_continuation:
                completed_continuation = (continuation + ' ' + cur_line).strip()
            else:
                completed_continuation = (continuation + cur_line).strip()

            template_continuation = False
            normal_continuation = False

            continuation = ''

            if self._directive_handlers(completed_continuation):
                return

            self._lex_and_run_invocation(completed_continuation)

        def remove_tail_comments_unlexable(string):
            try:
                # find the start of a possible comment
                comment_start = string.index('#')
            except ValueError:
                # no comments, all good
                return string

            if comment_start == 0:
                # it starts the string, ignore
                return string

            segment = string[:comment_start]

            if segment.endswith('\\'):
                # found a comment token, but it was escaped

                next_spaces = ''.join(itertools.takewhile(lambda x: x.isspace(), string[comment_start + 1:]))
                # record all space characters after the #, they may be taken by the lexer if
                # it can make sense of what came after #

                lexed, result = remove_tail_comments(string[comment_start + 1:])
                # recursive descent to the right starting after the # to solve the rest of the string

                if next_spaces and lexed:
                    # the spaces would have been consumed by lexing, add them back to the
                    # left side of the string where the user intended them to be
                    result = next_spaces + result

                # return the left side minus the escape sequence + the comment
                # token, plus the evaluated right side
                return segment.removesuffix('\\') + '#' + result

            # unescaped comment start, return the segment to the left
            return segment

        def remove_tail_comments(string):
            # returns the tuple (was_lexed, text_without_tail_comment)
            try:
                parts = _textprocessing.tokenized_split(string, '#',
                                                        escapable_separator=True,
                                                        allow_unterminated_strings=True)
                # attempt to split off the comment

                if not parts:
                    # empty case, effectively un-lexed
                    return False, string

                new_value = parts[0]

                if not new_value.strip():
                    # do not remove if the comment is all that exists on the line
                    # that is handled elsewhere
                    return False, string

                # the left side was lexed and stripped of leading and trailing whitespace
                return True, new_value
            except _textprocessing.TokenizedSplitSyntaxError:
                # could not lex this because of a syntax error, since unterminated
                # strings are understandable by the lexer given our options, this
                # is an uncommon if not near impossible occurrence
                return False, remove_tail_comments_unlexable(string)

        last_line = None

        for line_idx, line_and_next in enumerate(PeekReader(stream)):
            line: str
            next_line: typing.Optional[str]
            line, next_line = line_and_next

            line_strip = remove_tail_comments(line)[1].strip()

            if not self._running_template_continuation:
                # while a rendered template block is being executed, keep
                # the line index of the outer script
                self._current_line = line_idx

            if line_strip == '':
                # a blank line terminates a continuation whose last line
                # was an argument line without a trailing backslash
                if continuation and last_line is not None:
                    if last_line.startswith('-') and \
                            not last_line.endswith('\\'):
                        run_continuation('')
            elif line_strip.startswith('#'):
                self._look_for_version_mismatch(line_idx, line)
            elif line_strip.startswith('{') and not template_continuation and not normal_continuation:
                # start of a template block, accumulate raw lines until !END
                continuation += line
                template_continuation = True
            elif not template_continuation and (line_strip.endswith('\\') or next_line
                                                and next_line.lstrip().startswith('-')):
                # explicit backslash continuation, or the next
                # line begins with an argument (starts with -)
                continuation += ' ' + line_strip.strip().removesuffix('\\').strip()
                normal_continuation = True
            elif template_continuation:
                line_rstrip = remove_tail_comments(line)[1].rstrip()
                if line_rstrip.endswith('!END'):
                    run_continuation(line_rstrip.removesuffix('!END'))
                else:
                    continuation += line
            else:
                run_continuation(line_strip)

            last_line = line_strip

        if continuation:
            # flush whatever is left over at the end of the file
            run_continuation('')

    def run_file(self, stream: typing.TextIO):
        """
        Process a batch processing script from a file stream

        :raise BatchProcessError:

        :param stream: A filestream in text read mode
        """

        try:
            parsed, _ = _arguments.parse_known_args(
                self.injected_args,
                log_error=False
            )
        except _arguments.DgenerateUsageError as e:
            raise BatchProcessError(f'Error parsing injected arguments: {str(e).strip()}') from e

        directive_exceptions_last = self._directive_exceptions

        if parsed.verbose:
            _messages.push_level(_messages.DEBUG)
            self._directive_exceptions = True
        try:
            self._run_file(stream)
        except BatchProcessError as e:
            raise BatchProcessError(f'Error on line {self.current_line}: {str(e).strip()}') from e
        finally:
            # NOTE(review): the level is popped even when nothing was pushed
            # above; left as-is because other code may depend on it. Confirm
            # against dgenerate.messages and the invoker before changing.
            _messages.pop_level()
            self._directive_exceptions = directive_exceptions_last

    def run_string(self, string: str):
        """
        Process a batch processing script from a string

        :raise BatchProcessError:

        :param string: a string containing the script
        """
        self.run_file(io.StringIO(string))
# Export list for this module, produced by the project helper
# ``_types.module_all()`` (presumably it collects the public names
# defined in this module — confirm against dgenerate.types).
__all__ = _types.module_all()