-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
55 lines (52 loc) · 2.07 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# %%
from my_utils import *
import torch
import tqdm
# %%
# Build the sparse autoencoder and the activation buffer from the shared
# config dict. AutoEncoder, Buffer and cfg are not defined in this file —
# presumably provided by the `from my_utils import *` above; verify there.
encoder = AutoEncoder(cfg)
buffer = Buffer(cfg)
# Code used to remove the "rare freq direction", the shared direction among the ultra low frequency features.
# I experimented with removing it and retraining the autoencoder.
# %%
# Main training loop for the sparse autoencoder: stream activation batches
# from the buffer, optimize with Adam, log to Weights & Biases, periodically
# evaluate reconstruction quality / feature firing frequencies, and
# occasionally checkpoint and re-initialize near-dead neurons.
try:
    wandb.init(project="sparse_encoder", name="orig")
    num_batches = cfg["num_tokens"] // cfg["batch_size"]
    encoder_optim = torch.optim.Adam(
        encoder.parameters(), lr=cfg["lr"], betas=(cfg["beta1"], cfg["beta2"])
    )
    recons_scores = []
    act_freq_scores_list = []
    for i in tqdm.trange(num_batches):
        # NOTE(review): removed `i = i % all_tokens.shape[0]` here. It clobbered
        # the step counter — once i exceeded all_tokens.shape[0], the %100/%1000
        # and (i+1)%30000 schedules below would silently reset their cadence —
        # and the wrapped value served no other purpose: batches come from
        # buffer.next(), never indexed by i.
        acts = buffer.next()
        loss, x_reconstruct, mid_acts, l2_loss, l1_loss = encoder(acts)
        loss.backward()
        # Keep decoder columns unit-norm (projects both weights and grads).
        encoder.make_decoder_weights_and_grad_unit_norm()
        encoder_optim.step()
        encoder_optim.zero_grad()
        loss_dict = {
            "loss": loss.item(),
            "l2_loss": l2_loss.item(),
            "l1_loss": l1_loss.item(),
        }
        # Drop graph-holding tensors before the (potentially heavy) eval below.
        del loss, x_reconstruct, mid_acts, l2_loss, l1_loss, acts
        if i % 100 == 0:
            wandb.log(loss_dict)
            print(loss_dict)
        if i % 1000 == 0:
            # Periodic eval: reconstruction score and feature firing frequencies.
            x = get_recons_loss(local_encoder=encoder)
            print("Reconstruction:", x)
            recons_scores.append(x[0])
            freqs = get_freqs(5, local_encoder=encoder)
            act_freq_scores_list.append(freqs)
            # histogram(freqs.log10(), marginal="box", histnorm="percent", title="Frequencies")
            wandb.log({
                "recons_score": x[0],
                "dead": (freqs == 0).float().mean().item(),
                "below_1e-6": (freqs < 1e-6).float().mean().item(),
                "below_1e-5": (freqs < 1e-5).float().mean().item(),
            })
        if (i + 1) % 30000 == 0:
            # Checkpoint, then resurrect neurons firing less than ~10^-5.5 of
            # the time (measured over 50 eval batches).
            encoder.save()
            wandb.log({"reset_neurons": 0.0})
            freqs = get_freqs(50, local_encoder=encoder)
            to_be_reset = freqs < 10 ** (-5.5)
            print("Resetting neurons!", to_be_reset.sum())
            re_init(to_be_reset, encoder)
finally:
    # Always persist the encoder, even on interrupt or crash mid-training.
    encoder.save()
# %%