quant.py
import os

import torch
from transformer_lens import HookedTransformer

# Quantize the main model's weights (fake quantization, kept in float32).
def fake_quantize(w: torch.Tensor, num_bits: int = 8, symmetric: bool = True) -> torch.Tensor:
    """
    Simulate quantization of weights while keeping them in float32.

    Args:
        w: Input weights as torch.Tensor
        num_bits: Number of bits to quantize to
        symmetric: If True, use symmetric quantization around 0

    Returns:
        Quantized weights as float32 tensor

    Note: this first version has a bug in the symmetric branch. The scale is the
    maximum absolute weight and is never divided by n_levels, so the grid
    collapses to three values regardless of num_bits. Kept for reference; use
    fixed_fake_quantize below.
    """
    with torch.no_grad():
        if symmetric:
            n_levels = 2**(num_bits - 1) - 1  # One bit for sign
            scale = torch.max(torch.abs(w))   # BUG: should be max_abs / n_levels
            min_val = -scale
            max_val = scale
        else:
            n_levels = 2**num_bits - 1
            min_val = torch.min(w)
            max_val = torch.max(w)
            scale = (max_val - min_val) / n_levels
        # Clip weights to the representable range
        w_clipped = torch.clamp(w, min_val, max_val)
        # Snap to the quantization grid, then map back to float
        w_quantized = torch.round((w_clipped - min_val) / scale) * scale + min_val
        return w_quantized
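# Worked example (added for illustration, not in the original script) of why the
# symmetric branch above is broken. Take w = torch.tensor([0.30, -0.12, 0.05])
# and num_bits = 4:
#   scale = max|w| = 0.30, min_val = -0.30
#   (w - min_val) / scale = [2.0, 0.6, 1.1667] -> round -> [2, 1, 1]
#   [2, 1, 1] * 0.30 - 0.30 = [0.30, 0.00, 0.00]
# Every weight is snapped onto {-0.30, 0.00, +0.30}: three values, no matter how
# large num_bits is. fixed_fake_quantize below divides the scale by n_levels, so
# num_bits = 4 gives a 15-point grid instead.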
def fixed_fake_quantize(w: torch.Tensor, num_bits: int = 8, symmetric: bool = True) -> torch.Tensor:
    """
    Corrected fake quantization: same interface as fake_quantize, but the
    symmetric scale is max_abs / n_levels, so num_bits actually controls the
    number of representable levels.
    """
    with torch.no_grad():
        if symmetric:
            n_levels = 2**(num_bits - 1) - 1
            max_abs = torch.max(torch.abs(w))
            scale = max_abs / n_levels  # Divide by n_levels to get the step size
            min_val = -max_abs
            max_val = max_abs
        else:
            n_levels = 2**num_bits - 1
            min_val = torch.min(w)
            max_val = torch.max(w)
            scale = (max_val - min_val) / n_levels
        w_clipped = torch.clamp(w, min_val, max_val)
        w_int = torch.round((w_clipped - min_val) / scale)  # Integer grid indices
        w_quantized = w_int * scale + min_val                # Back to float values
        return w_quantized
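# Optional sanity check (sketch, not in the original script): the corrected
# function should produce at most 2 * (2**(num_bits - 1) - 1) + 1 distinct
# values, e.g. 15 for num_bits = 4.
# _w = torch.randn(1000)
# print(torch.unique(fixed_fake_quantize(_w, num_bits=4)).numel())  # <= 15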
# Load the pretrained 2-layer GELU model in float32 on GPU.
model = HookedTransformer.from_pretrained("gelu-2l").to(torch.float32).to("cuda:0")

# Optionally start from a previously saved state dict instead:
# model.load_state_dict(torch.load("something.bin"))

# Fake-quantize every weight matrix inside the transformer blocks to 4 bits.
for name, param in model.named_parameters():
    if "block" in name and "W" in name:
        param.data = fixed_fake_quantize(param.data, num_bits=4)

torch.save(model.state_dict(), "quant_4_f.bin")
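# Sketch of reloading the quantized checkpoint later (assumes the same "gelu-2l"
# architecture; not part of the original script):
# reloaded = HookedTransformer.from_pretrained("gelu-2l").to(torch.float32).to("cuda:0")
# reloaded.load_state_dict(torch.load("quant_4_f.bin"))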
# Helper for checking on-disk model size (kept commented out; uses the os import above).
# def print_size_of_model(model):
#     torch.save(model.state_dict(), "temp.p")
#     print("Size (MB):", os.path.getsize("temp.p") / 1e6)
#     os.remove("temp.p")
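# Note: because fake quantization keeps the weights in float32, quant_4_f.bin is
# the same size on disk as the unquantized checkpoint; only the number of
# distinct weight values is reduced.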
# Earlier experiment (kept for reference): PyTorch dynamic quantization instead
# of fake quantization.
# model_dynamic_quantized = torch.ao.quantization.quantize_dynamic(
#     model, dtype=torch.qint8
# )
# print_size_of_model(model)
# print_size_of_model(model_dynamic_quantized)
# for name, param in model_dynamic_quantized.named_parameters():
#     print(f"Layer: {name}, Shape: {param.shape}, Dtype: {param.dtype}")