# Copyright (c) 2023, Teriks
#
# dgenerate is distributed under the following BSD 3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in
# the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import diffusers
import torch
import dgenerate.textprocessing as _textprocessing
import dgenerate.types as _types
from dgenerate.pipelinewrapper.enums import get_torch_dtype as _get_torch_dtype
from dgenerate.pipelinewrapper.uris import exceptions as _exceptions
# URI parser for the bitsandbytes quantizer, labeled 'BNB Quantizer'.
# The list holds the argument names that BNBQuantizerUri.parse reads from a URI.
_bnb_quantizer_uri_parser = _textprocessing.ConceptUriParser(
    'BNB Quantizer',
    [
        'bits',
        'bits4-compute-dtype',
        'bits4-quant-type',
        'bits4-use-double-quant',
        'bits4-quant-storage',
    ]
)
class BNBQuantizerUri:
    """
    Representation of a bitsandbytes ``--quantizer`` URI, e.g. ``bnb;bits=4;bits4-quant-type=nf4``.

    Instances are produced by :py:meth:`BNBQuantizerUri.parse` (or constructed directly)
    and converted into a :py:class:`diffusers.BitsAndBytesConfig` via
    :py:meth:`BNBQuantizerUri.to_config`.
    """

    # dtype names accepted by the "bits4-compute-dtype" and "bits4-quant-storage" arguments
    _valid_dtypes = ["float16", "bfloat16", "float32", "float64", "int8", "uint8"]

    # pipelinewrapper.uris.util.get_uri_accepted_args_schema metadata

    NAMES = ['bnb', 'bitsandbytes']

    @staticmethod
    def help():
        """
        Return the user facing help text for this quantizer URI.

        :return: help text string
        """
        return """
        Bitsandbytes quantization backend configuration.

        This backend can be specified as "bnb" or "bitsandbytes" in the URI.

        URI Format: bnb;argument1=value1;argument2=value2

        Example: bnb;bits=4;bits4-quant-type=nf4

        The argument "bits" is Quantization bit width. Must be 4 or 8.

        NOWRAP!
        - bits=8: Uses LLM.int8() quantization method
        - bits=4: Uses QLoRA 4-bit quantization method

        The argument "bits4-compute-dtype" is the compute data type for 4-bit quantization.
        Only applies when bits=4. When None, automatically determined. This should generally
        match the dtype that you loaded the model with.

        The argument "bits4-quant-type" is the quantization data type for 4-bit weights.
        Only applies when bits=4.

        NOWRAP!
        - "fp4": 4-bit floating point (default)
        - "nf4": Normal Float 4 data type, adapted for weights from normal distribution.

        The argument "bits4-use-double-quant" Enables nested quantization for 4-bit mode.
        Only applies when bits=4. When True, performs a second quantization of already
        quantized weights to save an additional 0.4 bits/parameter with no performance cost.

        The argument "bits4-quant-storage" is the storage data type for 4-bit quantized weights.
        Only applies when bits=4. When None, uses default storage format. Controls memory
        layout of quantized parameters.
        """

    OPTION_ARGS = {
        'bits': [8, 4],
        'bits4-compute-dtype': _valid_dtypes,
        'bits4-quant-type': ["fp4", "nf4"],
        'bits4-quant-storage': _valid_dtypes
    }

    # ===

    def __init__(self,
                 bits: int = 8,
                 bits4_compute_dtype: str | None = None,
                 bits4_quant_type: str = "fp4",
                 bits4_use_double_quant: bool = False,
                 bits4_quant_storage: str | None = None):
        """
        :param bits: quantization bit width, must be ``4`` or ``8``.
        :param bits4_compute_dtype: compute dtype name for 4-bit mode, one of
            ``BNBQuantizerUri._valid_dtypes``, or ``None``.
        :param bits4_quant_type: 4-bit quantization data type, ``"fp4"`` or ``"nf4"``.
        :param bits4_use_double_quant: enable nested quantization in 4-bit mode.
        :param bits4_quant_storage: storage dtype name for 4-bit weights, one of
            ``BNBQuantizerUri._valid_dtypes``, or ``None``.

        :raises InvalidBNBQuantizerUriError: if any argument value is not accepted.
        """

        if bits not in {4, 8}:
            raise _exceptions.InvalidBNBQuantizerUriError(
                'BNB Quant Config bits must be 4 or 8.')

        if bits4_quant_type not in {'fp4', 'nf4'}:
            raise _exceptions.InvalidBNBQuantizerUriError(
                'BNB Quant Config bits4-quant-type must be fp4 or nf4.')

        self.bits4_quant_storage = self._dtype_check(bits4_quant_storage)
        self.bits4_compute_dtype = self._dtype_check(bits4_compute_dtype)
        self.bits = bits
        self.bits4_quant_type = bits4_quant_type
        self.bits4_use_double_quant = bits4_use_double_quant

    @staticmethod
    def _dtype_check(s):
        # Validate an optional dtype name; None passes through unchanged.
        if s is None:
            return None

        if s not in BNBQuantizerUri._valid_dtypes:
            raise _exceptions.InvalidBNBQuantizerUriError(
                f'BNB Quant dtypes must be one of: '
                f'{_textprocessing.oxford_comma(BNBQuantizerUri._valid_dtypes, "or")}.')

        return s

    def to_config(self, compute_dtype: str | torch.dtype | None = None) -> diffusers.BitsAndBytesConfig:
        """
        Build a :py:class:`diffusers.BitsAndBytesConfig` from this URI.

        :param compute_dtype: fallback compute dtype for 4-bit mode, used only when
            ``bits4-compute-dtype`` was not given in the URI. This should generally match
            the dtype the model is loaded with.

        :return: :py:class:`diffusers.BitsAndBytesConfig`
        """

        # NOTE(review): assumes _get_torch_dtype(None) returns None so that, with neither
        # value given, no compute dtype is forced — confirm in pipelinewrapper.enums.
        compute_dtype = _get_torch_dtype(compute_dtype)

        return diffusers.BitsAndBytesConfig(
            load_in_4bit=self.bits == 4,
            load_in_8bit=self.bits == 8,
            bnb_4bit_use_double_quant=self.bits4_use_double_quant,
            bnb_4bit_quant_type=self.bits4_quant_type,
            bnb_4bit_quant_storage=self.bits4_quant_storage,
            # URI-specified dtype name takes precedence over the caller's fallback
            bnb_4bit_compute_dtype=_types.default(self.bits4_compute_dtype, compute_dtype)
        )

    @staticmethod
    def parse(uri: _types.Uri) -> 'BNBQuantizerUri':
        """
        Parse a ``--quantizer`` URI string into a :py:class:`BNBQuantizerUri`.

        :param uri: the URI, e.g. ``bnb;bits=4;bits4-quant-type=nf4``

        :raises InvalidBNBQuantizerUriError: on URI syntax errors, an unknown backend
            name, or argument values that cannot be converted or are not accepted.

        :return: :py:class:`BNBQuantizerUri`
        """
        try:
            r = _bnb_quantizer_uri_parser.parse(uri)
        except _textprocessing.ConceptUriParseError as e:
            raise _exceptions.InvalidBNBQuantizerUriError(e) from e

        if r.concept not in BNBQuantizerUri.NAMES:
            raise _exceptions.InvalidBNBQuantizerUriError(
                f'Unknown quantization backend: {r.concept}'
            )

        bits_value = r.args.get('bits', 8)
        try:
            bits = int(bits_value)
        except (ValueError, TypeError) as e:
            # surface bad user input as a URI error instead of a raw conversion error
            raise _exceptions.InvalidBNBQuantizerUriError(
                f'BNB Quant Config bits must be an integer, 4 or 8, received: {bits_value}') from e

        double_quant_value = r.args.get('bits4-use-double-quant', False)
        try:
            bits4_use_double_quant = _types.parse_bool(double_quant_value)
        except ValueError as e:
            # NOTE(review): assumes parse_bool reports unparsable input with ValueError
            # — confirm against dgenerate.types.parse_bool.
            raise _exceptions.InvalidBNBQuantizerUriError(
                f'BNB Quant Config bits4-use-double-quant must be a boolean, '
                f'received: {double_quant_value}') from e

        return BNBQuantizerUri(
            bits=bits,
            bits4_compute_dtype=r.args.get('bits4-compute-dtype', None),
            bits4_quant_type=r.args.get('bits4-quant-type', 'fp4'),
            bits4_use_double_quant=bits4_use_double_quant,
            bits4_quant_storage=r.args.get('bits4-quant-storage', None)
        )