Source code for dgenerate.pipelinewrapper.cache

# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
#    the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import gc
import typing

import diffusers

import dgenerate.memoize as _d_memoize
import dgenerate.memory as _memory
import dgenerate.messages as _messages
import dgenerate.types as _types

# Every model family below keeps one cache per backend (torch / flax) plus a
# running byte estimate that the functions in this module keep up to date.

# -- DiffusionPipeline -------------------------------------------------------

_TORCH_PIPELINE_CACHE = {}
"""Global in memory cache for torch diffusers pipelines"""

_FLAX_PIPELINE_CACHE = {}
"""Global in memory cache for flax diffusers pipelines"""

_PIPELINE_CACHE_SIZE = 0
"""Estimated memory consumption in bytes of all pipelines cached in memory"""

# -- ControlNet --------------------------------------------------------------

_TORCH_CONTROL_NET_CACHE = {}
"""Global in memory cache for torch ControlNet models"""

_FLAX_CONTROL_NET_CACHE = {}
"""Global in memory cache for flax ControlNet models"""

_CONTROL_NET_CACHE_SIZE = 0
"""Estimated memory consumption in bytes of all ControlNet models cached in memory"""

# -- VAE ---------------------------------------------------------------------

_TORCH_VAE_CACHE = {}
"""Global in memory cache for torch VAE models"""

_FLAX_VAE_CACHE = {}
"""Global in memory cache for flax VAE models"""

_VAE_CACHE_SIZE = 0
"""Estimated memory consumption in bytes of all VAE models cached in memory"""

# -- UNet --------------------------------------------------------------------

_TORCH_UNET_CACHE = {}
"""Global in memory cache for torch UNet models"""

_FLAX_UNET_CACHE = {}
"""Global in memory cache for flax UNet models"""

_UNET_CACHE_SIZE = 0
"""Estimated memory consumption in bytes of all UNet models cached in memory"""


def pipeline_cache_size() -> int:
    """
    Report the running size estimate kept for the diffusers pipeline cache.

    The figure is the module level counter that the pipeline cache
    bookkeeping functions in this module add to, subtract from, and reset.

    :return: estimated memory usage in bytes.
    """
    estimate = _PIPELINE_CACHE_SIZE
    return estimate
def vae_cache_size() -> int:
    """
    Report the running size estimate kept for the user specified VAE cache.

    The figure is the module level counter that the VAE cache
    bookkeeping functions in this module add to, subtract from, and reset.

    :return: estimated memory usage in bytes.
    """
    estimate = _VAE_CACHE_SIZE
    return estimate
def unet_cache_size() -> int:
    """
    Report the running size estimate kept for the user specified UNet cache.

    The figure is the module level counter that the UNet cache
    bookkeeping functions in this module add to, subtract from, and reset.

    :return: estimated memory usage in bytes.
    """
    estimate = _UNET_CACHE_SIZE
    return estimate
def control_net_cache_size() -> int:
    """
    Report the running size estimate kept for the user specified ControlNet cache.

    The figure is the module level counter that the ControlNet cache
    bookkeeping functions in this module add to, subtract from, and reset.

    :return: estimated memory usage in bytes.
    """
    estimate = _CONTROL_NET_CACHE_SIZE
    return estimate
CACHE_MEMORY_CONSTRAINTS: list[str] = ['used_percent > 70']
"""
Cache constraint expressions for when to clear all model caches
(DiffusionPipeline, UNet, VAE, and ControlNet), syntax provided via
:py:func:`dgenerate.memory.memory_constraints`

If any of these constraints are met, a call to :py:func:`.enforce_cache_constraints` will call
:py:func:`.clear_model_cache` and, unless told otherwise, force a garbage collection.
"""

PIPELINE_CACHE_MEMORY_CONSTRAINTS: list[str] = ['pipeline_size > (available * 0.75)']
"""
Cache constraint expressions for when to clear the DiffusionPipeline cache,
syntax provided via :py:func:`dgenerate.memory.memory_constraints`

If any of these constraints are met, a call to :py:func:`.enforce_pipeline_cache_constraints` will call
:py:func:`.clear_pipeline_cache` and, unless told otherwise, force a garbage collection.

Extra variables include: ``cache_size`` (the current estimated cache size in bytes),
and ``pipeline_size`` (the estimated size of the new pipeline before it is brought into memory, in bytes)
"""

