From 8ddfa2f28c6d85274fe7565bd5b50a8be7cb2dfe Mon Sep 17 00:00:00 2001 From: robinhundt <24554122+robinhundt@users.noreply.github.com> Date: Fri, 23 May 2025 16:33:05 +0200 Subject: [PATCH 1/7] Remove unused_unsafe fixes #9 --- cryprot-core/src/alloc.rs | 5 +-- cryprot-core/src/block/gf128.rs | 64 +++++++++++++----------------- cryprot-core/src/transpose/avx2.rs | 37 ++++++++--------- 3 files changed, 45 insertions(+), 61 deletions(-) diff --git a/cryprot-core/src/alloc.rs b/cryprot-core/src/alloc.rs index 633f634..ca62376 100644 --- a/cryprot-core/src/alloc.rs +++ b/cryprot-core/src/alloc.rs @@ -51,10 +51,7 @@ impl HugePageMemory { // new_len <= self.capacity // self[len..new_len] is initialized either because of Self::zeroed // or with data written to it. - #[allow(unused_unsafe)] - unsafe { - self.len = new_len; - } + self.len = new_len; } } diff --git a/cryprot-core/src/block/gf128.rs b/cryprot-core/src/block/gf128.rs index 0f0d1eb..dfc3f63 100644 --- a/cryprot-core/src/block/gf128.rs +++ b/cryprot-core/src/block/gf128.rs @@ -104,11 +104,9 @@ mod clmul { #[target_feature(enable = "pclmulqdq")] #[inline] - pub unsafe fn gf128_mul(a: __m128i, b: __m128i) -> __m128i { - unsafe { - let (low, high) = clmul128(a, b); - gf128_reduce(low, high) - } + pub fn gf128_mul(a: __m128i, b: __m128i) -> __m128i { + let (low, high) = clmul128(a, b); + gf128_reduce(low, high) } /// Carry-less multiply of two 128-bit numbers. @@ -116,47 +114,39 @@ mod clmul { /// Return (low, high) bits #[target_feature(enable = "pclmulqdq")] #[inline] - pub unsafe fn clmul128(a: __m128i, b: __m128i) -> (__m128i, __m128i) { - // This is currently needed because we run the nightly version of - // clippy where this is an unused unsafe because the used - // intrinsincs have been marked safe on nightly but not yet on - // stable. 
- #[allow(unused_unsafe)] - unsafe { - // NOTE: I tried using karatsuba but it was slightly slower than the naive - // multiplication - let ab_low = _mm_clmulepi64_si128::<0x00>(a, b); - let ab_high = _mm_clmulepi64_si128::<0x11>(a, b); - let ab_lohi1 = _mm_clmulepi64_si128::<0x01>(a, b); - let ab_lohi2 = _mm_clmulepi64_si128::<0x10>(a, b); - let ab_mid = _mm_xor_si128(ab_lohi1, ab_lohi2); - let low = _mm_xor_si128(ab_low, _mm_slli_si128::<8>(ab_mid)); - let high = _mm_xor_si128(ab_high, _mm_srli_si128::<8>(ab_mid)); - (low, high) - } + pub fn clmul128(a: __m128i, b: __m128i) -> (__m128i, __m128i) { + // NOTE: I tried using karatsuba but it was slightly slower than the naive + // multiplication + let ab_low = _mm_clmulepi64_si128::<0x00>(a, b); + let ab_high = _mm_clmulepi64_si128::<0x11>(a, b); + let ab_lohi1 = _mm_clmulepi64_si128::<0x01>(a, b); + let ab_lohi2 = _mm_clmulepi64_si128::<0x10>(a, b); + let ab_mid = _mm_xor_si128(ab_lohi1, ab_lohi2); + let low = _mm_xor_si128(ab_low, _mm_slli_si128::<8>(ab_mid)); + let high = _mm_xor_si128(ab_high, _mm_srli_si128::<8>(ab_mid)); + (low, high) } #[target_feature(enable = "pclmulqdq")] #[inline] - pub unsafe fn gf128_reduce(mut low: __m128i, mut high: __m128i) -> __m128i { + pub fn gf128_reduce(mut low: __m128i, mut high: __m128i) -> __m128i { // NOTE: I tried a sse shift based reduction but it was slower than the clmul // implementation - unsafe { - let modulus = [MOD, 0]; - let modulus = _mm_loadu_si64(modulus.as_ptr().cast()); + let modulus = [MOD, 0]; + // SAFETY: Ptr to modulus is valid and pclmulqdq implies sse2 is enabled + let modulus = unsafe { _mm_loadu_si64(modulus.as_ptr().cast()) }; - let tmp = _mm_clmulepi64_si128::<0x01>(high, modulus); - let tmp_shifted = _mm_slli_si128::<8>(tmp); - low = _mm_xor_si128(low, tmp_shifted); - high = _mm_xor_si128(high, tmp_shifted); + let tmp = _mm_clmulepi64_si128::<0x01>(high, modulus); + let tmp_shifted = _mm_slli_si128::<8>(tmp); + low = _mm_xor_si128(low, tmp_shifted); + high = _mm_xor_si128(high, tmp_shifted); - // reduce overflow - let tmp = _mm_clmulepi64_si128::<0x01>(tmp, modulus); - low = _mm_xor_si128(low, tmp); + // reduce overflow + let tmp = _mm_clmulepi64_si128::<0x01>(tmp, modulus); + low = _mm_xor_si128(low, tmp); - let tmp = _mm_clmulepi64_si128::<0x00>(high, modulus); - _mm_xor_si128(low, tmp) - } + let tmp = _mm_clmulepi64_si128::<0x00>(high, modulus); + _mm_xor_si128(low, tmp) } #[cfg(all(test, target_feature = "pclmulqdq"))] diff --git a/cryprot-core/src/transpose/avx2.rs b/cryprot-core/src/transpose/avx2.rs index 5a32716..6409349 100644 --- a/cryprot-core/src/transpose/avx2.rs +++ b/cryprot-core/src/transpose/avx2.rs @@ -9,15 +9,14 @@ unsafe fn _mm256_slli_epi64_var_shift(a: __m256i, shift: usize) -> __m256i { matches!(shift, 2 | 4 | 8 | 16 | 32), "Must be called with correct shift" ); - unsafe { - match shift { - 2 => _mm256_slli_epi64::<2>(a), - 4 => _mm256_slli_epi64::<4>(a), - 8 => _mm256_slli_epi64::<8>(a), - 16 => _mm256_slli_epi64::<16>(a), - 32 => _mm256_slli_epi64::<32>(a), - _ => unreachable_unchecked(), - } + match shift { + 2 => _mm256_slli_epi64::<2>(a), + 4 => _mm256_slli_epi64::<4>(a), + 8 => _mm256_slli_epi64::<8>(a), + 16 => _mm256_slli_epi64::<16>(a), + 32 => _mm256_slli_epi64::<32>(a), + // SAFETY: Shift is upheld by caller + _ => unsafe { unreachable_unchecked() }, } } @@ -29,15 +28,14 @@ unsafe fn _mm256_srli_epi64_var_shift(a: __m256i, shift: usize) -> __m256i { matches!(shift, 2 | 4 | 8 | 16 | 32), "Must be called with correct shift" ); - unsafe { - 
match shift { - 2 => _mm256_srli_epi64::<2>(a), - 4 => _mm256_srli_epi64::<4>(a), - 8 => _mm256_srli_epi64::<8>(a), - 16 => _mm256_srli_epi64::<16>(a), - 32 => _mm256_srli_epi64::<32>(a), - _ => unreachable_unchecked(), - } + match shift { + 2 => _mm256_srli_epi64::<2>(a), + 4 => _mm256_srli_epi64::<4>(a), + 8 => _mm256_srli_epi64::<8>(a), + 16 => _mm256_srli_epi64::<16>(a), + 32 => _mm256_srli_epi64::<32>(a), + // SAFETY: Shift is upheld by caller + _ => unsafe { unreachable_unchecked() }, } } @@ -202,8 +200,7 @@ pub unsafe fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) let cols = input.len() * 8 / rows; assert_eq!(0, cols % 128); assert_eq!(0, rows % 128); - #[allow(unused_unsafe)] - let mut buf = [unsafe { _mm256_setzero_si256() }; 64]; + let mut buf = [_mm256_setzero_si256(); 64]; let in_stride = cols / 8; let out_stride = rows / 8; From 9df3626062234338929bf819a55cb3bc1ba6c650 Mon Sep 17 00:00:00 2001 From: "robinhundt (aider)" <24554122+robinhundt@users.noreply.github.com> Date: Fri, 23 May 2025 17:19:22 +0200 Subject: [PATCH 2/7] refactor: Improve AVX2 transpose clarity and readability --- Cargo.lock | 1 + cryprot-core/Cargo.toml | 1 + cryprot-core/src/transpose/avx2.rs | 357 ++++++++++++++--------------- 3 files changed, 170 insertions(+), 189 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c991410..712a850 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -533,6 +533,7 @@ dependencies = [ "rand_core 0.6.4", "rand_core 0.9.3", "rayon", + "seq-macro", "serde", "subtle", "thiserror", diff --git a/cryprot-core/Cargo.toml b/cryprot-core/Cargo.toml index 182c09d..af78388 100644 --- a/cryprot-core/Cargo.toml +++ b/cryprot-core/Cargo.toml @@ -32,6 +32,7 @@ rand.workspace = true rand_core.workspace = true rand_core_0_6.workspace = true rayon = { workspace = true, optional = true } +seq-macro.workspace = true serde = { workspace = true, features = ["derive"] } subtle.workspace = true thiserror.workspace = true diff --git a/cryprot-core/src/transpose/avx2.rs b/cryprot-core/src/transpose/avx2.rs index 6409349..3cbde4a 100644 --- a/cryprot-core/src/transpose/avx2.rs +++ b/cryprot-core/src/transpose/avx2.rs @@ -1,239 +1,218 @@ //! Implementation of AVX2 BitMatrix transpose based on libOTe. -use std::{arch::x86_64::*, hint::unreachable_unchecked}; +use std::arch::x86_64::*; +use bytemuck::{must_cast_slice, must_cast_slice_mut}; +use seq_macro::seq; + +/// Performs a 2x2 bit transpose operation on two 256-bit vectors representing a +/// 4x128 matrix. #[inline] #[target_feature(enable = "avx2")] -/// Must be called with `matches!(shift, 2 | 4 | 8 | 16 | 32)` -unsafe fn _mm256_slli_epi64_var_shift(a: __m256i, shift: usize) -> __m256i { - debug_assert!( - matches!(shift, 2 | 4 | 8 | 16 | 32), - "Must be called with correct shift" - ); - match shift { - 2 => _mm256_slli_epi64::<2>(a), - 4 => _mm256_slli_epi64::<4>(a), - 8 => _mm256_slli_epi64::<8>(a), - 16 => _mm256_slli_epi64::<16>(a), - 32 => _mm256_slli_epi64::<32>(a), - // SAFETY: Shift is upheld by caller - _ => unsafe { unreachable_unchecked() }, - } +fn transpose_2x2_matrices(x: &mut __m256i, y: &mut __m256i) { + // x = [x_H | x_L] and y = [y_H | y_L] + // u = [y_L | x_L] u is the low 128 bits of x and y + let u = _mm256_permute2x128_si256(*x, *y, 0x20); + // v = [y_H | x_H] v is the high 128 bits of x and y + let v = _mm256_permute2x128_si256(*x, *y, 0x31); + // Shift v by one left so each element in at (i, j) aligns with (i+1, j-1) and + // compute the difference. 
+    // the row shift i+1 is done by the permute
+    // instructions before and the column by the sll instruction
+    let mut diff = _mm256_xor_si256(u, _mm256_slli_epi16(v, 1));
+    // select all odd indices of diff and zero out even indices. the idea is to
+    // calculate the difference of all odd numbered indices j of the even
+    // numbered row i with the even numbered indices j-1 in row i+1.
+    // These are precisely the elements in the 2x2 matrices that make up x and y
+    // that potentially need to be swapped for the transpose if they differ
+    diff = _mm256_and_si256(diff, _mm256_set1_epi16(0b1010101010101010_u16 as i16));
+    // perform the swaps in u, which corresponds to the lower bits of x and y, by
+    // XORing the diff
+    let u = _mm256_xor_si256(u, diff);
+    // for the bottom row in the 2x2 matrices (the high bits of x and y) we need to
+    // shift the diff by 1 to the right so it aligns with the even numbered indices
+    let v = _mm256_xor_si256(v, _mm256_srli_epi16(diff, 1));
+    // the permuted 2x2 matrices are split over u and v, with the upper row in u and
+    // the lower in v. We perform the same permutation as in the beginning, thereby
+    // writing the 2x2 permuted bits of x and y back
+    *x = _mm256_permute2x128_si256(u, v, 0x20);
+    *y = _mm256_permute2x128_si256(u, v, 0x31);
 }
 
+/// Performs a general bit-level transpose.
+///
+/// `SHIFT_AMOUNT` is the constant shift value (e.g., 2, 4, 8, 16, 32) for the
+/// intrinsics. `MASK` is the bitmask for the XOR-swap.
 #[inline]
 #[target_feature(enable = "avx2")]
-/// Must be called with `matches!(shift, 2 | 4 | 8 | 16 | 32)`
-unsafe fn _mm256_srli_epi64_var_shift(a: __m256i, shift: usize) -> __m256i {
-    debug_assert!(
-        matches!(shift, 2 | 4 | 8 | 16 | 32),
-        "Must be called with correct shift"
-    );
-    match shift {
-        2 => _mm256_srli_epi64::<2>(a),
-        4 => _mm256_srli_epi64::<4>(a),
-        8 => _mm256_srli_epi64::<8>(a),
-        16 => _mm256_srli_epi64::<16>(a),
-        32 => _mm256_srli_epi64::<32>(a),
-        // SAFETY: Shift is upheld by caller
-        _ => unsafe { unreachable_unchecked() },
-    }
+fn partial_swap_sub_matrices<const SHIFT_AMOUNT: i32, const MASK: u64>(
+    x: &mut __m256i,
+    y: &mut __m256i,
+) {
+    // calculate the diff of the bits that need to be potentially swapped
+    let mut diff = _mm256_xor_si256(*x, _mm256_slli_epi64::<SHIFT_AMOUNT>(*y));
+    diff = _mm256_and_si256(diff, _mm256_set1_epi64x(MASK as i64));
+    // swap the bits in x by xoring the difference
+    *x = _mm256_xor_si256(*x, diff);
+    // and in y
+    *y = _mm256_xor_si256(*y, _mm256_srli_epi64::<SHIFT_AMOUNT>(diff));
 }
 
-// Transpose a 2^block_size_shift x 2^block_size_shift block within a larger
-// matrix Only handles first two rows out of every 2^block_rows_shift rows in
-// each block
+/// Performs a partial 64x64 bit matrix swap. This is used to swap the rows in
+/// the upper right quadrant with those of the lower left in the 128x128 matrix.
 #[inline]
 #[target_feature(enable = "avx2")]
-unsafe fn avx_transpose_block_iter1(
-    in_out: *mut __m256i,
-    block_size_shift: usize,
-    block_rows_shift: usize,
-    j: usize,
-) {
-    if j < (1 << block_size_shift) && block_size_shift == 6 {
-        unsafe {
-            let x = in_out.add(j / 2);
-            let y = in_out.add(j / 2 + 32);
+fn partial_swap_64x64_matrices(x: &mut __m256i, y: &mut __m256i) {
+    let out_x = _mm256_unpacklo_epi64(*x, *y);
+    let out_y = _mm256_unpackhi_epi64(*x, *y);
+    *x = out_x;
+    *y = out_y;
+}
 
-            let out_x = _mm256_unpacklo_epi64(*x, *y);
-            let out_y = _mm256_unpackhi_epi64(*x, *y);
-            *x = out_x;
-            *y = out_y;
-            return;
+/// Transpose a 128x128 bit matrix using AVX2 intrinsics.
+///
+/// # Safety
+/// AVX2 needs to be enabled.
+#[target_feature(enable = "avx2")] +pub fn avx_transpose128x128(in_out: &mut [__m256i; 64]) { + // This algorithm implements a bit-transpose of a 128x128 bit matrix using a + // divide-and-conquer algorithm. The idea is that for + // A = [ A B ] + // [ C D ] + // A^T is equal to + // [ A^T C^T ] + // [ B^T D^T ] + // + // We first divide our matrix into 2x2 bit matrices which we transpose at the + // bit level. Then we swap the 2x2 bit matrices to complete a 4x4 + // transpose. We swap the 4x4 bit matrices to complete a 8x8 transpose and so on + // until we swap 64x64 bit matrices and thus complete the intended 128x128 bit + // transpose. + + // Part 1: Specialized 2x2 block transpose transposing individual bits + for chunk in in_out.chunks_exact_mut(2) { + if let [x, y] = chunk { + transpose_2x2_matrices(x, y); + } else { + unreachable!("chunk size is 2") } } - if block_size_shift == 0 || block_size_shift >= 6 || block_rows_shift < 1 { - return; - } - - // Calculate mask for the current block size - let mut mask = (!0u64) << 32; - for k in (block_size_shift as i32..=4).rev() { - mask ^= mask >> (1 << k); - } - - unsafe { - let x = in_out.add(j / 2); - let y = in_out.add(j / 2 + (1 << (block_size_shift - 1))); - - // Special case for 2x2 blocks (block_size_shift == 1) - if block_size_shift == 1 { - let u = _mm256_permute2x128_si256(*x, *y, 0x20); - let v = _mm256_permute2x128_si256(*x, *y, 0x31); - - let mut diff = _mm256_xor_si256(u, _mm256_slli_epi16(v, 1)); - diff = _mm256_and_si256(diff, _mm256_set1_epi16(0b1010101010101010_u16 as i16)); - let u = _mm256_xor_si256(u, diff); - let v = _mm256_xor_si256(v, _mm256_srli_epi16(diff, 1)); - - *x = _mm256_permute2x128_si256(u, v, 0x20); - *y = _mm256_permute2x128_si256(u, v, 0x31); + // Phases 1-5: swap sub-matrices of size 2x2, 4x4, 8x8, 16x16, 32x32 bit + // Using seq_macro to reduce repetition + seq!(N in 1..=5 { + const SHIFT_~N: i32 = 1 << N; + // Our mask selects the part of the sub-matrix that needs to be potentially + // swapped allong the diagonal. The lower 2^SHIFT bits are 0 and the following + // 2^SHIFT bits are 1, repeated to a 64 bit mask + const MASK_~N: u64 = match N { + 1 => mask(0b1100, 4), + 2 => mask(0b11110000, 8), + 3 => mask(0b1111111100000000, 16), + 4 => mask(0b11111111111111110000000000000000, 32), + 5 => 0xffffffff00000000, + _ => unreachable!(), + }; + // The offset between x and y for matrix rows that need to be swapped in terms + // of 256 bit elements. In the first iteration we swap the 2x2 matrices that + // are at positions in_out[i] and in_out[j], so the offset is 1. 
+        // For 4x4 matrices
+        // the offset is 2
+        const OFFSET~N: usize = 1 << (N - 1);
+
+        for chunk in in_out.chunks_exact_mut(2 * OFFSET~N) {
+            let (x_chunk, y_chunk) = chunk.split_at_mut(OFFSET~N);
+            // For larger matrices, and larger offsets, we need to iterate over all
+            // rows of the sub-matrices
+            for (x, y) in x_chunk.iter_mut().zip(y_chunk.iter_mut()) {
+                partial_swap_sub_matrices::<SHIFT_~N, MASK_~N>(x, y);
+            }
+        }
+    });
-        let mut diff = _mm256_xor_si256(*x, _mm256_slli_epi64_var_shift(*y, 1 << block_size_shift));
-        diff = _mm256_and_si256(diff, _mm256_set1_epi64x(mask as i64));
-        *x = _mm256_xor_si256(*x, diff);
-        *y = _mm256_xor_si256(*y, _mm256_srli_epi64_var_shift(diff, 1 << block_size_shift));
-    }
-}
-
-#[inline] // Process a range of rows in the matrix
-#[target_feature(enable = "avx2")]
-unsafe fn avx_transpose_block_iter2(
-    in_out: *mut __m256i,
-    block_size_shift: usize,
-    block_rows_shift: usize,
-    n_rows: usize,
-) {
-    let mat_size = 1 << (block_size_shift + 1);
+    // Phase 6: swap 64x64 bit-matrices, thereby completing the 128x128 bit
+    // transpose
+    const SHIFT_6: usize = 6;
+    const OFFSET_6: usize = 1 << (SHIFT_6 - 1); // 32
-    for i in (0..n_rows).step_by(mat_size) {
-        for j in (0..(1 << block_size_shift)).step_by(1 << block_rows_shift) {
-            unsafe {
-                avx_transpose_block_iter1(in_out.add(i / 2), block_size_shift, block_rows_shift, j);
-            }
+    for chunk in in_out.chunks_exact_mut(2 * OFFSET_6) {
+        let (x_chunk, y_chunk) = chunk.split_at_mut(OFFSET_6);
+        for (x, y) in x_chunk.iter_mut().zip(y_chunk.iter_mut()) {
+            partial_swap_64x64_matrices(x, y);
+        }
         }
     }
 }
 
-#[inline] // Main transpose function for blocks within the matrix
-#[target_feature(enable = "avx2")]
-unsafe fn avx_transpose_block(
-    in_out: *mut __m256i,
-    block_size_shift: usize,
-    mat_size_shift: usize,
-    block_rows_shift: usize,
-    mat_rows_shift: usize,
-) {
-    if block_size_shift >= mat_size_shift {
-        return;
-    }
-
-    // Process current block size
-    let total_rows = 1 << (mat_rows_shift + mat_size_shift);
-
-    unsafe {
-        avx_transpose_block_iter2(in_out, block_size_shift, block_rows_shift, total_rows);
+/// Create a u64 bit mask based on the pattern, which is repeated to fill the u64
+const fn mask(pattern: u64, pattern_len: u32) -> u64 {
+    let mut mask = pattern;
+    let mut current_block_len = pattern_len;
-        // Recursively process larger blocks
-        avx_transpose_block(
-            in_out,
-            block_size_shift + 1,
-            mat_size_shift,
-            block_rows_shift,
-            mat_rows_shift,
-        );
+    // We keep doubling the effective length of our repeating block
+    // until it covers 64 bits.
+    while current_block_len < 64 {
+        mask = (mask << current_block_len) | mask;
+        current_block_len *= 2;
     }
-}
-
-const AVX_BLOCK_SHIFT: usize = 4;
-const AVX_BLOCK_SIZE: usize = 1 << AVX_BLOCK_SHIFT;
-/// Transpose 128x128 bit matrix using AVX2.
-///
-/// # Safety
-/// AVX2 needs to be enabled.
-#[target_feature(enable = "avx2")]
-pub fn avx_transpose128x128(in_out: &mut [__m256i; 64]) {
-    const MAT_SIZE_SHIFT: usize = 7;
-    unsafe {
-        let in_out = in_out.as_mut_ptr();
-        for i in (0..64).step_by(AVX_BLOCK_SIZE) {
-            avx_transpose_block(
-                in_out.add(i),
-                1,
-                MAT_SIZE_SHIFT - AVX_BLOCK_SHIFT,
-                1,
-                AVX_BLOCK_SHIFT + 1 - (MAT_SIZE_SHIFT - AVX_BLOCK_SHIFT),
-            );
-        }
-
-        // Process larger blocks
-        let block_size_shift = MAT_SIZE_SHIFT - AVX_BLOCK_SHIFT;
-
-        // Special case for full matrix
-        for i in 0..(1 << (block_size_shift - 1)) {
-            avx_transpose_block(
-                in_out.add(i),
-                block_size_shift,
-                MAT_SIZE_SHIFT,
-                block_size_shift,
-                0,
-            );
-        }
-    }
+    mask
 }
 
-/// Transpose a bit matrix.
+/// Transpose a bit matrix of arbitrary (but constrained) dimensions using AVX2. /// /// # Panics -/// If the input is not divisable by 128. -/// If the number of columns (= input.len() * 8 / 128) is less than 128. +/// If the input is not divisible by 128. +/// If the number of columns (= input.len() * 8 / rows) is less than 128. /// If `input.len() != output.len()` /// /// # Safety /// AVX2 instruction set must be available. #[target_feature(enable = "avx2")] -pub unsafe fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) { +pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) { assert_eq!(input.len(), output.len()); let cols = input.len() * 8 / rows; - assert_eq!(0, cols % 128); - assert_eq!(0, rows % 128); + assert_eq!( + 0, + cols % 128, + "Number of columns must be a multiple of 128." + ); + assert_eq!(0, rows % 128, "Number of rows must be a multiple of 128."); + assert!(cols >= 128, "Number of columns must be at least 128."); + + // Buffer to hold a single 128x128 bit square (64 __m256i registers = 2048 + // bytes) let mut buf = [_mm256_setzero_si256(); 64]; - let in_stride = cols / 8; - let out_stride = rows / 8; + let in_stride = cols / 8; // Stride in bytes for input rows + let out_stride = rows / 8; // Stride in bytes for output rows - // Number of 128x128 bit squares + // Number of 128x128 bit squares in rows and columns let r_main = rows / 128; let c_main = cols / 128; + // Iterate through each 128x128 bit square in the matrix + // Row block index for i in 0..r_main { + // Column block index for j in 0..c_main { - // Process each 128x128 bit square - unsafe { - let src_ptr = input.as_ptr().add(i * 128 * in_stride + j * 16); - - let buf_u8_ptr = buf.as_mut_ptr() as *mut u8; - - // Copy 128 rows into buffer - for k in 0..128 { - let src_row = src_ptr.add(k * in_stride); - std::ptr::copy_nonoverlapping(src_row, buf_u8_ptr.add(k * 16), 16); - } + // Load 128x128 bit sub-matrix into `buf` + let input_block_start_byte_idx = i * 128 * in_stride + j * 16; + let buf_as_bytes: &mut [u8] = must_cast_slice_mut(&mut buf); + + for k in 0..128 { + let src_slice = &input[input_block_start_byte_idx + k * in_stride + ..input_block_start_byte_idx + k * in_stride + 16]; + buf_as_bytes[k * 16..(k + 1) * 16].copy_from_slice(src_slice); } - // Transpose the 128x128 bit square + + // Transpose the 128x128 bit sub-matrix in `buf` avx_transpose128x128(&mut buf); - unsafe { - // needs to be recreated because prev &mut borrow invalidates ptr - let buf_u8_ptr = buf.as_mut_ptr() as *mut u8; - // Copy transposed data to output - let dst_ptr = output.as_mut_ptr().add(j * 128 * out_stride + i * 16); - for k in 0..128 { - let dst_row = dst_ptr.add(k * out_stride); - std::ptr::copy_nonoverlapping(buf_u8_ptr.add(k * 16), dst_row, 16); - } + // Copy the transposed data from `buf` to the output slice. 
+ let output_block_start_byte_idx = j * 128 * out_stride + i * 16; + let buf_as_bytes: &[u8] = must_cast_slice(&buf); // Now read-only + + for k in 0..128 { + let src_slice = &buf_as_bytes[k * 16..(k + 1) * 16]; + let dst_slice = &mut output[output_block_start_byte_idx + k * out_stride + ..output_block_start_byte_idx + k * out_stride + 16]; + dst_slice.copy_from_slice(src_slice); } } } From 1588d2583c901f469b66ac8c4555b13d427c2c82 Mon Sep 17 00:00:00 2001 From: robinhundt <24554122+robinhundt@users.noreply.github.com> Date: Tue, 27 May 2025 14:45:58 +0200 Subject: [PATCH 3/7] avx2 handle whole cache line --- .gitignore | 3 +- cryprot-core/src/transpose/avx2.rs | 130 ++++++++++++++++++++++++----- 2 files changed, 112 insertions(+), 21 deletions(-) diff --git a/.gitignore b/.gitignore index 1e83165..4962aba 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,4 @@ target profile*.json* perf.data* - -.flyio +*.annotation diff --git a/cryprot-core/src/transpose/avx2.rs b/cryprot-core/src/transpose/avx2.rs index 3cbde4a..f5046c0 100644 --- a/cryprot-core/src/transpose/avx2.rs +++ b/cryprot-core/src/transpose/avx2.rs @@ -1,5 +1,5 @@ //! Implementation of AVX2 BitMatrix transpose based on libOTe. -use std::arch::x86_64::*; +use std::{arch::x86_64::*, cmp}; use bytemuck::{must_cast_slice, must_cast_slice_mut}; use seq_macro::seq; @@ -178,7 +178,7 @@ pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) { // Buffer to hold a single 128x128 bit square (64 __m256i registers = 2048 // bytes) - let mut buf = [_mm256_setzero_si256(); 64]; + let mut buf = [_mm256_setzero_si256(); 64 * 4]; let in_stride = cols / 8; // Stride in bytes for input rows let out_stride = rows / 8; // Stride in bytes for output rows @@ -190,30 +190,80 @@ pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) { // Row block index for i in 0..r_main { // Column block index - for j in 0..c_main { - // Load 128x128 bit sub-matrix into `buf` - let input_block_start_byte_idx = i * 128 * in_stride + j * 16; + let mut j = 0; + while j < c_main { + let input_offset = i * 128 * in_stride + j * 16; + let curr_addr = input[input_offset..].as_ptr().addr(); + let next_cache_line_addr = (curr_addr + 1).next_multiple_of(64); // cache line size + let blocks_in_cache_line = (next_cache_line_addr - curr_addr) / 16; + + let remaining_blocks_in_cache_line = if blocks_in_cache_line == 0 { + // will cross over a cache line, but if the blocks are not 16-byte aligned, this + // is the best we can do + 4 + } else { + blocks_in_cache_line + }; + // Ensure we don't read OOB of the input + let remaining_blocks_in_cache_line = + cmp::min(remaining_blocks_in_cache_line, c_main - j); + let buf_as_bytes: &mut [u8] = must_cast_slice_mut(&mut buf); - for k in 0..128 { - let src_slice = &input[input_block_start_byte_idx + k * in_stride - ..input_block_start_byte_idx + k * in_stride + 16]; - buf_as_bytes[k * 16..(k + 1) * 16].copy_from_slice(src_slice); + // The loading loop loads the input data into the buf. By using a macro and + // matching on 4 blocks in a cache line (each row in a block is 16 bytes, so the + // rows 4 consecutive blocks are 64 bytes long) the optimizer uses a loop + // unrolled version for this case. + macro_rules! 
loading_loop { + ($remaining_blocks_in_cache_line:expr) => { + for k in 0..128 { + let src_slice = &input[input_offset + k * in_stride + ..input_offset + k * in_stride + 16 * remaining_blocks_in_cache_line]; + + for block in 0..remaining_blocks_in_cache_line { + buf_as_bytes[block * 2048 + k * 16..block * 2048 + (k + 1) * 16] + .copy_from_slice(&src_slice[block * 16..(block + 1) * 16]); + } + } + }; + } + + // This gets optimized to the unrolled loop for the default case of 4 blocks + match remaining_blocks_in_cache_line { + 4 => loading_loop!(4), + #[allow(unused_variables)] // false positive + other => loading_loop!(other), } - // Transpose the 128x128 bit sub-matrix in `buf` - avx_transpose128x128(&mut buf); + for block in 0..remaining_blocks_in_cache_line { + avx_transpose128x128((&mut buf[block * 64..(block + 1) * 64]).try_into().unwrap()); + } - // Copy the transposed data from `buf` to the output slice. - let output_block_start_byte_idx = j * 128 * out_stride + i * 16; - let buf_as_bytes: &[u8] = must_cast_slice(&buf); // Now read-only + let mut output_offset = j * 128 * out_stride + i * 16; + let buf_as_bytes: &[u8] = must_cast_slice(&buf); - for k in 0..128 { - let src_slice = &buf_as_bytes[k * 16..(k + 1) * 16]; - let dst_slice = &mut output[output_block_start_byte_idx + k * out_stride - ..output_block_start_byte_idx + k * out_stride + 16]; - dst_slice.copy_from_slice(src_slice); + if out_stride == 16 { + // if the out_stride is 16 bytes, the transposed sub-matrices are in contigous + // memory in the output, so we can use a single copy_from_slice. This is + // especially helpfule for the case of transposing a 128xl matrix as done in OT + // extension. + let dst_slice = &mut output + [output_offset..output_offset + 16 * 128 * remaining_blocks_in_cache_line]; + dst_slice.copy_from_slice(&buf_as_bytes[..remaining_blocks_in_cache_line * 2048]); + } else { + for block in 0..remaining_blocks_in_cache_line { + for k in 0..128 { + let src_slice = + &buf_as_bytes[block * 2048 + k * 16..block * 2048 + (k + 1) * 16]; + let dst_slice = &mut output + [output_offset + k * out_stride..output_offset + k * out_stride + 16]; + dst_slice.copy_from_slice(src_slice); + } + output_offset += 128 * out_stride; + } } + + j += remaining_blocks_in_cache_line; } } } @@ -268,4 +318,46 @@ mod tests { assert_eq!(sse_transposed, avx_transposed); } + + #[test] + fn test_avx_transpose_unaligned_data() { + let rows = 128 * 2; + let cols = 128 * 2; + let mut v = vec![0_u8; rows * (cols + 128) / 8]; + StdRng::seed_from_u64(42).fill_bytes(&mut v); + + let v = { + let addr = v.as_ptr().addr(); + let offset = addr.next_multiple_of(3) - addr; + &v[offset..rows * cols / 8] + }; + assert_eq!(0, v.as_ptr().addr() % 3); + // allocate out bufs with same dims + let mut avx_transposed = v.to_owned(); + let mut sse_transposed = v.to_owned(); + + unsafe { + transpose_bitmatrix(&v, &mut avx_transposed, rows); + } + crate::transpose::portable::transpose_bitmatrix(&v, &mut sse_transposed, rows); + + assert_eq!(sse_transposed, avx_transposed); + } + + #[test] + fn test_avx_transpose_larger_cols_divisable_by_4_times_128() { + let rows = 128; + let cols = 128 * 8; + let mut v = vec![0_u8; rows * cols / 8]; + StdRng::seed_from_u64(42).fill_bytes(&mut v); + + let mut avx_transposed = v.clone(); + let mut sse_transposed = v.clone(); + unsafe { + transpose_bitmatrix(&v, &mut avx_transposed, rows); + } + crate::transpose::portable::transpose_bitmatrix(&v, &mut sse_transposed, rows); + + assert_eq!(sse_transposed, avx_transposed); + } } 
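The cache-line grouping that this patch introduces boils down to a small address computation; the following is an illustrative, self-contained sketch of that bookkeeping (the helper name is not part of the patch):

// Sketch, assuming 64-byte cache lines and 16-byte (128-bit) block rows, as in
// the patch: given the address of the next block row to load, how many 128-bit
// column blocks still fall into the current cache line?
fn blocks_in_current_cache_line(curr_addr: usize) -> usize {
    let next_cache_line_addr = (curr_addr + 1).next_multiple_of(64);
    let blocks = (next_cache_line_addr - curr_addr) / 16;
    if blocks == 0 {
        // the data is not 16-byte aligned; fall back to a full group of 4 blocks
        4
    } else {
        blocks
    }
}

A 64-byte aligned address yields 4 blocks, an address 16 bytes into a cache line yields 3, and so on; the patch additionally clamps the result to the number of column blocks left in the row (`c_main - j`). Grouping up to four 16-byte block rows keeps each load within one cache line and lets the compiler unroll the loading loop for the common aligned case.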
From a9316aca5a6727ff6d528a332cc80fc00a9d0121 Mon Sep 17 00:00:00 2001 From: robinhundt <24554122+robinhundt@users.noreply.github.com> Date: Fri, 6 Jun 2025 11:09:43 +0200 Subject: [PATCH 4/7] cryprot-core: avx2 transpose handle rest cols Adds handling for number of columns that is not divisable by 128. --- cryprot-core/src/transpose/avx2.rs | 94 +++++++++++++++++++++++++----- 1 file changed, 81 insertions(+), 13 deletions(-) diff --git a/cryprot-core/src/transpose/avx2.rs b/cryprot-core/src/transpose/avx2.rs index f5046c0..3387aae 100644 --- a/cryprot-core/src/transpose/avx2.rs +++ b/cryprot-core/src/transpose/avx2.rs @@ -155,28 +155,30 @@ const fn mask(pattern: u64, pattern_len: u32) -> u64 { mask } -/// Transpose a bit matrix of arbitrary (but constrained) dimensions using AVX2. +/// Transpose a bit matrix using AVX2. +/// +/// This implementation is specifically tuned for transposing `128 x l` matrices +/// as done in OT protocols. Performance might be better if `input` is 16-byte +/// aligned and the number of columns is divisable by 512 on systems with +/// 64-byte cache lines. /// /// # Panics -/// If the input is not divisible by 128. -/// If the number of columns (= input.len() * 8 / rows) is less than 128. /// If `input.len() != output.len()` +/// If the number of rows is less than 128. +/// If the number of rows is not divisable by 128. +/// If the number of columns (= input.len() * 8 / rows) is not divisable by 8. /// /// # Safety /// AVX2 instruction set must be available. #[target_feature(enable = "avx2")] pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) { assert_eq!(input.len(), output.len()); - let cols = input.len() * 8 / rows; - assert_eq!( - 0, - cols % 128, - "Number of columns must be a multiple of 128." - ); + assert!(rows >= 128, "Number of rows must be >= 128."); assert_eq!(0, rows % 128, "Number of rows must be a multiple of 128."); - assert!(cols >= 128, "Number of columns must be at least 128."); + let cols = input.len() * 8 / rows; + assert_eq!(0, cols % 8, "Number of columns must be a multiple of 8."); - // Buffer to hold a single 128x128 bit square (64 __m256i registers = 2048 + // Buffer to hold a 4 128x128 bit squares (64 * 4 __m256i registers = 2048 * 4 // bytes) let mut buf = [_mm256_setzero_si256(); 64 * 4]; let in_stride = cols / 8; // Stride in bytes for input rows @@ -185,6 +187,7 @@ pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) { // Number of 128x128 bit squares in rows and columns let r_main = rows / 128; let c_main = cols / 128; + let c_rest = cols % 128; // Iterate through each 128x128 bit square in the matrix // Row block index @@ -236,7 +239,11 @@ pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) { } for block in 0..remaining_blocks_in_cache_line { - avx_transpose128x128((&mut buf[block * 64..(block + 1) * 64]).try_into().unwrap()); + avx_transpose128x128( + (&mut buf[block * 64..(block + 1) * 64]) + .try_into() + .expect("slice has length 64"), + ); } let mut output_offset = j * 128 * out_stride + i * 16; @@ -265,6 +272,50 @@ pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) { j += remaining_blocks_in_cache_line; } + + if c_rest > 0 { + handle_rest_cols(input, output, &mut buf, in_stride, out_stride, c_rest, i, j); + } + } +} + +// Inline never to reduce code size of main method. 
+#[inline(never)] +#[target_feature(enable = "avx2")] +fn handle_rest_cols( + input: &[u8], + output: &mut [u8], + buf: &mut [__m256i; 256], + in_stride: usize, + out_stride: usize, + c_rest: usize, + i: usize, + j: usize, +) { + let input_offset = i * 128 * in_stride + j * 16; + let remaining_cols_bytes = c_rest / 8; + buf[0..64].fill(_mm256_setzero_si256()); + let buf_as_bytes: &mut [u8] = must_cast_slice_mut(buf); + + for k in 0..128 { + let src_row_offset = input_offset + k * in_stride; + let src_slice = &input[src_row_offset..src_row_offset + remaining_cols_bytes]; + // we use 16 because we still transpose a 128x128 matrix, of which only a part + // is filled + let buf_offset = k * 16; + buf_as_bytes[buf_offset..buf_offset + remaining_cols_bytes].copy_from_slice(&src_slice); + } + + avx_transpose128x128((&mut buf[..64]).try_into().expect("slice has length 64")); + + let output_offset = j * 128 * out_stride + i * 16; + let buf_as_bytes: &[u8] = must_cast_slice(&*buf); + + for k in 0..c_rest { + let src_slice = &buf_as_bytes[k * 16..(k + 1) * 16]; + let dst_slice = + &mut output[output_offset + k * out_stride..output_offset + k * out_stride + 16]; + dst_slice.copy_from_slice(src_slice); } } @@ -329,7 +380,7 @@ mod tests { let v = { let addr = v.as_ptr().addr(); let offset = addr.next_multiple_of(3) - addr; - &v[offset..rows * cols / 8] + &v[offset..offset + rows * cols / 8] }; assert_eq!(0, v.as_ptr().addr() % 3); // allocate out bufs with same dims @@ -360,4 +411,21 @@ mod tests { assert_eq!(sse_transposed, avx_transposed); } + + #[test] + fn test_avx_transpose_larger_cols_divisable_by_8() { + let rows = 128; + let cols = 128 + 32; + let mut v = vec![0_u8; rows * cols / 8]; + StdRng::seed_from_u64(42).fill_bytes(&mut v); + + let mut avx_transposed = v.clone(); + let mut sse_transposed = v.clone(); + unsafe { + transpose_bitmatrix(&v, &mut avx_transposed, rows); + } + crate::transpose::portable::transpose_bitmatrix(&v, &mut sse_transposed, rows); + + assert_eq!(sse_transposed, avx_transposed); + } } From 92f7e4af81667deefa0fe8125bba2d848d264fca Mon Sep 17 00:00:00 2001 From: robinhundt <24554122+robinhundt@users.noreply.github.com> Date: Fri, 6 Jun 2025 11:21:20 +0200 Subject: [PATCH 5/7] Fix new compiler warning --- cryprot-net/src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cryprot-net/src/lib.rs b/cryprot-net/src/lib.rs index e130e3a..621c819 100644 --- a/cryprot-net/src/lib.rs +++ b/cryprot-net/src/lib.rs @@ -469,7 +469,7 @@ impl SendStreamBytes { self.inner.close().await.map_err(StreamError::Close) } - pub fn as_stream(&mut self) -> SendStreamTemp { + pub fn as_stream(&mut self) -> SendStreamTemp<'_, T> { let framed_send = default_codec().new_write(self); SymmetricallyFramed::new(framed_send, Bincode::default()) } @@ -517,7 +517,7 @@ fn trace_poll(p: Poll>) -> Poll> { } impl ReceiveStreamBytes { - pub fn as_stream(&mut self) -> ReceiveStreamTemp { + pub fn as_stream(&mut self) -> ReceiveStreamTemp<'_, T> { let framed_read = default_codec().new_read(self); SymmetricallyFramed::new(framed_read, Bincode::default()) } From f674a09d7ff94d3e9b16ff47fd494565d03c7f40 Mon Sep 17 00:00:00 2001 From: robinhundt <24554122+robinhundt@users.noreply.github.com> Date: Fri, 6 Jun 2025 11:21:20 +0200 Subject: [PATCH 6/7] CI: Update nightly version --- .github/workflows/pull_request.yml | 6 +++--- .github/workflows/push.yml | 2 +- .github/workflows/rustdoc.yml | 2 +- .vscode/settings.json | 5 +---- 4 files changed, 6 insertions(+), 9 deletions(-) diff --git 
a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index 95ec339..a5c82ec 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -23,7 +23,7 @@ jobs: uses: dtolnay/rust-toolchain@master id: toolchain with: - toolchain: nightly-2025-02-14 + toolchain: nightly-2025-05-30 components: "rustfmt, miri" - name: Override default toolchain run: rustup override set ${{steps.toolchain.outputs.name}} @@ -52,7 +52,7 @@ jobs: uses: dtolnay/rust-toolchain@master id: toolchain with: - toolchain: nightly-2025-02-14 + toolchain: nightly-2025-05-30 components: "clippy, rustfmt" - name: Override default toolchain run: rustup override set ${{steps.toolchain.outputs.name}} @@ -80,7 +80,7 @@ jobs: uses: dtolnay/rust-toolchain@master id: toolchain with: - toolchain: nightly-2025-02-14 + toolchain: nightly-2025-05-30 components: rust-docs - name: Override default toolchain run: rustup override set ${{steps.toolchain.outputs.name}} diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 145189c..26abed7 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -17,7 +17,7 @@ jobs: uses: dtolnay/rust-toolchain@master id: toolchain with: - toolchain: nightly-2025-02-14 + toolchain: nightly-2025-05-30 - name: Override default toolchain run: rustup override set ${{steps.toolchain.outputs.name}} - run: cargo --version diff --git a/.github/workflows/rustdoc.yml b/.github/workflows/rustdoc.yml index 32fc847..22c3818 100644 --- a/.github/workflows/rustdoc.yml +++ b/.github/workflows/rustdoc.yml @@ -21,7 +21,7 @@ jobs: uses: dtolnay/rust-toolchain@master id: toolchain with: - toolchain: nightly-2025-02-14 + toolchain: nightly-2025-05-30 components: rust-docs - name: Override default toolchain run: rustup override set ${{steps.toolchain.outputs.name}} diff --git a/.vscode/settings.json b/.vscode/settings.json index 852d76c..0e8a88d 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,8 +1,5 @@ { "rust-analyzer.rustfmt.extraArgs": [ "+nightly" - ], - "rust-analyzer.cargo.extraEnv": { - "RUSTUP_TOOLCHAIN": "nightly" - }, + ] } \ No newline at end of file From 91363af97a36b81cf0ebf76d09ea824a9e21e427 Mon Sep 17 00:00:00 2001 From: robinhundt <24554122+robinhundt@users.noreply.github.com> Date: Fri, 6 Jun 2025 11:53:10 +0200 Subject: [PATCH 7/7] cryprot-core: Fix clippy errors --- cryprot-core/src/transpose/avx2.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cryprot-core/src/transpose/avx2.rs b/cryprot-core/src/transpose/avx2.rs index 3387aae..cae75f7 100644 --- a/cryprot-core/src/transpose/avx2.rs +++ b/cryprot-core/src/transpose/avx2.rs @@ -115,6 +115,7 @@ pub fn avx_transpose128x128(in_out: &mut [__m256i; 64]) { // of 256 bit elements. In the first iteration we swap the 2x2 matrices that // are at positions in_out[i] and in_out[j], so the offset is 1. For 4x4 matrices // the offset is 2 + #[allow(clippy::eq_op)] // false positive due to use of seq! const OFFSET~N: usize = 1 << (N - 1); for chunk in in_out.chunks_exact_mut(2 * OFFSET~N) { @@ -282,6 +283,7 @@ pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) { // Inline never to reduce code size of main method. 
#[inline(never)] #[target_feature(enable = "avx2")] +#[allow(clippy::too_many_arguments)] fn handle_rest_cols( input: &[u8], output: &mut [u8], @@ -303,7 +305,7 @@ fn handle_rest_cols( // we use 16 because we still transpose a 128x128 matrix, of which only a part // is filled let buf_offset = k * 16; - buf_as_bytes[buf_offset..buf_offset + remaining_cols_bytes].copy_from_slice(&src_slice); + buf_as_bytes[buf_offset..buf_offset + remaining_cols_bytes].copy_from_slice(src_slice); } avx_transpose128x128((&mut buf[..64]).try_into().expect("slice has length 64"));
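Taken together, the patched AVX2 path is a drop-in alternative to the portable transpose. Below is a minimal usage sketch, not part of the patches: the module paths `cryprot_core::transpose::{avx2, portable}` are assumed from the test code above and may not match the crate's public API, while the `transpose_bitmatrix(input, output, rows)` signatures and the unsafe-call-under-runtime-detection pattern mirror the tests in the series.

// Assumed module path; adjust to the crate's actual public layout.
use cryprot_core::transpose;

/// Illustrative dispatcher: `rows` must be a multiple of 128 and the number of
/// columns (input.len() * 8 / rows) a multiple of 8, per the patched asserts.
fn transpose_bits(input: &[u8], output: &mut [u8], rows: usize) {
    assert_eq!(input.len(), output.len());
    if is_x86_feature_detected!("avx2") {
        // SAFETY: AVX2 support was just verified at runtime, which satisfies the
        // #[target_feature(enable = "avx2")] requirement of the AVX2 path.
        unsafe { transpose::avx2::transpose_bitmatrix(input, output, rows) };
    } else {
        transpose::portable::transpose_bitmatrix(input, output, rows);
    }
}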