
I want to propose a PR for a new op, which could be in the form of a Triton or a CUDA kernel? #20658

Open
pass-lin opened this issue Dec 18, 2024 · 3 comments
Labels: type:feature The user is asking for a new feature.


pass-lin commented Dec 18, 2024

RWKV is a new-generation RNN model. It has pre-trained versions of different sizes, ranging from 0.3B to 14B, offers performance comparable to LLMs, and has the inference advantages of Mamba.
I want to contribute the RNN part of RWKV to Keras, but I have several questions. Firstly, the core operator of RWKV, time-mix iteration, is quite fast. Should I wait for the stable version to submit a PR, or should I submit a new op for each minor version?
Secondly, we have implemented RWKV-6-Keras and found that if we only use Keras ops, the efficiency is relatively low. To achieve high efficiency, we need an implementation based on CUDA or Triton. Personally, I prefer to provide a Triton implementation: torch ships with the Triton library by default, and for JAX we only need to additionally install jax-triton to support it. A CUDA implementation requires a complete CUDA environment, and the jax and torch we usually install with pip cannot directly compile CUDA operators. Therefore, the Triton implementation seems more user-friendly.
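
For reference, here is a minimal sketch of what calling a Triton kernel from JAX through jax-triton looks like, adapted from the jax-triton add example. The kernel is an illustrative element-wise add, not the actual RWKV time-mix kernel, and the exact jax-triton API may differ between versions:

```python
import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr,
               n_elements: tl.constexpr, block_size: tl.constexpr):
    # Each program instance handles one block_size-wide slice of the inputs.
    pid = tl.program_id(axis=0)
    offsets = pid * block_size + tl.arange(0, block_size)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x + y, mask=mask)


def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    block_size = 1024
    out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
    # jax_triton.triton_call wraps the kernel launch so it composes with jit.
    return jt.triton_call(
        x, y,
        kernel=add_kernel,
        out_shape=out_shape,
        grid=(triton.cdiv(x.size, block_size),),
        n_elements=x.size,
        block_size=block_size,
    )
```

The appeal of the Triton route is that the kernel source can be shared, and only the launch wrapper differs between the torch and JAX backends.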

mehtamansi29 added the type:feature and keras-team-review-pending labels on Dec 18, 2024
@Mr-back007 commented:


To propose a PR for a new operation (op) in the form of either a Triton or a CUDA kernel, here's a concise solution outline:

1. Triton Kernel Implementation

```python
import triton
import triton.language as tl


@triton.jit
def my_op_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance processes one BLOCK_SIZE-wide block of elements.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x + y, mask=mask)


def launch_triton_kernel(x, y, output, n_elements, block_size=1024):
    # Round the grid up so every element is covered even when n_elements
    # is not a multiple of block_size.
    grid = (triton.cdiv(n_elements, block_size),)
    my_op_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=block_size)
```
2. CUDA Kernel Implementation

```cpp
#include <cuda_runtime.h>

__global__ void my_op_kernel(const float *x, const float *y, float *output, int N) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < N) {
        output[idx] = x[idx] + y[idx];
    }
}

void launch_cuda_kernel(const float *x, const float *y, float *output, int N) {
    int threadsPerBlock = 256;
    int blocks = (N + threadsPerBlock - 1) / threadsPerBlock;
    my_op_kernel<<<blocks, threadsPerBlock>>>(x, y, output, N);
    cudaDeviceSynchronize();
}
```
3. Proposed PR Structure

Title: "Add Custom Operation Kernel (Triton or CUDA)"

Description:
- Triton Kernel: optimized for integration in ML workloads.
- CUDA Kernel: offers low-level control for maximum performance.
- Provide the user with the ability to toggle between the two implementations:

```python
def new_op(x, y, output, N, use_triton=True):
    if use_triton:
        launch_triton_kernel(x, y, output, N)
    else:
        launch_cuda_kernel(x, y, output, N)
```

- Performance: show that both implementations deliver faster execution for large data.
- Testing: add tests for both kernels.
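
On the testing point, a minimal pytest-style check of the Triton launcher above against a NumPy reference might look like the sketch below. It assumes torch provides the GPU buffers; the CUDA launcher is C++, so it would need its own Python binding (e.g. a torch extension) before it could be tested the same way:

```python
import numpy as np
import torch


def test_triton_add_matches_numpy(n_elements=4096):
    # CPU reference result.
    x_np = np.random.rand(n_elements).astype(np.float32)
    y_np = np.random.rand(n_elements).astype(np.float32)
    expected = x_np + y_np

    # Triton kernels take GPU buffers; torch CUDA tensors act as the pointers.
    x = torch.from_numpy(x_np).cuda()
    y = torch.from_numpy(y_np).cuda()
    output = torch.empty_like(x)

    launch_triton_kernel(x, y, output, n_elements)
    np.testing.assert_allclose(output.cpu().numpy(), expected, rtol=1e-6)
```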

mattdangerw removed the keras-team-review-pending label on Dec 23, 2024
mattdangerw self-assigned this on Dec 23, 2024
@mattdangerw (Member) commented:

@pass-lin are you planning to just contribute the ops? Or a model? Via KerasHub or a separate repo?

> Firstly, the core operator of RWKV, time-mix iteration, is quite fast. Should I wait for the stable version to submit a PR, or should I submit a new op for each minor version?

I'm not totally sure I follow here. Do you mean the RWKV core operator updates quite quickly? I don't think we would want to have Keras ops track version updates in another project.

If there's some core op functionality we could pull into Keras, that will stay generally applicable for all models of this type over a long period of time, that's a good fit for Keras.

If we are looking at something that is model specific and updates model version to version, that's probably a better fit for KerasHub, along with the actual model implementation it goes with.

> Secondly, we have implemented RWKV-6-Keras and found that if we only use Keras ops, the efficiency is relatively low.

The Triton question is a good one. I'm not totally sure. In general, we try to keep all Keras features and KerasHub models supporting both GPUs and TPUs. Would the same slowdown apply to TPUs? If not, a fast path for CUDA of some sort is reasonable; we already have some for regular RNNs, I believe.

@pass-lin (Author) commented:

> @pass-lin are you planning to just contribute the ops? Or a model? Via KerasHub or a separate repo?
>
> > Firstly, the core operator of RWKV, time-mix iteration, is quite fast. Should I wait for the stable version to submit a PR, or should I submit a new op for each minor version?
>
> I'm not totally sure I follow here. Do you mean the RWKV core operator updates quite quickly? I don't think we would want to have Keras ops track version updates in another project.
>
> If there's some core op functionality we could pull into Keras, that will stay generally applicable for all models of this type over a long period of time, that's a good fit for Keras.
>
> If we are looking at something that is model specific and updates model version to version, that's probably a better fit for KerasHub, along with the actual model implementation it goes with.
>
> > Secondly, we have implemented RWKV-6-Keras and found that if we only use Keras ops, the efficiency is relatively low.
>
> The Triton question is a good one. I'm not totally sure. In general, we try to keep all Keras features and KerasHub models supporting both GPUs and TPUs. Would the same slowdown apply to TPUs? If not, a fast path for CUDA of some sort is reasonable; we already have some for regular RNNs, I believe.

The core kernel of RWKV has been updated several times in recent versions. Following your suggestion, I will propose the relevant kernel once there is a stable version.
I plan to provide the kernels in Keras ops and implement the models in KerasHub.
When I was implementing RWKV6-Keras, I found that an RNN based on for loops was significantly less efficient than the CUDA operators. In our preliminary tests on GPUs, JAX with CUDA operators can be about 40 times faster when prefilling, and the TF implementation without CUDA operators is similarly slow.
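
To make the slow path concrete, the sketch below shows roughly the shape of a pure keras.ops recurrence when it is unrolled step by step in Python. It is a toy linear recurrence for illustration only, not the actual RWKV-6 time-mix:

```python
from keras import ops


def toy_recurrence(x, decay=0.9):
    """state_t = decay * state_{t-1} + x_t, computed one timestep at a time.

    Every loop iteration dispatches a handful of small ops, so for long
    prefill sequences this is far slower than a single fused CUDA/Triton
    kernel that sweeps the whole sequence on-device.
    """
    batch, seq_len, dim = x.shape
    state = ops.zeros((batch, dim), dtype="float32")
    outputs = []
    for t in range(seq_len):
        state = decay * state + x[:, t, :]
        outputs.append(state)
    return ops.stack(outputs, axis=1)
```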
