# Source code for dgenerate.promptupscalers.promptupscaler

# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
#    the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import abc
import typing

import torch

import dgenerate.promptupscalers.exceptions as _exceptions
import dgenerate.plugin as _plugin
import dgenerate.prompt as _prompt
import dgenerate.promptupscalers.constants as _constants
import dgenerate.messages as _messages
import dgenerate.devicecache as _devicecache
import dgenerate.memoize as _memoize
import dgenerate.memory as _memory
import dgenerate.torchutil as _torchutil

# Shared object cache named 'prompt_upscaler'. Objects are placed here by
# PromptUpscaler.load_object_cached, and its CPU memory constraints are
# enforced by PromptUpscaler.memory_guard_device.
_prompt_upscaler_cache = _memoize.create_object_cache(
    'prompt_upscaler', cache_type=_memory.SizedConstrainedObjectCache)


def _cache_debug_hit(key, hit):
    """
    ``on_hit`` callback for the prompt upscaler memoize cache.

    Forwards the cache key and the cached object to the shared
    memoize debug helper under the "Prompt Upscaler Model" label.
    """
    label = "Prompt Upscaler Model"
    _memoize.simple_cache_hit_debug(label, key, hit)


def _cache_debug_miss(key, new):
    """
    ``on_create`` callback for the prompt upscaler memoize cache.

    Forwards the cache key and the newly created object to the shared
    memoize debug helper under the "Prompt Upscaler Model" label.
    """
    label = "Prompt Upscaler Model"
    _memoize.simple_cache_miss_debug(label, key, new)


