diff --git a/tests/pytorch/triton_kernels/test_norms.py b/tests/pytorch/triton_kernels/test_norms.py index 44c481e29..b9b652600 100644 --- a/tests/pytorch/triton_kernels/test_norms.py +++ b/tests/pytorch/triton_kernels/test_norms.py @@ -17,13 +17,11 @@ ) from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8Tensor from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor -from transformer_engine.pytorch.triton_kernels.rmsnorm import ( - te_rmsnorm_bwd_triton, - te_rmsnorm_fwd_triton, -) -from transformer_engine.pytorch.triton_kernels.layernorm import ( +from transformer_engine.pytorch.triton_kernels.norms import ( te_layernorm_bwd_triton, te_layernorm_fwd_triton, + te_rmsnorm_bwd_triton, + te_rmsnorm_fwd_triton, ) from test_common import dtype_tols, te_compare_results, str_to_torch_dtype, fill_uniform diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index c9a823fe3..42efc3e0f 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -24,8 +24,12 @@ _use_cudnn_mxfp8_norm = bool(int(os.getenv("NVTE_CUDNN_MXFP8_NORM", "0"))) if IS_HIP_EXTENSION: - from ..triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton - from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton, te_rmsnorm_fwd_triton + from ..triton_kernels.norms import ( + te_layernorm_fwd_triton, + te_layernorm_bwd_triton, + te_rmsnorm_fwd_triton, + te_rmsnorm_bwd_triton + ) def _get_normalization_func(normalization: str, forward: bool): use_rmsnorm_triton = bool( int(os.environ.get('NVTE_USE_RMSNORM_TRITON', '0')) ) and IS_HIP_EXTENSION diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 264f5d937..dbb052733 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -68,8 +68,7 @@ ) if IS_HIP_EXTENSION: - from ..triton_kernels.layernorm import te_layernorm_bwd_triton - from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton + from ..triton_kernels.norms import te_layernorm_bwd_triton, te_rmsnorm_bwd_triton from ..rocm_utils import create_fp8_weight_transpose_cache, clear_fp8_weight_transpose_cache diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a8a95bfac..cd5766bbb 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -75,8 +75,7 @@ ) if IS_HIP_EXTENSION: - from ..triton_kernels.layernorm import te_layernorm_bwd_triton - from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton + from ..triton_kernels.norms import te_layernorm_bwd_triton, te_rmsnorm_bwd_triton from ..rocm_utils import create_fp8_weight_transpose_cache, clear_fp8_weight_transpose_cache diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index c94459bc3..d6294ee89 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -17,7 +17,7 @@ from transformer_engine_torch import layernorm_bwd, layernorm_fwd from torch.utils.cpp_extension import IS_HIP_EXTENSION if IS_HIP_EXTENSION: - from ...triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton + from ...triton_kernels.norms import te_layernorm_fwd_triton, te_layernorm_bwd_triton from 
...fp8 import FP8GlobalStateManager from ...tensor import QuantizedTensor from ...constants import TE_DType diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index e945d25fc..a7bbe5d7f 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -17,7 +17,7 @@ from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd from torch.utils.cpp_extension import IS_HIP_EXTENSION if IS_HIP_EXTENSION: - from ...triton_kernels.rmsnorm import ( + from ...triton_kernels.norms import ( te_rmsnorm_bwd_triton, te_rmsnorm_fwd_triton ) diff --git a/transformer_engine/pytorch/triton_kernels/layernorm.py b/transformer_engine/pytorch/triton_kernels/layernorm.py index 265093b73..27e279f40 100644 --- a/transformer_engine/pytorch/triton_kernels/layernorm.py +++ b/transformer_engine/pytorch/triton_kernels/layernorm.py @@ -3,25 +3,9 @@ from itertools import product -import os -import torch - -from ..tensor.float8_tensor import Float8Quantizer -from ..constants import TE_DType -from ..tensor.mxfp8_tensor import MXFP8Quantizer -from ..tensor.quantized_tensor import Quantizer -from ..triton_kernels.cast import te_quantize_triton import triton import triton.language as tl -import warnings -import transformer_engine_torch as tex -from .common import ( - get_fp8_max, - te_dtype_to_torch_dtype, - te_dtype_to_triton_dtype, -) -from .norm_common import make_ln_out def get_autotune_config(full_tuning_space=False): if full_tuning_space: @@ -36,20 +20,20 @@ def get_autotune_config(full_tuning_space=False): @triton.jit def _layernorm_fwd_triton_impl( - x_ptr, - y_ptr, - w_ptr, + input_ptr, + output_ptr, + g_ptr, b_ptr, mean_ptr, - rstd_ptr, - scale_ptr, - amax_ptr, - scale_inv_ptr, - x_row_stride, - y_row_stride, + rsigma_ptr, + input_row_stride, + output_row_stride, n_rows, n_cols, - eps, + epsilon, + q_amax_ptr, + q_scale_ptr, + scale_inv_ptr, out_transpose_ptr, out_transpose_stride, ZERO_CENTERED_GAMMA: tl.constexpr, @@ -81,12 +65,12 @@ def _layernorm_fwd_triton_impl( start_row = pid if IS_FP8: - scale = tl.load(scale_ptr) + scale = tl.load(q_scale_ptr) amax = 0.0 for row_idx in range(start_row, start_row + rows_per_tile): - x_ptr_start = x_ptr + (row_idx * x_row_stride) - y_ptr_start = y_ptr + (row_idx * y_row_stride) + x_ptr_start = input_ptr + (row_idx * input_row_stride) + y_ptr_start = output_ptr + (row_idx * output_row_stride) n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1 @@ -118,16 +102,16 @@ def _layernorm_fwd_triton_impl( _var += x_block * x_block var = tl.sum(_var, axis=0) / n_cols - rstd = tl.rsqrt(var + eps) + rstd = tl.rsqrt(var + epsilon) # Write mean / rstd tl.store(mean_ptr + row_idx, mean) - tl.store(rstd_ptr + row_idx, rstd) + tl.store(rsigma_ptr + row_idx, rstd) # Normalize and store for blk_idx in range(0, n_cols_blks): cols = blk_idx * BLOCK_SIZE + col_offsets - w_block = tl.load(w_ptr + cols).to(tl.float32) + w_block = tl.load(g_ptr + cols).to(tl.float32) b_block = tl.load(b_ptr + cols).to(tl.float32) x_block = tl.load(x_ptr_start + cols).to(tl.float32) if ZERO_CENTERED_GAMMA: @@ -139,7 +123,7 @@ def _layernorm_fwd_triton_impl( amax = amax_temp if amax_temp > amax else amax y_block = y_block * scale y_block = tl.clamp(y_block, -FP8_MAX, FP8_MAX) - y_block = y_block.to(y_ptr.type.element_ty) + y_block = y_block.to(output_ptr.type.element_ty) tl.store(y_ptr_start + cols, y_block) if MAKE_TRANSPOSE: output_t_ptrs = out_transpose_ptr + cols * out_transpose_stride + row_idx @@ -148,7 
+132,7 @@ def _layernorm_fwd_triton_impl( # For last iteration, do masked load and store cols = n_cols_blks * BLOCK_SIZE + col_offsets mask = cols < n_cols - w_block = tl.load(w_ptr + cols, mask=mask, other=0.0).to(tl.float32) + w_block = tl.load(g_ptr + cols, mask=mask, other=0.0).to(tl.float32) b_block = tl.load(b_ptr + cols, mask=mask, other=0.0).to(tl.float32) x_block = tl.load(x_ptr_start + cols, mask=mask, other=0.0).to(tl.float32) if ZERO_CENTERED_GAMMA: @@ -160,20 +144,20 @@ def _layernorm_fwd_triton_impl( amax = amax_temp if amax_temp > amax else amax y_block = y_block * scale y_block = tl.clamp(y_block, -FP8_MAX, FP8_MAX) - y_block = y_block.to(y_ptr.type.element_ty) + y_block = y_block.to(output_ptr.type.element_ty) tl.store(y_ptr_start + cols, y_block, mask=mask) if MAKE_TRANSPOSE: output_t_ptrs = out_transpose_ptr + cols * out_transpose_stride + row_idx tl.store(output_t_ptrs, y_block, mask=mask) if IS_FP8: + if pid == 0: + scale_inv = tl.fdiv(1.0, scale) + tl.store(scale_inv_ptr, scale_inv) if APPLY_ATOMIC: - if pid == 0: - scale_inv = tl.fdiv(1.0, scale) - tl.store(scale_inv_ptr, scale_inv) - tl.atomic_max(amax_ptr, amax, sem="relaxed") + tl.atomic_max(q_amax_ptr, amax, sem="relaxed") else: - tl.store(amax_ptr + pid, amax) + tl.store(q_amax_ptr + pid, amax) autotune_dec = triton.autotune(configs=get_autotune_config(), key=["n_rows", "n_cols"], use_cuda_graph=True) _layernorm_fwd_triton = autotune_dec(_layernorm_fwd_triton_impl) @@ -182,8 +166,6 @@ def _layernorm_fwd_triton_impl( def _layernorm_fwd_reduce_triton( amax_input_ptr, amax_output_ptr, - scale_ptr, - scale_inv_ptr, n_rows, BLOCK_SIZE: tl.constexpr, ): @@ -200,12 +182,6 @@ def _layernorm_fwd_reduce_triton( tl.atomic_max(amax_output_ptr, amax, sem="relaxed") - if pid == 0: - scale = tl.load(scale_ptr) - scale_inv = tl.fdiv(1.0, scale) - tl.store(scale_inv_ptr, scale_inv) - - @triton.jit def _layernorm_bwd_dx_fused_triton( DX, # pointer to the input gradient @@ -455,214 +431,3 @@ def _layernorm_bwd_dwdb_triton_v2( sum_db = tl.sum(db, axis=0) tl.store(FINAL_DW + cols, sum_dw.to(FINAL_DW.type.element_ty), mask=cols < N) tl.store(FINAL_DB + cols, sum_db.to(FINAL_DB.type.element_ty), mask=cols < N) - -def te_layernorm_fwd_triton(input: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - eps: float, - ln_out: torch.Tensor, - quantizer: Quantizer, - otype: tex.DType, - sm_margin: int, - zero_centered_gamma: bool, - autotune: bool = True,): - if sm_margin is not None and sm_margin > 0: - warnings.warn( - '"sm_margin" is not supported in the Triton based forward layer-norm kernel. ' - + f"sm_margin={sm_margin} will be ignored." 
- ) - device = input.device - M, N = input.shape - - IS_MXFP8 = isinstance(quantizer, MXFP8Quantizer) - MAKE_TRANSPOSE = False - - # Create empty tensors for mu and rsigma - mu = torch.empty((M,), dtype=torch.float32, device=device) - rsigma = torch.empty((M,), dtype=torch.float32, device=device) - torch_out_dtype = ( - otype if isinstance(otype, torch.dtype) - else te_dtype_to_torch_dtype(otype) - ) - # Create ln_out - ln_out = make_ln_out(ln_out, quantizer=quantizer, input_shape=input.shape, out_dtype=torch_out_dtype) - # To update the amax ptr directly with atomic max - APPLY_ATOMIC = M < 512 - - # MXFP8 is handled regularly, hence quantizer of Float8Quantizer is considered FP8 - IS_FP8 = isinstance(quantizer, Float8Quantizer) - - amax_temp = torch.empty((M,), dtype=torch.float32, device=device) if IS_FP8 else None - - max_fused_size = 16384 // input.element_size() - BLOCK_SIZE = min(max_fused_size, triton.next_power_of_2(N)) - - out_transpose_ptr = None - out_transpose_stride = None - - # Create necessary values for fp8 if needed - if IS_FP8: - scale = quantizer.scale - amax_out = quantizer.amax - scale_inv = ln_out._scale_inv - cast_out = ln_out._data - MAKE_TRANSPOSE = quantizer.columnwise_usage - if MAKE_TRANSPOSE: - tl_dtype = te_dtype_to_triton_dtype(quantizer.dtype) - if ln_out._transpose_invalid: - ln_out._transpose = torch.empty((ln_out._data.shape[1], ln_out._data.shape[0]), dtype=ln_out._data.dtype, device=device) - ln_out._transpose_invalid = False - out_transpose_ptr = triton.reinterpret(ln_out._transpose, tl_dtype) - out_transpose_stride = ln_out._transpose.stride(0) - else: - scale = None - amax_out = None - scale_inv = None - cast_out = ln_out - - kernel = _layernorm_fwd_triton if autotune else _layernorm_fwd_triton_impl - kernel[(M,)]( - input, - triton.reinterpret(cast_out, te_dtype_to_triton_dtype(ln_out._fp8_dtype)) if IS_FP8 else cast_out, - weight, - bias, - mu, - rsigma, - scale, - amax_out if APPLY_ATOMIC else amax_temp, - scale_inv, - input.stride(0), - cast_out.stride(0), - M, - N, - eps, - out_transpose_ptr, - out_transpose_stride, - ZERO_CENTERED_GAMMA=zero_centered_gamma, - BLOCK_SIZE=BLOCK_SIZE, - IS_FP8=IS_FP8, - APPLY_ATOMIC=APPLY_ATOMIC, - # TODO: Improve performance with persistent kernel - # Persistent kernel currently lags behind non persistent version - # It also lags behind TE implementation in a few cases - PERSISTENT=False, - FP8_MAX=get_fp8_max(quantizer.dtype) if IS_FP8 else None, - MAKE_TRANSPOSE=MAKE_TRANSPOSE - ) - - # For MXFP8, we do regular layernorm and then quantize it separately - if IS_MXFP8: - ln_out = te_quantize_triton(ln_out, quantizer) - - # Reduce and find amax if "not APPLY_ATOMIC" is True. - if IS_FP8 and not APPLY_ATOMIC: - _layernorm_fwd_reduce_triton[(triton.cdiv(M, 256),)]( - amax_temp, - amax_out, - scale, - scale_inv, - M, - 256, - ) - return ln_out, mu, rsigma - -# drop in replacement for transformer_engine::pytorch::layernorm_bwd -# TODO: Add support for `sm_margin > 0`. -def te_layernorm_bwd_triton( - dz: torch.Tensor, - x: torch.Tensor, - mu: torch.Tensor, - rsigma: torch.Tensor, - gamma: torch.Tensor, - sm_margin: int, - zero_centered_gamma: bool -): - if sm_margin is not None and sm_margin > 0: - warnings.warn( - '"sm_margin" is not supported in the Triton based backward layer-norm kernel. ' - + f"sm_margin={sm_margin} will be ignored." 
- ) - M, N = x.shape - # calculate dw and db separately when M is small - IGNORE_DW_DB_IN_FUSED = M <= 512 - tile_num = max(min(256, M // 4), 1) - if M <= 512 and M * N < 64 * 1024 * 1024: - tile_num = M - elif M >= 8192: - tile_num = 2048 - max_fused_size = 32768 // x.element_size() - next_power = triton.next_power_of_2(N) - BLOCK_SIZE = min(max_fused_size, next_power) - # For cases with small M and large N, decrease block size to help with occupancy and register spill - if tile_num == M: - if tile_num > 256: - BLOCK_SIZE = min(BLOCK_SIZE, 2048) - else: - BLOCK_SIZE = min(BLOCK_SIZE, 4096) - USE_BLOCKED = N > BLOCK_SIZE - num_warps = min(max(BLOCK_SIZE // 256, 1), 8) - - dx = torch.empty_like(x) - if not IGNORE_DW_DB_IN_FUSED: - _dgamma = torch.zeros((tile_num, N), dtype=torch.float32, device=gamma.device) - _dbeta = torch.zeros((tile_num, N), dtype=torch.float32, device=gamma.device) - else: - _dgamma = None - _dbeta = None - dgamma = torch.zeros((N,), dtype=gamma.dtype, device=gamma.device) - dbeta = torch.zeros((N,), dtype=gamma.dtype, device=gamma.device) - grid_bwd = (tile_num,) - _layernorm_bwd_dx_fused_triton[grid_bwd]( - dx, - dz, - _dgamma, - _dbeta, - x, - gamma, - mu, - rsigma, - x.stride(0), - N, - ZERO_CENTERED_GAMMA=zero_centered_gamma, - NUM_ROWS=M, - BLOCK_SIZE_N=BLOCK_SIZE, - USE_BLOCKED=USE_BLOCKED, - num_warps=num_warps, - IGNORE_DW_DB=IGNORE_DW_DB_IN_FUSED, - ) - grid_reduce = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE_N"]),) - if not IGNORE_DW_DB_IN_FUSED: - dwdb_block_n = max(16, N // 256) - dwdb_block_n = triton.next_power_of_2(dwdb_block_n) - dwdb_block_m = (64 * 128) // dwdb_block_n - dwdb_block_m = min(triton.next_power_of_2(tile_num), dwdb_block_m) - _layernorm_bwd_dwdb_triton[grid_reduce]( - _dgamma, - _dbeta, - dgamma, - dbeta, - min(tile_num, M), - N, - BLOCK_SIZE_M=dwdb_block_m, - BLOCK_SIZE_N=dwdb_block_n, - ) - else: - dwdb_block_n = max(16, N // 256) - dwdb_block_n = triton.next_power_of_2(dwdb_block_n) - dwdb_block_m = (64 * 128) // dwdb_block_n - dwdb_block_m = min(triton.next_power_of_2(M), dwdb_block_m) - _layernorm_bwd_dwdb_triton_v2[grid_reduce]( - x, - dz, - mu, - rsigma, - x.stride(0), - dgamma, - dbeta, - M, - N, - BLOCK_SIZE_M=dwdb_block_m, - BLOCK_SIZE_N=dwdb_block_n, - ) - - return dx, dgamma, dbeta diff --git a/transformer_engine/pytorch/triton_kernels/norms.py b/transformer_engine/pytorch/triton_kernels/norms.py new file mode 100644 index 000000000..efb9e5cbf --- /dev/null +++ b/transformer_engine/pytorch/triton_kernels/norms.py @@ -0,0 +1,348 @@ +import torch +import triton +import warnings +import transformer_engine_torch as tex + +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.pytorch.triton_kernels.common import ( + te_dtype_to_torch_dtype, + te_dtype_to_triton_dtype, +) +from ..tensor.quantized_tensor import Quantizer +from .norm_common import num_programs, block_size, use_blocked, make_ln_out +from .common import get_fp8_max +from .rmsnorm import ( + _rmsnorm_fwd_triton, + _rmsnorm_fwd_triton_impl, + _rmsnorm_bwd_triton, + _rmsnorm_bwd_dg_reduce_triton, +) +from .layernorm import ( + _layernorm_fwd_triton, + _layernorm_fwd_triton_impl, + _layernorm_fwd_reduce_triton, + _layernorm_bwd_dwdb_triton, + _layernorm_bwd_dwdb_triton_v2, + _layernorm_bwd_dx_fused_triton, +) + +_norm_kernels={ + "rms":{ + True: _rmsnorm_fwd_triton, + False: _rmsnorm_fwd_triton_impl, + }, + "layer":{ + 
True: _layernorm_fwd_triton, + False: _layernorm_fwd_triton_impl, + } +} +# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_fwd +def te_rmsnorm_fwd_triton( + input: torch.Tensor, + weight: torch.Tensor, + eps: float, + ln_out: torch.Tensor, + quantizer: Quantizer, + otype: tex.DType, + sm_margin: int, + zero_centered_gamma: bool, + autotune: bool = True, +): + return te_norm_fwd_triton( + kernel='rms', + input_tensor=input, + weight=weight, + bias=None, + eps=eps, + ln_out=ln_out, + quantizer=quantizer, + otype=otype, + sm_margin=sm_margin, + zero_centered_gamma=zero_centered_gamma, + autotune=autotune, + ) + +# triton drop-in replacement for transformer_engine::pytorch::layernorm_fwd +def te_layernorm_fwd_triton( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + ln_out: torch.Tensor, + quantizer: Quantizer, + otype: tex.DType, + sm_margin: int, + zero_centered_gamma: bool, + autotune: bool = True, +): + return te_norm_fwd_triton( + kernel='layer', + input_tensor=input, + weight=weight, + bias=bias, + eps=eps, + ln_out=ln_out, + quantizer=quantizer, + otype=otype, + sm_margin=sm_margin, + zero_centered_gamma=zero_centered_gamma, + autotune=autotune, + ) + +# shared triton implementation backing the rmsnorm_fwd/layernorm_fwd drop-in replacements above +def te_norm_fwd_triton( + kernel: str, + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + ln_out: torch.Tensor, + quantizer: Quantizer, + otype: tex.DType, + sm_margin: int, + zero_centered_gamma: bool, + autotune: bool = True, +): + if kernel not in {'rms', 'layer'}: + raise ValueError(f"Expected `kernel` in ('rms', 'layer') but got {kernel=} instead.") + if eps < 0: + raise ValueError(f"`eps` must be non-negative, but a value of {eps} was passed") + if len(input_tensor.shape) != 2: + raise ValueError( + f"The input must be a 2-dimensional matrix, but an input with {input_tensor.ndim} dimensions was passed.") + + device = input_tensor.device + N, H = input_tensor.shape + if weight.shape[0] != H: + raise ValueError( + f"The shape of `weight` must be feature-aligned, " + f"but {weight.shape[0]=} while {input_tensor.shape[1]=}" + ) + IS_FP8 = isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)) + IS_MXFP8 = isinstance(quantizer, MXFP8Quantizer) + BLOCK_SIZE = block_size(input_tensor) + USE_BLOCKED = use_blocked(input_tensor) + NUM_PRGMS = num_programs(input_tensor, sm_margin) + MAKE_TRANSPOSE = False + APPLY_ATOMIC = N < 512 or kernel == 'rms' + ATOMIC_REDUCTION_BLOCK_SIZE=256 + + mu = torch.empty((N,), dtype=torch.float32, device=device) if kernel == 'layer' else None + rsigma = torch.empty((N,), dtype=torch.float32, device=device) + torch_out_dtype = ( + otype if isinstance(otype, torch.dtype) + else te_dtype_to_torch_dtype(otype) + ) + out = make_ln_out( + ln_out, + quantizer=quantizer, + input_shape=input_tensor.shape, + out_dtype=torch_out_dtype + ) + amax = None + tl_dtype = None + scale_inv_ptr = None + q_scale = None + out_ptr = out + out_transpose_ptr = None + out_transpose_stride = None + FP8_MAX = None + if IS_FP8: + MAKE_TRANSPOSE = quantizer.columnwise_usage + amax = ( + quantizer.amax if APPLY_ATOMIC else + torch.empty((NUM_PRGMS,), dtype=torch.float32, device=device) + ) + tl_dtype = te_dtype_to_triton_dtype(quantizer.dtype) + scale_inv_ptr = out._scale_inv + q_scale = quantizer.scale + out_ptr = triton.reinterpret(out._data, tl_dtype) + FP8_MAX = get_fp8_max(quantizer.dtype) + if MAKE_TRANSPOSE: + if out._transpose_invalid: + out._transpose =
torch.empty( + (out._data.shape[1], out._data.shape[0]), + dtype=out._data.dtype, device=device + ) + out._transpose_invalid = False + out_transpose_ptr = triton.reinterpret(out._transpose, tl_dtype) + out_transpose_stride = out._transpose.stride(0) + + grid_fwd = lambda meta: (N if kernel=='layer' else NUM_PRGMS,) + kernel_func = _norm_kernels[kernel][autotune] + kwargs = dict( + input_ptr=input_tensor, + output_ptr=out_ptr, + g_ptr=weight, + rsigma_ptr=rsigma, + input_row_stride=input_tensor.stride(0), + output_row_stride=out_ptr.stride(0), + n_rows=N, n_cols=H, + epsilon=eps, + q_amax_ptr=amax, + q_scale_ptr=q_scale, + scale_inv_ptr=scale_inv_ptr, + out_transpose_ptr=out_transpose_ptr, + out_transpose_stride=out_transpose_stride, + ZERO_CENTERED_GAMMA=zero_centered_gamma, + BLOCK_SIZE=BLOCK_SIZE, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + MAKE_TRANSPOSE=MAKE_TRANSPOSE, + ) + if kernel == 'layer': + kwargs["APPLY_ATOMIC"]=APPLY_ATOMIC + kwargs["PERSISTENT"]=False # TODO: Improve persistent algo performance + kwargs["b_ptr"]=bias + kwargs["mean_ptr"]=mu + elif kernel == "rms": + kwargs["USE_BLOCKED"]=USE_BLOCKED + kwargs["NUM_PRGMS"]=NUM_PRGMS + + kernel_func[grid_fwd](**kwargs) + + # Reduce and find amax if "not APPLY_ATOMIC" is True for layernorm. + if IS_FP8 and not APPLY_ATOMIC: + _layernorm_fwd_reduce_triton[(triton.cdiv(N, ATOMIC_REDUCTION_BLOCK_SIZE),)]( + amax, + quantizer.amax, + N, ATOMIC_REDUCTION_BLOCK_SIZE, + ) + elif IS_MXFP8: + out = quantizer.quantize(out, out=ln_out) + + return out, mu, rsigma + + +# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_bwd +def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): + # may take non-contiguous inputs + dz_ = dz.contiguous() + x_ = x.contiguous() + rsigma_ = rsigma.contiguous() + gamma_ = gamma.contiguous() + + dx = torch.empty_like(x_) + dgamma = torch.empty_like(gamma_) + + M, N = x_.shape + blk_size = block_size(x_) + USE_BLOCKED = use_blocked(x_) + NUM_PRGMS = num_programs(x_, sm_margin) + need_reduction = N > 1 + dg_tmp_rows = x_.shape[0] if use_blocked(x_) else num_programs(x_, sm_margin) + dg_tmp = torch.empty(dg_tmp_rows, N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None + + grid_bwd = lambda meta: (NUM_PRGMS, ) + _rmsnorm_bwd_triton[grid_bwd](dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma, + x_.stride(0), dz_.stride(0), M, N, zero_centered_gamma, blk_size, + USE_BLOCKED, NUM_PRGMS, num_warps=8) + + if need_reduction: + grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] + _rmsnorm_bwd_dg_reduce_triton[grid_reduce](dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1], + BLOCK_SIZE_M=128, BLOCK_SIZE_N=64) + + return dx, dgamma + +# drop in replacement for transformer_engine::pytorch::layernorm_bwd +# TODO: Add support for `sm_margin > 0`. +def te_layernorm_bwd_triton( + dz: torch.Tensor, + x: torch.Tensor, + mu: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool +): + if sm_margin is not None and sm_margin > 0: + warnings.warn( + '"sm_margin" is not supported in the Triton based backward layer-norm kernel. ' + + f"sm_margin={sm_margin} will be ignored." 
+ ) + M, N = x.shape + # calculate dw and db separately when M is small + IGNORE_DW_DB_IN_FUSED = M <= 512 + tile_num = max(min(256, M // 4), 1) + if M <= 512 and M * N < 64 * 1024 * 1024: + tile_num = M + elif M >= 8192: + tile_num = 2048 + max_fused_size = 32768 // x.element_size() + next_power = triton.next_power_of_2(N) + BLOCK_SIZE = min(max_fused_size, next_power) + # For cases with small M and large N, decrease block size to help with occupancy and register spill + if tile_num == M: + if tile_num > 256: + BLOCK_SIZE = min(BLOCK_SIZE, 2048) + else: + BLOCK_SIZE = min(BLOCK_SIZE, 4096) + USE_BLOCKED = N > BLOCK_SIZE + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + + dx = torch.empty_like(x) + if not IGNORE_DW_DB_IN_FUSED: + _dgamma = torch.zeros((tile_num, N), dtype=torch.float32, device=gamma.device) + _dbeta = torch.zeros((tile_num, N), dtype=torch.float32, device=gamma.device) + else: + _dgamma = None + _dbeta = None + dgamma = torch.zeros((N,), dtype=gamma.dtype, device=gamma.device) + dbeta = torch.zeros((N,), dtype=gamma.dtype, device=gamma.device) + grid_bwd = (tile_num,) + _layernorm_bwd_dx_fused_triton[grid_bwd]( + dx, + dz, + _dgamma, + _dbeta, + x, + gamma, + mu, + rsigma, + x.stride(0), + N, + ZERO_CENTERED_GAMMA=zero_centered_gamma, + NUM_ROWS=M, + BLOCK_SIZE_N=BLOCK_SIZE, + USE_BLOCKED=USE_BLOCKED, + num_warps=num_warps, + IGNORE_DW_DB=IGNORE_DW_DB_IN_FUSED, + ) + grid_reduce = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE_N"]),) + if not IGNORE_DW_DB_IN_FUSED: + dwdb_block_n = max(16, N // 256) + dwdb_block_n = triton.next_power_of_2(dwdb_block_n) + dwdb_block_m = (64 * 128) // dwdb_block_n + dwdb_block_m = min(triton.next_power_of_2(tile_num), dwdb_block_m) + _layernorm_bwd_dwdb_triton[grid_reduce]( + _dgamma, + _dbeta, + dgamma, + dbeta, + min(tile_num, M), + N, + BLOCK_SIZE_M=dwdb_block_m, + BLOCK_SIZE_N=dwdb_block_n, + ) + else: + dwdb_block_n = max(16, N // 256) + dwdb_block_n = triton.next_power_of_2(dwdb_block_n) + dwdb_block_m = (64 * 128) // dwdb_block_n + dwdb_block_m = min(triton.next_power_of_2(M), dwdb_block_m) + _layernorm_bwd_dwdb_triton_v2[grid_reduce]( + x, + dz, + mu, + rsigma, + x.stride(0), + dgamma, + dbeta, + M, + N, + BLOCK_SIZE_M=dwdb_block_m, + BLOCK_SIZE_N=dwdb_block_n, + ) + + return dx, dgamma, dbeta diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index b62d61ced..0394ccf73 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -5,40 +5,27 @@ import triton import triton.language as tl from itertools import product -from .norm_common import num_programs, block_size, use_blocked, make_ln_out -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer -from transformer_engine.pytorch.triton_kernels.common import ( - te_dtype_to_torch_dtype, - te_dtype_to_triton_dtype, -) -from .common import get_fp8_max -from ..tensor.quantized_tensor import Quantizer -import transformer_engine_torch as tex - -def dg_tmp_rows(x, sm_margin=None): - return x.shape[0] if use_blocked(x) else num_programs(x, sm_margin) - def get_autotune_config(): return [triton.Config({'waves_per_eu': we}, num_warps=nw) for (we, nw) in product([0, 1, 2, 4], [4, 8, 16])] +# TODO(micky774) Implement fused MXFP8 quantization within the kernel @triton.jit def _rmsnorm_fwd_triton_impl( - output_ptr, input_ptr, - g_ptr, rsigma_ptr, + output_ptr, + g_ptr, 
+ rsigma_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon, - amax_ptr, q_amax_ptr, q_scale_ptr, scale_inv_ptr, out_transpose_ptr, - transpose_row_stride, + out_transpose_stride, ZERO_CENTERED_GAMMA: tl.constexpr, BLOCK_SIZE: tl.constexpr, USE_BLOCKED: tl.constexpr, @@ -115,7 +102,7 @@ def _rmsnorm_fwd_triton_impl( rms_norm = rms_norm * scale rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX) if MAKE_TRANSPOSE: - output_t_ptrs = out_transpose_ptr + cols * transpose_row_stride + row_idx + output_t_ptrs = out_transpose_ptr + cols * out_transpose_stride + row_idx tl.store(output_t_ptrs, rms_norm.to(output_type)) tl.store(output_ptrs, rms_norm.to(output_type)) @@ -136,7 +123,7 @@ def _rmsnorm_fwd_triton_impl( rms_norm = rms_norm * scale rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX) if MAKE_TRANSPOSE: - output_t_ptrs = out_transpose_ptr + cols * transpose_row_stride + row_idx + output_t_ptrs = out_transpose_ptr + cols * out_transpose_stride + row_idx tl.store(output_t_ptrs, rms_norm.to(output_type), mask=mask) tl.store(output_ptrs, rms_norm.to(output_type), mask=mask) @@ -167,11 +154,10 @@ def _rmsnorm_fwd_triton_impl( rms_norm = rms_norm * scale rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX) if MAKE_TRANSPOSE: - output_t_ptrs = out_transpose_ptr + col_offsets * transpose_row_stride + row_idx + output_t_ptrs = out_transpose_ptr + col_offsets * out_transpose_stride + row_idx tl.store(output_t_ptrs, rms_norm.to(output_type), mask=mask) tl.store(output_ptrs, rms_norm.to(output_type), mask=mask) if IS_FP8: - tl.store(amax_ptr + row_start, amax) tl.atomic_max(q_amax_ptr, amax, sem="relaxed") if row_start == 0: scale = tl.load(q_scale_ptr) @@ -333,134 +319,3 @@ def _rmsnorm_bwd_dg_reduce_triton(dg_in_ptr, dg_out_ptr, dg_in_stride, n_rows, n sum_dg = tl.sum(acc, axis=0) tl.store(dg_out_ptr + cols, sum_dg.to(dg_out_ptr.type.element_ty), mask=cols < n_cols) -# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_bwd -def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): - # may take non-contiguous inputs - dz_ = dz.contiguous() - x_ = x.contiguous() - rsigma_ = rsigma.contiguous() - gamma_ = gamma.contiguous() - - dx = torch.empty_like(x_) - dgamma = torch.empty_like(gamma_) - - M, N = x_.shape - blk_size = block_size(x_) - USE_BLOCKED = use_blocked(x_) - NUM_PRGMS = num_programs(x_, sm_margin) - need_reduction = N > 1 - dg_tmp = torch.empty(dg_tmp_rows(x_, sm_margin), N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None - - grid_bwd = lambda meta: (NUM_PRGMS, ) - _rmsnorm_bwd_triton[grid_bwd](dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma, - x_.stride(0), dz_.stride(0), M, N, zero_centered_gamma, blk_size, - USE_BLOCKED, NUM_PRGMS, num_warps=8) - - if need_reduction: - grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] - _rmsnorm_bwd_dg_reduce_triton[grid_reduce](dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1], - BLOCK_SIZE_M=128, BLOCK_SIZE_N=64) - - return dx, dgamma - -# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_fwd -def te_rmsnorm_fwd_triton( - input: torch.Tensor, - weight: torch.Tensor, - eps: float, - ln_out: torch.Tensor, - quantizer: Quantizer, - otype: tex.DType, - sm_margin: int, - zero_centered_gamma: bool, - autotune: bool = True, -): - if eps < 0: - raise ValueError(f"`eps` must be non-negative, but a value of {eps} was passed") - if len(input.shape) != 2: - raise ValueError( - f"The input must be a 2-dimensional 
matrix, but an input with {input.ndim} was passed.") - - device = input.device - N, H = input.shape - if weight.shape[0] != H: - raise ValueError( - f"The shape of `weight` must be feature-aligned, " - f"but {weight.shape[0]=} while {input.shape[1]=}" - ) - IS_FP8 = isinstance(quantizer, Float8Quantizer) - IS_MXFP8 = isinstance(quantizer, MXFP8Quantizer) - BLOCK_SIZE = block_size(input) - USE_BLOCKED = use_blocked(input) - NUM_PRGMS = num_programs(input, sm_margin) - MAKE_TRANSPOSE = False - - rsigma = torch.empty((N,), dtype=torch.float32, device=device) - torch_out_dtype = ( - otype if isinstance(otype, torch.dtype) - else te_dtype_to_torch_dtype(otype) - ) - out = make_ln_out( - ln_out, - quantizer=quantizer, - input_shape=input.shape, - out_dtype=torch_out_dtype - ) - if IS_FP8: - MAKE_TRANSPOSE = quantizer.columnwise_usage - amax = torch.empty((NUM_PRGMS,), dtype=torch.float32, device=device) - tl_dtype = te_dtype_to_triton_dtype(quantizer.dtype) - scale_inv_ptr = out._scale_inv - q_scale = quantizer.scale - q_amax = quantizer.amax - out_ptr = triton.reinterpret(out._data, tl_dtype) - FP8_MAX = get_fp8_max(quantizer.dtype) - if MAKE_TRANSPOSE: - if out._transpose_invalid: - out._transpose = torch.empty((out._data.shape[1], out._data.shape[0]), dtype=out._data.dtype, device=device) - out._transpose_invalid = False - out_transpose_ptr = triton.reinterpret(out._transpose, tl_dtype) - out_transpose_stride = out._transpose.stride(0) - else: - out_transpose_ptr = None - out_transpose_stride = None - else: - amax = None - tl_dtype = None - scale_inv_ptr = None - q_scale = None - q_amax = None - out_ptr = out - out_transpose_ptr = None - out_transpose_stride = None - FP8_MAX = None - - grid_fwd = lambda meta: (NUM_PRGMS, ) - # TODO(micky774) Implement fused MXFP8 quantization within the kernel - kernel = _rmsnorm_fwd_triton if autotune else _rmsnorm_fwd_triton_impl - kernel[grid_fwd]( - out_ptr, - input, - weight, - rsigma, - input.stride(0), - out_ptr.stride(0), - N, H, eps, - amax, - q_amax, - q_scale, - scale_inv_ptr, - out_transpose_ptr, - out_transpose_stride, - zero_centered_gamma, - BLOCK_SIZE, - USE_BLOCKED, - NUM_PRGMS, - IS_FP8, - FP8_MAX, - MAKE_TRANSPOSE, - ) - if IS_MXFP8: - out = quantizer.quantize(out, out=ln_out) - - return out, None, rsigma
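For reference, a minimal usage sketch of the consolidated entry points in transformer_engine.pytorch.triton_kernels.norms exercising the non-quantized path. It assumes a ROCm/HIP build (where these Triton kernels are imported), that make_ln_out allocates the output when ln_out and quantizer are both None, and that passing a torch.dtype as otype is acceptable (the dispatcher checks isinstance(otype, torch.dtype)); shapes, eps, and dtypes below are illustrative only, not taken from the patch.

# Hypothetical usage sketch; values are illustrative.
import torch
from transformer_engine.pytorch.triton_kernels.norms import (
    te_layernorm_fwd_triton,
    te_layernorm_bwd_triton,
    te_rmsnorm_fwd_triton,
    te_rmsnorm_bwd_triton,
)

# Illustrative activations and parameters (on ROCm, the "cuda" device maps to HIP).
x = torch.randn(1024, 4096, device="cuda", dtype=torch.float32)
gamma = torch.ones(4096, device="cuda", dtype=torch.float32)
beta = torch.zeros(4096, device="cuda", dtype=torch.float32)
dz = torch.randn_like(x)

# LayerNorm forward/backward; the forward routes through the shared te_norm_fwd_triton
# dispatcher with kernel='layer'. quantizer=None selects the non-quantized path
# (assumes make_ln_out allocates the output when ln_out is None).
y, mu, rsigma = te_layernorm_fwd_triton(
    x, gamma, beta, 1e-5, None, None, torch.float32,
    sm_margin=0, zero_centered_gamma=False,
)
dx, dgamma, dbeta = te_layernorm_bwd_triton(
    dz, x, mu, rsigma, gamma, sm_margin=0, zero_centered_gamma=False,
)

# RMSNorm path: same dispatcher with kernel='rms' (no bias, no mean; mu is returned as None).
y_rms, _, rsigma_rms = te_rmsnorm_fwd_triton(
    x, gamma, 1e-5, None, None, torch.float32,
    sm_margin=0, zero_centered_gamma=False,
)
dx_rms, dgamma_rms = te_rmsnorm_bwd_triton(
    dz, x, rsigma_rms, gamma, sm_margin=0, zero_centered_gamma=False,
)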