# Source code for dgenerate.pygments

# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
#    the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


import pygments.lexer as _lexer
import pygments.token as _token

__doc__ = (
    "\n"
    "This module provides a pygments lexer for the dgenerate config / shell language.\n"
    "\n"
    "This can be used for syntax highlighting.\n"
)

_DGENERATE_FUNCTIONS = sorted((
    'abs', 'all', 'any', 'ascii', 'bin', 'bool', 'bytearray', 'bytes',
    'callable', 'chr', 'complex', 'cwd', 'dict', 'divmod', 'enumerate',
    'filter', 'first', 'float', 'format', 'format_prompt', 'format_size',
    'frozenset', 'gen_seeds', 'getattr', 'hasattr', 'hash', 'hex', 'int',
    'iter', 'last', 'len', 'list', 'map', 'max', 'min', 'next', 'object',
    'oct', 'ord', 'pow', 'quote', 'range', 'repr', 'reversed', 'round',
    'set', 'slice', 'sorted', 'str', 'sum', 'tuple', 'type', 'unquote',
    'zip', 'download'
), key=lambda s: len(s), reverse=True)

_DGENERATE_FUNCTIONS_NS = sorted(map(lambda x: '({})(.)({})'.format(*x.split('.')), (
    'glob.glob', 'glob.iglob', 'glob.escape', 'path.abspath', 'path.basename',
    'path.commonpath', 'path.commonprefix',
    'path.dirname', 'path.exists', 'path.expanduser', 'path.expandvars',
    'path.getatime', 'path.getctime', 'path.getmtime', 'path.getsize',
    'path.isabs', 'path.isdir', 'path.isfile', 'path.islink',
    'path.ismount', 'path.join', 'path.lexists', 'path.normcase',
    'path.normpath', 'path.realpath', 'path.relpath', 'path.samefile',
    'path.sameopenfile', 'path.samestat', 'path.split', 'path.splitdrive',
    'path.splitext', 'path.supports_unicode_filenames'
)), key=lambda s: len(s), reverse=True)

_JINJA2_FUNCTIONS = sorted((
    # Functions
    'range', 'lipsum', 'cycler', 'joiner',

    # Filters
    'safe', 'capitalize', 'lower', 'upper', 'title', 'trim', 'striptags', 'urlencode',
    'wordcount', 'replace', 'format', 'escape', 'tojson', 'join', 'sort', 'reverse',
    'length', 'sum', 'random', 'batch', 'slice', 'first', 'last', 'default', 'groupby'
), key=lambda s: len(s), reverse=True)

# Define Jinja2 keywords
_JINJA2_KEYWORDS = sorted((
    'true', 'false', 'none', 'and', 'or', 'not', 'if', 'else', 'elif', 'for', 'endfor', 'in', 'do', 'async', 'await',
    'block', 'extends', 'include', 'import', 'from', 'macro', 'call', 'set', 'with', 'without', 'filter',
    'endfilter', 'capture', 'endcapture', 'spaceless', 'endspaceless', 'flush', 'load', 'url', 'static',
    'trans', 'endtrans', 'as', 'safe', 'end', 'autoescape', 'endautoescape', 'raw', 'endraw'
), key=lambda s: len(s), reverse=True)

# asteval keywords
_SETP_KEYWORDS = sorted((
    'for', 'in', 'if', 'else', 'elif', 'and', 'or', 'not', 'is', 'True', 'False'
), key=lambda s: len(s), reverse=True)

# Common patterns

# One character of a path component: a backslash-escaped '#', or any
# character other than '/', whitespace, backslash, '#', '=', ';', ':',
# braces or quotes.
_ecos = r'(\\#|[^/\s\\#=;:{}"\'])'

# '#' through the end of the line, unless the '#' is preceded by a backslash.
_comment_pattern = (r'(?<!\\)(#.*$)', _token.Comment.Single)

# Environment variable references: $NAME, ${NAME} and %NAME%.
_env_var_pattern = (r'\$(?:\w+|\{\w+\})|%[\w]+%', _token.Name.Constant)

# Opening of a Jinja2 statement and its first word, e.g. "{% for". The
# whitespace-control forms "{%-" and "{%+" are accepted too. Enters the
# 'jinja_block' state.
_jinja_block_pattern = (r'(\{%[-+]?)(\s*)(\w+)',
                        _lexer.bygroups(_token.String.Interpol, _token.Whitespace, _token.Keyword), 'jinja_block')

