
[Bug] tuning deepseek v2/v3 fused_moe_triton crashed. #2599

Open
BBuf opened this issue Dec 26, 2024 · 2 comments

BBuf commented Dec 26, 2024

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
  • 5. Please use English, otherwise it will be closed.

Describe the bug

When tuning fused_moe_triton for DeepSeek V2/V3, the tuning script crashes.

Namespace(model='deepseek-ai/DeepSeek-V3-Base', tp_size=8, dtype='fp8_w8a8', seed=0, batch_size=None, tune=True)
A new version of the following files was downloaded from https://huggingface.co/deepseek-ai/DeepSeek-V3-Base:
- configuration_deepseek.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
2024-12-26 07:34:39,941	INFO worker.py:1821 -- Started a local Ray instance.
Start tuning over 1920 configurations...
(pid=5645)  0:   0%|          | 0.00/1.92k [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/sgl-workspace/sglang/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py", line 456, in <module>
    main(args)
  File "/sgl-workspace/sglang/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py", line 387, in main
    configs = _distribute(
  File "/sgl-workspace/sglang/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py", line 380, in _distribute
    return ray.get(outputs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2755, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 906, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::BenchmarkWorker.tune() (pid=5645, ip=172.17.0.3, actor_id=b0915f88d10281612620b09501000000, repr=<tuning_fused_moe_triton.BenchmarkWorker object at 0x7f8d7ff5dea0>)
  File "/sgl-workspace/sglang/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py", line 103, in run
    fused_moe(
  File "/sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 902, in fused_moe
    return fused_experts(
  File "/sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 649, in fused_experts
    torch.ops.sglang.inplace_fused_experts(
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1116, in __call__
    return self._op(*args, **(kwargs or {}))
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_device.py", line 106, in __torch_function__
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1116, in __call__
    return self._op(*args, **(kwargs or {}))
  File "/sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 535, in inplace_fused_experts
    fused_experts_impl(
  File "/sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 782, in fused_experts_impl
    invoke_fused_moe_kernel(
  File "/sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 344, in invoke_fused_moe_kernel
    fused_moe_kernel[grid](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 691, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/usr/local/lib/python3.10/dist-packages/triton/backends/nvidia/driver.py", line 365, in __call__
    self.launch(*args, **kwargs)
RuntimeError: Triton Error [CUDA]: operation failed due to a previous error during capture

During handling of the above exception, another exception occurred:

ray::BenchmarkWorker.tune() (pid=5645, ip=172.17.0.3, actor_id=b0915f88d10281612620b09501000000, repr=<tuning_fused_moe_triton.BenchmarkWorker object at 0x7f8d7ff5dea0>)
  File "/sgl-workspace/sglang/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py", line 245, in tune
    kernel_time = benchmark_config(
  File "/sgl-workspace/sglang/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py", line 125, in benchmark_config
    with torch.cuda.graph(graph):
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/graphs.py", line 186, in __exit__
    self.cuda_graph.capture_end()
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/graphs.py", line 84, in capture_end
    super().capture_end()
RuntimeError: CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Reproduction

python3 benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py --model deepseek-ai/DeepSeek-V3-Base --tp-size 8 --dtype fp8_w8a8 --tune
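As the error message suggests, re-running with serialized kernel launches makes the asynchronous CUDA error surface at the offending launch site. A minimal sketch of such a run, assuming the same script and arguments as above (CUDA_LAUNCH_BLOCKING is a standard CUDA runtime environment variable, not a script flag, and may need to be forwarded to the Ray workers):

# Serialize kernel launches so the failure is attributed to the right stack frame
CUDA_LAUNCH_BLOCKING=1 python3 benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py --model deepseek-ai/DeepSeek-V3-Base --tp-size 8 --dtype fp8_w8a8 --tune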

Environment

H100 or H200

BBuf self-assigned this on Dec 26, 2024

BBuf commented Dec 27, 2024

Standalone minimal reproduction:

import torch

def minimal_repro():
    torch.set_default_device("cuda:6")
    torch.cuda.manual_seed_all(0)
    
    x = torch.randn(1, 5120, dtype=torch.bfloat16)
    w1 = torch.randn(160, 6144, 5120, dtype=torch.float16)
    w2 = torch.randn(160, 5120, 3072, dtype=torch.float16)
    input_gating = torch.randn(1, 160, dtype=torch.float32)
    
    w1 = w1.to(torch.float8_e4m3fn)
    w2 = w2.to(torch.float8_e4m3fn)
    
    w1_scale = torch.randn(160, dtype=torch.float32)
    w2_scale = torch.randn(160, dtype=torch.float32)
    a1_scale = torch.randn(1, dtype=torch.float32)
    a2_scale = torch.randn(1, dtype=torch.float32)
    
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
        fused_moe
    )
    from sglang.srt.layers.moe.fused_moe_triton import override_config
    
    config = {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 8,
        "num_warps": 4,
        "num_stages": 3,
    }
    
    def run():
        with override_config(config):
            fused_moe(
                x,
                w1,
                w2,
                input_gating,
                topk=6,
                renormalize=True,
                inplace=True,
                use_fp8_w8a8=True,
                use_int8_w8a16=False,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
            )
    
    run()
    torch.cuda.synchronize()
    
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        run()

    torch.cuda.synchronize()

if __name__ == "__main__":
    minimal_repro() 

Error message:

Traceback (most recent call last):
  File "/opt/dlami/nvme/bbuf/repro.py", line 70, in <module>
    minimal_repro() 
  File "/opt/dlami/nvme/bbuf/repro.py", line 58, in minimal_repro
    run()
  File "/opt/dlami/nvme/bbuf/repro.py", line 41, in run
    fused_moe(
  File "/opt/dlami/nvme/bbuf/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 902, in fused_moe
    return fused_experts(
  File "/opt/dlami/nvme/bbuf/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 649, in fused_experts
    torch.ops.sglang.inplace_fused_experts(
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1116, in __call__
    return self._op(*args, **(kwargs or {}))
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_device.py", line 106, in __torch_function__
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1116, in __call__
    return self._op(*args, **(kwargs or {}))
  File "/opt/dlami/nvme/bbuf/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 535, in inplace_fused_experts
    fused_experts_impl(
  File "/opt/dlami/nvme/bbuf/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 782, in fused_experts_impl
    invoke_fused_moe_kernel(
  File "/opt/dlami/nvme/bbuf/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 344, in invoke_fused_moe_kernel
    fused_moe_kernel[grid](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 691, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 381, in __getattribute__
    self._init_handles()
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 376, in _init_handles
    self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered

Use compute-sanitizer to debug the illegal memory access:

compute-sanitizer python3 repro.py

compute-sanitizer logs:

========= COMPUTE-SANITIZER
========= Invalid __global__ read of size 4 bytes
=========     at 0x8b0 in void vllm::moe::moe_align_block_size_kernel<int>(T1 *, int *, int *, int *, int, int, unsigned long)
=========     by thread (0,0,0) in block (0,0,0)
=========     Address 0x74e8fde03c00 is out of bounds
=========     and is 811,613,184 bytes after the nearest allocation at 0x74e8cd800000 of size 1 bytes
=========     Saved host backtrace up to driver entry point at kernel launch time
=========     Host Frame: [0x2f26f0]
=========                in /usr/lib/x86_64-linux-gnu/libcuda.so.1
=========     Host Frame: [0x15804]
=========                in /usr/local/lib/python3.10/dist-packages/torch/lib/../../nvidia/cuda_runtime/lib/libcudart.so.12
=========     Host Frame:cudaLaunchKernel [0x75231]
=========                in /usr/local/lib/python3.10/dist-packages/torch/lib/../../nvidia/cuda_runtime/lib/libcudart.so.12
=========     Host Frame:moe_align_block_size(at::Tensor, long, long, at::Tensor, at::Tensor, at::Tensor) [0x149d0]
=========                in /usr/local/lib/python3.10/dist-packages/vllm/_moe_C.abi3.so
=========     Host Frame:c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<void (*)(at::Tensor, long, long, at::Tensor, at::Tensor, at::Tensor), void, c10::guts::typelist::typelist<at::Tensor, long, long, at::Tensor, at::Tensor, at::Tensor> >, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) [0x13851]
=========                in /usr/local/lib/python3.10/dist-packages/vllm/_moe_C.abi3.so
=========     Host Frame:c10::OperatorHandle::redispatchBoxed(c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const [0x55b224b]
=========                in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:torch::autograd::basicAutogradNotImplementedFallbackImpl(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) [0x55afad9]
=========                in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:void c10::BoxedKernel::make_boxed_function<&(anonymous namespace)::autograd_fallback>(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) [0x1a8c3f8]
=========                in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:c10::Dispatcher::callBoxed(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const [0xcff728]
=========                in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so
=========     Host Frame:torch::jit::invokeOperatorFromPython(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, pybind11::args const&, pybind11::kwargs const&, std::optional<c10::DispatchKey>) [0xa8e136]
=========                in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so
=========     Host Frame:torch::jit::_get_operation_for_overload_or_packet(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, c10::Symbol, pybind11::args const&, pybind11::kwargs const&, bool, std::optional<c10::DispatchKey>) [0xa8e447]
=========                in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so
=========     Host Frame:pybind11::cpp_function::initialize<torch::jit::initJITBindings(_object*)::{lambda(std::string const&)#217}::operator()(std::string const&) const::{lambda(pybind11::args const&, pybind11::kwargs const&)#1}, pybind11::object, pybind11::args const, pybind11::kwargs const, pybind11::name, pybind11::doc>(torch::jit::initJITBindings(_object*)::{lambda(std::string const&)#217}::operator()(std::string const&) const::{lambda(pybind11::args const&, pybind11::kwargs const&)#1}&&, pybind11::object (*)(pybind11::args const, pybind11::kwargs const), pybind11::name const&, pybind11::doc const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail) [0x976c22]
=========                in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so
=========     Host Frame:pybind11::cpp_function::dispatcher(_object*, _object*, _object*) [0x4cb474]
=========                in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so
=========     Host Frame: [0x1399f9]
=========                in /usr/bin/python3
=========     Host Frame:PyObject_Call [0x22d1ab]
=========                in /usr/bin/python3
=========     Host Frame:_PyEval_EvalFrameDefault [0x1afb7e]
=========                in /usr/bin/python3
=========     Host Frame:_PyFunction_Vectorcall [0x22d490]
=========                in /usr/bin/python3
=========     Host Frame:_PyObject_FastCallDictTstate [0x22ffb9]
=========                in /usr/bin/python3
=========     Host Frame:_PyObject_Call_Prepend [0x23014f]
=========                in /usr/bin/python3
=========     Host Frame: [0x2ed647]
=========                in /usr/bin/python3
=========     Host Frame:_PyObject_MakeTpCall [0x22e454]
=========                in /usr/bin/python3
=========     Host Frame:_PyEval_EvalFrameDefault [0x1af24c]
=========                in /usr/bin/python3
=========     Host Frame:_PyFunction_Vectorcall [0x22d490]
=========                in /usr/bin/python3
=========     Host Frame:PyObject_Call [0x22d16a]
=========                in /usr/bin/python3
=========     Host Frame:_PyEval_EvalFrameDefault [0x1ac60b]
=========                in /usr/bin/python3
=========     Host Frame:_PyFunction_Vectorcall [0x22d490]
=========                in /usr/bin/python3
=========     Host Frame:_PyEval_EvalFrameDefault [0x1ae6fb]
=========                in /usr/bin/python3
=========     Host Frame:_PyFunction_Vectorcall [0x22d490]
=========                in /usr/bin/python3
=========     Host Frame:_PyEval_EvalFrameDefault [0x1a9d0a]
=========                in /usr/bin/python3
=========     Host Frame:_PyFunction_Vectorcall [0x22d490]
=========                in /usr/bin/python3
=========     Host Frame:_PyEval_EvalFrameDefault [0x1a9d0a]
=========                in /usr/bin/python3
=========     Host Frame:_PyFunction_Vectorcall [0x22d490]
=========                in /usr/bin/python3
=========     Host Frame:PyObject_Call [0x22d16a]
=========                in /usr/bin/python3
=========     Host Frame:pybind11::object
=========                in /usr/bin/python3
=========     Host Frame:_PyFunction_Vectorcall [0x22d490]
=========                in /usr/bin/python3
=========     Host Frame:_PyEval_EvalFrameDefault [0x1a9d0a]
=========                in /usr/bin/python3
=========     Host Frame:_PyFunction_Vectorcall [0x22d490]
=========                in /usr/bin/python3
=========     Host Frame:_PyEval_EvalFrameDefault [0x1a9d0a]
=========                in /usr/bin/python3
=========     Host Frame:_PyFunction_Vectorcall [0x22d490]
=========                in /usr/bin/python3
=========     Host Frame:PyObject_Call [0x22d16a]
=========                in /usr/bin/python3
=========     Host Frame:pybind11::object 
Traceback (most recent call last):
  File "/opt/dlami/nvme/bbuf/repro.py", line 70, in <module>
    minimal_repro() 
  File "/opt/dlami/nvme/bbuf/repro.py", line 58, in minimal_repro
    run()
  File "/opt/dlami/nvme/bbuf/repro.py", line 41, in run
    fused_moe(
  File "/opt/dlami/nvme/bbuf/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 902, in fused_moe
    return fused_experts(
  File "/opt/dlami/nvme/bbuf/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 649, in fused_experts
    torch.ops.sglang.inplace_fused_experts(
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1116, in __call__
    return self._op(*args, **(kwargs or {}))
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_device.py", line 106, in __torch_function__
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1116, in __call__
    return self._op(*args, **(kwargs or {}))
  File "/opt/dlami/nvme/bbuf/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 535, in inplace_fused_experts
    fused_experts_impl(
  File "/opt/dlami/nvme/bbuf/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 782, in fused_experts_impl
    invoke_fused_moe_kernel(
  File "/opt/dlami/nvme/bbuf/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 344, in invoke_fused_moe_kernel
    fused_moe_kernel[grid](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 691, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 381, in __getattribute__
    self._init_handles()
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 376, in _init_handles
    self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
RuntimeError: Triton Error [CUDA]: unspecified launch failure
========= Target application returned an error
========= ERROR SUMMARY: 7 errors

Conclusion:

========= Invalid __global__ read of size 4 bytes
=========     at 0x8b0 in void vllm::moe::moe_align_block_size_kernel<int>(T1 *, int *, int *, int *, int, int, unsigned long)
=========     by thread (4,0,0) in block (0,0,0)
=========     Address 0x74e8fde03c10 is out of bounds
=========     and is 811,613,200 bytes after the nearest allocation at 0x74e8cd800000 of size 1 bytes
=========     Saved host backtrace up to driver entry point at kernel launch time

@zhyncs It seems that the https://github.com/vllm-project/vllm/blob/main/csrc/moe/moe_align_sum_kernels.cu kernel has another out-of-bounds memory access bug, and it should affect DeepSeek V2/V3.
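For reference, a rough sketch of exercising the suspected kernel in isolation. This assumes sglang's fused_moe module exposes the same Python helper moe_align_block_size(topk_ids, block_size, num_experts) as the vllm code it is ported from; the shapes and block size below are illustrative, not the exact tuning-script values:

import torch

# Hypothetical isolation of the suspected out-of-bounds kernel; the helper name
# and signature are assumed to match the vllm port inside sglang.
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import moe_align_block_size

torch.set_default_device("cuda")

num_experts = 160   # DeepSeek-style expert count used in the repro above
topk = 6
num_tokens = 1
block_size = 128    # BLOCK_SIZE_M from the failing config

# Random routing result, matching the int32 topk_ids the fused_moe path produces.
topk_ids = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32)

sorted_ids, expert_ids, num_tokens_post_pad = moe_align_block_size(
    topk_ids, block_size, num_experts
)
torch.cuda.synchronize()
print(sorted_ids.shape, expert_ids.shape, num_tokens_post_pad)

Running this under compute-sanitizer should show whether moe_align_block_size_kernel alone already reads out of bounds for this expert count and block size.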


BBuf commented Dec 29, 2024

2024/12/29:

========= COMPUTE-SANITIZER
========= Invalid __global__ read of size 16 bytes
=========     at 0xda0 in /opt/dlami/nvme/bbuf/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py:175:fused_moe_kernel
=========     by thread (96,0,0) in block (44,0,0)
=========     Address 0x763995f9f000 is out of bounds
=========     and is 1,241,911,296 bytes before the nearest allocation at 0x7639e0000000 of size 5,033,164,800 bytes
=========     Saved host backtrace up to driver entry point at kernel launch time
=========     Host Frame: [0x2f26f0]
=========                in /usr/lib/x86_64-linux-gnu/libcuda.so.1
=========     Host Frame: [0x2800]
=========                in /root/.triton/cache/b8a3dfcada0aedee787ec15cd5e0e9d8c94897015b9fdb243d4e29ce3a773d89/__triton_launcher.so
=========     Host Frame: [0x139a51]
=========                in /usr/bin/python3
=========     Host Frame:PyObject_Call [0x22d1ab]
=========                in /usr/bin/python3
=========     Host Frame:_PyEval_EvalFrameDefault [0x1afb7e]
=========                in /usr/bin/python3
=========     Host Frame:_PyFunction_Vectorcall [0x22d490]
=========                in /usr/bin/python3
=========     Host Frame:_PyObject_FastCallDictTstate [0x22ffb9]
=========                in /usr/bin/python3
=========     Host Frame:_PyObject_Call_Prepend [0x23014f]
=========                in /usr/bin/python3
=========     Host Frame: [0x2ed647]
=========                in /usr/bin/python3
=========     Host Frame:PyObject_Call [0x22d1ab]
=========                in /usr/bin/python3
=========     Host Frame:_PyEval_EvalFrameDefault [0x1ac60b]
=========                in /usr/bin/python3
=========     Host Frame: [0x1473ba]
=========                in /usr/bin/python3
=========     Host Frame:PyObject_Call [0x22d06c]
=========                in /usr/bin/python3
=========     Host Frame:_PyEval_EvalFrameDefault [0x1ac60b]
=========                in /usr/bin/python3
=========     Host Frame:_PyFunction_Vectorcall [0x22d490]
=========                in /usr/bin/python3
=========     Host Frame:PyObject_Call [0x22d06c]
=========                in /usr/bin/python3
=========     Host Frame:_PyEval_EvalFrameDefault [0x1ac60b]
=========                in /usr/bin/python3
=========     Host Frame:_PyFunction_Vectorcall [0x22d490]
=========                in /usr/bin/python3
=========     Host Frame:_PyEval_EvalFrameDefault [0x1aadc5]
=========                in /usr/bin/python3

Traceback (most recent call last):
  File "/opt/dlami/nvme/bbuf/sglang/../repro.py", line 62, in <module>
    minimal_repro() 
  File "/opt/dlami/nvme/bbuf/sglang/../repro.py", line 52, in minimal_repro
    run()
  File "/opt/dlami/nvme/bbuf/sglang/../repro.py", line 36, in run
    fused_moe(
  File "/opt/dlami/nvme/bbuf/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 936, in fused_moe
    return fused_experts(
  File "/opt/dlami/nvme/bbuf/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 683, in fused_experts
    torch.ops.sglang.inplace_fused_experts(
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1116, in __call__
    return self._op(*args, **(kwargs or {}))
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_device.py", line 106, in __torch_function__
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1116, in __call__
    return self._op(*args, **(kwargs or {}))
  File "/opt/dlami/nvme/bbuf/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 569, in inplace_fused_experts
    fused_experts_impl(
  File "/opt/dlami/nvme/bbuf/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 838, in fused_experts_impl
    invoke_fused_moe_kernel(
  File "/opt/dlami/nvme/bbuf/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 357, in invoke_fused_moe_kernel
    fused_moe_kernel[grid](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 691, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 381, in __getattribute__
    self._init_handles()
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 376, in _init_handles
    self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
RuntimeError: Triton Error [CUDA]: unspecified launch failure
========= Target application returned an error
========= ERROR SUMMARY: 25955 errors
========= ERROR SUMMARY: 25855 errors were not printed. Use --print-limit option to adjust the number of printed errors

A memory out-of-bounds error also happens inside the fused_moe_triton kernel itself with nightly sglang:

========= Invalid __global__ read of size 16 bytes
=========     at 0xda0 in /opt/dlami/nvme/bbuf/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py:175:fused_moe_kernel
=========     by thread (96,0,0) in block (44,0,0)
=========     Address 0x763995f9f000 is out of bounds
=========     and is 1,241,911,296 bytes before the nearest allocation at 0x7639e0000000 of size 5,033,164,800 bytes
=========     Saved host backtrace up to driver entry point at kernel launch time
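To localize the out-of-bounds load reported at fused_moe.py:175, the repro can be replayed with standard Triton/CUDA debugging switches; a minimal sketch, assuming the repro.py from the earlier comment (Triton interpreter mode may not cover every fp8 code path):

# Report the CUDA error at the launch that actually failed
CUDA_LAUNCH_BLOCKING=1 python3 repro.py

# Run the Triton kernel in interpreter mode to step through the offending load in Python
TRITON_INTERPRET=1 python3 repro.py

# Keep compute-sanitizer output manageable (the full run reported 25955 errors)
compute-sanitizer --print-limit 10 python3 repro.py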
