41 commits
0563c1a
Add general C API for setting tensor params
timmoon10 Dec 4, 2025
5c9b1be
Implement general accessors for NVTETensor
timmoon10 Dec 4, 2025
219ddc1
Merge branch 'main' into tmoon/pre-swizzled-scales
timmoon10 Dec 4, 2025
1c49646
Refactor tex swizzling to skip if scales are already swizzled
timmoon10 Dec 5, 2025
5f60184
Add checks for non-swizzled scales in MXFP8 and NVFP4 kernels
timmoon10 Dec 5, 2025
21ec928
Support pre-swizzled scales in MXFP8Tensor
timmoon10 Dec 5, 2025
fa7e7c0
Add tex function to swizzle MXFP8 scales
timmoon10 Dec 6, 2025
b796c96
Fix bug in inplace swizzle function
timmoon10 Dec 6, 2025
52ce3a4
Tweak comments to use "compact/swizzled format"
timmoon10 Dec 6, 2025
5c7c1d9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2025
dfb4b94
MXFP8 quantize kernel with pre-swizzled scales
timmoon10 Dec 9, 2025
1a8b551
Expose pre-swizzled scales in modules
timmoon10 Dec 9, 2025
cb1254a
Fix bug in multi-swizzle
timmoon10 Dec 10, 2025
8b10300
Support MXFP8 gated activations with swizzled scales
timmoon10 Dec 10, 2025
1de4b5e
Merge branch 'main' into tmoon/pre-swizzled-scales
timmoon10 Dec 10, 2025
8c6ea61
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2025
a0184bc
Add PyTorch infrastructure for pre-swizzled NVFP4 tensors
timmoon10 Dec 10, 2025
2365821
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2025
bf12da9
Deprecate DSv3-specific quantization logic in C API
timmoon10 Dec 11, 2025
a89c006
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2025
b7eced8
Remove support for DSv3 compact data from quantizer
timmoon10 Dec 12, 2025
1da2c19
Remove DSv3 compact data format from core lib
timmoon10 Dec 12, 2025
9ed62cb
Fix bug in FP8 all-gather
timmoon10 Dec 12, 2025
43c8132
Fix linter warnings
timmoon10 Dec 12, 2025
f37036e
Update JAX to use new swizzled scale API
timmoon10 Dec 12, 2025
c549e90
Merge branch 'main' into tmoon/pre-swizzled-scales
timmoon10 Dec 12, 2025
4b06462
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 12, 2025
6c11bb5
Review suggestion from @greptile-apps
timmoon10 Dec 12, 2025
736a971
Review suggestions from @greptile-apps
timmoon10 Dec 12, 2025
8b5e43d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 12, 2025
78b572c
Update C++ swizzle test with swizzled scales API
timmoon10 Dec 12, 2025
d13760c
Return default tensor params when querying params for invalid NVTETensor
timmoon10 Dec 13, 2025
9cc7fe4
Debug DSv3 FP8 test failures
timmoon10 Dec 13, 2025
41c8d51
Debug Userbuffers test failures
timmoon10 Dec 13, 2025
7b3e231
Make sure gated activations populate FP8 transpose if needed
timmoon10 Dec 13, 2025
732425c
Merge branch 'main' into tmoon/pre-swizzled-scales
timmoon10 Dec 13, 2025
7b55b9d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2025
dc235e9
Review suggestions from @greptile-apps
timmoon10 Dec 13, 2025
5aec484
Disable pre-swizzling with debug quantizer
timmoon10 Dec 15, 2025
f05fd06
Merge branch 'main' into tmoon/pre-swizzled-scales
timmoon10 Dec 22, 2025
c6f12e1
Review suggestion from @greptile-apps
timmoon10 Dec 22, 2025
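Taken together, these commits add general NVTETensor parameter accessors to the C API, introduce a with_gemm_swizzled_scales flag so MXFP8/NVFP4 scale-inverse tensors can be produced directly in the layout the GEMM consumes (letting the framework skip the separate swizzle pass), and retire the DSv3-specific compact data format from the quantizer. A minimal, hedged Python sketch of the intended effect follows; it is not the PR's API surface. The import path and constructor arguments are assumptions based on existing Transformer Engine conventions, and the only pre-swizzle hook actually visible in this diff is the C++ test helper set_with_gemm_swizzled_scales.

# Hedged sketch only, not the exact API added by this PR. Import path and
# constructor arguments are assumptions based on existing Transformer Engine
# conventions; the flag for requesting pre-swizzled scales is omitted because
# this diff only exposes the C++ setter set_with_gemm_swizzled_scales.
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

x = torch.randn(256, 1024, dtype=torch.bfloat16, device="cuda")

# Quantize to MXFP8. With pre-swizzled scales, the rowwise/columnwise
# scale-inverse buffers come out already in the cuBLAS block-scaled GEMM
# layout, so no extra swizzle kernel needs to run before the GEMM.
quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
x_mxfp8 = quantizer(x)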
qa/L0_pytorch_debug_unittest/test.sh (8 additions, 8 deletions)
@@ -28,16 +28,16 @@ mkdir -p "$XML_LOG_DIR"

pip install pytest==8.2.1 || error_exit "Failed to install pytest"

-pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_sanity.py"
-pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_config.py"
-pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_numerics.py"
-pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_log.py"
-NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_api_features.py"
-pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py"
+pytest -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_sanity.py"
+pytest -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_config.py"
+pytest -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_numerics.py"
+pytest -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_log.py"
+NVTE_TORCH_COMPILE=0 pytest -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_api_features.py"
+pytest -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py"

# standard sanity and numerics tests with initialized debug
-NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py"
-NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py"
+NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py"
+NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py"

if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
tests/cpp/operator/test_swizzle.cu (1 addition, 0 deletions)
@@ -85,6 +85,7 @@ void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool row
std::vector<int> scaling_mode = {SF_MODE_X, SF_MODE_Y, 0};
Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);
Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING);
+output.set_with_gemm_swizzled_scales(true);

fillUniform(&input);

tests/cpp/test_common.h (4 additions, 0 deletions)
@@ -286,6 +286,10 @@ class Tensor {
tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape);
}

+void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales){
+tensor_.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
+}
+
void to_cpu() const;
void from_cpu() const;
void set_scale(float scale);
tests/pytorch/test_float8_blockwise_scaling_exact.py (0 additions, 120 deletions)
@@ -87,126 +87,6 @@ def initialize_for_many_scales(
return result


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# full tile cases
(128, 128),
(256, 256),
(256, 1024),
(1024, 256),
# Padding required cases
(256, 272),
(303, 300),
(305, 256),
# Some larger tiles.
(2000, 2000),
(2048, 2000),
(2000, 1024),
(2048, 1024),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"])
def test_quantization_1D_block_tiling_with_compact_data_and_scales(
x_dtype: torch.dtype,
M: int,
N: int,
quant_dtype: torch.dtype,
eps: float,
pow_2_scales: bool,
) -> None:
te_dtype = TE_DType[quant_dtype]
tile_size = (1, 128)
# This test runs a comparison of the ref class versus the class using
# CUDA kernels to quantize. They should quantize identically for pixels
# that are not DC values in the scale factor shape.
ref_quantizer = BlockwiseQuantizerReference()
sut_quantizer = Float8BlockQuantizer(
fp8_dtype=te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=eps,
force_pow_2_scales=pow_2_scales,
block_scaling_dim=1,
all_gather_usage=True,
)

# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Input
x = initialize_for_many_scales((M, N), tile_size, dtype=x_dtype, device=device)

x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False)
x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut)
x_fp8_sut_cpp_alloc = sut_quantizer(x)

assert x_fp8_sut._rowwise_data is not None
qx: torch.Tensor = x_fp8_sut._rowwise_data.view(dtype=quant_dtype)
assert x_fp8_sut._rowwise_scale_inv is not None
sx: torch.Tensor = x_fp8_sut._rowwise_scale_inv
qx_t = x_fp8_sut._columnwise_data
sx_t = x_fp8_sut._columnwise_scale_inv

qresult_ref = ref_quantizer.quantize(
x,
quant_dtype=quant_dtype,
return_transpose=True,
eps=eps,
pow_2_scales=pow_2_scales,
quant_tile_shape=tile_size,
munge_scale_shapes=False,
)
qx_ref, sx_ref, qx_t_ref, sx_t_ref = (
qresult_ref.data,
qresult_ref.scale,
qresult_ref.data_t,
qresult_ref.scale_t,
)

# match the reference quantize transpose output with the columnwise non-transpose method
qx_t_ref = qx_t_ref.transpose(-1, -2).contiguous()
sx_t_ref = sx_t_ref.transpose(-1, -2).contiguous()

# Check
torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0)
assert qx_t is not None
qx_t = qx_t.view(dtype=quant_dtype)
assert qx_t_ref is not None
assert sx_t is not None
assert sx_t_ref is not None
torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0)
torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0)

# check that the C++ and Python allocators are equivalent
torch.testing.assert_close(
x_fp8_sut._rowwise_data, x_fp8_sut_cpp_alloc._rowwise_data, atol=0.0, rtol=0.0
)
torch.testing.assert_close(
x_fp8_sut._rowwise_scale_inv, x_fp8_sut_cpp_alloc._rowwise_scale_inv, atol=0.0, rtol=0.0
)
torch.testing.assert_close(
x_fp8_sut._columnwise_data, x_fp8_sut_cpp_alloc._columnwise_data, atol=0.0, rtol=0.0
)
torch.testing.assert_close(
x_fp8_sut._columnwise_scale_inv,
x_fp8_sut_cpp_alloc._columnwise_scale_inv,
atol=0.0,
rtol=0.0,
)

# check if the fp8 output between C++ and Python are the same
assert x_fp8_sut._data_format == x_fp8_sut_cpp_alloc._data_format


def check_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
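The test deleted above exercised Float8BlockQuantizer with all_gather_usage=True, i.e. the DSv3 compact data and scale format that this PR removes from the core library and quantizer. For orientation, here is a minimal sketch of how the quantizer is still driven without that flag, mirroring the constructor arguments and tensor attributes that remain in the surviving tests; the import paths are assumptions based on the names used in the test and may need adjusting.

# Minimal sketch based on the surviving test code; import paths are assumptions.
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer

torch.manual_seed(0)
x = torch.randn(256, 1024, dtype=torch.bfloat16, device="cuda")

# 1D (1x128) block scaling, row- and column-wise, without the removed
# all_gather_usage/compact-format option.
quantizer = Float8BlockQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    rowwise=True,
    columnwise=True,
    block_scaling_dim=1,
)
x_fp8 = quantizer(x)

# Quantized data and scale-inverse buffers, as inspected in the tests above.
print(x_fp8._rowwise_data.shape, x_fp8._rowwise_scale_inv.shape)
print(x_fp8._columnwise_data.shape, x_fp8._columnwise_scale_inv.shape)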
tests/pytorch/test_float8blockwisetensor.py (1 addition, 14 deletions)
@@ -175,24 +175,19 @@ def test_quantize_dequantize_columnwise_only(
)
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
@pytest.mark.parametrize("dq_columnwise", [True, False])
-@pytest.mark.parametrize("all_gather_usage", [True, False])
def test_quantize_dequantize_dims(
self,
dims: DimsType,
block_scaling_dim: int,
dq_columnwise: bool,
-all_gather_usage: bool,
) -> None:
-if all_gather_usage and block_scaling_dim != 1:
-pytest.skip("all_gather_usage only implemented for 1D block quantization.")
atol = _tols[tex.DType.kFloat8E4M3]["atol"]
rtol = _tols[tex.DType.kFloat8E4M3]["rtol"]
quantizer = Float8BlockQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
rowwise=True,
columnwise=dq_columnwise,
block_scaling_dim=block_scaling_dim,
-all_gather_usage=all_gather_usage,
)
self._test_quantize_dequantize(
quantizer=quantizer,
Expand All @@ -218,7 +213,6 @@ def test_quantize_dequantize_compact_format(
rowwise=True,
columnwise=dq_columnwise,
block_scaling_dim=block_scaling_dim,
-all_gather_usage=(block_scaling_dim == 1),
)
self._test_quantize_dequantize(
quantizer=quantizer,
@@ -283,13 +277,8 @@ def test_data_accessors(self, dims: DimsType, block_scaling_dim: int) -> None:

@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
-@pytest.mark.parametrize("all_gather_usage", [True, False])
-def test_serialization(
-self, dims: DimsType, block_scaling_dim: int, all_gather_usage: bool
-) -> None:
+def test_serialization(self, dims: DimsType, block_scaling_dim: int) -> None:
"""Test serialization of Float8BlockwiseQTensor"""
-if all_gather_usage and block_scaling_dim != 1:
-pytest.skip("all_gather_usage only implemented for 1D block quantization.")
device = "cuda"
dtype = torch.bfloat16
x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
@@ -298,7 +287,6 @@ def test_serialization(
rowwise=True,
columnwise=True,
block_scaling_dim=block_scaling_dim,
-all_gather_usage=all_gather_usage,
)

# Create FP8 tensor
@@ -322,7 +310,6 @@ def test_serialization(
assert x_fp8_loaded._is_2D_scaled == x_fp8._is_2D_scaled
assert x_fp8_loaded.dtype == x_fp8.dtype
assert x_fp8_loaded._fp8_dtype == x_fp8._fp8_dtype
-assert x_fp8_loaded._data_format == x_fp8._data_format

# Test that dequantized values match
x_fp8_dequant = x_fp8.dequantize()
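With all_gather_usage and the _data_format check gone, serialization of Float8BlockwiseQTensor is exercised only for the standard GEMM-ready format. A hedged sketch of the round trip that the trimmed test_serialization still covers is below, using torch.save/torch.load through an in-memory buffer; the import paths are assumptions, and the exact save/load mechanism in the real test may differ.

# Hedged sketch of the serialization round trip; import paths are assumptions.
import io
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer

x_hp = torch.rand(256, 512, dtype=torch.bfloat16, device="cuda")
quantizer = Float8BlockQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    rowwise=True,
    columnwise=True,
    block_scaling_dim=1,
)
x_fp8 = quantizer(x_hp)

# Round-trip through torch.save/torch.load and compare dequantized values,
# matching the checks kept in test_serialization above.
buffer = io.BytesIO()
torch.save(x_fp8, buffer)
buffer.seek(0)
x_fp8_loaded = torch.load(buffer, weights_only=False)

assert x_fp8_loaded._fp8_dtype == x_fp8._fp8_dtype
torch.testing.assert_close(x_fp8_loaded.dequantize(), x_fp8.dequantize(), atol=0.0, rtol=0.0)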