UNET_CACHE_MEMORY_CONSTRAINTS: list[str] = ['unet_size > (available * 0.75)']
"""
Cache constraint expressions for when to clear the UNet cache,
syntax provided via :py:func:`dgenerate.memory.memory_constraints`

If any of these constraints are met, a call to :py:func:`.enforce_unet_cache_constraints` will call
:py:func:`.clear_unet_cache` and, unless told otherwise, force a garbage collection.

Extra variables include: ``cache_size`` (the current estimated cache size in bytes),
and ``unet_size`` (the estimated size of the new UNet before it is brought into memory, in bytes)
"""

VAE_CACHE_MEMORY_CONSTRAINTS: list[str] = ['vae_size > (available * 0.75)']
"""
Cache constraint expressions for when to clear the VAE cache,
syntax provided via :py:func:`dgenerate.memory.memory_constraints`

If any of these constraints are met, a call to :py:func:`.enforce_vae_cache_constraints` will call
:py:func:`.clear_vae_cache` and, unless told otherwise, force a garbage collection.

Extra variables include: ``cache_size`` (the current estimated cache size in bytes),
and ``vae_size`` (the estimated size of the new VAE before it is brought into memory, in bytes)
"""

CONTROL_NET_CACHE_MEMORY_CONSTRAINTS: list[str] = ['control_net_size > (available * 0.75)']
"""
Cache constraint expressions for when to clear the ControlNet cache,
syntax provided via :py:func:`dgenerate.memory.memory_constraints`

If any of these constraints are met, a call to :py:func:`.enforce_control_net_cache_constraints` will call
:py:func:`.clear_control_net_cache` and, unless told otherwise, force a garbage collection.

Extra variables include: ``cache_size`` (the current estimated cache size in bytes),
and ``control_net_size`` (the estimated size of the new ControlNet before it is brought into memory, in bytes)
"""
def clear_pipeline_cache(collect=True):
    """
    Empty both DiffusionPipeline caches (torch and flax) and zero the size estimate.

    :param collect: When true, log the request and call :py:func:`gc.collect` afterwards.
    """
    global _PIPELINE_CACHE_SIZE

    # The dictionaries are emptied in place, so they keep their identity.
    for cache in (_TORCH_PIPELINE_CACHE, _FLAX_PIPELINE_CACHE):
        cache.clear()

    _PIPELINE_CACHE_SIZE = 0

    if not collect:
        return

    _messages.debug_log(
        f'{_types.fullname(clear_pipeline_cache)} calling gc.collect() by request')
    gc.collect()
def clear_control_net_cache(collect=True):
    """
    Empty both ControlNet caches (torch and flax) and zero the size estimate.

    :param collect: When true, log the request and call :py:func:`gc.collect` afterwards.
    """
    global _CONTROL_NET_CACHE_SIZE

    # The dictionaries are emptied in place, so they keep their identity.
    for cache in (_TORCH_CONTROL_NET_CACHE, _FLAX_CONTROL_NET_CACHE):
        cache.clear()

    _CONTROL_NET_CACHE_SIZE = 0

    if not collect:
        return

    _messages.debug_log(
        f'{_types.fullname(clear_control_net_cache)} calling gc.collect() by request')
    gc.collect()
def clear_vae_cache(collect=True):
    """
    Empty both VAE caches (torch and flax) and zero the size estimate.

    :param collect: When true, log the request and call :py:func:`gc.collect` afterwards.
    """
    global _VAE_CACHE_SIZE

    # The dictionaries are emptied in place, so they keep their identity.
    for cache in (_TORCH_VAE_CACHE, _FLAX_VAE_CACHE):
        cache.clear()

    _VAE_CACHE_SIZE = 0

    if not collect:
        return

    _messages.debug_log(
        f'{_types.fullname(clear_vae_cache)} calling gc.collect() by request')
    gc.collect()
