Merge pull request #174 from MollySophia/rwkv6
Add initial support for RWKV v6
PicoCreator authored Jul 2, 2024
2 parents d8f13ff + 3c4c01f commit 970a813
Showing 17 changed files with 598 additions and 78 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -10,6 +10,8 @@ This project provides [a C library rwkv.h](rwkv.h) and [a convinient Python wrap

 [RWKV v5](https://huggingface.co/BlinkDL/rwkv-5-world) is a major upgrade to RWKV architecture, making it competitive with Transformers in quality. RWKV v5 models are supported.

+[RWKV v6](https://huggingface.co/BlinkDL/rwkv-6-world) is a further improvement to RWKV architecture, with better quality. RWKV v6 models are supported.
+
 Loading LoRA checkpoints in [Blealtan's format](https://github.com/Blealtan/RWKV-LM-LoRA) is supported through [merge_lora_into_ggml.py script](rwkv%2Fmerge_lora_into_ggml.py).

 ## Quality and performance
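As a usage sketch for the README addition above: a v6 checkpoint can also be converted by driving the conversion script from Python rather than the command line. This is a minimal sketch, assuming python/convert_pytorch_to_ggml.py is importable (for example, when run from the repository root); the checkpoint file name and the 'FP16' data-type string are illustrative assumptions, not values taken from this diff.

```python
# Hedged conversion sketch; see assumptions above.
import sys
sys.path.insert(0, 'python')  # assumption: executed from the repository root

import torch
from convert_pytorch_to_ggml import write_state_dict

# Illustrative file names; any RWKV v6 checkpoint saved as a plain state dict should work.
state_dict = torch.load('RWKV-x060-World-1B6.pth', map_location='cpu')
write_state_dict(state_dict, 'rwkv-v6-1b6-FP16.bin', 'FP16')  # dest_path, data_type
```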
21 changes: 18 additions & 3 deletions python/convert_pytorch_to_ggml.py
@@ -34,8 +34,11 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t

     is_v5_1_or_2: bool = 'blocks.0.att.ln_x.weight' in state_dict
     is_v5_2: bool = 'blocks.0.att.gate.weight' in state_dict
+    is_v6_0: bool = 'blocks.0.att.time_maa_x' in state_dict

-    if is_v5_2:
+    if is_v6_0:
+        print('Detected RWKV v6.0')
+    elif is_v5_2:
         print('Detected RWKV v5.2')
     elif is_v5_1_or_2:
         print('Detected RWKV v5.1')
@@ -57,13 +60,25 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t
             1 if is_FP16 else 0
         ))

+        if is_v6_0:
+            n_head: int = state_dict['blocks.0.att.time_faaaa'].shape[0]
         for k in state_dict.keys():
             tensor: torch.Tensor = state_dict[k].float()

             if '.time_' in k:
                 tensor = tensor.squeeze()

-            if is_v5_1_or_2:
+            if is_v6_0:
+                if '.time_faaaa' in k:
+                    tensor = tensor.unsqueeze(-1)
+                if '.time_maa_w1' in k or '.time_decay_w' in k:
+                    tensor = tensor.transpose(0, 1)
+                if '.time_maa_w2' in k:
+                    tensor = tensor.transpose(1, 2)
+                if '.time_decay' in k and '_w' not in k:
+                    tensor = tensor.reshape(n_head, -1, 1)
+
+            elif is_v5_1_or_2:
                 if '.time_decay' in k:
                     if is_v5_2:
                         tensor = torch.exp(-torch.exp(tensor)).unsqueeze(-1)
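To make the v6.0 branch above easier to follow, here is a standalone sketch of the same shape transformations applied to dummy tensors. The concrete sizes (n_head, head_size, n_embd, the five mixing channels, and the low-rank width of 32) are illustrative assumptions, not values read from a real checkpoint.

```python
# Illustrative shapes only; the operations mirror the v6.0 branch above.
import torch

n_head, head_size = 32, 64
n_embd = n_head * head_size
rank = 32  # assumed low-rank width of the time_maa projections

time_faaaa = torch.zeros(n_head, head_size)
print(time_faaaa.unsqueeze(-1).shape)           # (n_head, head_size, 1)

time_maa_w1 = torch.zeros(n_embd, 5 * rank)
print(time_maa_w1.transpose(0, 1).shape)        # (5 * rank, n_embd)

time_maa_w2 = torch.zeros(5, rank, n_embd)
print(time_maa_w2.transpose(1, 2).shape)        # (5, n_embd, rank)

time_decay = torch.zeros(n_embd)                # already flat after the earlier squeeze()
print(time_decay.reshape(n_head, -1, 1).shape)  # (n_head, head_size, 1)
```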
@@ -105,7 +120,7 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t

             out_file.write(k_encoded)

-            tensor.numpy().tofile(out_file)
+            tensor.detach().numpy().tofile(out_file)

 def main() -> None:
     args = parse_args()
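The switch to tensor.detach().numpy() above matters because PyTorch refuses to convert a tensor that still requires grad directly to NumPy, so checkpoints whose parameters were saved with requires_grad=True would otherwise fail during export. A minimal illustration:

```python
import torch

t = torch.zeros(4, requires_grad=True)

try:
    t.numpy()                    # raises RuntimeError for tensors that require grad
except RuntimeError as err:
    print('direct .numpy() failed:', err)

arr = t.detach().numpy()         # detach() drops the autograd link first
print(arr.shape)                 # (4,)
```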
16 changes: 13 additions & 3 deletions python/merge_lora_into_ggml.py
@@ -13,7 +13,7 @@
 def parse_args():
     parser = argparse.ArgumentParser(description='Merge a PyTorch LoRA checkpoint (.pth) into an rwkv.cpp model file')
     parser.add_argument('src_path', help='Path to source rwkv.cpp model')
-    parser.add_argument('rwkv_arch_version', help='Version of RWKV architecture: v4, v5.1, v5.2', type=str, choices=['v4', 'v5.1', 'v5.2'])
+    parser.add_argument('rwkv_arch_version', help='Version of RWKV architecture: v4, v5.1, v5.2, v6.0', type=str, choices=['v4', 'v5.1', 'v5.2', 'v6.0'])
     parser.add_argument('lora_path', help='Path to LoRA checkpoint in PyTorch format')
     parser.add_argument('lora_alpha', help='Value of lora_alpha parameter used when training this LoRA checkpoint', type=int)
     parser.add_argument('dest_path', help='Path to destination rwkv.cpp model, will be overwitten with the merged model')
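A usage sketch for the widened rwkv_arch_version choice: invoking the merge script for a v6.0 model. The positional argument order follows the parser definition above; the file names and the alpha value are illustrative assumptions.

```python
# Hedged invocation sketch; argument order follows the argparse definition above.
import subprocess

subprocess.run(
    [
        'python', 'python/merge_lora_into_ggml.py',
        'rwkv-v6-base-FP16.bin',    # src_path: source rwkv.cpp model (illustrative name)
        'v6.0',                     # rwkv_arch_version: the newly accepted choice
        'rwkv-v6-lora.pth',         # lora_path: LoRA checkpoint in PyTorch format
        '32',                       # lora_alpha used when training the LoRA (illustrative)
        'rwkv-v6-merged-FP16.bin',  # dest_path: overwritten with the merged model
    ],
    check=True,
)
```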
@@ -47,7 +47,7 @@ def main() -> None:

     arch_version: str = args.rwkv_arch_version

-    if not (arch_version == 'v4' or arch_version == 'v5.1' or arch_version == 'v5.2'):
+    if not (arch_version == 'v4' or arch_version == 'v5.1' or arch_version == 'v5.2' or arch_version == 'v6.0'):
         raise ValueError(f'Invalid RWKV architecture version {arch_version}')

     print(f'Reading {args.lora_path}')
@@ -108,7 +108,17 @@ def main() -> None:
            if '.time_' in key:
                replacement = replacement.squeeze()

-            if arch_version == 'v5.1' or arch_version == 'v5.2':
+            if arch_version == 'v6.0':
+                if '.time_faaaa' in key:
+                    replacement = replacement.unsqueeze(-1)
+                if '.time_maa_w1' in key or '.time_decay_w' in key:
+                    replacement = replacement.transpose(0, 1)
+                if '.time_maa_w2' in key:
+                    n_head: int = replacement.shape[1]
+                    replacement = replacement.transpose(1, 2)
+                if '.time_decay' in key and '_w' not in key:
+                    replacement = replacement.reshape(n_head, -1, 1)
+            elif arch_version == 'v5.1' or arch_version == 'v5.2':
                if '.time_decay' in key:
                    if arch_version == 'v5.2':
                        replacement = torch.exp(-torch.exp(replacement)).unsqueeze(-1)
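For context, the replacement tensor that the v6.0 branch above reshapes is produced by the usual LoRA fold-in, W' = W + (alpha / rank) * (B @ A). Below is a minimal sketch of that merge step with illustrative shapes; it shows the standard formula, not a verbatim excerpt of this script.

```python
import torch

def merge_lora_pair(w: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor, lora_alpha: int) -> torch.Tensor:
    """Fold a LoRA pair into a base weight: W' = W + (alpha / rank) * (B @ A)."""
    rank = lora_a.shape[0]  # lora_a: (rank, in_features), lora_b: (out_features, rank)
    return w + (lora_alpha / rank) * (lora_b @ lora_a)

# Illustrative shapes: a 2048x2048 projection with a rank-8 LoRA and alpha = 32.
w = torch.zeros(2048, 2048)
lora_a = torch.randn(8, 2048)
lora_b = torch.randn(2048, 8)
print(merge_lora_pair(w, lora_a, lora_b, lora_alpha=32).shape)  # torch.Size([2048, 2048])
```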
2 changes: 2 additions & 0 deletions rwkv.cpp
@@ -49,6 +49,8 @@ static_assert(sizeof(decltype(ftell(NULL))) >= 8, "File offsets should be 64-bit

 #include "rwkv_operators_wkv_v5.inc"

+#include "rwkv_operators_wkv_v6.inc"
+
 #include "rwkv_graph.inc"

 // API function.
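The new rwkv_operators_wkv_v6.inc provides the v6 WKV kernel; the defining change from v5 is that the per-channel decay becomes data-dependent (a different w for every token). As a reference point only, not the library's implementation, here is a naive single-head sketch of that recurrence, assuming w already holds the decayed values in (0, 1) and u is the time_faaaa bonus:

```python
import torch

def wkv6_naive(r: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
               w: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    """Naive single-head RWKV v6 WKV recurrence (reference sketch, not rwkv.cpp's kernel).

    r, k, v, w: (seq_len, head_size); u: (head_size,).
    w is assumed to already hold per-token decay factors in (0, 1).
    """
    seq_len, head_size = r.shape
    state = torch.zeros(head_size, head_size, dtype=r.dtype)
    out = torch.empty(seq_len, head_size, dtype=r.dtype)
    for t in range(seq_len):
        kv = torch.outer(k[t], v[t])                  # rank-1 update, (head_size, head_size)
        out[t] = r[t] @ (torch.diag(u) @ kv + state)  # current token gets the u (time_faaaa) bonus
        state = torch.diag(w[t]) @ state + kv         # data-dependent decay, then accumulate
    return out
```

Viewed this way, the (n_head, head_size, 1) reshapes of time_faaaa and time_decay in the conversion script above correspond to per-row diagonal scalings of each head's state.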
