# Source code for dgenerate.pygments

# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
#    the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import pygments.lexer as _lexer
import pygments.token as _token

# Assigned explicitly: a bare string literal at this point would not become the
# module docstring, because it follows the import statements above.
__doc__ = """
This module provides a pygments lexer for the dgenerate config / shell language.

This can be used for syntax highlighting.
"""

# Define Jinja2 keywords
_JINJA2_KEYWORDS = sorted((
    'true', 'false', 'none', 'and', 'or', 'not', 'if', 'else', 'elif', 'for', 'endfor', 'in', 'do', 'async', 'await',
    'block', 'extends', 'include', 'import', 'from', 'macro', 'call', 'set', 'with', 'without', 'filter',
    'endfilter', 'capture', 'endcapture', 'spaceless', 'endspaceless', 'flush', 'load', 'url', 'static',
    'trans', 'endtrans', 'as', 'safe', 'end', 'autoescape', 'endautoescape', 'raw', 'endraw'
), key=lambda s: len(s), reverse=True)

# asteval keywords
_SETP_KEYWORDS = sorted((
    'for', 'in', 'if', 'else', 'elif', 'and', 'or', 'not', 'is', 'True', 'False'
), key=lambda s: len(s), reverse=True)

# Common patterns
#
# Each ``_*_pattern`` below is a RegexLexer rule tuple ``(regex, token[, new_state])``
# that is shared between several lexer states.

# One "path character": a backslash escape of any non-whitespace character, or any
# character other than whitespace, '/', '\' and the characters # = ; : { } " '
_ecos = r'(\\[^\s]|[^/\s\\#=;:{}"\'])'

# '#' up to the end of the line, unless the '#' is escaped with a backslash.
_comment_pattern = (r'(?<!\\)(#.*$)', _token.Comment.Single)

# $NAME, ${NAME} and %NAME% style environment variable references.
_env_var_pattern = (r'\$(?:\w+|\{\w+\})|%[\w]+%', _token.Name.Constant)

# '{%' followed by the block's first word; enters the 'jinja_block' state.
_jinja_block_pattern = (
    r'(\{%)(\s*)(\w+)',
    _lexer.bygroups(
        _token.String.Interpol,
        _token.Whitespace,
        _token.Keyword
    ), 'jinja_block'
)

# '{#' enters 'jinja_comment', '{{' enters 'jinja_interpolate'.
_jinja_comment_pattern = (r'(\{#)', _token.Comment.Multiline, 'jinja_comment')
_jinja_interpolate_pattern = (r'(\{\{)', _token.String.Interpol, 'jinja_interpolate')

_operators_punctuation_pattern = (r'[\[\]{}()=\\;,:]', _token.Operator)
_operators_pattern = (r'\*\*|<<|>>|[-+*/%^|&<>!.$]', _token.Operator)

# Two to four integers joined by 'x' / 'X', e.g. 512x512.
_dimensions_pattern = (r'(?<!\w)-?\d+[xX]-?\d+([xX]-?\d+){0,2}(?!\w)', _token.Number.Hex)

# A float must contain a decimal point or an exponent. If both were optional this
# rule would also match every plain integer and, because it is listed ahead of the
# integer rule in _number_patterns, Number.Integer could never be produced.
_number_float_pattern = (
    r'(?<!\w)-?(?:\d+\.\d*(?:[eE][-+]?\d+)?|\d+[eE][-+]?\d+)(?!\w)',
    _token.Number.Float
)
_decimal_integer_pattern = (r'(?<!\w)-?\d+(?!\w)', _token.Number.Integer)
_binary_integer_pattern = (r'(?<!\w)0[bB][01]+(?!\w)', _token.Number.Binary)
_hexa_decimal_integer_pattern = (r'(?<!\w)0[xX][0-9a-fA-F]+(?!\w)', _token.Number.Hex)
_octal_integer_pattern = (r'(?<!\w)0[oO][0-7]+(?!\w)', _token.Number.Octal)

# Fallback: any run of characters that no other rule treats specially.
_text_pattern = (r'[^=\s\[\]{}()$"\'`\\<&|;:,]+', _token.Text)
_variable_names_pattern = (r'(?<!\w)[a-zA-Z_][a-zA-Z0-9_]*(?!\w)', _token.Name.Variable)