class PromptUpscaler(_plugin.Plugin, abc.ABC):
    """
    Abstract base class for prompt upscaler implementations.
    """

    # you cannot specify these via a URI
    HIDE_ARGS = ['local-files-only']

    @classmethod
    def inheritable_help(cls, loaded_by_name):
        """
        Return help text for arguments inherited by every prompt upscaler.

        :param loaded_by_name: The name the prompt upscaler was loaded by (not used here)
        :return: dictionary mapping argument name to its help text
        """
        help_messages = {
            'device': (
                'The "device" argument can be used to set the device '
                'the prompt upscaler will run any models on, for example: cpu, cuda, cuda:1. '
                'this argument will default to the value of the dgenerate argument --device.'
            )
        }
        return help_messages

    def __init__(self,
                 loaded_by_name: str,
                 device: str | None = None,
                 local_files_only: bool = False,
                 **kwargs):
        """
        :param loaded_by_name: The name the prompt upscaler was loaded by
        :param device: Torch device string for running any models,
            passing ``None`` defaults the device to ``cpu``
        :param local_files_only: if ``True``, the plugin should never try to download
            models from the internet automatically, and instead only look
            for them in cache / on disk.
        :param kwargs: child class forwarded arguments

        :raise dgenerate.promptupscalers.exceptions.PromptUpscalerArgumentError:
            if ``device`` is given but is not a valid torch device string.
        """
        super().__init__(loaded_by_name=loaded_by_name,
                         argument_error_type=_exceptions.PromptUpscalerArgumentError,
                         **kwargs)

        if device is not None:
            if not _torchutil.is_valid_device_string(device):
                raise _exceptions.PromptUpscalerArgumentError(
                    f'Invalid device argument, {_torchutil.invalid_device_message(device, cap=False)}')

        self.__device = device if device else 'cpu'
        self.__local_files_only = local_files_only
        self.__size_estimate = 0

    def memory_guard_device(self, device: str | torch.device, memory_required: int):
        """
        Check a specific device against an amount of memory in bytes.

        If the device is a gpu device and any of the memory constraints specified by
        :py:attr:`dgenerate.promptupscalers.constants.PROMPT_UPSCALER_GPU_MEMORY_CONSTRAINTS`
        are met on that device, attempt to remove cached objects off a gpu device to free space.

        If the device is a cpu and any of the memory constraints specified by
        :py:attr:`dgenerate.promptupscalers.constants.PROMPT_UPSCALER_CACHE_GC_CONSTRAINTS`
        are met, attempt to remove cached prompt upscaler objects off the device to free space.

        Then, always enforce
        :py:attr:`dgenerate.promptupscalers.constants.PROMPT_UPSCALER_CACHE_MEMORY_CONSTRAINTS`.

        :param device: the device
        :param memory_required: the amount of memory required on the device in bytes

        :return: ``True`` if an attempt was made to free memory, ``False`` otherwise.
        """
        device = torch.device(device)

        cleared = False

        if _memory.is_supported_gpu_device(device):
            if _memory.gpu_memory_constraints(
                    _constants.PROMPT_UPSCALER_GPU_MEMORY_CONSTRAINTS,
                    extra_vars={'memory_required': memory_required},
                    device=device):
                _messages.debug_log(
                    f'Prompt Upscaler "{self.__class__.__name__}" is clearing the GPU side object '
                    f'cache due to GPU side memory constraint evaluating to True.')
                _devicecache.clear_device_cache(device)
                cleared = True
        elif device.type == 'cpu':
            if _memory.memory_constraints(
                    _constants.PROMPT_UPSCALER_CACHE_GC_CONSTRAINTS,
                    extra_vars={'memory_required': memory_required}):
                _messages.debug_log(
                    f'Prompt upscaler "{self.__class__.__name__}" is clearing the CPU side object '
                    f'cache due to CPU side memory constraint evaluating to True.')
                _memoize.clear_object_caches()
                cleared = True

        # Evaluated unconditionally (not inside a short-circuiting ``or``) so the
        # prompt upscaler cache constraints are enforced even when an earlier
        # branch has already freed memory.
        enforced = _prompt_upscaler_cache.enforce_cpu_mem_constraints(
            _constants.PROMPT_UPSCALER_CACHE_MEMORY_CONSTRAINTS,
            size_var='memory_required',
            new_object_size=memory_required
        )

        return cleared or enforced

    def set_size_estimate(self, size_bytes: int):
        """
        Set the estimated size of this plugin in bytes for memory
        management heuristics. This is intended to be used by implementors
        of the :py:class:`PromptUpscaler` plugin class.

        For the best memory optimization, set this value shortly before
        any associated model enters CPU side RAM, i.e. before it is loaded at all.

        :raise ValueError: if ``size_bytes`` is less than zero.

        :param size_bytes: the size in bytes
        """
        if size_bytes < 0:
            raise ValueError(
                'prompt upscaler size estimate cannot be less than zero.')

        self.__size_estimate = int(size_bytes)

    def load_object_cached(self,
                           tag: str,
                           estimated_size: int,
                           method: typing.Callable,
                           memory_guard_device: str | torch.device | None = 'cpu'
                           ):
        """
        Load a potentially large object into the CPU side ``prompt_upscaler`` object cache.

        :param tag: A unique string within the context of the prompt
            upscaler implementation constructor.
        :param estimated_size: Estimated size in bytes of the object in RAM.
        :param method: A method which loads and returns the object.
        :param memory_guard_device: call :py:meth:`PromptUpscaler.memory_guard_device`
            on the specified device before the object is loaded (on cache miss),
            passing ``None`` skips the memory guard.

        :return: The loaded object
        """

        # NOTE(review): loaded_by_name and tag are bound as default arguments,
        # presumably so the memoize decorator includes them in the cache key.
        # Confirm against dgenerate.memoize.
        @_memoize.memoize(
            _prompt_upscaler_cache,
            on_hit=_cache_debug_hit,
            on_create=_cache_debug_miss)
        def load_cached(loaded_by_name=self.loaded_by_name, tag=tag):
            # Only runs on a cache miss, so memory is guarded right before loading.
            if memory_guard_device is not None:
                self.memory_guard_device(memory_guard_device, estimated_size)
            # The memoized function returns (object, metadata). The size metadata
            # is what the sized cache uses for its memory accounting.
            return method(), _memoize.CachedObjectMetadata(size=estimated_size)

        return load_cached()

    @property
    def device(self) -> str:
        """
        Device that will be used for any text processing models.
        """
        return self.__device

    @property
    def local_files_only(self) -> bool:
        """
        Is this prompt upscaler only going to look for resources such as models in cache / on disk?
        """
        return self.__local_files_only

    @property
    def size_estimate(self) -> int:
        """
        Estimated size of the models / objects used by this prompt upscaler.

        :return: size in bytes
        """
        return self.__size_estimate

    @property
    def accepts_batch(self):
        """
        Can this prompt upscaler accept a batch of prompts?

        Implementors that accept batches should override this property;
        this default implementation reports no batch support.

        :return: Default: ``False``
        """
        return False

    @abc.abstractmethod
    def upscale(self, prompt: _prompt.PromptOrPrompts) -> _prompt.PromptOrPrompts:
        """
        Upscale a prompt / prompts and return them modified.

        :param prompt: The incoming prompt or prompts
        :return: Modified prompt / prompts, you may return multiple prompts (an iterable) to indicate expansion
        """
        return prompt