def clear_unet_cache(collect=True):
    """
    Empty both UNet caches (torch and flax) and zero the size estimate.

    :param collect: When true, log the request and call :py:func:`gc.collect` afterwards.
    """
    global _UNET_CACHE_SIZE

    # The dictionaries are emptied in place, so they keep their identity.
    for cache in (_TORCH_UNET_CACHE, _FLAX_UNET_CACHE):
        cache.clear()

    _UNET_CACHE_SIZE = 0

    if not collect:
        return

    _messages.debug_log(
        f'{_types.fullname(clear_unet_cache)} calling gc.collect() by request')
    gc.collect()
def clear_model_cache(collect=True):
    """
    Empty every in memory model cache (pipeline, ControlNet, UNet and VAE,
    torch and flax alike) and zero all four size estimates.

    :param collect: When true, log the request and call :py:func:`gc.collect` afterwards.
    """
    global _PIPELINE_CACHE_SIZE, \
        _CONTROL_NET_CACHE_SIZE, \
        _UNET_CACHE_SIZE, \
        _VAE_CACHE_SIZE

    # The dictionaries are emptied in place, so they keep their identity.
    every_cache = (
        _TORCH_PIPELINE_CACHE,
        _FLAX_PIPELINE_CACHE,
        _TORCH_CONTROL_NET_CACHE,
        _FLAX_CONTROL_NET_CACHE,
        _TORCH_UNET_CACHE,
        _FLAX_UNET_CACHE,
        _TORCH_VAE_CACHE,
        _FLAX_VAE_CACHE,
    )
    for cache in every_cache:
        cache.clear()

    _PIPELINE_CACHE_SIZE = 0
    _CONTROL_NET_CACHE_SIZE = 0
    _UNET_CACHE_SIZE = 0
    _VAE_CACHE_SIZE = 0

    if not collect:
        return

    _messages.debug_log(
        f'{_types.fullname(clear_model_cache)} calling gc.collect() by request')
    gc.collect()
def enforce_cache_constraints(collect=True):
    """
    Evaluate :py:attr:`dgenerate.pipelinewrapper.CACHE_MEMORY_CONSTRAINTS`
    and clear every model cache when any of the expressions is met.

    :param collect: Forwarded to :py:func:`.clear_model_cache`, controls whether
        :py:func:`gc.collect` is called after a cache clear.

    :return: ``True`` if the caches were cleared, ``False`` otherwise.
    """
    constraints_name = f'{__name__}.CACHE_MEMORY_CONSTRAINTS'

    _messages.debug_log(f'Enforcing {constraints_name} =',
                        CACHE_MEMORY_CONSTRAINTS)
    _messages.debug_log(_memory.memory_use_debug_string())

    if not _memory.memory_constraints(CACHE_MEMORY_CONSTRAINTS):
        return False

    _messages.debug_log(f'{constraints_name} '
                        f'{CACHE_MEMORY_CONSTRAINTS} met, '
                        f'calling {_types.fullname(clear_model_cache)}.')

    clear_model_cache(collect=collect)
    return True
def enforce_pipeline_cache_constraints(new_pipeline_size, collect=True):
    """
    Evaluate :py:attr:`dgenerate.pipelinewrapper.PIPELINE_CACHE_MEMORY_CONSTRAINTS`
    and clear the :py:class:`diffusers.DiffusionPipeline` cache when any expression is met.

    The expressions are evaluated with two extra variables: ``cache_size`` (the current
    pipeline cache size estimate) and ``pipeline_size`` (``new_pipeline_size``).

    :param new_pipeline_size: estimated size in bytes of any new pipeline that is about to enter memory
    :param collect: Forwarded to :py:func:`.clear_pipeline_cache`, controls whether
        :py:func:`gc.collect` is called after a cache clear.

    :return: ``True`` if the cache was cleared, ``False`` otherwise.
    """
    constraints_name = f'{__name__}.PIPELINE_CACHE_MEMORY_CONSTRAINTS'
    current_size = pipeline_cache_size()

    _messages.debug_log(f'Enforcing {constraints_name} =',
                        PIPELINE_CACHE_MEMORY_CONSTRAINTS,
                        f'(cache_size = {_memory.bytes_best_human_unit(current_size)},',
                        f'pipeline_size = {_memory.bytes_best_human_unit(new_pipeline_size)})')
    _messages.debug_log(_memory.memory_use_debug_string())

    constraint_met = _memory.memory_constraints(
        PIPELINE_CACHE_MEMORY_CONSTRAINTS,
        extra_vars={'cache_size': current_size,
                    'pipeline_size': new_pipeline_size})

    if not constraint_met:
        return False

    _messages.debug_log(f'{constraints_name} '
                        f'{PIPELINE_CACHE_MEMORY_CONSTRAINTS} met, '
                        f'calling {_types.fullname(clear_pipeline_cache)}.')

    clear_pipeline_cache(collect=collect)
    return True