# identifier immediately followed by '('
_function_call_pattern = (
    r'(?<!\w)([a-zA-Z_][a-zA-Z0-9_]*)(\()',
    _lexer.bygroups(_token.Name.Function,
                    _token.Operator)
)

_http_pattern = (
    r'(?<!\w)(https?://(?:'  # Protocol
    r'(?:[a-zA-Z0-9\-]+\.)+[a-zA-Z]{2,6}|'  # Domain
    r'localhost|'  # Localhost
    r'(?:\d{1,3}\.){3}\d{1,3}|'  # IPv4
    r'\[[a-fA-F0-9:]+\])'  # IPv6
    r'(?::\d+)?'  # Optional port
    r'(?:/(\\[$%\']?|[a-zA-Z0-9\-._~%!$&\'()*+,=:@/])*)?'  # Path (including allowed shell escapes)
    r'(?:\?(\\[$%\']?|[a-zA-Z0-9\-._~%!$&\'()*+,=:@/?])*)?'  # Query parameters (including allowed shell escapes)
    r')', _token.String
)

# NOTE: the order of this tuple matters; other code slices it by index.
_path_patterns = (
    _http_pattern,
    (rf'(?<!\w)([a-zA-Z]:(([/]|\\\\){_ecos}*)+)', _token.String),  # Drive letters
    # Absolute paths, optionally prefixed with '~', '.' or '..'. The dots must be
    # escaped: a bare '..?' matches ANY one or two characters, which let this rule
    # swallow e.g. an opening quote or a space that precedes the first slash.
    (rf'(?<!\w)(~|\.\.?)?(([/]|\\\\){_ecos}+)(([/]|\\\\){_ecos}*)*', _token.String),
    # absolute paths with relative components
    (rf'(?<!\w)({_ecos}+)(([/]|\\\\){_ecos}*)+', _token.String),  # relative paths with relative components
    (rf'(?<!\w){_ecos}+\.{_ecos}+', _token.String)  # file.ext
)

_MISC_KEYWORDS = sorted((
    'help',
    'helpargs',
    'null',
    'auto',
    'cpu',
    'cuda',
    'mps'
), key=lambda s: len(s), reverse=True)

_SCHEDULER_KEYWORDS = sorted((
    "DDIMScheduler",
    "DDPMScheduler",
    "DEISMultistepScheduler",
    "DPMSolverMultistepScheduler",
    "DPMSolverSDEScheduler",
    "DPMSolverSinglestepScheduler",
    "EDMEulerScheduler",
    "EulerAncestralDiscreteScheduler",
    "EulerDiscreteScheduler",
    "HeunDiscreteScheduler",
    "KDPM2AncestralDiscreteScheduler",
    "KDPM2DiscreteScheduler",
    "LCMScheduler",
    "LMSDiscreteScheduler",
    "PNDMScheduler",
    "UniPCMultistepScheduler",
    "DDPMWuerstchenScheduler",
    "FlowMatchEulerDiscreteScheduler"
), key=lambda s: len(s), reverse=True)

_CLASS_KEYWORDS = sorted((
    "AutoencoderKL",
    "AsymmetricAutoencoderKL",
    "AutoencoderTiny",
    "ConsistencyDecoderVAE",
    "CLIPTextModel",
    "CLIPTextModelWithProjection",
    "T5EncoderModel",
    "ChatGLMModel",
    "DistillT5EncoderModel"
), key=lambda s: len(s), reverse=True)

_MODEL_TYPE_KEYWORDS = sorted((
    'sd',
    'pix2pix',
    'sd3',
    'sd3-pix2pix',
    'flux',
    'flux-fill',
    'flux-kontext',
    'sdxl',
    'kolors',
    'if',
    'ifs',
    'ifs-img2img',
    'sdxl-pix2pix',
    'upscaler-x2',
    'upscaler-x4',
    's-cascade'
), key=lambda s: len(s), reverse=True)

_DTYPE_KEYWORDS = sorted((
    'float16',
    'bfloat16',
    'float32',
    'float64',
    'fp16',
    'bf16'
), key=lambda s: len(s), reverse=True)

# Numeric literal rules, splatted into the lexer states as a group.
# Rules are tried in the order listed here.
_number_patterns = (
    _number_float_pattern,
    _decimal_integer_pattern,
    _binary_integer_pattern,
    _hexa_decimal_integer_pattern,
    _octal_integer_pattern,
)


def _create_string_continue(name, char, root_state):
    """
    Build the lexer states for a quoted string that may span continued lines.

    Three states are produced: ``name`` (entered at the opening quote),
    ``f'{name}_escape'`` (entered after a line continuation inside the string)
    and ``f'{name}_continue'`` (the string body after a continuation).

    :param name: name of the state entered on the opening quote
    :param char: the quote character, ``"`` or ``'``
    :param root_state: state transitioned to on an unescaped newline / end of line,
        and on the closing quote when in the ``_continue`` state
    :return: dict mapping the three state names to their rule lists
    """
    double_quoted = char == '"'
    string_token = _token.String.Double if double_quoted else _token.String.Single
    escape_state = f'{name}_escape'
    continue_state = f'{name}_continue'

    # Rules at the top of both string-body states. Environment variables are
    # only recognised inside double quoted strings.
    head_rules = [_env_var_pattern] if double_quoted else []
    head_rules += [
        _jinja_block_pattern,
        _jinja_comment_pattern,
        _jinja_interpolate_pattern,
        (r'\\[^\s]', _token.String.Escape),
        # lone backslash: line continuation
        (r'\\', _token.Operator, escape_state),
    ]

    # Plain string content, shared by both string-body states.
    content_rules = [
        (r'#', string_token),
        # text running up to a newline whose next line starts with '-'
        (rf'[^{char}\\\n#{{]+\n\s*-', string_token, escape_state),
        (rf'[^{char}\\\n#{{]+', string_token),
        # a '{' that does not open a Jinja construct
        (r'{[^{#%]', string_token),
    ]

    # An unescaped newline / end of line leaves the string for root_state.
    line_end_rules = [
        (r'\n', _token.Whitespace, root_state),
        (r'$', _token.Whitespace, root_state),
    ]

    states = {}

    # Entered at the opening quote: comments stop at the quote character and
    # the closing quote pops back to the previous state.
    states[name] = (
        head_rules
        + [
            (rf'(?<!\\)(#[^{char}\n]*)(\s*\n\s*)(-)',
             _lexer.bygroups(_token.Comment.Single, _token.Whitespace, string_token), escape_state),
            (rf'(?<!\\)(#[^{char}\n]*)(\s*\n)',
             _lexer.bygroups(_token.Comment.Single, _token.Whitespace), root_state),
        ]
        + content_rules
        + [(char, string_token, '#pop')]
        + line_end_rules
    )

    # After a continuation: skip comments and whitespace, then resume the string.
    states[escape_state] = [
        (r'(#[^\n]*)(\s*\n)', _lexer.bygroups(_token.Comment.Single, _token.Whitespace)),
        (r'(\s|\n)+', _token.Whitespace),
        (r'(?<=[-\n\s.])', _token.Whitespace, continue_state)
    ]

    # String body after a continuation: comments run to the end of the line
    # and the closing quote transitions to root_state instead of popping.
    states[continue_state] = (
        head_rules
        + [
            (r'(?<!\\)(#[^\n]*)(\s*\n\s*)(-)',
             _lexer.bygroups(_token.Comment.Single, _token.Whitespace, string_token), escape_state),
            (r'(?<!\\)(#[^\n]*)(\s*\n)',
             _lexer.bygroups(_token.Comment.Single, _token.Whitespace), root_state),
        ]
        + content_rules
        + [(char, string_token, root_state)]
        + line_end_rules
    )

    return states


def _create_wait_for_var(name, next_state):
    """
    Build a lexer state that consumes input until a variable has been seen.

    An identifier, an environment variable reference, or a single whitespace
    character directly after a closing Jinja delimiter (``}}``, ``%}``, ``#}``)
    moves the lexer on to ``next_state``; everything else is tokenized in place.

    :param name: name of the state to create
    :param next_state: state entered once the variable has been consumed
    :return: dict with the single key ``name`` mapped to its rule list
    """
    rules = [
        ('([a-zA-Z_][a-zA-Z0-9_]*)', _token.Name.Variable, next_state),
        (*_env_var_pattern, next_state),
    ]

    # whitespace immediately following the end of a Jinja construct
    for closing_delimiter in (r'\}\}', r'%\}', r'#\}'):
        rules.append((r'(?<=' + closing_delimiter + r')\s', _token.Whitespace, next_state))

    rules.extend([
        _comment_pattern,
        _jinja_block_pattern,
        _jinja_comment_pattern,
        _jinja_interpolate_pattern,
        _operators_punctuation_pattern,
        _operators_pattern,
        _text_pattern,
        (r'\s+', _token.Whitespace),
    ])

    return {name: rules}

