
[CuTe] adding DSL INT8 MMA support on SM80+, with CuTe C++ example#3097

Open

ayghri wants to merge 2 commits into NVIDIA:main from ayghri:main


@ayghri ayghri commented Mar 9, 2026

This PR adds INT8 warp-level MMA with an Ampere batched GEMM example. Partial solution to #3081

  • Add `MmaI8Op` and `MmaIntOverflow` in `python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py`; these wrap `mma.sync.aligned` with `.s8`/`.u8` qualifiers, shapes (16,8,16) and (16,8,32), and saturate/wrap overflow modes.
  • `examples/python/CuTeDSL/ampere/tensorop_gemm_i8.py`: Python example following `tensorop_gemm.py` conventions.
  • `examples/cute/tutorial/igemm_sm80.cu`: CuTe C++ implementation following `sgemm_sm80.cu`.

What works:

  • S8×S8 signedness; the `--b_dtype U8` option is kept for future IR support
  • Arbitrary M, N, K (tensors are padded to tile boundaries)
  • N-major or M-major output C
  • Auto-selected BM tile (16/32/64/128) based on the M dimension
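
The tile auto-selection and boundary padding described above can be sketched in plain Python. This is illustrative only: the thresholds and helper names here are assumptions, and the shipped `tensorop_gemm_i8.py` may use different cutoffs.

```python
def select_bm(m: int) -> int:
    """Pick the smallest BM tile in {16, 32, 64, 128} that covers M,
    falling back to 128 for large M (illustrative thresholds)."""
    for bm in (16, 32, 64, 128):
        if m <= bm:
            return bm
    return 128

def pad_to_tile(extent: int, tile: int) -> int:
    """Round a problem extent up to the next tile boundary."""
    return -(-extent // tile) * tile  # ceil-division, then scale back up
```

For example, `select_bm(33)` picks the 64-wide tile, and `pad_to_tile(200, 128)` pads an extent of 200 up to 256 before tiling.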

What does not work:

  • DSL: S8×U8 is supported by the hardware but currently blocked by the DSL compiler.

INT8 vs FP16 differences:

  • K-major inputs only: `ldmatrix` requires 128-bit-aligned smem rows. FP16 column-major gives 8×16b = 128 bits; INT8 column-major gives only 8×8b = 64 bits.
  • `Swizzle<2,4,3>` instead of `Swizzle<bits,3,3>`: INT8 needs MBase = 4 (not 3) because log2(128/8) = 4.
  • Direct gmem epilogue: an INT32 output would make the smem C buffer 128×128×4 B = 64 KB, limiting occupancy to 1 CTA/SM (on my RTX 3090; not yet benchmarked on A100).
  • The Python example script auto-selects the BM tile based on my local benchmarks.
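
The alignment and swizzle arithmetic behind these choices can be double-checked with a few lines of plain Python. This is a sketch of the reasoning, not the DSL API; the helper names are invented for illustration.

```python
import math

def fragment_bits(elems_per_row: int, dtype_bits: int) -> int:
    """Bits in one smem row fragment; ldmatrix wants 128-bit-aligned rows."""
    return elems_per_row * dtype_bits

def swizzle_mbase(dtype_bits: int, alignment_bits: int = 128) -> int:
    """MBase = log2(elements per 128-bit aligned unit) for the smem swizzle."""
    return int(math.log2(alignment_bits // dtype_bits))

def smem_c_bytes(bm: int, bn: int, acc_bytes: int = 4) -> int:
    """Size of an INT32 smem C buffer for a BM x BN tile."""
    return bm * bn * acc_bytes

# FP16 column-major: 8 x 16b = 128 bits -> ldmatrix-friendly
# INT8 column-major: 8 x  8b =  64 bits -> too small, hence K-major only
# INT8 swizzle:      log2(128/8) = 4    -> Swizzle<2,4,3>
# 128x128 INT32 C:   65536 B = 64 KB    -> motivates the direct gmem epilogue
```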

Performance vs PyTorch 2.9.1+cu128 (`torch._int_mm`), RTX 3090

100 iterations, CUDA event timing, TOPS = 2·M·N·K / time.
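
As a sanity check, the throughput figure used throughout these tables can be reproduced with a one-liner (a sketch; `time_ms` stands for the averaged CUDA-event elapsed time per iteration):

```python
def tops(m: int, n: int, k: int, time_ms: float) -> float:
    """GEMM throughput in tera-ops/s: 2*M*N*K integer ops over elapsed time."""
    ops = 2 * m * n * k              # one multiply + one add per MAC
    return ops / (time_ms * 1e-3) / 1e12
```

For instance, a 4096³ GEMM finishing in 1 ms corresponds to about 137 TOPS.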

N = K = 4096

| M | PyTorch (TOPS) | CuTe DSL (TOPS) | vs PyTorch |
|---:|---:|---:|---:|
| 16 | n/a | 18.8 | n/a |
| 32 | 7.0 | 35.9 | 514% |
| 64 | 13.8 | 60.1 | 435% |
| 128 | 27.5 | 84.5 | 307% |
| 256 | 55.0 | 165.0 | 300% |
| 512 | 56.5 | 179.7 | 318% |
| 1024 | 57.1 | 183.6 | 322% |
| 2048 | 65.6 | 218.9 | 334% |
| 4096 | 74.1 | 226.7 | 306% |

N = K = 16384

| M | PyTorch (TOPS) | CuTe DSL (TOPS) | vs PyTorch |
|---:|---:|---:|---:|
| 32 | 15.4 | 49.5 | 321% |
| 64 | 28.9 | 101.2 | 350% |
| 128 | 57.6 | 168.9 | 293% |
| 256 | 58.3 | 179.8 | 309% |
| 512 | 69.3 | 212.7 | 307% |
| 1024 | 73.1 | 225.9 | 309% |
| 4096 | 75.0 | 230.2 | 307% |

Non-aligned dimensions

| M | N | K | PyTorch (TOPS) | CuTe DSL (TOPS) | vs PyTorch |
|---:|---:|---:|---:|---:|---:|
| 33 | 4096 | 4096 | 7.4 | 32.7 | 440% |
| 65 | 4096 | 4096 | 15.1 | 46.0 | 305% |
| 200 | 4096 | 4096 | 46.2 | 138.2 | 299% |
| 128 | 1000 | 4096 | 7.3 | 21.1 | 289% |
| 128 | 4096 | 1000 | 24.2 | 65.4 | 270% |

Local benchmarks show cuBLAS within ~1-25% of this kernel for most sizes.

Testing

```
python examples/python/CuTeDSL/ampere/tensorop_gemm_i8.py --mnkl 512,512,512,1
python examples/python/CuTeDSL/ampere/tensorop_gemm_i8.py --mnkl 33,200,100,1
python examples/python/CuTeDSL/ampere/tensorop_gemm_i8.py --mnkl 4096,4096,4096,1 --skip_ref_check --iterations 100
```

@ayghri ayghri changed the title adding INT8 MMA support on SM80+ [CuTe] adding DSL INT8 MMA support on SM80+, with CuTe C++ example Mar 9, 2026