# Opening of a Jinja2 comment; enters the 'jinja_comment' state.
_jinja_comment_pattern = (r'(\{#)', _token.Comment.Multiline, 'jinja_comment')

# Opening of a Jinja2 expression; enters the 'jinja_interpolate' state.
_jinja_interpolate_pattern = (r'(\{\{)', _token.String.Interpol, 'jinja_interpolate')

_operators_punctuation_pattern = (r'[\[\]{}()=\\;,:]', _token.Operator)
_operators_pattern = (r'\*\*|<<|>>|[-+*/%^|&<>!.]', _token.Operator)

# Dimension strings such as 512x512.
_size_pattern = (r'(?<!\w)\d+[xX]\d+(?!\w)', _token.Number.Hex)

# A float must have a fractional part and/or an exponent. Without that
# requirement every plain integer is claimed by this rule (it is listed
# before the integer rule in _number_patterns) and highlighted as a float.
_number_float_pattern = (r'(?<!\w)(-?\d+(?:\.\d*(?:[eE][-+]?\d+)?|[eE][-+]?\d+))(?!\w)', _token.Number.Float)
_decimal_integer_pattern = (r'(?<!\w)-?\d+(?!\w)', _token.Number.Integer)
_binary_integer_pattern = (r'(?<!\w)0[bB][01]+(?!\w)', _token.Number.Binary)
_hexa_decimal_integer_pattern = (r'(?<!\w)0[xX][0-9a-fA-F]+(?!\w)', _token.Number.Hex)
_octal_integer_pattern = (r'(?<!\w)0[oO][0-7]+(?!\w)', _token.Number.Octal)

# Fallback run of plain text: stops at whitespace, brackets, quotes, '$',
# backslash and shell-like punctuation.
_text_pattern = (r'[^=\s\[\]{}()$"\'`\\<&|;:,]+', _token.Text)
_variable_names_pattern = (r'(?<!\w)[a-zA-Z_][a-zA-Z0-9_]*(?!\w)', _token.Name.Variable)

# An identifier immediately followed by '(' (name highlighted, '(' as operator).
_function_call_generic = (r'(?<!\w)([a-zA-Z_][a-zA-Z0-9_]*)(\()', _lexer.bygroups(_token.Name, _token.Operator))

_http_pattern = (r'(?<!\w)(https?://(?:'  # Protocol
                 r'(?:[a-zA-Z0-9\-]+\.)+[a-zA-Z]{2,6}|'  # Domain
                 r'localhost|'  # Localhost
                 r'(?:\d{1,3}\.){3}\d{1,3}|'  # IPv4
                 r'\[[a-fA-F0-9:]+\])'  # IPv6
                 r'(?::\d+)?'  # Optional port
                 r'(?:/[a-zA-Z0-9\-._~%!$&\'()*+,;=:@/]*)?'  # Path
                 r'(?:\?[a-zA-Z0-9\-._~%!$&\'()*+,;=:@/?]*)?'  # Query parameters
                 r')',
                 _token.String)

# Path-like tokens, all highlighted as String. The separator is '/' or a
# doubled backslash. The order (and count) of entries is relied upon
# elsewhere via slicing, so keep it stable.
_path_patterns = (
    _http_pattern,
    # Drive letters, e.g. C:/dir
    (rf'(?<!\w)([a-zA-Z]:(([/]|\\\\){_ecos}*)+)', _token.String),
    # Absolute paths, optionally prefixed with ~, . or .. (e.g. /dir, ~/dir,
    # ./dir, ../dir). The dots are escaped: an unescaped "..?" matched any
    # one or two characters and swallowed a preceding quote, '=' or '('.
    (rf'(?<!\w)(~|\.\.?)?(([/]|\\\\){_ecos}+)(([/]|\\\\){_ecos}*)*', _token.String),
    # Relative paths with at least one separator, e.g. dir/file
    (rf'(?<!\w)({_ecos}+)(([/]|\\\\){_ecos}*)+', _token.String),
    # Bare file names with an extension, e.g. file.ext
    (rf'(?<!\w){_ecos}+\.{_ecos}+', _token.String)
)

