I want to propose a PR for a new op, which could be in the form of a Triton or a CUDA kernel? #20658
To propose a PR for a new operation (op) in the form of either a Triton or a CUDA kernel, here's a concise solution outline:

[Collapsed code snippets: a Triton kernel stub (`@triton.jit def launch_triton_kernel(x, y, output, N)`), a CUDA kernel (`__global__ void my_op_kernel(float *x, float *y, float *output, int N)`), and a host-side launcher (`void launch_cuda_kernel(float *x, float *y, float *output, int N)`).]
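A minimal sketch of what the Triton side of such a stub could look like, reconstructed only from the fragments above; the element-wise add body is a placeholder assumption rather than the actual RWKV op, and the names simply follow the snippet:

```python
import triton
import triton.language as tl


@triton.jit
def my_op_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice of the tensors.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x + y, mask=mask)  # placeholder op body


def launch_triton_kernel(x, y, output, n_elements, block_size=1024):
    # Host-side launcher: pick a 1D grid and dispatch the kernel.
    grid = (triton.cdiv(n_elements, block_size),)
    my_op_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=block_size)
```

Called on CUDA tensors as `launch_triton_kernel(x, y, out, x.numel())`; a real PR would replace the placeholder body with the actual RWKV recurrence.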
@pass-lin are you planning to just contribute the ops? Or a model? Via KerasHub or a separate repo?
I'm not totally sure I follow here. Do you mean the RWKV core operator updates quite quickly? I don't think we would want Keras ops to track version updates in another project. If there's some core op functionality we could pull into Keras that will stay generally applicable for all models of this type over a long period of time, that's a good fit for Keras. If we are looking at something that is model specific and updates from model version to model version, that's probably a better fit for KerasHub, along with the actual model implementation it goes with.
The Triton question is a good one. I'm not totally sure. In general, we try to keep all Keras features and KerasHub models supporting both GPUs and TPUs. Would the same slowdown apply to TPUs? If not, a fast path for CUDA of some sort is reasonable; we already have some for regular RNNs, I believe.
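For what such a fast path could look like, here is a hedged sketch of the dispatch pattern, reusing the stub names from earlier in the thread; `fused_my_op` and the `my_op_kernels` module are purely hypothetical, and the fallback keeps TPU and CPU working through plain keras.ops:

```python
from keras import config, ops


def my_op(x, y):
    # Fast path: a hypothetical fused Triton/CUDA kernel, used only when the
    # current backend can provide it (e.g. torch on an NVIDIA GPU).
    if config.backend() == "torch":
        try:
            from my_op_kernels import fused_my_op  # hypothetical extension module
            return fused_my_op(x, y)
        except ImportError:
            pass
    # Portable fallback written in plain keras.ops, so TPU and CPU still work,
    # analogous to how Keras RNN layers fall back from their fused kernels.
    return ops.add(x, y)
```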
The core kernel of RWKV has been updated several times across recent versions. Following your suggestion, I will submit the relevant kernel once a stable version is released.
RWKV is a new-generation RNN model. It has pre-trained versions at a range of sizes, from 0.3B to 14B parameters. It offers performance comparable to Transformer-based LLMs together with the inference advantages of Mamba.
I want to contribute the RNN part of RWKV to Keras, but I have a couple of questions. First, the core operator of RWKV, the time-mix recurrence, iterates quickly from version to version. Should I wait for a stable version before submitting a PR, or should I submit a new op for each minor version?
Second, we have implemented RWKV-6-Keras and found that using only Keras ops yields relatively low efficiency. To achieve high efficiency, we need an implementation based on CUDA or Triton. Personally, I prefer to provide a Triton implementation: PyTorch ships with the Triton library by default, and for JAX we only need to additionally install jax-triton to support it. A CUDA implementation requires a complete CUDA toolchain, and the JAX and PyTorch builds we usually install with pip cannot directly compile CUDA operators. Therefore, the Triton implementation seems more user-friendly.
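To make the efficiency point concrete, here is a deliberately simplified, backend-agnostic sketch of the kind of per-token recurrence involved; it is not the exact RWKV-6 time-mix formula, only a structurally similar linear recurrence, and the function name is made up. The Python-level loop over time steps is what a fused Triton/CUDA kernel would eliminate:

```python
from keras import ops


def naive_time_mix(r, k, v, w):
    # r, k, v: (batch, seq_len, dim) float32 tensors; w: (dim,) per-channel decay.
    batch, seq_len, dim = r.shape  # assumes statically known shapes
    state = ops.zeros((batch, dim, dim))  # matrix-valued recurrent state
    decay = ops.reshape(w, (1, dim, 1))
    outputs = []
    for t in range(seq_len):
        r_t, k_t, v_t = r[:, t, :], k[:, t, :], v[:, t, :]
        # Decay the state per channel, then add the rank-1 update (outer
        # product of k_t and v_t), and read it out with r_t.
        state = state * decay + ops.einsum("bi,bj->bij", k_t, v_t)
        outputs.append(ops.einsum("bi,bij->bj", r_t, state))
    return ops.stack(outputs, axis=1)  # (batch, seq_len, dim)
```

Because each of the `seq_len` steps launches several separate ops, this path is slow for long sequences, which is exactly why a single fused kernel helps.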