def enforce_vae_cache_constraints(new_vae_size, collect=True):
    """
    Evaluate :py:attr:`dgenerate.pipelinewrapper.VAE_CACHE_MEMORY_CONSTRAINTS`
    and clear the VAE cache when any expression is met.

    The expressions are evaluated with two extra variables: ``cache_size`` (the current
    VAE cache size estimate) and ``vae_size`` (``new_vae_size``).

    :param new_vae_size: estimated size in bytes of any new vae that is about to enter memory
    :param collect: Forwarded to :py:func:`.clear_vae_cache`, controls whether
        :py:func:`gc.collect` is called after a cache clear.

    :return: ``True`` if the cache was cleared, ``False`` otherwise.
    """
    constraints_name = f'{__name__}.VAE_CACHE_MEMORY_CONSTRAINTS'
    current_size = vae_cache_size()

    _messages.debug_log(f'Enforcing {constraints_name} =',
                        VAE_CACHE_MEMORY_CONSTRAINTS,
                        f'(cache_size = {_memory.bytes_best_human_unit(current_size)},',
                        f'vae_size = {_memory.bytes_best_human_unit(new_vae_size)})')
    _messages.debug_log(_memory.memory_use_debug_string())

    constraint_met = _memory.memory_constraints(
        VAE_CACHE_MEMORY_CONSTRAINTS,
        extra_vars={'cache_size': current_size,
                    'vae_size': new_vae_size})

    if not constraint_met:
        return False

    _messages.debug_log(f'{constraints_name} '
                        f'{VAE_CACHE_MEMORY_CONSTRAINTS} met, '
                        f'calling {_types.fullname(clear_vae_cache)}.')

    clear_vae_cache(collect=collect)
    return True
def enforce_unet_cache_constraints(new_unet_size, collect=True):
    """
    Evaluate :py:attr:`dgenerate.pipelinewrapper.UNET_CACHE_MEMORY_CONSTRAINTS`
    and clear the UNet cache when any expression is met.

    The expressions are evaluated with two extra variables: ``cache_size`` (the current
    UNet cache size estimate) and ``unet_size`` (``new_unet_size``).

    :param new_unet_size: estimated size in bytes of any new unet that is about to enter memory
    :param collect: Forwarded to :py:func:`.clear_unet_cache`, controls whether
        :py:func:`gc.collect` is called after a cache clear.

    :return: ``True`` if the cache was cleared, ``False`` otherwise.
    """
    constraints_name = f'{__name__}.UNET_CACHE_MEMORY_CONSTRAINTS'
    current_size = unet_cache_size()

    _messages.debug_log(f'Enforcing {constraints_name} =',
                        UNET_CACHE_MEMORY_CONSTRAINTS,
                        f'(cache_size = {_memory.bytes_best_human_unit(current_size)},',
                        f'unet_size = {_memory.bytes_best_human_unit(new_unet_size)})')
    _messages.debug_log(_memory.memory_use_debug_string())

    constraint_met = _memory.memory_constraints(
        UNET_CACHE_MEMORY_CONSTRAINTS,
        extra_vars={'cache_size': current_size,
                    'unet_size': new_unet_size})

    if not constraint_met:
        return False

    _messages.debug_log(f'{constraints_name} '
                        f'{UNET_CACHE_MEMORY_CONSTRAINTS} met, '
                        f'calling {_types.fullname(clear_unet_cache)}.')

    clear_unet_cache(collect=collect)
    return True
def enforce_control_net_cache_constraints(new_control_net_size, collect=True):
    """
    Evaluate :py:attr:`dgenerate.pipelinewrapper.CONTROL_NET_CACHE_MEMORY_CONSTRAINTS`
    and clear the ControlNet cache when any expression is met.

    The expressions are evaluated with two extra variables: ``cache_size`` (the current
    ControlNet cache size estimate) and ``control_net_size`` (``new_control_net_size``).

    :param new_control_net_size: estimated size in bytes of any new control net that is about to enter memory
    :param collect: Forwarded to :py:func:`.clear_control_net_cache`, controls whether
        :py:func:`gc.collect` is called after a cache clear.

    :return: ``True`` if the cache was cleared, ``False`` otherwise.
    """
    constraints_name = f'{__name__}.CONTROL_NET_CACHE_MEMORY_CONSTRAINTS'
    current_size = control_net_cache_size()

    _messages.debug_log(f'Enforcing {constraints_name} =',
                        CONTROL_NET_CACHE_MEMORY_CONSTRAINTS,
                        f'(cache_size = {_memory.bytes_best_human_unit(current_size)},',
                        f'control_net_size = {_memory.bytes_best_human_unit(new_control_net_size)})')
    _messages.debug_log(_memory.memory_use_debug_string())

    constraint_met = _memory.memory_constraints(
        CONTROL_NET_CACHE_MEMORY_CONSTRAINTS,
        extra_vars={'cache_size': current_size,
                    'control_net_size': new_control_net_size})

    if not constraint_met:
        return False

    _messages.debug_log(f'{constraints_name} '
                        f'{CONTROL_NET_CACHE_MEMORY_CONSTRAINTS} met, '
                        f'calling {_types.fullname(clear_control_net_cache)}.')

    clear_control_net_cache(collect=collect)
    return True
def uri_hash_with_parser(parser):
    """
    Build a hash function for URI strings around a given URI parser.

    The returned function parses the URI with ``parser`` and hashes the parse result
    with :py:func:`dgenerate.memoize.struct_hasher`. A falsy URI (``None``, empty string)
    is returned unchanged without being parsed.

    :param parser: The URI parser function
    :return: a hash function compatible with :py:func:`dgenerate.memoize.memoize`
    """

    def hasher(path):
        # Nothing to parse: hand the falsy value straight back.
        return _d_memoize.struct_hasher(parser(path)) if path else path

    return hasher
def uri_list_hash_with_parser(parser):
    """
    Build a hash function for a list of URI strings around a given URI parser.

    Each URI is hashed with the function produced by :py:func:`.uri_hash_with_parser`,
    and the results are joined with commas inside square brackets. A falsy list
    (``None``, empty) hashes to ``'[]'``.

    :param parser: The URI parser function
    :return: a hash function compatible with :py:func:`dgenerate.memoize.memoize`
    """

    def hasher(paths):
        if not paths:
            return '[]'

        hash_single_uri = uri_hash_with_parser(parser)
        return '[' + ','.join(map(hash_single_uri, paths)) + ']'

    return hasher
def pipeline_create_update_cache_info(pipeline, estimated_size: int):
    """
    Account for a newly created :py:class:`diffusers.DiffusionPipeline` in the
    pipeline cache size estimate, and remember its size on the object itself.

    :param pipeline: the :py:class:`diffusers.DiffusionPipeline` object
    :param estimated_size: size bytes
    """
    global _PIPELINE_CACHE_SIZE

    _PIPELINE_CACHE_SIZE = _PIPELINE_CACHE_SIZE + estimated_size

    # Internal tag, read back by the to / off CPU bookkeeping functions.
    setattr(pipeline, 'DGENERATE_SIZE_ESTIMATE', estimated_size)
def controlnet_create_update_cache_info(controlnet, estimated_size: int):
    """
    Account for a newly created ControlNet model in the ControlNet cache
    size estimate, and remember its size on the object itself.

    :param controlnet: the ControlNet object
    :param estimated_size: size bytes
    """
    global _CONTROL_NET_CACHE_SIZE

    _CONTROL_NET_CACHE_SIZE = _CONTROL_NET_CACHE_SIZE + estimated_size

    # Internal tag, read back by the to / off CPU bookkeeping functions.
    setattr(controlnet, 'DGENERATE_SIZE_ESTIMATE', estimated_size)
