forked from BlinkDL/ChatRWKV
-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_lm_eval.py
150 lines (123 loc) · 5 KB
/
run_lm_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
#
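# pip install --upgrade rwkv "lm_eval==0.3.0"
# (lm_eval is pinned: this script uses the pre-0.4 API - lm_eval.models.gpt2.GPT2LM and
#  evaluator.evaluate(provide_description=...) - which upgrading lm_eval to 0.4+ removes)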
#
# Standard library / numeric stack.
# NOTE(review): sys, types, json and time are not referenced in the visible script.
import os, sys, types, json, math, time
import numpy as np
# Only affects how numpy arrays are printed.
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import torch
# Let cuDNN autotune its kernels and allow TF32 math in cuDNN and CUDA matmuls
# (trades last-bit precision for speed on GPUs that support TF32).
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
from torch.nn import functional as F
# These RWKV flags are set BEFORE `from rwkv.model import RWKV` below - presumably the
# rwkv package reads them at import time (NOTE(review): confirm). Keep them above it.
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '1'
from rwkv.model import RWKV
from rwkv.utils import PIPELINE
# NOTE(review): lm_eval.models.gpt2.GPT2LM and evaluate(provide_description=...) belong
# to the older (pre-0.4) lm_eval API - confirm the installed lm_eval version matches.
from lm_eval import tasks, evaluator
from lm_eval.models.gpt2 import GPT2LM
########################################################################################################
# Checkpoint passed to RWKV(model=...). The default points at one specific cluster
# path; set the RWKV_EVAL_MODEL environment variable to evaluate another checkpoint
# without editing this file.
MODEL_NAME = os.environ.get(
    "RWKV_EVAL_MODEL",
    "/fsx/BlinkDL/HF-MODEL/rwkv-5-world/RWKV-5-World-1.5B-v2-OnlyForTest_14%_trained-20231001-ctx4096",
)
print(f'Loading model - {MODEL_NAME}')
model = RWKV(model=MODEL_NAME, strategy='cuda fp16', verbose=False)
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

# lm_eval task names to run; uncomment lines below to add more.
eval_tasks = []
eval_tasks += ['lambada_openai']
# eval_tasks += ['hellaswag','winogrande']
# eval_tasks += ['lambada_openai','piqa','storycloze_2016','hellaswag','winogrande']
# eval_tasks += ['arc_challenge','arc_easy','headqa','openbookqa','sciq']
# eval_tasks += ['record','copa']
# eval_tasks += ['triviaqa']
# eval_tasks += ['coqa']

# Token ids prepended to every scored request (the encoding of '\n').
RWKV_PAD = pipeline.tokenizer.encode('\n') # we will use '\n' as PAD
# RWKV_PAD = [0] # you can try using [0] as pad
print('RWKV_PAD', RWKV_PAD)
########################################################################################################
# Memo caches filled by EvalHarnessAdapter._loglikelihood_tokens, both keyed by
# str() of the PAD-prefixed token list: summed log-prob, and the all-argmax flag.
logitBuf, correctBuf = {}, {}
class TokenizerWrapper:
    """Thin adapter giving an RWKV tokenizer the ``encode``/``decode``/``eos_token_id``
    surface that ``EvalHarnessAdapter`` exposes as ``self.tokenizer``."""

    def __init__(self, tokenizer):
        # Token id 0 is reported as the end-of-sequence id.
        self.eos_token_id = 0
        self.tokenizer = tokenizer

    def encode(self, string: str, add_special_tokens=False):
        """Return the token ids for ``string``.

        ``add_special_tokens`` is accepted for call-site compatibility and ignored.
        """
        encoder = self.tokenizer.encode
        return encoder(string)

    def decode(self, tokens):
        """Return the text for the token id sequence ``tokens``."""
        decoder = self.tokenizer.decode
        return decoder(tokens)
class EvalHarnessAdapter(GPT2LM):
    """Expose the module-level RWKV ``model`` to lm_eval through the GPT2LM interface.

    Only ``tokenizer`` and ``_loglikelihood_tokens`` are supplied here. Scores are
    memoised in the module-level ``logitBuf`` / ``correctBuf`` dicts, keyed by
    ``str()`` of the PAD-prefixed token list, so a repeated request is not re-run
    through the model.
    """

    def __init__(self):
        # NOTE(review): GPT2LM.__init__ is intentionally not called (the original never
        # calls it) - presumably it would set up its own HF model. Only the tokenizer
        # attribute is provided; confirm nothing else from the parent is needed.
        self.tokenizer = TokenizerWrapper(pipeline.tokenizer)

    # Greedy generation, designed for coqa; uncomment when evaluating that task.
    # def greedy_until(self, requests): # designed for coqa
    #     res = []
    #     for i in range(len(requests)):
    #         if i % 50 == 0:
    #             print(i)
    #         otoken = []
    #         while True:
    #             src = self.tokenizer.encode(requests[i][0]) + otoken
    #             src = src[-4096:]
    #             outputs, _ = model.forward(src, None)
    #             otoken += [int(torch.argmax(outputs))]
    #             ss = self.tokenizer.decode(otoken)
    #             if '\n' in ss or len(ss) > 200:
    #                 if not ss.endswith('\n'):
    #                     ss = ss + '\n'
    #                 print(ss)
    #                 res += [(ss)]
    #                 break
    #     print(res)
    #     return res

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        """Score every request's continuation tokens given its context tokens.

        Args:
            requests: sequence of items whose index 1 is the context token list and
                index 2 the continuation token list (index 0 is not read).
            disable_tqdm: accepted for interface compatibility; unused.

        Returns:
            A list with one ``(logprob, is_greedy)`` tuple per request: the summed
            log-probability of the continuation tokens, and True when every
            continuation token is also the model's top-1 prediction.
        """
        res = []
        total = len(requests)
        for n, request in enumerate(requests):
            context_tokens, continuation_tokens = request[1], request[2]
            # The PAD prefix guarantees some position predicts the first scored token,
            # even when the context is empty.
            src = RWKV_PAD + context_tokens + continuation_tokens
            key = str(src)
            if key in logitBuf:
                logit, correct = logitBuf[key], correctBuf[key]
            else:
                q_len = len(RWKV_PAD) + len(context_tokens)
                logit, correct = self._score_continuation(src, q_len)
                logitBuf[key] = logit
                correctBuf[key] = correct
            res.append((logit, correct))
            if n % 1000 == 0:
                print(f'{n//1000}/{total//1000}', end = ' ', flush=True)
        return res

    def _score_continuation(self, src, q_len):
        """Return ``(sum of log-probs of src[q_len:], every such token is the argmax)``.

        ``outputs[i]`` is the logit vector that predicts ``src[i + 1]``, so the tokens
        ``src[q_len:]`` are scored by positions ``q_len - 1 .. len(src) - 2``.
        """
        logit = 0.0
        correct = True
        with torch.no_grad():
            outputs, _ = model.forward(src, None, full_output=True)
            for i in range(q_len - 1, len(src) - 1):
                oo = outputs[i].detach().float()
                dst = src[i + 1]
                # log_softmax stays finite where softmax would underflow to 0 and make
                # math.log raise "math domain error".
                logit += F.log_softmax(oo, dim=-1)[dst].item()
                # Top-1 via argmax instead of sorting the whole vocabulary per token.
                if torch.argmax(oo).item() != dst:
                    correct = False
        return logit, correct

    @torch.no_grad()
    def run_eval(self, eval_tasks=None, num_fewshot=0, bootstrap_iters=2):
        """Run lm_eval's ``evaluator.evaluate`` on ``eval_tasks`` with this adapter as the LM.

        Returns the dict produced by ``evaluator.evaluate`` unchanged.
        """
        results = evaluator.evaluate(
            lm=self,
            task_dict=tasks.get_task_dict(eval_tasks),
            provide_description=False,
            num_fewshot=num_fewshot,
            limit=None,
            bootstrap_iters=bootstrap_iters,
        )
        return results
# Evaluate the configured tasks (10000 bootstrap iterations, forwarded to lm_eval's
# evaluator) and print the 'results' entry of the returned dict.
adapter = EvalHarnessAdapter()
results = adapter.run_eval(bootstrap_iters=10000, eval_tasks=eval_tasks)
print(results['results'])