diff --git a/python/convert_pytorch_to_ggml.py b/python/convert_pytorch_to_ggml.py index 0c3dfbf..37c1ee6 100644 --- a/python/convert_pytorch_to_ggml.py +++ b/python/convert_pytorch_to_ggml.py @@ -60,6 +60,8 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t 1 if is_FP16 else 0 )) + if is_v6_0: + n_head: int = state_dict['blocks.0.att.time_faaaa'].shape[0] for k in state_dict.keys(): tensor: torch.Tensor = state_dict[k].float() @@ -72,7 +74,6 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t if '.time_maa_w1' in k or '.time_decay_w' in k: tensor = tensor.transpose(0, 1) if '.time_maa_w2' in k: - n_head: int = tensor.shape[1] tensor = tensor.transpose(1, 2) if '.time_decay' in k and '_w' not in k: tensor = tensor.reshape(n_head, -1, 1) diff --git a/rwkv_graph.inc b/rwkv_graph.inc index 462a335..78e9755 100644 --- a/rwkv_graph.inc +++ b/rwkv_graph.inc @@ -364,7 +364,8 @@ static struct ggml_tensor * rwkv_att_v6( ctx, ggml_mul_mat(ctx, layer.att_time_maa_w1, xxx) ), - head_count, 1, 5, sequence_length + 32, 1, 5, sequence_length + // D_MIX_LORA = 32 ); xxx = ggml_cont( @@ -378,7 +379,8 @@ static struct ggml_tensor * rwkv_att_v6( ggml_reshape_4d( ctx, layer.att_time_maa_w2, - head_count, n_embed, 1, 5 + 32, n_embed, 1, 5 + // D_MIX_LORA = 32 ), xxx ); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 60c783a..7e5ec45 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -39,6 +39,12 @@ file(COPY tiny-rwkv-5v2-730K-Q5_0.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) file(COPY tiny-rwkv-5v2-730K-Q5_1.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) file(COPY expected-logits-5v2-730K.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +file(COPY tiny-rwkv-6v0-1m-FP32.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +file(COPY tiny-rwkv-6v0-1m-FP16.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +file(COPY tiny-rwkv-6v0-1m-Q5_0.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +file(COPY tiny-rwkv-6v0-1m-Q5_1.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +file(COPY expected-logits-6v0-1m.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) + rwkv_add_test(test_ggml_basics.c) rwkv_add_test(test_quantized_matmul_on_gpu.c) rwkv_add_test(test_tiny_rwkv.c) diff --git a/tests/expected-logits-6v0-1m.bin b/tests/expected-logits-6v0-1m.bin index 8b53948..ed2689e 100644 Binary files a/tests/expected-logits-6v0-1m.bin and b/tests/expected-logits-6v0-1m.bin differ diff --git a/tests/test_tiny_rwkv.c b/tests/test_tiny_rwkv.c index 98767b3..e75cb89 100644 --- a/tests/test_tiny_rwkv.c +++ b/tests/test_tiny_rwkv.c @@ -46,7 +46,7 @@ int main(void) { +0.206919F, // FP16 // 6v0 +0.001000F, // FP32 - +0.206919F // FP16 + -0.184410F // FP16 }; // *** Why the hell the expected logit difference sum for v4 models is < 1, and for v5 models it can be as high as 160? *** @@ -83,11 +83,11 @@ int main(void) { +048.068733F, // Q5_1 -009.441034F, // Q8_0 // 6v0 - +035.271305F, // Q4_0 - +061.719509F, // Q4_1 - +025.273308F, // Q5_0 - +048.068733F, // Q5_1 - -009.441034F // Q8_0 + +039.715752F, // Q4_0 + +049.779972F, // Q4_1 + -005.441267F, // Q5_0 + -017.046452F, // Q5_1 + -000.220227F // Q8_0 }; const float expected_difference_sum_quantized_FP16[VERSION_COUNT * (FORMAT_COUNT - 2)] = { @@ -110,11 +110,11 @@ int main(void) { +029.726818F, // Q5_1 -007.242277F, // Q8_0 // 6v0 - +034.135971F, // Q4_0 - +059.066830F, // Q4_1 - +021.588751F, // Q5_0 - +029.726818F, // Q5_1 - -007.242277F // Q8_0 + +039.676075F, // Q4_0 + +049.956646F, // Q4_1 + -005.413362F, // Q5_0 + -016.773785F, // Q5_1 + -000.038582F // Q8_0 }; for (int i_version = 0; i_version < VERSION_COUNT; i_version++) { diff --git a/tests/tiny-rwkv-6v0-1m-FP16.bin b/tests/tiny-rwkv-6v0-1m-FP16.bin index f7ba004..bd32e38 100644 Binary files a/tests/tiny-rwkv-6v0-1m-FP16.bin and b/tests/tiny-rwkv-6v0-1m-FP16.bin differ diff --git a/tests/tiny-rwkv-6v0-1m-FP32.bin b/tests/tiny-rwkv-6v0-1m-FP32.bin index 9240d88..e40d325 100644 Binary files a/tests/tiny-rwkv-6v0-1m-FP32.bin and b/tests/tiny-rwkv-6v0-1m-FP32.bin differ diff --git a/tests/tiny-rwkv-6v0-1m-Q5_0.bin b/tests/tiny-rwkv-6v0-1m-Q5_0.bin index b3ac400..7ceb358 100644 Binary files a/tests/tiny-rwkv-6v0-1m-Q5_0.bin and b/tests/tiny-rwkv-6v0-1m-Q5_0.bin differ diff --git a/tests/tiny-rwkv-6v0-1m-Q5_1.bin b/tests/tiny-rwkv-6v0-1m-Q5_1.bin index aa7815a..313cf84 100644 Binary files a/tests/tiny-rwkv-6v0-1m-Q5_1.bin and b/tests/tiny-rwkv-6v0-1m-Q5_1.bin differ