Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for QRWKV6 hybrid models & slight optimization for RWKV6 #11001

Open
wants to merge 9 commits into
base: master
Choose a base branch
from

Conversation

MollySophia
Copy link
Contributor

@MollySophia MollySophia commented Dec 28, 2024

QRWKV6-32B is a new model by Recursal which is a combination of the Qwen2.5 architecture and RWKV6.
It 'converts' a Qwen2.5-32B-Instruct model's QKV attention into RWKV6 linear attention, keeping knowledges in the origin Qwen model while gaining the advantages of linear models (constant vram usage and flops, independent of ctxlen).
More info/model for testing: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1
Some converted GGUF for testing: https://huggingface.co/mollysama/QRWKV6-32B-Instruct-Preview-GGUF

Changes in this PR:

  • Add OP gated linear attention with CPU and CUDA impl, which looks like a simplified version of RWKV6 wkv attention.
  • Model conversion and inferencing for QRWKV6-32B
  • RWKV6 optimizations: graph simplification; concated lerp weights to reduce cpu overhead during inference (credit to @compilade)

Testing details:

  • 32B Q4_0/Q4_K quantized model running on a single 4090 with decent speed:
$ ./build/bin/llama-bench -m ../QRWKV6-32B-Instruct-Preview-v0.1-Q4_0.gguf -ngl 99
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 4090, compute capability 8.9, VMM: yes
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| rwkv6qwen2 32B Q4_0            |  19.34 GiB |    34.74 B | CUDA       |  99 |         pp512 |        819.60 ± 1.01 |
| rwkv6qwen2 32B Q4_0            |  19.34 GiB |    34.74 B | CUDA       |  99 |         tg128 |         32.72 ± 0.01 |

build: 5a73dbcb (4397)
  • wikitext2 PPLs:
Quant type PPL
f32 5.6987 +/- 0.03365
q8_0 5.7005 +/- 0.03370
q5_k_s 5.7339 +/- 0.03393
q4_k_m 5.7921 +/- 0.03428
q4_0 5.8568 +/- 0.03481
q2_k 7.4547 +/- 0.04597
  • Performance of QRWKV6-32B difference before/after concating lerp weights together:
image (Sry for the image attachment)
before:
$ ./build/bin/llama-bench -m ../QRWKV6-32B-Instruct-Preview-v0.1/QRWKV6-32B-Instruct-Preview-v0.1-F16.gguf -sm none -mg 0
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 8 CUDA devices:
  Device 0: NVIDIA H800, compute capability 9.0, VMM: yes
  Device 1: NVIDIA H800, compute capability 9.0, VMM: yes
  Device 2: NVIDIA H800, compute capability 9.0, VMM: yes
  Device 3: NVIDIA H800, compute capability 9.0, VMM: yes
  Device 4: NVIDIA H800, compute capability 9.0, VMM: yes
  Device 5: NVIDIA H800, compute capability 9.0, VMM: yes
  Device 6: NVIDIA H800, compute capability 9.0, VMM: yes
  Device 7: NVIDIA H800, compute capability 9.0, VMM: yes
| model                          |       size |     params | backend    | ngl |    sm |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ----: | ------------: | -------------------: |
| rwkv6qwen2 32B F16             |  65.26 GiB |    34.74 B | CUDA       |  99 |  none |         pp512 |        697.64 ± 0.59 |
| rwkv6qwen2 32B F16             |  65.26 GiB |    34.74 B | CUDA       |  99 |  none |         tg128 |         21.91 ± 0.00 |

build: b7b45753 (4397)
after:
$ ./build/bin/llama-bench -m ../QRWKV6-32B-Instruct-Preview-v0.1/QRWKV6-32B-Instruct-Preview-v0.1-F16-fused-lerp.gguf -sm none -mg 0
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 8 CUDA devices:
  Device 0: NVIDIA H800, compute capability 9.0, VMM: yes
  Device 1: NVIDIA H800, compute capability 9.0, VMM: yes
  Device 2: NVIDIA H800, compute capability 9.0, VMM: yes
  Device 3: NVIDIA H800, compute capability 9.0, VMM: yes
  Device 4: NVIDIA H800, compute capability 9.0, VMM: yes
  Device 5: NVIDIA H800, compute capability 9.0, VMM: yes
  Device 6: NVIDIA H800, compute capability 9.0, VMM: yes
  Device 7: NVIDIA H800, compute capability 9.0, VMM: yes
| model                          |       size |     params | backend    | ngl |    sm |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ----: | ------------: | -------------------: |
| rwkv6qwen2 32B F16             |  65.26 GiB |    34.74 B | CUDA       |  99 |  none |         pp512 |        731.32 ± 1.10 |
| rwkv6qwen2 32B F16             |  65.26 GiB |    34.74 B | CUDA       |  99 |  none |         tg128 |         26.51 ± 0.01 |

build: b7b45753 (4397)

@github-actions github-actions bot added Nvidia GPU Issues specific to Nvidia GPUs Vulkan Issues specific to the Vulkan backend python python script changes ggml changes relating to the ggml tensor library for machine learning SYCL https://en.wikipedia.org/wiki/SYCL - GPU programming language labels Dec 28, 2024
Signed-off-by: Molly Sophia <[email protected]>
Signed-off-by: Molly Sophia <[email protected]>
@github-actions github-actions bot added the testing Everything test related label Dec 28, 2024
Signed-off-by: Molly Sophia <[email protected]>
Signed-off-by: Molly Sophia <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ggml changes relating to the ggml tensor library for machine learning Nvidia GPU Issues specific to Nvidia GPUs python python script changes SYCL https://en.wikipedia.org/wiki/SYCL - GPU programming language testing Everything test related Vulkan Issues specific to the Vulkan backend
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant