Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ mpgemm$(PYEXT): src/bindings.cpp $(HEADERS)

# run pytest
pytest: all
PYTHONPATH=. python3 -m pytest -q tests/test_post_process.py
PYTHONPATH=. python3 -m pytest -q tests/test_api.py

run: all
./$(TARGET_MAIN)
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ mpGEMM/
│ └── bindings.cpp
├── tests/
│ ├── test_correctness.cpp
│ ├── test_post_process.py
│ ├── test_api.py
│ └── run_benchmark.cpp
├── scripts/
│ └── benchmark.py
Expand Down
49 changes: 26 additions & 23 deletions scripts/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,55 @@
import os
import sys

# Ensure project root is on PYTHONPATH
# 确保项目根目录在 PYTHONPATH
sys.path.insert(0, os.path.abspath(os.path.join(__file__, os.pardir, '..')))

import numpy as np
import mpgemm
from mpgemm import Activation


def main():
# Matrix dimensions
# 矩阵维度
M, K, N = 128, 128, 128

# Random number generator
rng = np.random.default_rng(42)
# 随机数生成器
rng = np.random.default_rng(2025)

# Generate random quantized weights (INT4 range 0-15)
# Generate random INT4 weights (0-15)
weights = rng.integers(0, 16, size=(M, K), dtype=np.uint8)
# Generate random FP16 activations
# 生成随机 FP16 激活
activations = rng.standard_normal(size=(K, N)).astype(np.float16)
# Generate random bias (FP32)
# Random bias (FP32)
bias = rng.standard_normal(size=N).astype(np.float32)

# Initialize GEMM engine (LUT backend)
gemm = mpgemm.Engine("lut")
gemm.generate_lut(bit_width=4)

# Flatten inputs to Python lists
# 扁平化并转为 Python 列表
w_flat = weights.flatten().tolist()
# cast activations to Python float list
a_flat = activations.flatten().astype(float).tolist()
bias_list = bias.tolist()

# Perform matrix multiplication
out_flat = gemm.matmul(w_flat, a_flat, M, K, N)
# === 1. 基准参考输出 ===
gemm_ref = mpgemm.Engine("naive")
ref_flat = gemm_ref.matmul(w_flat, a_flat, M, K, N)

# Post-processing
out_biased = gemm.add_bias(out_flat, M, N, bias_list)
out_relu = gemm.apply_activation(out_biased, M, N, Activation.ReLU)
# === 2. LUT 后端输出 ===
gemm_lut = mpgemm.Engine("lut")
gemm_lut.generate_lut(bit_width=4)
out_flat = gemm_lut.matmul(w_flat, a_flat, M, K, N)

# Reshape back to matrix
output = np.array(out_relu, dtype=np.float32).reshape(M, N)
# === 3. 后处理示例 ===
out_biased = gemm_lut.add_bias(out_flat, M, N, bias_list)
out_relu = gemm_lut.apply_activation(out_biased, M, N, Activation.ReLU)

# Display results
# 还原成矩阵
output = np.array(out_relu, dtype=np.float32).reshape(M, N)
print(f"Output shape: {output.shape}")
print("Sample output [0,:5]:", output[0, :5])
print("Sample row[0, :5]:", output[0, :5])

# === 4. 误差分析示例 ===
stats = mpgemm.measure_error(ref_flat, out_flat)
print(f"\nError relative to naive:")
print(f" MSE = {stats['mse']:.6f}")
print(f" Max error = {stats['max_error']:.6f}")

if __name__ == "__main__":
main()
24 changes: 24 additions & 0 deletions src/accuracy_utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

/// Aggregate error metrics produced by measure_error().
struct ErrorStats {
    double mse;        ///< Mean of the squared element-wise differences.
    double max_error;  ///< Largest absolute element-wise difference.
};

/// Computes MSE and max absolute error between two same-sized flat arrays.
///
/// @param ref   Reference values.
/// @param test  Values under test; must contain exactly as many elements as ref.
/// @return ErrorStats with the mean squared error and the maximum absolute
///         error. Both fields are 0.0 when the inputs are empty, so callers
///         never receive the NaN a 0/0 division would otherwise produce.
/// @throws std::invalid_argument if ref and test differ in size (indexing a
///         shorter `test` would read out of bounds).
inline ErrorStats measure_error(const std::vector<float>& ref,
                                const std::vector<float>& test) {
    if (ref.size() != test.size()) {
        throw std::invalid_argument(
            "measure_error: ref and test must have the same size");
    }

    const std::size_t count = ref.size();
    if (count == 0) {
        return {0.0, 0.0};
    }

    double sum_sq = 0.0;
    double max_err = 0.0;
    for (std::size_t i = 0; i < count; ++i) {
        // Accumulate in double so float rounding does not compound.
        const double diff =
            static_cast<double>(test[i]) - static_cast<double>(ref[i]);
        sum_sq += diff * diff;
        max_err = std::max(max_err, std::abs(diff));
    }
    return {sum_sq / static_cast<double>(count), max_err};
}
23 changes: 23 additions & 0 deletions src/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "lut_utils.hpp"
#include "post_processing.hpp"
#include "gemm_engine.hpp"
#include "accuracy_utils.hpp"

