-
-
Notifications
You must be signed in to change notification settings - Fork 136
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
kernels: disambiguate quantized types via a new ScalarType
Co-authored-by: Lucas Wilkinson <[email protected]>
- Loading branch information
1 parent
7844103
commit 141672a
Showing
27 changed files
with
1,008 additions
and
295 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
import importlib.util | ||
from enum import Enum | ||
from typing import TYPE_CHECKING, Optional, Union | ||
|
||
import torch | ||
from loguru import logger | ||
|
||
# True when an import spec for the compiled `aphrodite._core_C` extension can
# be located; selects between the real ScalarType and the Python mock below.
core_C_available = (
    importlib.util.find_spec("._core_C", package="aphrodite") is not None
)
|
||
|
||
# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    """Ways a scalar type may encode NaN values."""

    # NaNs are not supported.
    NONE = 0
    # NaNs are: exponent all 1s, mantissa not all 0s.
    IEEE_754 = 1
    # NaNs are: exponent all 1s, mantissa all 1s.
    EXTD_RANGE_MAX_MIN = 2
|
||
|
||
if TYPE_CHECKING or not core_C_available:
    # On platforms where we cannot use/build the C++ core extension (i.e.
    # namely neuron and tpu), we define the mock ScalarType class here that
    # partially mimics the C++ ScalarType class.
    #
    # We also use this to provide type signatures to the Python LSP for the
    # methods in the C++ ScalarType class. So these type signatures should be
    # kept in sync with csrc/core/scalar_type.hpp

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ScalarType:
        """
        ScalarType can represent a wide range of floating point and integer
        types, in particular it can be used to represent sub-byte data types
        (something that torch.dtype currently does not support). It is also
        capable of representing types with a bias, i.e.:
            `stored_value = value + bias`,
        this is useful for quantized types (e.g. standard GPTQ 4bit uses a
        bias of 8). The implementation for this class can be found in
        csrc/core/scalar_type.hpp, these type signatures should be kept in
        sync with that file.
        """

        exponent: int
        """
        Number of bits in the exponent if this is a floating point type
        (zero if this is an integer type)
        """

        mantissa: int
        """
        Number of bits in the mantissa if this is a floating point type,
        or the number of bits representing an integer excluding the sign bit
        if this is an integer type.
        """

        bias: int
        """
        bias used to encode the values in this scalar type
        (value = stored_value - bias, default 0) for example if we store the
        type as an unsigned integer with a bias of 128 then the value 0 will
        be stored as 128 and -1 will be stored as 127 and 1 will be stored as
        129.
        """

        signed: bool
        "If the type is signed (i.e. has a sign bit)"

        _finite_values_only: bool = False
        """
        Private: if NANs are supported, use `has_infs()` instead.
        """

        nan_repr: int = NanRepr.IEEE_754.value
        """
        How NaNs are represented in this scalar type, holds a NanRepr value.
        (not applicable for integer types)
        """

        @property
        def size_bits(self) -> int:
            "Total width in bits: exponent + mantissa + sign bit (if signed)."
            return self.exponent + self.mantissa + int(self.signed)

        def min(self) -> Union[int, float]:
            """
            Min representable value for this scalar type.
            (accounting for bias if there is one)

            Not implemented by this Python mock; always raises
            NotImplementedError.
            """
            raise NotImplementedError

        def max(self) -> Union[int, float]:
            """
            Max representable value for this scalar type.
            (accounting for bias if there is one)

            Not implemented by this Python mock; always raises
            NotImplementedError.
            """
            raise NotImplementedError

        def is_signed(self) -> bool:
            """
            If the type is signed (i.e. has a sign bit), same as `signed`
            added for consistency with:
            https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
            """
            # Previously the body was `...`, which silently returned None.
            return self.signed

        def is_floating_point(self) -> bool:
            "If the type is a floating point type"
            return self.exponent != 0

        def is_integer(self) -> bool:
            "If the type is an integer type"
            return self.exponent == 0

        def has_bias(self) -> bool:
            "If the type has a non-zero bias"
            return self.bias != 0

        def has_infs(self) -> bool:
            """
            If the type supports infinity. Only the `_finite_values_only`
            flag is consulted, so this is meaningful for floating point
            types only.
            """
            return not self._finite_values_only

        def has_nans(self) -> bool:
            """
            If `nan_repr` is anything other than `NanRepr.NONE`
            (not applicable for integer types).
            """
            return self.nan_repr != NanRepr.NONE.value

        def is_ieee_754(self) -> bool:
            """
            If the type is a floating point type that follows IEEE 754
            conventions
            """
            return self.nan_repr == NanRepr.IEEE_754.value and \
                not self._finite_values_only

        # NOTE: defining these explicitly suppresses the dataclass-generated
        # `__repr__`, so printing a mock ScalarType raises.
        def __str__(self) -> str:
            raise NotImplementedError

        def __repr__(self) -> str:
            raise NotImplementedError

        #
        # Convenience Constructors
        #

        @classmethod
        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            "Create a signed integer scalar type (size_bits includes sign-bit)."
            # Integer types have no exponent bits; the sign bit is counted
            # separately via `signed`, so the mantissa holds size_bits - 1.
            return cls(0, size_bits - 1, bias if bias else 0, True)

        @classmethod
        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            """Create an unsigned integer scalar type."""
            # Integer types have no exponent bits; all size_bits are value
            # bits because there is no sign bit.
            return cls(0, size_bits, bias if bias else 0, False)

        @classmethod
        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
            """
            Create a standard floating point type
            (i.e. follows IEEE 754 conventions).
            """
            return cls(exponent, mantissa, 0, True)

        @classmethod
        def float_(cls, exponent: int, mantissa: int,
                   finite_values_only: bool,
                   nan_repr: int) -> 'ScalarType':
            """
            Create a non-standard floating point type
            (i.e. does not follow IEEE 754 conventions).
            """
            return cls(exponent, mantissa, 0, True, finite_values_only,
                       nan_repr)

elif core_C_available:
    try:
        import aphrodite._core_C  # noqa: F401
    except ImportError as e:
        logger.warning(f"Failed to import from aphrodite._core_C with {e}")

    # NOTE(review): if the import above failed, this lookup will presumably
    # raise as well; the warning is only there to surface the root cause.
    ScalarType = torch.classes._core_C.ScalarType
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.