Conversation

Contributor @Laurawly commented Jan 6, 2026

Add the kernelagent-oink vLLM plugin that registers Blackwell (SM100) RMSNorm
custom ops via torch.library.custom_op under the oink:: namespace:

  • oink::rmsnorm(x, weight, eps) -> Tensor
  • oink::fused_add_rms_norm(x!, residual!, weight, eps) -> () (in-place, vLLM semantics)

The SM100 CuTeDSL implementation is layout-aware and preserves padded-row
strides (stride(1)==1, stride(0)>=N) so torch.compile/CUDA-graph capture sees a
stable stride contract. Includes small-M latency tuning for DSv3-like N=7168
and maintains high-M bandwidth, with correctness-first fallbacks on non-SM100.

  oink::fused_add_rms_norm backed by an SM100 CuTeDSL RMSNorm kernel.

  The ops are torch.compile-friendly (stride-preserving for padded-row inputs)
  and the fused op matches vLLM's in-place residual-add RMSNorm semantics.
@meta-cla bot added the CLA Signed label Jan 6, 2026
@Laurawly Laurawly changed the title kernelagent-oink: add SM100 CuTeDSL RMSNorm custom ops plugin for vLLM KernelAgent-Oink: Add SM100 CuTeDSL RMSNorm custom ops plugin for vLLM Jan 6, 2026
Contributor @Jack-Khuu left a comment

Initial comments

Need to go through rmsnorm.py


import math
import operator
from typing import Callable, Optional

nit: Prefer <type> | None instead of Optional keyword

#66

Comment on lines 21 to 22
numerical behaviour and performance close to the original reference
implementations.

original reference implementations.

A commit hash would be nice if you have it handy.

if sm >= 100:
    # Use the tuned CuTeDSL SM100 kernel. The public API already
    # contains all necessary gating and layout checks internally.
    _rms = _get_rmsnorm_mod()

nit: pull _rms out of conditional

    sm = _get_sm(x.device)
    _rms = _get_rmsnorm_mod()

    if sm >= 100:
        return ...  # tuned SM100 path

    return _rms.rmsnorm_ref(...)

assert weight.dim() == 1, "weight must be 1D [N]"

sm = _get_sm(x.device)
if sm >= 100:

nit: Check inverse to reduce nesting

if sm < 100:
    # Non-SM100: keep semantics in-place (correctness-first).

Comment on lines +33 to +39
local_rank = os.environ.get("LOCAL_RANK")
if local_rank is not None:
    try:
        return int(local_rank)
    except ValueError:
        pass
return 0

Ignore this suggestion if we want to guard/enable on "off/on" / "yes/no" values.

Suggested change

    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is not None:
        try:
            return int(local_rank)
        except ValueError:
            pass
    return 0

becomes

    rank = os.environ.get("LOCAL_RANK", "0")
    return int(rank)
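If the intent is indeed to accept on/off-style spellings, a flag parser could look roughly like the following; the helper name and the exact accepted value set are assumptions, not code from the PR:

```python
import os

# Hypothetical env-flag parser accepting common truthy/falsy spellings.
# The accepted value sets below are an assumption for illustration.
_TRUTHY = {"1", "true", "yes", "on"}
_FALSY = {"0", "false", "no", "off"}


def env_flag(name: str, default: bool) -> bool:
    raw = os.environ.get(name)
    if raw is None:
        return default
    val = raw.strip().lower()
    if val in _TRUTHY:
        return True
    if val in _FALSY:
        return False
    return default  # unrecognized spellings fall back to the default


os.environ["DEMO_FLAG"] = "yes"  # demo value, not a real vLLM variable
```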

import subprocess
import sys
import threading
from typing import Optional, Tuple

If we want to keep it for <= Python 3.9 support, that's fine. If not, let's use | None and tuple for 3.10+.

    f"falling back to staged SMEM path (returncode={rc}).",
    file=sys.stderr,
)
failing_proc = proc_128 if proc_128 is not None else proc_256

Do we want to emit both error traces, since both probes exist and fail?

Since 128 is the fallback, fixing the 256 probe makes more sense, right?

_CLUSTER_DIRECT_GMEM_PROBE_WARNED = False


def _probe_cluster_direct_gmem_max_copy_bits() -> int:

This isn't called until ~line 2560, do we want to move this lower?

Specifically somewhere after lines 263-299, which are still configuring the env variables (and are called on import).

"""Resolve copy width (in bits) from the (import-time) policy string."""
if _COPY_BITS_POLICY in {"128"}:
return 128
if _COPY_BITS_POLICY in {"256"} and can_use_256:

Why in instead of ==?

# This relies on internal CuTeDSL runtime pointer fields (`_desc`, `_pointer`,
# etc.). If these internals change in a future CuTeDSL upgrade, callers
# should catch AttributeError and fall back to the regular launch path.
device_ptr = int(device_ptr)

Why cast if the function expects an int?

@Laurawly Laurawly requested review from drisspg and v0i0 January 12, 2026 23:01
sm_count = (
    sm_count * sm_count_multiple
    if N <= 8192
    else sm_count // 2

This seems strange; why would we ever want to run fewer than sm_count?

Contributor Author

For clustered launches, sm_count is effectively a cluster-count heuristic (matching Quack's naming/launch shape). Launch uses grid=[sm_count, cluster_n, 1], so the total CTA count is sm_count * cluster_n.
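The arithmetic in the reply above can be sketched as follows; the helper name is made up, while the grid shape matches the reply:

```python
# Illustrative: with clustered launches, "sm_count" acts as a cluster-count
# heuristic, so the total number of CTAs is sm_count * cluster_n.
def total_ctas(sm_count: int, cluster_n: int) -> int:
    grid = [sm_count, cluster_n, 1]  # launch shape from the reply above
    n = 1
    for dim in grid:
        n *= dim
    return n
```

So halving sm_count while doubling cluster_n keeps the total CTA count unchanged, which is why running "fewer than sm_count" clusters is not necessarily running fewer CTAs.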

_PTR_FAST_LAUNCH_TLS = threading.local()


def _env_flag(name: str, default: bool) -> bool:

This already exists in init; remove dupes like that.

elif N <= 8192:
    # Allow an override (used by 2-rows/CTA path for N≈6k/8k)
    try:
        return self._tpr_override  # type: ignore[attr-defined]

Seems redundant w.r.t. the first override check?

@@ -0,0 +1,2927 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Member @msaroufim commented Jan 15, 2026

Sorry for the late review, it's nice that we have an extension system now to try things in VLLM so I mostly want to spend time reviewing the kernel itself and what'd make it easier for vendors like VLLM to actually merge this in. Mostly reiterating points I made here https://x.com/marksaroufim/status/2009096176789016600?s=20

A lot of it stems from the file being too long, but I think it shouldn't be too hard to clean it up:

  1. We don't need the cache to work over multiple CuTe DSL versions; presumably they're making breaking changes fairly frequently, so let's just pick the latest version and update as needed.
  2. The code almost looks like a splatted autotune run because it's trying to handle many cases and choose between different optimizations. I think we should just try to ship the one specific config that is fast on the specific shapes, on the specific model, that the vLLM team cares about on B200. Otherwise they'll have trouble reviewing this code even if it's faster, and I'd rather we generalize the code progressively as the need arises.
  3. A lot of the pointer-marshalling code can be deleted in favor of using tvm-ffi; a good chunk of the file is doing this, and it will be error-prone.
  4. Point 2 will also have unexpected side effects: tons of fallbacks make it unpredictable for an end user precisely which kernel configuration will run, which is something all of our numerics-sensitive customers will really care about. A user would often like to explicitly state whether they want an op to be in-place or not. I'd argue that instead of environment variables gating specific optimizations, we should have arguments to a function or separate functions. Even further, PyTorch now has an intra-kernel dispatcher where we can make guarantees about which specific kernel will be called for a specific shape.
  5. Finally, while I think an e2e test in vLLM works great, we probably also want some smaller unit tests comparing numerics vs. vanilla PyTorch code and Quack right here.
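Point 4 (explicit arguments instead of env-var gating) could look roughly like the sketch below; the config fields and function names are hypothetical, not from the PR:

```python
from dataclasses import dataclass


# Hypothetical explicit config object replacing env-var gating: the caller
# states exactly which variant runs, instead of relying on import-time
# environment reads and silent fallbacks.
@dataclass(frozen=True)
class RMSNormConfig:
    copy_bits: int = 128   # 128 or 256; no runtime probing behind the caller's back
    inplace: bool = False  # caller decides in-place vs out-of-place explicitly


def select_kernel(cfg: RMSNormConfig) -> str:
    # Fail loudly on unsupported configs rather than falling back silently.
    if cfg.copy_bits not in (128, 256):
        raise ValueError(f"unsupported copy_bits={cfg.copy_bits}")
    variant = f"rmsnorm_copy{cfg.copy_bits}"
    return variant + ("_inplace" if cfg.inplace else "")
```

With this shape, the kernel choice is a pure function of the arguments, so numerics-sensitive users can pin exactly which configuration runs.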

- Switch correctness gate to PyTorch ref + record err stats
- Tighten Softmax/LayerNorm tolerances (Quack-like)
- Quack-style benchmark suite layout + SVG plots
- Packaging/README polish for publishability