def _keyword_top_level(s):
    return rf'(?<![-_])\b({s})\b(?![-_])'

class DgenerateLexer(_lexer.RegexLexer):
    """
    pygments lexer for dgenerate configuration / script

    Registered under the alias ``dgenerate`` and the filename glob ``*.dgen``.

    State overview:

    * ``root`` - top level of a config; directives (``\\name``) switch to one of
      the states below.
    * ``var_then_value`` / ``var_then_setp_value`` / ``var_then_root`` - consume a
      variable name, then continue in ``value`` / ``setp_value`` / ``root``.
    * ``value``, ``setp_value``, ``import_plugins_value``, ``import_value`` - the
      remainder of a directive line. A newline followed by ``-`` continues the
      value, any other newline transitions to ``root``.
    * ``*_escape`` - entered after a backslash line continuation; skips comments
      and whitespace and then resumes the corresponding value state.
    * ``double_string*`` / ``single_string*`` - quoted strings, see
      ``_create_string_continue``.
    * ``jinja_block``, ``jinja_comment``, ``jinja_interpolate`` - the inside of
      ``{% %}``, ``{# #}`` and ``{{ }}``.
    """

    name = 'DgenerateLexer'
    aliases = ['dgenerate']
    filenames = ['*.dgen']

    tokens = {
        **_create_wait_for_var('var_then_value', 'value'),
        **_create_wait_for_var('var_then_setp_value', 'setp_value'),
        **_create_wait_for_var('var_then_root', 'root'),
        'root': [
            _comment_pattern,
            _env_var_pattern,
            # directives followed by a variable name and a value
            (r'(?<!\w)(\\set[e]?|\\gen_seeds|\\download)(\s+)',
             _lexer.bygroups(_token.Name.Builtin, _token.Text.Whitespace), 'var_then_value'),
            (r'(?<!\w)(\\setp)(\s+)',
             _lexer.bygroups(_token.Name.Builtin, _token.Text.Whitespace), 'var_then_setp_value'),
            # \import_plugins must be tried before \import
            (r'(?<!\w)(\\import_plugins)(\s+)',
             _lexer.bygroups(_token.Name.Builtin, _token.Text.Whitespace), 'import_plugins_value'),
            (r'(?<!\w)(\\import)(\s+)',
             _lexer.bygroups(_token.Name.Builtin, _token.Text.Whitespace), 'import_value'),
            # directives followed by variable names only
            (r'(?<!\w)(\\unset|\\save_modules|\\use_modules|\\clear_modules)(\s+)',
             _lexer.bygroups(_token.Name.Builtin, _token.Text.Whitespace), 'var_then_root'),
            # any other directive
            (r'(?<!\w)(\\[a-zA-Z_][a-zA-Z0-9_]*)', _lexer.bygroups(_token.Name.Builtin)),
            (r'\\[^\s]', _token.Escape),
            (_keyword_top_level('|'.join(_MISC_KEYWORDS)), _token.Keyword),
            (_keyword_top_level('|'.join(_SCHEDULER_KEYWORDS)), _token.Keyword),
            (_keyword_top_level('|'.join(_CLASS_KEYWORDS)), _token.Keyword),
            (_keyword_top_level('|'.join(_MODEL_TYPE_KEYWORDS)), _token.Keyword),
            (_keyword_top_level('|'.join(_DTYPE_KEYWORDS)), _token.Keyword),
            (_keyword_top_level('True|true'), _token.Keyword),
            (_keyword_top_level('False|false'), _token.Keyword),
            _jinja_block_pattern,
            _jinja_comment_pattern,
            _jinja_interpolate_pattern,
            _dimensions_pattern,
            *_number_patterns,
            *_path_patterns,
            _operators_punctuation_pattern,
            _operators_pattern,
            _function_call_pattern,
            (r'"', _token.String.Double, 'double_string'),
            (r"'", _token.String.Single, 'single_string'),
            _text_pattern,
            (r'\s+', _token.Whitespace),
        ],
        'value': [
            # newline followed by '-' continues the value
            (r'(\s*?\n\s*?)(-)', _lexer.bygroups(_token.Whitespace, _token.Operator)),
            # any other newline ends it
            (r'\s*?\n', _token.Whitespace, 'root'),
            (r'\\[^\s]', _token.Escape),
            # backslash line continuation, optionally followed by a comment
            (r'(\\)(\s*?\n)', _lexer.bygroups(_token.Operator, _token.Whitespace), 'value_escape'),
            (r'(\\)(\s+)(#[^\n]*)',
             _lexer.bygroups(_token.Operator, _token.Whitespace, _token.Comment.Single), 'value_escape'),
            _comment_pattern,
            _env_var_pattern,
            _jinja_block_pattern,
            _jinja_comment_pattern,
            _jinja_interpolate_pattern,
            _dimensions_pattern,
            *_number_patterns,
            *_path_patterns,
            _operators_punctuation_pattern,
            _operators_pattern,
            _function_call_pattern,
            (r'"', _token.String.Double, 'double_string'),
            (r"'", _token.String.Single, 'single_string'),
            _text_pattern,
            (r'\s+', _token.Whitespace),
        ],
        'setp_value': [
            # newline followed by '-' continues the value
            (r'(\s*?\n\s*?)(-)', _lexer.bygroups(_token.Whitespace, _token.Operator)),
            # any other newline ends it
            (r'\s*?\n', _token.Whitespace, 'root'),
            (r'\\[^\s]', _token.Escape),
            (r'(\\)(\s*?\n)', _lexer.bygroups(_token.Operator, _token.Whitespace), 'setp_value_escape'),
            (r'(\\)(\s+)(#[^\n]*)',
             _lexer.bygroups(_token.Operator, _token.Whitespace, _token.Comment.Single), 'setp_value_escape'),
            _comment_pattern,
            _env_var_pattern,
            (r'\b(%s)\b' % '|'.join(_SETP_KEYWORDS), _token.Keyword),
            _function_call_pattern,
            _variable_names_pattern,
            _jinja_block_pattern,
            _jinja_comment_pattern,
            _jinja_interpolate_pattern,
            *_number_patterns,
            _operators_punctuation_pattern,
            _operators_pattern,
            (r'"', _token.String.Double, 'double_string_setp_value'),
            (r"'", _token.String.Single, 'single_string_setp_value'),
            (r'\s+', _token.Whitespace),
        ],
        'import_plugins_value': [
            # newline followed by '-' continues the value
            (r'(\s*?\n\s*?)(-)', _lexer.bygroups(_token.Whitespace, _token.Operator)),
            # any other newline ends it
            (r'\s*?\n', _token.Whitespace, 'root'),
            (r'\\[^\s]', _token.Escape),
            (r'(\\)(\s*?\n)', _lexer.bygroups(_token.Operator, _token.Whitespace), 'import_plugins_value_escape'),
            (r'(\\)(\s+)(#[^\n]*)',
             _lexer.bygroups(_token.Operator, _token.Whitespace, _token.Comment.Single),
             'import_plugins_value_escape'),
            _comment_pattern,
            _env_var_pattern,
            *((p[0], _token.Keyword) for p in _path_patterns[1:4]),  # everything except http and file.ext
            (_variable_names_pattern[0], _token.Keyword),  # plugin names are keywords
            (r'[.]', _token.Operator),  # namespace operator
            _jinja_block_pattern,
            _jinja_comment_pattern,
            _jinja_interpolate_pattern,
            (r'"', _token.String.Double, 'double_string_import_plugins_value'),
            (r"'", _token.String.Single, 'single_string_import_plugins_value'),
            (r'\s+', _token.Whitespace),
        ],
        'import_value': [
            # newline followed by '-' continues the value
            (r'(\s*?\n\s*?)(-)', _lexer.bygroups(_token.Whitespace, _token.Operator)),
            # any other newline ends it
            (r'\s*?\n', _token.Whitespace, 'root'),
            (r'\\[^\s]', _token.Escape),
            (r'(\\)(\s*?\n)', _lexer.bygroups(_token.Operator, _token.Whitespace), 'import_value_escape'),
            (r'(\\)(\s+)(#[^\n]*)',
             _lexer.bygroups(_token.Operator, _token.Whitespace, _token.Comment.Single), 'import_value_escape'),
            _comment_pattern,
            _env_var_pattern,
            *((p[0], _token.Keyword) for p in _path_patterns[1:4]),  # everything except http and file.ext
            # 'as' keyword for imports - only when standalone. Word characters are
            # excluded on both sides as well as dots, otherwise this rule would
            # match the first two letters of module names such as 'asteval'.
            (r'(?<![\w.])as(?![\w.])', _token.Keyword),
            (_variable_names_pattern[0], _token.Name.Namespace),  # python module names are namespace
            (r'[.]', _token.Operator),  # namespace operator
            _jinja_block_pattern,
            _jinja_comment_pattern,
            _jinja_interpolate_pattern,
            (r'"', _token.String.Double, 'double_string_import_value'),
            (r"'", _token.String.Single, 'single_string_import_value'),
            (r'\s+', _token.Whitespace),
        ],
        # States entered after a backslash line continuation: skip comments and
        # whitespace, then resume the value state via a zero-width match.
        'import_value_escape': [
            (r'(#[^\n]*)(\s*\n)', _lexer.bygroups(_token.Comment.Single, _token.Whitespace)),
            (r'(\s|\n)+', _token.Whitespace),
            (r'(?<=[-\n\s.])', _token.Whitespace, 'import_value')
        ],
        'import_plugins_value_escape': [
            (r'(#[^\n]*)(\s*\n)', _lexer.bygroups(_token.Comment.Single, _token.Whitespace)),
            (r'(\s|\n)+', _token.Whitespace),
            (r'(?<=[-\n\s.])', _token.Whitespace, 'import_plugins_value')
        ],
        'setp_value_escape': [
            (r'(#[^\n]*)(\s*\n)', _lexer.bygroups(_token.Comment.Single, _token.Whitespace)),
            (r'(\s|\n)+', _token.Whitespace),
            (r'(?<=[-\n\s.])', _token.Whitespace, 'setp_value')
        ],
        'value_escape': [
            (r'(#[^\n]*)(\s*\n)', _lexer.bygroups(_token.Comment.Single, _token.Whitespace)),
            (r'(\s|\n)+', _token.Whitespace),
            (r'(?<=[-\n\s.])', _token.Whitespace, 'value')
        ],
        **_create_string_continue('double_string', '"', 'root'),
        **_create_string_continue('single_string', "'", 'root'),
        **_create_string_continue('double_string_setp_value', '"', 'setp_value'),
        **_create_string_continue('single_string_setp_value', "'", 'setp_value'),
        **_create_string_continue('double_string_import_plugins_value', '"', 'import_plugins_value'),
        **_create_string_continue('single_string_import_plugins_value', "'", 'import_plugins_value'),
        **_create_string_continue('double_string_import_value', '"', 'import_value'),
        **_create_string_continue('single_string_import_value', "'", 'import_value'),
        'jinja_block': [
            # End of Jinja2 block statement
            (r'(\s*)(%\})', _lexer.bygroups(_token.Text, _token.String.Interpol), '#pop'),
            # Function calls
            _function_call_pattern,
            # Jinja2 Keywords
            (r'\b(%s)\b' % '|'.join(_JINJA2_KEYWORDS), _token.Keyword),
            # Variable names
            _variable_names_pattern,
            # Strings
            (r':?"(\\\\|\\[^\\]|[^"\\])*"', _token.String.Double),
            (r":?'(\\\\|\\[^\\]|[^'\\])*'", _token.String.Single),
            # Numbers
            *_number_patterns,
            # Operators and punctuation
            _operators_punctuation_pattern,
            _operators_pattern,
            # Other text
            _text_pattern,
            (r'\s+', _token.Whitespace),
        ],
        'jinja_comment': [
            # End of Jinja2 comment
            (r'(#\})', _token.Comment.Multiline, '#pop'),
            # Comment content
            (r'.+?(?=#\})', _token.Comment.Multiline),
        ],
        'jinja_interpolate': [
            # End of Jinja2 template expression
            (r'(\}\})', _token.String.Interpol, '#pop'),
            # Function calls
            _function_call_pattern,
            # Jinja2 Keywords
            (r'\b(%s)\b' % '|'.join(_JINJA2_KEYWORDS), _token.Keyword),
            # Variable names
            _variable_names_pattern,
            # Strings
            (r':?"(\\\\|\\[^\\]|[^"\\])*"', _token.String.Double),
            (r":?'(\\\\|\\[^\\]|[^'\\])*'", _token.String.Single),
            # Numbers
            *_number_patterns,
            # Operators and punctuation
            _operators_punctuation_pattern,
            _operators_pattern,
            # Other text
            _text_pattern,
            (r'\s+', _token.Whitespace),
        ],
    }