_SCHEDULER_KEYWORDS = sorted((
    "DDIMScheduler",
    "DDPMScheduler",
    "DEISMultistepScheduler",
    "DPMSolverMultistepScheduler",
    "DPMSolverSDEScheduler",
    "DPMSolverSinglestepScheduler",
    "EDMEulerScheduler",
    "EulerAncestralDiscreteScheduler",
    "EulerDiscreteScheduler",
    "HeunDiscreteScheduler",
    "KDPM2AncestralDiscreteScheduler",
    "KDPM2DiscreteScheduler",
    "LCMScheduler",
    "LMSDiscreteScheduler",
    "PNDMScheduler",
    "UniPCMultistepScheduler"
    "DDPMWuerstchenScheduler"),
    key=lambda s: len(s), reverse=True)

_VAE_KEYWORDS = sorted((
    "AutoencoderKL",
    "AsymmetricAutoencoderKL",
    "AutoencoderTiny",
    "ConsistencyDecoderVAE",
    "FlaxAutoencoderKL"
), key=lambda s: len(s), reverse=True)

_MODEL_TYPE_KEYWORDS = sorted((
    'torch',
    'torch-pix2pix',
    'torch-sdxl',
    'torch-if',
    'torch-ifs',
    'torch-ifs-img2img',
    'torch-sdxl-pix2pix',
    'torch-upscaler-x2',
    'torch-upscaler-x4',
    'torch-s-cascade',
    'flax',
    'help',
    'helpargs',
), key=lambda s: len(s), reverse=True)

_DTYPE_KEYWORDS = sorted((
    'float16',
    'bfloat16',
    'float32',
    'float64',
    'fp16',
    'bf16'
), key=lambda s: len(s), reverse=True)

# Every numeric-literal rule, in the order the lexer tries them. Order is
# significant: pygments applies the first rule in a state that matches.
_number_patterns = (
    _number_float_pattern,
    _decimal_integer_pattern,
    _binary_integer_pattern,
    _hexa_decimal_integer_pattern,
    _octal_integer_pattern,
)


def _create_string_continue(name, char, root_state):
    """
    Build the pygments lexer states for a string quoted with ``char``.

    Three states are produced:

    * ``name`` - the body of the string right after its opening quote.
    * ``{name}_escape`` - whitespace and ``#`` comments that follow a line
      continuation (a backslash not followed by a non-whitespace character,
      or string text that runs into a newline followed by ``-``).
    * ``{name}_continue`` - the string body after such a continuation.

    :param name: base state name, e.g. ``'double_string'``
    :param char: the quote character, ``'"'`` or ``"'"``
    :param root_state: state entered when the string ends at a newline,
        at end of input or at a comment, and when the closing quote is found
        in the ``_continue`` state
    :return: dict mapping the three state names to their rule lists
    """
    string_token = _token.String.Double if char == '"' else _token.String.Single
    return {
        name: [
            _env_var_pattern,
            _jinja_block_pattern,
            _jinja_comment_pattern,
            _jinja_interpolate_pattern,
            # backslash followed by a non-whitespace character
            (r'\\[^\s]', _token.String.Escape),
            # any other backslash starts a line continuation
            (r'\\', _token.Operator, f"{name}_escape"),
            # unescaped comment, then newline and '-' continuation
            (rf'(?<!\\)(#[^{char}\n]*)(\s*\n\s*)(-)',
             _lexer.bygroups(_token.Comment.Single, _token.Whitespace, string_token), f'{name}_escape'),
            # unescaped comment that ends the line
            (rf'(?<!\\)(#[^{char}\n]*)(\s*\n)', _lexer.bygroups(_token.Comment.Single, _token.Whitespace), root_state),
            (r'#', string_token),
            # string text running into a newline followed by '-' continuation
            (rf'[^{char}\\\n#{{]+\n\s*-', string_token, f'{name}_escape'),
            (rf'[^{char}\\\n#{{]+', string_token),
            # A '{' that did not open a Jinja2 construct above. Only the brace
            # itself is consumed, so a following quote, backslash or newline
            # is still handled by its own rule.
            (r'\{', string_token),
            (char, string_token, '#pop'),
            (r'\n', _token.Whitespace, root_state),
            (r'$', _token.Whitespace, root_state),
        ],
        f'{name}_escape': [
            (r'(#[^\n]*)(\s*\n)', _lexer.bygroups(_token.Comment.Single, _token.Whitespace)),
            (r'(\s|\n)+', _token.Whitespace),
            # zero-width transition into the continued string body
            (r'(?<=[-\n\s.])', _token.Whitespace, f'{name}_continue')
        ],
        f'{name}_continue': [
            _env_var_pattern,
            _jinja_block_pattern,
            _jinja_comment_pattern,
            _jinja_interpolate_pattern,
            (r'\\[^\s]', _token.String.Escape),
            (r'\\', _token.Operator, f"{name}_escape"),
            (rf'(?<!\\)(#[^\n]*)(\s*\n\s*)(-)',
             _lexer.bygroups(_token.Comment.Single, _token.Whitespace, string_token), f'{name}_escape'),
            (r'(?<!\\)(#[^\n]*)(\s*\n)', _lexer.bygroups(_token.Comment.Single, _token.Whitespace), root_state),
            (r'#', string_token),
            (rf'[^{char}\\\n#{{]+\n\s*-', string_token, f'{name}_escape'),
            (rf'[^{char}\\\n#{{]+', string_token),
            # lone '{', see the matching rule in the first state
            (r'\{', string_token),
            (char, string_token, root_state),
            (r'\n', _token.Whitespace, root_state),
            (r'$', _token.Whitespace, root_state),
        ]}


class DgenerateLexer(_lexer.RegexLexer):
    """
    pygments lexer for dgenerate configuration / script

    Lexer states (pygments tries the rules of the current state in order and
    applies the first one that matches):

    * ``root`` - the top level of a script.
    * ``value`` - the value following the set, sete, gen_seeds and download
      directives.
    * ``setp_value`` - the expression following the setp directive.
    * ``import_plugins_value`` - the plugin module names following the
      import_plugins directive.
    * ``*_escape`` - whitespace and comments after a backslash continuation.
    * ``*string*`` - quoted strings, built by ``_create_string_continue``.
    * ``jinja_block`` / ``jinja_comment`` / ``jinja_interpolate`` - the inside
      of Jinja2 ``{% %}``, ``{# #}`` and ``{{ }}`` constructs.
    """

    name = 'DgenerateLexer'
    aliases = ['dgenerate']
    filenames = ['*.dgen']

    tokens = {
        'root': [
            _comment_pattern,
            _env_var_pattern,
            # directives taking a variable name followed by a value
            (r'(?<!\w)(\\set[e]?|\\gen_seeds|\\download)(\s+)([a-zA-Z_][a-zA-Z0-9_]*)',
             _lexer.bygroups(_token.Name.Builtin, _token.Text.Whitespace, _token.Name.Variable), 'value'),
            (r'(?<!\w)(\\setp)(\s+)([a-zA-Z_][a-zA-Z0-9_]*)',
             _lexer.bygroups(_token.Name.Builtin, _token.Text.Whitespace, _token.Name.Variable), 'setp_value'),
            (r'(?<!\w)(\\import_plugins)(\s+)',
             _lexer.bygroups(_token.Name.Builtin, _token.Text.Whitespace), 'import_plugins_value'),
            # directives whose argument is a single name
            (r'(?<!\w)(\\unset|\\save_modules|\\use_modules|\\clear_modules)(\s+)([a-zA-Z_][a-zA-Z0-9_]*)',
             _lexer.bygroups(_token.Name.Builtin, _token.Text.Whitespace, _token.Name.Variable)),
            # any other directive
            (r'(?<!\w)(\\[a-zA-Z_][a-zA-Z0-9_]*)', _lexer.bygroups(_token.Name.Builtin)),
            (r'\b(%s)\b' % '|'.join(_SCHEDULER_KEYWORDS), _token.Keyword),
            (r'\b(%s)\b' % '|'.join(_VAE_KEYWORDS), _token.Keyword),
            (r'\b(%s)\b' % '|'.join(_MODEL_TYPE_KEYWORDS), _token.Keyword),
            (r'\b(%s)\b' % '|'.join(_DTYPE_KEYWORDS), _token.Keyword),
            (r'\bauto\b', _token.Keyword),
            (r'\bcpu\b', _token.Keyword),
            (r'\bcuda\b', _token.Keyword),
            # The alternation is grouped so both word boundaries apply to both
            # spellings. Ungrouped, "\bTrue|true\b" also matched the "True"
            # prefix of words such as "Trueblood".
            (r'\b(?:True|true)\b', _token.Keyword),
            (r'\b(?:False|false)\b', _token.Keyword),
            (r'[!]END\b', _token.Keyword),
            _jinja_block_pattern,
            _jinja_comment_pattern,
            _jinja_interpolate_pattern,
            _size_pattern,
            *_number_patterns,
            *_path_patterns,
            _operators_punctuation_pattern,
            _operators_pattern,
            (r'"', _token.String.Double, 'double_string'),
            (r"'", _token.String.Single, 'single_string'),
            _text_pattern,
            (r'\s+', _token.Whitespace),
        ],
        'value': [
            # newline followed by '-' continues the value on the next line
            (r'(\s*?\n\s*?)(-)', _lexer.bygroups(_token.Whitespace, _token.Operator)),
            # NOTE(review): a plain newline pushes 'root' rather than popping
            # back to it, so the state stack grows by one entry per statement.
            # Presumably harmless since the top of the stack is 'root' either
            # way, but '#pop' may have been intended - confirm.
            (r'\s*?\n', _token.Whitespace, 'root'),
            (r'\\[^\s]', _token.Operator),
            (r'(\\)(\s*?\n)', _lexer.bygroups(_token.Operator, _token.Whitespace), 'value_escape'),
            (r'(\\)(\s+)(#[^\n]*)',
             _lexer.bygroups(_token.Operator, _token.Whitespace, _token.Comment.Single), 'value_escape'),
            _comment_pattern,
            _env_var_pattern,
            _jinja_block_pattern,
            _jinja_comment_pattern,
            _jinja_interpolate_pattern,
            _size_pattern,
            *_number_patterns,
            *_path_patterns,
            _operators_punctuation_pattern,
            _operators_pattern,
            (r'"', _token.String.Double, 'double_string'),
            (r"'", _token.String.Single, 'single_string'),
            _text_pattern,
            (r'\s+', _token.Whitespace),
        ],
        'setp_value': [
            # (a former "(\n\s*?)(-)" rule here was unreachable: the first
            # rule matches everything it could match, so it was removed)
            (r'(\s*?\n\s*?)(-)', _lexer.bygroups(_token.Whitespace, _token.Operator)),
            (r'\s*?\n', _token.Whitespace, 'root'),
            (r'\\[^\s]', _token.Operator),
            (r'(\\)(\s*?\n)', _lexer.bygroups(_token.Operator, _token.Whitespace), 'setp_value_escape'),
            (r'(\\)(\s+)(#[^\n]*)',
             _lexer.bygroups(_token.Operator, _token.Whitespace, _token.Comment.Single), 'setp_value_escape'),
            _comment_pattern,
            _env_var_pattern,
            (r'\b(%s)\b' % '|'.join(_SETP_KEYWORDS), _token.Keyword),
            (r'\b(%s)\b' % '|'.join(_DGENERATE_FUNCTIONS), _token.Name.Function),
            *((r'\b%s\b' % fun, _lexer.bygroups(_token.Name.Variable, _token.Operator, _token.Name.Function))
              for fun in _DGENERATE_FUNCTIONS_NS),
            _function_call_generic,
            _variable_names_pattern,
            _jinja_block_pattern,
            _jinja_comment_pattern,
            _jinja_interpolate_pattern,
            *_number_patterns,
            _operators_punctuation_pattern,
            _operators_pattern,
            (r'"', _token.String.Double, 'double_string_setp_value'),
            (r"'", _token.String.Single, 'single_string_setp_value'),
            (r'\s+', _token.Whitespace),
        ],
        'import_plugins_value': [
            (r'(\s*?\n\s*?)(-)', _lexer.bygroups(_token.Whitespace, _token.Operator)),
            (r'\s*?\n', _token.Whitespace, 'root'),
            (r'\\[^\s]', _token.Operator),
            (r'(\\)(\s*?\n)', _lexer.bygroups(_token.Operator, _token.Whitespace), 'import_plugins_value_escape'),
            (r'(\\)(\s+)(#[^\n]*)',
             _lexer.bygroups(_token.Operator, _token.Whitespace, _token.Comment.Single),
             'import_plugins_value_escape'),
            _comment_pattern,
            _env_var_pattern,
            *((p[0], _token.Keyword) for p in _path_patterns[1:4]),  # everything except http and file.ext
            (_variable_names_pattern[0], _token.Keyword),  # module names are keywords
            (r'[.]', _token.Operator),  # namespace operator
            _jinja_block_pattern,
            _jinja_comment_pattern,
            _jinja_interpolate_pattern,
            (r'"', _token.String.Double, 'double_string_import_plugins_value'),
            (r"'", _token.String.Single, 'single_string_import_plugins_value'),
            (r'\s+', _token.Whitespace),
        ],
        'import_plugins_value_escape': [
            (r'(#[^\n]*)(\s*\n)', _lexer.bygroups(_token.Comment.Single, _token.Whitespace)),
            (r'(\s|\n)+', _token.Whitespace),
            (r'(?<=[-\n\s.])', _token.Whitespace, 'import_plugins_value')
        ],
        'setp_value_escape': [
            (r'(#[^\n]*)(\s*\n)', _lexer.bygroups(_token.Comment.Single, _token.Whitespace)),
            (r'(\s|\n)+', _token.Whitespace),
            (r'(?<=[-\n\s.])', _token.Whitespace, 'setp_value')
        ],
        'value_escape': [
            (r'(#[^\n]*)(\s*\n)', _lexer.bygroups(_token.Comment.Single, _token.Whitespace)),
            (r'(\s|\n)+', _token.Whitespace),
            (r'(?<=[-\n\s.])', _token.Whitespace, 'value')
        ],
        **_create_string_continue('double_string', '"', 'root'),
        **_create_string_continue('single_string', "'", 'root'),
        **_create_string_continue('double_string_setp_value', '"', 'setp_value'),
        **_create_string_continue('single_string_setp_value', "'", 'setp_value'),
        **_create_string_continue('double_string_import_plugins_value', '"', 'import_plugins_value'),
        **_create_string_continue('single_string_import_plugins_value', "'", 'import_plugins_value'),
        'jinja_block': [
            # End of Jinja2 block statement
            (r'(\s*)(%\})', _lexer.bygroups(_token.Text, _token.String.Interpol), '#pop'),
            # Function names
            (r'\b(%s)\b' % '|'.join(_DGENERATE_FUNCTIONS), _token.Name.Function),
            *((r'\b%s\b' % fun, _lexer.bygroups(_token.Name.Variable, _token.Operator, _token.Name.Function))
              for fun in _DGENERATE_FUNCTIONS_NS),
            _function_call_generic,
            # Jinja2 Function names
            (r'\b(%s)\b' % '|'.join(_JINJA2_FUNCTIONS), _token.Name.Function),
            # Jinja2 Keywords
            (r'\b(%s)\b' % '|'.join(_JINJA2_KEYWORDS), _token.Keyword),
            # Variable names
            _variable_names_pattern,
            # Strings
            (r':?"(\\\\|\\[^\\]|[^"\\])*"', _token.String.Double),
            (r":?'(\\\\|\\[^\\]|[^'\\])*'", _token.String.Single),
            # Numbers
            *_number_patterns,
            # Operators and punctuation
            _operators_punctuation_pattern,
            _operators_pattern,
            # Other text
            _text_pattern,
            (r'\s+', _token.Whitespace),
        ],
        'jinja_comment': [
            # End of Jinja2 comment
            (r'(#\})', _token.Comment.Multiline, '#pop'),
            # Comment content. Runs of non-'#' characters and lone '#'
            # characters, so the comment may span lines like the other Jinja2
            # states do. The former ".+?(?=#\})" could not cross a newline and
            # failed on a '#' with no closer on the same line.
            (r'[^#]+', _token.Comment.Multiline),
            (r'#', _token.Comment.Multiline),
        ],
        'jinja_interpolate': [
            # End of Jinja2 template expression
            (r'(\}\})', _token.String.Interpol, '#pop'),
            # Function names
            (r'\b(%s)\b' % '|'.join(_DGENERATE_FUNCTIONS), _token.Name.Function),
            *((r'\b%s\b' % fun, _lexer.bygroups(_token.Name.Variable, _token.Operator, _token.Name.Function))
              for fun in _DGENERATE_FUNCTIONS_NS),
            _function_call_generic,
            # Jinja2 Function names
            (r'\b(%s)\b' % '|'.join(_JINJA2_FUNCTIONS), _token.Name.Function),
            # Jinja2 Keywords
            (r'\b(%s)\b' % '|'.join(_JINJA2_KEYWORDS), _token.Keyword),
            # Variable names
            _variable_names_pattern,
            # Strings
            (r':?"(\\\\|\\[^\\]|[^"\\])*"', _token.String.Double),
            (r":?'(\\\\|\\[^\\]|[^'\\])*'", _token.String.Single),
            # Numbers
            *_number_patterns,
            # Operators and punctuation
            _operators_punctuation_pattern,
            _operators_pattern,
            # Other text
            _text_pattern,
            (r'\s+', _token.Whitespace),
        ],
    }