def vae_create_update_cache_info(vae, estimated_size: int):
    """
    Account for a newly created VAE model in the VAE cache
    size estimate, and remember its size on the object itself.

    :param vae: the VAE object
    :param estimated_size: size bytes
    """
    global _VAE_CACHE_SIZE

    _VAE_CACHE_SIZE = _VAE_CACHE_SIZE + estimated_size

    # Internal tag, read back by the to / off CPU bookkeeping functions.
    setattr(vae, 'DGENERATE_SIZE_ESTIMATE', estimated_size)
def unet_create_update_cache_info(unet, estimated_size: int):
    """
    Account for a newly created UNet model in the UNet cache
    size estimate, and remember its size on the object itself.

    :param unet: the UNet object
    :param estimated_size: size bytes
    """
    global _UNET_CACHE_SIZE

    _UNET_CACHE_SIZE = _UNET_CACHE_SIZE + estimated_size

    # Internal tag, read back by the to / off CPU bookkeeping functions.
    setattr(unet, 'DGENERATE_SIZE_ESTIMATE', estimated_size)
def pipeline_to_cpu_update_cache_info(pipeline: diffusers.DiffusionPipeline):
    """
    Update CPU side cache size information when a diffusers pipeline is moved to the CPU.

    The pipeline cache constraints are enforced against the size recorded on the
    pipeline (its ``DGENERATE_SIZE_ESTIMATE`` tag) before that size is added to the
    pipeline cache size estimate.

    :param pipeline: the pipeline, it must carry the ``DGENERATE_SIZE_ESTIMATE`` tag
    """
    global _PIPELINE_CACHE_SIZE

    entering_size = pipeline.DGENERATE_SIZE_ESTIMATE

    # Enforce first: a cache clear resets the estimate to zero, and the
    # returning pipeline must be counted on top of whatever remains.
    enforce_pipeline_cache_constraints(entering_size)
    _PIPELINE_CACHE_SIZE += entering_size

    # The trailing parenthesised figure is the new cache total in human
    # readable units, matching the raw byte count printed just before it.
    _messages.debug_log(f'{_types.class_and_id_string(pipeline)} '
                        f'Size = {entering_size} Bytes '
                        f'({_memory.bytes_best_human_unit(entering_size)}) '
                        f'is entering CPU side memory, {_types.fullname(pipeline_cache_size)}() '
                        f'is now {pipeline_cache_size()} Bytes '
                        f'({_memory.bytes_best_human_unit(pipeline_cache_size())})')
def unet_to_cpu_update_cache_info(unet):
    """
    Update CPU side cache size information when a UNet module is moved to the CPU.

    Nothing happens for a UNet that does not carry the ``DGENERATE_SIZE_ESTIMATE`` tag.

    :param unet: the UNet
    """
    global _UNET_CACHE_SIZE

    if not hasattr(unet, 'DGENERATE_SIZE_ESTIMATE'):
        return

    # UNet returning to CPU side memory
    entering_size = unet.DGENERATE_SIZE_ESTIMATE

    # Enforce first: a cache clear resets the estimate to zero, and the
    # returning UNet must be counted on top of whatever remains.
    enforce_unet_cache_constraints(entering_size)
    _UNET_CACHE_SIZE += entering_size

    # The trailing parenthesised figure is the new cache total in human
    # readable units, matching the raw byte count printed just before it.
    _messages.debug_log(f'{_types.class_and_id_string(unet)} '
                        f'Size = {entering_size} Bytes '
                        f'({_memory.bytes_best_human_unit(entering_size)}) '
                        f'is entering CPU side memory, {_types.fullname(unet_cache_size)}() '
                        f'is now {unet_cache_size()} Bytes '
                        f'({_memory.bytes_best_human_unit(unet_cache_size())})')