namespace py = pybind11;

Expand Down Expand Up @@ -71,4 +72,26 @@ PYBIND11_MODULE(mpgemm, m) {
.def("apply_activation", &Engine::apply_activation,
"Apply activation to GEMM output",
py::arg("C"), py::arg("M"), py::arg("N"), py::arg("act"));

// --- Error measurement ---
py::class_<ErrorStats>(m, "ErrorStats")
.def_readonly("mse", &ErrorStats::mse)
.def_readonly("max_error", &ErrorStats::max_error);
m.def("measure_error",
[](const std::vector<float>& ref,
const std::vector<float>& test) {
auto s = measure_error(ref, test);
py::dict d;
d["mse"] = s.mse;
d["max_error"] = s.max_error;
return d;
},
py::arg("reference"),
py::arg("test"),
R"(
Compute error statistics between two flat float lists:
- mse: mean squared error
- max_error: maximum absolute error
Returns a dict: {\"mse\": ..., \"max_error\": ...}
)");
}
1 change: 0 additions & 1 deletion src/gemm_engine.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
// src/gemm_engine.hpp
#pragma once
#include <string>
#include <vector>
Expand Down
7 changes: 7 additions & 0 deletions tests/test_post_process.py → tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,10 @@ def test_relu():
R = mpgemm.apply_activation(M.flatten().tolist(), 2, 2, mpgemm.Activation.ReLU)
R = np.array(R).reshape(2,2)
assert np.all(R >= 0)

def test_measure_error():
    """measure_error reports the hand-computed MSE and max absolute error."""
    expected = [1.0, 2.0, 3.0]
    observed = [1.1, 1.9, 2.5]
    tolerance = 1e-6

    result = mpgemm.measure_error(expected, observed)

    # Differences are 0.1, -0.1, -0.5 -> MSE = 0.27 / 3 = 0.09, max |diff| = 0.5.
    mse_gap = abs(result["mse"] - 0.09)
    assert mse_gap < tolerance
    max_gap = abs(result["max_error"] - 0.5)
    assert max_gap < tolerance
19 changes: 18 additions & 1 deletion tests/test_correctness.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
#include "../src/lut_utils.hpp"
#include "../src/quant_utils.hpp"
#include "../src/post_processing.hpp"
#include "../src/accuracy_utils.hpp"

#include <iostream>
#include <fstream>
#include <sstream>
#include <stdexcept>
#include <random>
#include <cassert>

// Helper: compare two matrices for equality
template<typename T, typename Layout, typename Storage>
Expand Down Expand Up @@ -376,10 +378,24 @@ bool run_linear_test() {
return pass;
}

// 14. accuracy test
// Verifies measure_error() against hand-computed values. The result is
// returned as a bool (rather than enforced with assert) so the check still
// runs when NDEBUG is defined and a failure is counted by the caller instead
// of aborting the whole test binary.
bool run_accuracy_test() {
    std::cout << "Running accuracy test...\n";
    const std::vector<float> A = {1.0f, 2.0f, 3.0f};
    const std::vector<float> B = {1.1f, 1.9f, 2.5f};
    const auto stats = measure_error(A, B);
    // manual:
    // diffs = {0.1, -0.1, -0.5}, sq = {0.01,0.01,0.25}, mse = 0.27/3 = 0.09
    constexpr double kTolerance = 1e-6;
    const bool mse_ok = std::fabs(stats.mse - 0.09) < kTolerance;
    const bool max_ok = std::fabs(stats.max_error - 0.5) < kTolerance;
    const bool pass = mse_ok && max_ok;
    std::cout << (pass ? "Accuracy test PASS\n" : "Accuracy test FAIL\n");
    return pass;
}


int main() {
int passed=0;
int total=15;
int total=16;
if (run_basic_test()) ++passed;
if (run_negative_test()) ++passed;
if (run_non_square_test()) ++passed;
Expand All @@ -395,6 +411,7 @@ int main() {
if (run_sigmoid_test()) ++passed;
if (run_tanh_test()) ++passed;
if (run_linear_test()) ++passed;
if (run_accuracy_test()) ++passed;
#ifdef USE_MKL
++total; // 只有啟用 MKL 才加總數
if (run_mkl_test()) ++passed;
Expand Down