def vae_to_cpu_update_cache_info(vae):
    """
    Update CPU side cache size information when a VAE module is moved to the CPU.

    Nothing happens for a VAE that does not carry the ``DGENERATE_SIZE_ESTIMATE`` tag.

    :param vae: the VAE
    """
    global _VAE_CACHE_SIZE

    if not hasattr(vae, 'DGENERATE_SIZE_ESTIMATE'):
        return

    # vae returning to CPU side memory
    entering_size = vae.DGENERATE_SIZE_ESTIMATE

    # Enforce first: a cache clear resets the estimate to zero, and the
    # returning VAE must be counted on top of whatever remains.
    enforce_vae_cache_constraints(entering_size)
    _VAE_CACHE_SIZE += entering_size

    # The trailing parenthesised figure is the new cache total in human
    # readable units, matching the raw byte count printed just before it.
    _messages.debug_log(f'Cached VAE {_types.class_and_id_string(vae)} '
                        f'Size = {entering_size} Bytes '
                        f'({_memory.bytes_best_human_unit(entering_size)}) '
                        f'is entering CPU side memory, {_types.fullname(vae_cache_size)}() '
                        f'is now {vae_cache_size()} Bytes '
                        f'({_memory.bytes_best_human_unit(vae_cache_size())})')


def controlnet_to_cpu_update_cache_info(
        controlnet: typing.Union[diffusers.models.ControlNetModel,
                                 diffusers.pipelines.controlnet.MultiControlNetModel]):
    """
    Update CPU side cache size information when a ControlNet module is moved to the CPU.

    For a multi control net the sizes of all contained nets are summed, the
    ControlNet cache constraints are enforced once against that sum, and the
    sum is then added to the ControlNet cache size estimate.

    :param controlnet: the control net, or multi control net
    """
    global _CONTROL_NET_CACHE_SIZE

    if isinstance(controlnet, diffusers.pipelines.controlnet.MultiControlNetModel):
        total_size = 0

        for control_net in controlnet.nets:
            total_size += control_net.DGENERATE_SIZE_ESTIMATE

            _messages.debug_log(f'Cached ControlNetModel {_types.class_and_id_string(control_net)} '
                                f'Size = {control_net.DGENERATE_SIZE_ESTIMATE} Bytes '
                                f'({_memory.bytes_best_human_unit(control_net.DGENERATE_SIZE_ESTIMATE)}) '
                                f'from "MultiControlNetModel" is entering CPU side memory.')

        # Enforce before adding: a cache clear resets the estimate to zero.
        enforce_control_net_cache_constraints(total_size)
        _CONTROL_NET_CACHE_SIZE += total_size

        _messages.debug_log(f'"MultiControlNetModel" size fully estimated, '
                            f'{_types.fullname(control_net_cache_size)}() '
                            f'is now {control_net_cache_size()} Bytes '
                            f'({_memory.bytes_best_human_unit(control_net_cache_size())})')
    else:
        # ControlNet returning to CPU side memory
        enforce_control_net_cache_constraints(controlnet.DGENERATE_SIZE_ESTIMATE)
        _CONTROL_NET_CACHE_SIZE += controlnet.DGENERATE_SIZE_ESTIMATE

        _messages.debug_log(f'Cached ControlNetModel {_types.class_and_id_string(controlnet)} '
                            f'Size = {controlnet.DGENERATE_SIZE_ESTIMATE} Bytes '
                            f'({_memory.bytes_best_human_unit(controlnet.DGENERATE_SIZE_ESTIMATE)}) '
                            f'is entering CPU side memory, {_types.fullname(control_net_cache_size)}() '
                            f'is now {control_net_cache_size()} Bytes '
                            f'({_memory.bytes_best_human_unit(control_net_cache_size())})')
def pipeline_off_cpu_update_cache_info(
        pipeline: typing.Union[diffusers.DiffusionPipeline, diffusers.FlaxDiffusionPipeline]):
    """
    Update CPU side cache size information when a diffusers pipeline is
    moved to a device that is not the CPU.

    The size recorded on the pipeline (``DGENERATE_SIZE_ESTIMATE``) is subtracted
    from the pipeline cache size estimate, which never drops below zero.

    :param pipeline: the pipeline
    """
    global _PIPELINE_CACHE_SIZE

    leaving_size = pipeline.DGENERATE_SIZE_ESTIMATE

    # Clamp at zero so the estimate cannot go negative.
    _PIPELINE_CACHE_SIZE = max(0, _PIPELINE_CACHE_SIZE - leaving_size)

    _messages.debug_log(f'Cached Diffusers Pipeline {_types.class_and_id_string(pipeline)} '
                        f'Size = {leaving_size} Bytes '
                        f'({_memory.bytes_best_human_unit(leaving_size)}) '
                        f'is leaving CPU side memory, {_types.fullname(pipeline_cache_size)}() '
                        f'is now {pipeline_cache_size()} Bytes '
                        f'({_memory.bytes_best_human_unit(pipeline_cache_size())})')
def unet_off_cpu_update_cache_info(unet):
    """
    Update CPU side cache size information when a UNet module is
    moved to a device that is not the CPU.

    Nothing happens for a UNet that does not carry the ``DGENERATE_SIZE_ESTIMATE``
    tag. The UNet cache size estimate never drops below zero.

    :param unet: the UNet
    """
    global _UNET_CACHE_SIZE

    if not hasattr(unet, 'DGENERATE_SIZE_ESTIMATE'):
        return

    leaving_size = unet.DGENERATE_SIZE_ESTIMATE

    # Clamp at zero so the estimate cannot go negative.
    _UNET_CACHE_SIZE = max(0, _UNET_CACHE_SIZE - leaving_size)

    _messages.debug_log(f'Cached UNet {_types.class_and_id_string(unet)} '
                        f'Size = {leaving_size} Bytes '
                        f'({_memory.bytes_best_human_unit(leaving_size)}) '
                        f'is leaving CPU side memory, {_types.fullname(unet_cache_size)}() '
                        f'is now {unet_cache_size()} Bytes '
                        f'({_memory.bytes_best_human_unit(unet_cache_size())})')


def vae_off_cpu_update_cache_info(vae):
    """
    Update CPU side cache size information when a VAE module is
    moved to a device that is not the CPU.

    Nothing happens for a VAE that does not carry the ``DGENERATE_SIZE_ESTIMATE``
    tag. The VAE cache size estimate never drops below zero.

    :param vae: the VAE
    """
    global _VAE_CACHE_SIZE

    if not hasattr(vae, 'DGENERATE_SIZE_ESTIMATE'):
        return

    leaving_size = vae.DGENERATE_SIZE_ESTIMATE

    # Clamp at zero so the estimate cannot go negative.
    _VAE_CACHE_SIZE = max(0, _VAE_CACHE_SIZE - leaving_size)

    _messages.debug_log(f'Cached VAE {_types.class_and_id_string(vae)} '
                        f'Size = {leaving_size} Bytes '
                        f'({_memory.bytes_best_human_unit(leaving_size)}) '
                        f'is leaving CPU side memory, {_types.fullname(vae_cache_size)}() '
                        f'is now {vae_cache_size()} Bytes '
                        f'({_memory.bytes_best_human_unit(vae_cache_size())})')


def controlnet_off_cpu_update_cache_info(
        controlnet: typing.Union[diffusers.models.ControlNetModel,
                                 diffusers.pipelines.controlnet.MultiControlNetModel]):
    """
    Update CPU side cache size information when a ControlNet module is
    moved to a device that is not the CPU.

    Every net of a multi control net is subtracted from the ControlNet cache
    size estimate individually, with the estimate clamped at zero after each
    subtraction. Objects that are neither a multi control net nor a
    ControlNet model are ignored.

    :param controlnet: the control net, or multi control net
    """
    global _CONTROL_NET_CACHE_SIZE

    if isinstance(controlnet, diffusers.pipelines.controlnet.MultiControlNetModel):
        leaving_nets = controlnet.nets
        origin = 'from "MultiControlNetModel" '
    elif isinstance(controlnet, diffusers.models.ControlNetModel):
        leaving_nets = [controlnet]
        origin = ''
    else:
        return

    for leaving_net in leaving_nets:
        leaving_size = leaving_net.DGENERATE_SIZE_ESTIMATE

        # Clamp at zero so the estimate cannot go negative.
        _CONTROL_NET_CACHE_SIZE = max(0, _CONTROL_NET_CACHE_SIZE - leaving_size)

        _messages.debug_log(f'Cached ControlNetModel {_types.class_and_id_string(leaving_net)} '
                            f'Size = {leaving_size} Bytes '
                            f'({_memory.bytes_best_human_unit(leaving_size)}) '
                            f'{origin}is leaving CPU side memory, '
                            f'{_types.fullname(control_net_cache_size)}() '
                            f'is now {control_net_cache_size()} Bytes '
                            f'({_memory.bytes_best_human_unit(control_net_cache_size())})')


__all__ = _types.module_all()