diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml
index 95ec339..a5c82ec 100644
--- a/.github/workflows/pull_request.yml
+++ b/.github/workflows/pull_request.yml
@@ -23,7 +23,7 @@ jobs:
         uses: dtolnay/rust-toolchain@master
         id: toolchain
         with:
-          toolchain: nightly-2025-02-14
+          toolchain: nightly-2025-05-30
           components: "rustfmt, miri"
      - name: Override default toolchain
        run: rustup override set ${{steps.toolchain.outputs.name}}
@@ -52,7 +52,7 @@ jobs:
         uses: dtolnay/rust-toolchain@master
         id: toolchain
         with:
-          toolchain: nightly-2025-02-14
+          toolchain: nightly-2025-05-30
           components: "clippy, rustfmt"
      - name: Override default toolchain
        run: rustup override set ${{steps.toolchain.outputs.name}}
@@ -80,7 +80,7 @@ jobs:
         uses: dtolnay/rust-toolchain@master
         id: toolchain
         with:
-          toolchain: nightly-2025-02-14
+          toolchain: nightly-2025-05-30
           components: rust-docs
      - name: Override default toolchain
        run: rustup override set ${{steps.toolchain.outputs.name}}
diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml
index 145189c..26abed7 100644
--- a/.github/workflows/push.yml
+++ b/.github/workflows/push.yml
@@ -17,7 +17,7 @@ jobs:
         uses: dtolnay/rust-toolchain@master
         id: toolchain
         with:
-          toolchain: nightly-2025-02-14
+          toolchain: nightly-2025-05-30
      - name: Override default toolchain
        run: rustup override set ${{steps.toolchain.outputs.name}}
      - run: cargo --version
diff --git a/.github/workflows/rustdoc.yml b/.github/workflows/rustdoc.yml
index 32fc847..22c3818 100644
--- a/.github/workflows/rustdoc.yml
+++ b/.github/workflows/rustdoc.yml
@@ -21,7 +21,7 @@ jobs:
         uses: dtolnay/rust-toolchain@master
         id: toolchain
         with:
-          toolchain: nightly-2025-02-14
+          toolchain: nightly-2025-05-30
           components: rust-docs
      - name: Override default toolchain
        run: rustup override set ${{steps.toolchain.outputs.name}}
diff --git a/.gitignore b/.gitignore
index 1e83165..4962aba 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,4 @@
 target
 profile*.json*
 perf.data*
-
-.flyio
+*.annotation
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 852d76c..0e8a88d 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -1,8 +1,5 @@
 {
     "rust-analyzer.rustfmt.extraArgs": [
         "+nightly"
-    ],
-    "rust-analyzer.cargo.extraEnv": {
-        "RUSTUP_TOOLCHAIN": "nightly"
-    },
+    ]
 }
\ No newline at end of file
diff --git a/Cargo.lock b/Cargo.lock
index c991410..712a850 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -533,6 +533,7 @@ dependencies = [
  "rand_core 0.6.4",
  "rand_core 0.9.3",
  "rayon",
+ "seq-macro",
  "serde",
  "subtle",
  "thiserror",
diff --git a/cryprot-core/Cargo.toml b/cryprot-core/Cargo.toml
index 182c09d..af78388 100644
--- a/cryprot-core/Cargo.toml
+++ b/cryprot-core/Cargo.toml
@@ -32,6 +32,7 @@ rand.workspace = true
 rand_core.workspace = true
 rand_core_0_6.workspace = true
 rayon = { workspace = true, optional = true }
+seq-macro.workspace = true
 serde = { workspace = true, features = ["derive"] }
 subtle.workspace = true
 thiserror.workspace = true
diff --git a/cryprot-core/src/alloc.rs b/cryprot-core/src/alloc.rs
index 633f634..ca62376 100644
--- a/cryprot-core/src/alloc.rs
+++ b/cryprot-core/src/alloc.rs
@@ -51,10 +51,7 @@ impl HugePageMemory {
         // new_len <= self.capacity
         // self[len..new_len] is initialized either because of Self::zeroed
         // or with data written to it.
-        #[allow(unused_unsafe)]
-        unsafe {
-            self.len = new_len;
-        }
+        self.len = new_len;
     }
 }
diff --git a/cryprot-core/src/block/gf128.rs b/cryprot-core/src/block/gf128.rs
index 0f0d1eb..dfc3f63 100644
--- a/cryprot-core/src/block/gf128.rs
+++ b/cryprot-core/src/block/gf128.rs
@@ -104,11 +104,9 @@ mod clmul {
     #[target_feature(enable = "pclmulqdq")]
     #[inline]
-    pub unsafe fn gf128_mul(a: __m128i, b: __m128i) -> __m128i {
-        unsafe {
-            let (low, high) = clmul128(a, b);
-            gf128_reduce(low, high)
-        }
+    pub fn gf128_mul(a: __m128i, b: __m128i) -> __m128i {
+        let (low, high) = clmul128(a, b);
+        gf128_reduce(low, high)
     }

     /// Carry-less multiply of two 128-bit numbers.
@@ -116,47 +114,39 @@ mod clmul {
     /// Return (low, high) bits
     #[target_feature(enable = "pclmulqdq")]
     #[inline]
-    pub unsafe fn clmul128(a: __m128i, b: __m128i) -> (__m128i, __m128i) {
-        // This is currently needed because we run the nightly version of
-        // clippy where this is an unused unsafe because the used
-        // intrinsincs have been marked safe on nightly but not yet on
-        // stable.
-        #[allow(unused_unsafe)]
-        unsafe {
-            // NOTE: I tried using karatsuba but it was slightly slower than the naive
-            // multiplication
-            let ab_low = _mm_clmulepi64_si128::<0x00>(a, b);
-            let ab_high = _mm_clmulepi64_si128::<0x11>(a, b);
-            let ab_lohi1 = _mm_clmulepi64_si128::<0x01>(a, b);
-            let ab_lohi2 = _mm_clmulepi64_si128::<0x10>(a, b);
-            let ab_mid = _mm_xor_si128(ab_lohi1, ab_lohi2);
-            let low = _mm_xor_si128(ab_low, _mm_slli_si128::<8>(ab_mid));
-            let high = _mm_xor_si128(ab_high, _mm_srli_si128::<8>(ab_mid));
-            (low, high)
-        }
+    pub fn clmul128(a: __m128i, b: __m128i) -> (__m128i, __m128i) {
+        // NOTE: I tried using karatsuba but it was slightly slower than the naive
+        // multiplication
+        let ab_low = _mm_clmulepi64_si128::<0x00>(a, b);
+        let ab_high = _mm_clmulepi64_si128::<0x11>(a, b);
+        let ab_lohi1 = _mm_clmulepi64_si128::<0x01>(a, b);
+        let ab_lohi2 = _mm_clmulepi64_si128::<0x10>(a, b);
+        let ab_mid = _mm_xor_si128(ab_lohi1, ab_lohi2);
+        let low = _mm_xor_si128(ab_low, _mm_slli_si128::<8>(ab_mid));
+        let high = _mm_xor_si128(ab_high, _mm_srli_si128::<8>(ab_mid));
+        (low, high)
     }

     #[target_feature(enable = "pclmulqdq")]
     #[inline]
-    pub unsafe fn gf128_reduce(mut low: __m128i, mut high: __m128i) -> __m128i {
+    pub fn gf128_reduce(mut low: __m128i, mut high: __m128i) -> __m128i {
         // NOTE: I tried a sse shift based reduction but it was slower than the clmul
         // implementation
-        unsafe {
-            let modulus = [MOD, 0];
-            let modulus = _mm_loadu_si64(modulus.as_ptr().cast());
+        let modulus = [MOD, 0];
+        // SAFETY: Ptr to modulus is valid and pclmulqdq implies sse2 is enabled
+        let modulus = unsafe { _mm_loadu_si64(modulus.as_ptr().cast()) };

-            let tmp = _mm_clmulepi64_si128::<0x01>(high, modulus);
-            let tmp_shifted = _mm_slli_si128::<8>(tmp);
-            low = _mm_xor_si128(low, tmp_shifted);
-            high = _mm_xor_si128(high, tmp_shifted);
+        let tmp = _mm_clmulepi64_si128::<0x01>(high, modulus);
+        let tmp_shifted = _mm_slli_si128::<8>(tmp);
+        low = _mm_xor_si128(low, tmp_shifted);
+        high = _mm_xor_si128(high, tmp_shifted);

-            // reduce overflow
-            let tmp = _mm_clmulepi64_si128::<0x01>(tmp, modulus);
-            low = _mm_xor_si128(low, tmp);
+        // reduce overflow
+        let tmp = _mm_clmulepi64_si128::<0x01>(tmp, modulus);
+        low = _mm_xor_si128(low, tmp);

-            let tmp = _mm_clmulepi64_si128::<0x00>(high, modulus);
-            _mm_xor_si128(low, tmp)
-        }
+        let tmp = _mm_clmulepi64_si128::<0x00>(high, modulus);
+        _mm_xor_si128(low, tmp)
     }

     #[cfg(all(test, target_feature = "pclmulqdq"))]
diff --git a/cryprot-core/src/transpose/avx2.rs b/cryprot-core/src/transpose/avx2.rs
index 5a32716..cae75f7 100644
--- a/cryprot-core/src/transpose/avx2.rs
+++ b/cryprot-core/src/transpose/avx2.rs
@@ -1,244 +1,323 @@
 //! Implementation of AVX2 BitMatrix transpose based on libOTe.

-use std::{arch::x86_64::*, hint::unreachable_unchecked};
+use std::{arch::x86_64::*, cmp};

-#[inline]
-#[target_feature(enable = "avx2")]
-/// Must be called with `matches!(shift, 2 | 4 | 8 | 16 | 32)`
-unsafe fn _mm256_slli_epi64_var_shift(a: __m256i, shift: usize) -> __m256i {
-    debug_assert!(
-        matches!(shift, 2 | 4 | 8 | 16 | 32),
-        "Must be called with correct shift"
-    );
-    unsafe {
-        match shift {
-            2 => _mm256_slli_epi64::<2>(a),
-            4 => _mm256_slli_epi64::<4>(a),
-            8 => _mm256_slli_epi64::<8>(a),
-            16 => _mm256_slli_epi64::<16>(a),
-            32 => _mm256_slli_epi64::<32>(a),
-            _ => unreachable_unchecked(),
-        }
-    }
-}
+use bytemuck::{must_cast_slice, must_cast_slice_mut};
+use seq_macro::seq;

+/// Performs a 2x2 bit transpose operation on two 256-bit vectors representing a
+/// 4x128 matrix.
 #[inline]
 #[target_feature(enable = "avx2")]
-/// Must be called with `matches!(shift, 2 | 4 | 8 | 16 | 32)`
-unsafe fn _mm256_srli_epi64_var_shift(a: __m256i, shift: usize) -> __m256i {
-    debug_assert!(
-        matches!(shift, 2 | 4 | 8 | 16 | 32),
-        "Must be called with correct shift"
-    );
-    unsafe {
-        match shift {
-            2 => _mm256_srli_epi64::<2>(a),
-            4 => _mm256_srli_epi64::<4>(a),
-            8 => _mm256_srli_epi64::<8>(a),
-            16 => _mm256_srli_epi64::<16>(a),
-            32 => _mm256_srli_epi64::<32>(a),
-            _ => unreachable_unchecked(),
-        }
-    }
+fn transpose_2x2_matrices(x: &mut __m256i, y: &mut __m256i) {
+    // x = [x_H | x_L] and y = [y_H | y_L]
+    // u = [y_L | x_L], u is the low 128 bits of x and y
+    let u = _mm256_permute2x128_si256(*x, *y, 0x20);
+    // v = [y_H | x_H], v is the high 128 bits of x and y
+    let v = _mm256_permute2x128_si256(*x, *y, 0x31);
+    // Shift v left by one so each element at (i, j) aligns with (i+1, j-1) and
+    // compute the difference. The row shift i+1 is done by the permute
+    // instructions before and the column shift by the sll instruction
+    let mut diff = _mm256_xor_si256(u, _mm256_slli_epi16(v, 1));
+    // Select all odd indices of diff and zero out even indices. The idea is to
+    // calculate the difference of all odd numbered indices j of the even
+    // numbered row i with the even numbered indices j-1 in row i+1.
+    // These are precisely the elements in the 2x2 matrices that make up x and y
+    // that potentially need to be swapped for the transpose if they differ
+    diff = _mm256_and_si256(diff, _mm256_set1_epi16(0b1010101010101010_u16 as i16));
+    // Perform the swaps in u, which corresponds to the lower bits of x and y, by
+    // XORing the diff
+    let u = _mm256_xor_si256(u, diff);
+    // For the bottom row in the 2x2 matrices (the high bits of x and y) we need to
+    // shift the diff by 1 to the right so it aligns with the even numbered indices
+    let v = _mm256_xor_si256(v, _mm256_srli_epi16(diff, 1));
+    // The permuted 2x2 matrices are split over u and v, with the upper row in u and
+    // the lower in v. We perform the same permutation as in the beginning, thereby
+    // writing the 2x2 permuted bits of x and y back
+    *x = _mm256_permute2x128_si256(u, v, 0x20);
+    *y = _mm256_permute2x128_si256(u, v, 0x31);
 }

-// Transpose a 2^block_size_shift x 2^block_size_shift block within a larger
-// matrix Only handles first two rows out of every 2^block_rows_shift rows in
-// each block
+/// Performs a general sub-matrix swap step of the bit-level transpose.
+///
+/// `SHIFT_AMOUNT` is the constant shift value (e.g., 2, 4, 8, 16, 32) for the
+/// intrinsics. `MASK` is the bitmask for the XOR-swap.
 #[inline]
 #[target_feature(enable = "avx2")]
-unsafe fn avx_transpose_block_iter1(
-    in_out: *mut __m256i,
-    block_size_shift: usize,
-    block_rows_shift: usize,
-    j: usize,
+fn partial_swap_sub_matrices<const SHIFT_AMOUNT: i32, const MASK: u64>(
+    x: &mut __m256i,
+    y: &mut __m256i,
 ) {
-    if j < (1 << block_size_shift) && block_size_shift == 6 {
-        unsafe {
-            let x = in_out.add(j / 2);
-            let y = in_out.add(j / 2 + 32);
-
-            let out_x = _mm256_unpacklo_epi64(*x, *y);
-            let out_y = _mm256_unpackhi_epi64(*x, *y);
-            *x = out_x;
-            *y = out_y;
-            return;
-        }
-    }
-
-    if block_size_shift == 0 || block_size_shift >= 6 || block_rows_shift < 1 {
-        return;
-    }
-
-    // Calculate mask for the current block size
-    let mut mask = (!0u64) << 32;
-    for k in (block_size_shift as i32..=4).rev() {
-        mask ^= mask >> (1 << k);
-    }
-
-    unsafe {
-        let x = in_out.add(j / 2);
-        let y = in_out.add(j / 2 + (1 << (block_size_shift - 1)));
-
-        // Special case for 2x2 blocks (block_size_shift == 1)
-        if block_size_shift == 1 {
-            let u = _mm256_permute2x128_si256(*x, *y, 0x20);
-            let v = _mm256_permute2x128_si256(*x, *y, 0x31);
-
-            let mut diff = _mm256_xor_si256(u, _mm256_slli_epi16(v, 1));
-            diff = _mm256_and_si256(diff, _mm256_set1_epi16(0b1010101010101010_u16 as i16));
-            let u = _mm256_xor_si256(u, diff);
-            let v = _mm256_xor_si256(v, _mm256_srli_epi16(diff, 1));
-
-            *x = _mm256_permute2x128_si256(u, v, 0x20);
-            *y = _mm256_permute2x128_si256(u, v, 0x31);
-        }
-
-        let mut diff = _mm256_xor_si256(*x, _mm256_slli_epi64_var_shift(*y, 1 << block_size_shift));
-        diff = _mm256_and_si256(diff, _mm256_set1_epi64x(mask as i64));
-        *x = _mm256_xor_si256(*x, diff);
-        *y = _mm256_xor_si256(*y, _mm256_srli_epi64_var_shift(diff, 1 << block_size_shift));
-    }
+    // Calculate the diff of the bits that need to be potentially swapped
+    let mut diff = _mm256_xor_si256(*x, _mm256_slli_epi64::<SHIFT_AMOUNT>(*y));
+    diff = _mm256_and_si256(diff, _mm256_set1_epi64x(MASK as i64));
+    // Swap the bits in x by xoring the difference
+    *x = _mm256_xor_si256(*x, diff);
+    // and in y
+    *y = _mm256_xor_si256(*y, _mm256_srli_epi64::<SHIFT_AMOUNT>(diff));
 }

-#[inline] // Process a range of rows in the matrix
+/// Performs a partial 64x64 bit matrix swap. This is used to swap the rows in
+/// the upper right quadrant with those of the lower left in the 128x128 matrix.
+#[inline]
 #[target_feature(enable = "avx2")]
-unsafe fn avx_transpose_block_iter2(
-    in_out: *mut __m256i,
-    block_size_shift: usize,
-    block_rows_shift: usize,
-    n_rows: usize,
-) {
-    let mat_size = 1 << (block_size_shift + 1);
-
-    for i in (0..n_rows).step_by(mat_size) {
-        for j in (0..(1 << block_size_shift)).step_by(1 << block_rows_shift) {
-            unsafe {
-                avx_transpose_block_iter1(in_out.add(i / 2), block_size_shift, block_rows_shift, j);
-            }
-        }
-    }
+fn partial_swap_64x64_matrices(x: &mut __m256i, y: &mut __m256i) {
+    let out_x = _mm256_unpacklo_epi64(*x, *y);
+    let out_y = _mm256_unpackhi_epi64(*x, *y);
+    *x = out_x;
+    *y = out_y;
 }

-#[inline] // Main transpose function for blocks within the matrix
+/// Transpose a 128x128 bit matrix using AVX2 intrinsics.
+///
+/// # Safety
+/// AVX2 needs to be enabled.
 #[target_feature(enable = "avx2")]
-unsafe fn avx_transpose_block(
-    in_out: *mut __m256i,
-    block_size_shift: usize,
-    mat_size_shift: usize,
-    block_rows_shift: usize,
-    mat_rows_shift: usize,
-) {
-    if block_size_shift >= mat_size_shift {
-        return;
+pub fn avx_transpose128x128(in_out: &mut [__m256i; 64]) {
+    // This algorithm implements a bit-transpose of a 128x128 bit matrix using a
+    // divide-and-conquer algorithm. The idea is that for
+    // M = [ A B ]
+    //     [ C D ]
+    // M^T is equal to
+    // [ A^T C^T ]
+    // [ B^T D^T ]
+    //
+    // We first divide our matrix into 2x2 bit matrices which we transpose at the
+    // bit level. Then we swap the 2x2 bit matrices to complete a 4x4
+    // transpose. We swap the 4x4 bit matrices to complete an 8x8 transpose and so
+    // on until we swap 64x64 bit matrices and thus complete the intended 128x128
+    // bit transpose.
+
+    // Part 1: Specialized 2x2 block transpose transposing individual bits
+    for chunk in in_out.chunks_exact_mut(2) {
+        if let [x, y] = chunk {
+            transpose_2x2_matrices(x, y);
+        } else {
+            unreachable!("chunk size is 2")
+        }
     }

-    // Process current block size
-    let total_rows = 1 << (mat_rows_shift + mat_size_shift);
+    // Phases 1-5: swap sub-matrices of size 2x2, 4x4, 8x8, 16x16, 32x32 bit
+    // Using seq_macro to reduce repetition
+    seq!(N in 1..=5 {
+        const SHIFT_~N: i32 = 1 << N;
+        // Our mask selects the part of the sub-matrix that needs to be potentially
+        // swapped along the diagonal. The lower 2^SHIFT bits are 0 and the following
+        // 2^SHIFT bits are 1, repeated to a 64 bit mask
+        const MASK_~N: u64 = match N {
+            1 => mask(0b1100, 4),
+            2 => mask(0b11110000, 8),
+            3 => mask(0b1111111100000000, 16),
+            4 => mask(0b11111111111111110000000000000000, 32),
+            5 => 0xffffffff00000000,
+            _ => unreachable!(),
+        };
+        // The offset between x and y for matrix rows that need to be swapped in terms
+        // of 256 bit elements. In the first iteration we swap the 2x2 matrices that
+        // are at positions in_out[i] and in_out[i + 1], so the offset is 1. For 4x4
+        // matrices the offset is 2
+        #[allow(clippy::eq_op)] // false positive due to use of seq!
+        const OFFSET~N: usize = 1 << (N - 1);
+
+        for chunk in in_out.chunks_exact_mut(2 * OFFSET~N) {
+            let (x_chunk, y_chunk) = chunk.split_at_mut(OFFSET~N);
+            // For larger matrices, and larger offsets, we need to iterate over all
+            // rows of the sub-matrices
+            for (x, y) in x_chunk.iter_mut().zip(y_chunk.iter_mut()) {
+                partial_swap_sub_matrices::<SHIFT_~N, MASK_~N>(x, y);
+            }
+        }
+    });

-    unsafe {
-        avx_transpose_block_iter2(in_out, block_size_shift, block_rows_shift, total_rows);
+    // Phase 6: swap 64x64 bit-matrices, thereby completing the 128x128 bit
+    // transpose
+    const SHIFT_6: usize = 6;
+    const OFFSET_6: usize = 1 << (SHIFT_6 - 1); // 32

-        // Recursively process larger blocks
-        avx_transpose_block(
-            in_out,
-            block_size_shift + 1,
-            mat_size_shift,
-            block_rows_shift,
-            mat_rows_shift,
-        );
+    for chunk in in_out.chunks_exact_mut(2 * OFFSET_6) {
+        let (x_chunk, y_chunk) = chunk.split_at_mut(OFFSET_6);
+        for (x, y) in x_chunk.iter_mut().zip(y_chunk.iter_mut()) {
+            partial_swap_64x64_matrices(x, y);
+        }
     }
 }

-const AVX_BLOCK_SHIFT: usize = 4;
-const AVX_BLOCK_SIZE: usize = 1 << AVX_BLOCK_SHIFT;
+/// Create a u64 bit mask based on the pattern which is repeated to fill the u64
+const fn mask(pattern: u64, pattern_len: u32) -> u64 {
+    let mut mask = pattern;
+    let mut current_block_len = pattern_len;

-/// Transpose 128x128 bit matrix using AVX2.
-///
-/// # Safety
-/// AVX2 needs to be enabled.
-#[target_feature(enable = "avx2")]
-pub fn avx_transpose128x128(in_out: &mut [__m256i; 64]) {
-    const MAT_SIZE_SHIFT: usize = 7;
-    unsafe {
-        let in_out = in_out.as_mut_ptr();
-        for i in (0..64).step_by(AVX_BLOCK_SIZE) {
-            avx_transpose_block(
-                in_out.add(i),
-                1,
-                MAT_SIZE_SHIFT - AVX_BLOCK_SHIFT,
-                1,
-                AVX_BLOCK_SHIFT + 1 - (MAT_SIZE_SHIFT - AVX_BLOCK_SHIFT),
-            );
-        }
-
-        // Process larger blocks
-        let block_size_shift = MAT_SIZE_SHIFT - AVX_BLOCK_SHIFT;
-
-        // Special case for full matrix
-        for i in 0..(1 << (block_size_shift - 1)) {
-            avx_transpose_block(
-                in_out.add(i),
-                block_size_shift,
-                MAT_SIZE_SHIFT,
-                block_size_shift,
-                0,
-            );
-        }
+
+    // We keep doubling the effective length of our repeating block
+    // until it covers 64 bits.
+    while current_block_len < 64 {
+        mask = (mask << current_block_len) | mask;
+        current_block_len *= 2;
+    }
+
+    mask
 }

-/// Transpose a bit matrix.
+/// Transpose a bit matrix using AVX2.
+///
+/// This implementation is specifically tuned for transposing `128 x l` matrices
+/// as done in OT protocols. Performance might be better if `input` is 16-byte
+/// aligned and the number of columns is divisible by 512 on systems with
+/// 64-byte cache lines.
 ///
 /// # Panics
-/// If the input is not divisable by 128.
-/// If the number of columns (= input.len() * 8 / 128) is less than 128.
 /// If `input.len() != output.len()`
+/// If the number of rows is less than 128.
+/// If the number of rows is not divisible by 128.
+/// If the number of columns (= input.len() * 8 / rows) is not divisible by 8.
 ///
 /// # Safety
 /// AVX2 instruction set must be available.
 #[target_feature(enable = "avx2")]
-pub unsafe fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
+pub fn transpose_bitmatrix(input: &[u8], output: &mut [u8], rows: usize) {
     assert_eq!(input.len(), output.len());
+    assert!(rows >= 128, "Number of rows must be >= 128.");
+    assert_eq!(0, rows % 128, "Number of rows must be a multiple of 128.");
     let cols = input.len() * 8 / rows;
-    assert_eq!(0, cols % 128);
-    assert_eq!(0, rows % 128);
-    #[allow(unused_unsafe)]
-    let mut buf = [unsafe { _mm256_setzero_si256() }; 64];
-    let in_stride = cols / 8;
-    let out_stride = rows / 8;
-
-    // Number of 128x128 bit squares
+    assert_eq!(0, cols % 8, "Number of columns must be a multiple of 8.");
+
+    // Buffer to hold 4 128x128 bit squares (4 * 64 __m256i registers = 4 * 2048
+    // bytes)
+    let mut buf = [_mm256_setzero_si256(); 64 * 4];
+    let in_stride = cols / 8; // Stride in bytes for input rows
+    let out_stride = rows / 8; // Stride in bytes for output rows
+
+    // Number of 128x128 bit squares in rows and columns
     let r_main = rows / 128;
     let c_main = cols / 128;
+    let c_rest = cols % 128;

+    // Iterate through each 128x128 bit square in the matrix
+    // Row block index
     for i in 0..r_main {
-        for j in 0..c_main {
-            // Process each 128x128 bit square
-            unsafe {
-                let src_ptr = input.as_ptr().add(i * 128 * in_stride + j * 16);
+        // Column block index
+        let mut j = 0;
+        while j < c_main {
+            let input_offset = i * 128 * in_stride + j * 16;
+            let curr_addr = input[input_offset..].as_ptr().addr();
+            let next_cache_line_addr = (curr_addr + 1).next_multiple_of(64); // cache line size
+            let blocks_in_cache_line = (next_cache_line_addr - curr_addr) / 16;
+
+            let remaining_blocks_in_cache_line = if blocks_in_cache_line == 0 {
+                // will cross over a cache line, but if the blocks are not 16-byte aligned, this
+                // is the best we can do
+                4
+            } else {
+                blocks_in_cache_line
+            };
+            // Ensure we don't read OOB of the input
+            let remaining_blocks_in_cache_line =
+                cmp::min(remaining_blocks_in_cache_line, c_main - j);
+
+            let buf_as_bytes: &mut [u8] = must_cast_slice_mut(&mut buf);
+
+            // The loading loop loads the input data into the buf. By using a macro and
+            // matching on 4 blocks in a cache line (each row in a block is 16 bytes, so the
+            // rows of 4 consecutive blocks are 64 bytes long) the optimizer uses a loop
+            // unrolled version for this case.
+            macro_rules! loading_loop {
+                ($remaining_blocks_in_cache_line:expr) => {
+                    for k in 0..128 {
+                        let src_slice = &input[input_offset + k * in_stride
+                            ..input_offset + k * in_stride + 16 * remaining_blocks_in_cache_line];
+
+                        for block in 0..remaining_blocks_in_cache_line {
+                            buf_as_bytes[block * 2048 + k * 16..block * 2048 + (k + 1) * 16]
+                                .copy_from_slice(&src_slice[block * 16..(block + 1) * 16]);
+                        }
+                    }
+                };
+            }

-                let buf_u8_ptr = buf.as_mut_ptr() as *mut u8;
+            // This gets optimized to the unrolled loop for the default case of 4 blocks
+            match remaining_blocks_in_cache_line {
+                4 => loading_loop!(4),
+                #[allow(unused_variables)] // false positive
+                other => loading_loop!(other),
+            }

-                // Copy 128 rows into buffer
-                for k in 0..128 {
-                    let src_row = src_ptr.add(k * in_stride);
-                    std::ptr::copy_nonoverlapping(src_row, buf_u8_ptr.add(k * 16), 16);
-                }
+            for block in 0..remaining_blocks_in_cache_line {
+                avx_transpose128x128(
+                    (&mut buf[block * 64..(block + 1) * 64])
+                        .try_into()
+                        .expect("slice has length 64"),
+                );
             }

-            // Transpose the 128x128 bit square
-            avx_transpose128x128(&mut buf);
-
-            unsafe {
-                // needs to be recreated because prev &mut borrow invalidates ptr
-                let buf_u8_ptr = buf.as_mut_ptr() as *mut u8;
-                // Copy transposed data to output
-                let dst_ptr = output.as_mut_ptr().add(j * 128 * out_stride + i * 16);
-                for k in 0..128 {
-                    let dst_row = dst_ptr.add(k * out_stride);
-                    std::ptr::copy_nonoverlapping(buf_u8_ptr.add(k * 16), dst_row, 16);
+
+            let mut output_offset = j * 128 * out_stride + i * 16;
+            let buf_as_bytes: &[u8] = must_cast_slice(&buf);
+
+            if out_stride == 16 {
+                // if the out_stride is 16 bytes, the transposed sub-matrices are in contiguous
+                // memory in the output, so we can use a single copy_from_slice. This is
+                // especially helpful for the case of transposing a 128 x l matrix as done in OT
+                // extension.
+                let dst_slice = &mut output
+                    [output_offset..output_offset + 16 * 128 * remaining_blocks_in_cache_line];
+                dst_slice.copy_from_slice(&buf_as_bytes[..remaining_blocks_in_cache_line * 2048]);
+            } else {
+                for block in 0..remaining_blocks_in_cache_line {
+                    for k in 0..128 {
+                        let src_slice =
+                            &buf_as_bytes[block * 2048 + k * 16..block * 2048 + (k + 1) * 16];
+                        let dst_slice = &mut output
+                            [output_offset + k * out_stride..output_offset + k * out_stride + 16];
+                        dst_slice.copy_from_slice(src_slice);
+                    }
+                    output_offset += 128 * out_stride;
                 }
             }
+
+            j += remaining_blocks_in_cache_line;
         }
+
+        if c_rest > 0 {
+            handle_rest_cols(input, output, &mut buf, in_stride, out_stride, c_rest, i, j);
+        }
     }
 }

+// Inline never to reduce code size of the main method.
+#[inline(never)]
+#[target_feature(enable = "avx2")]
+#[allow(clippy::too_many_arguments)]
+fn handle_rest_cols(
+    input: &[u8],
+    output: &mut [u8],
+    buf: &mut [__m256i; 256],
+    in_stride: usize,
+    out_stride: usize,
+    c_rest: usize,
+    i: usize,
+    j: usize,
+) {
+    let input_offset = i * 128 * in_stride + j * 16;
+    let remaining_cols_bytes = c_rest / 8;
+    buf[0..64].fill(_mm256_setzero_si256());
+    let buf_as_bytes: &mut [u8] = must_cast_slice_mut(buf);
+
+    for k in 0..128 {
+        let src_row_offset = input_offset + k * in_stride;
+        let src_slice = &input[src_row_offset..src_row_offset + remaining_cols_bytes];
+        // we use 16 because we still transpose a 128x128 matrix, of which only a part
+        // is filled
+        let buf_offset = k * 16;
+        buf_as_bytes[buf_offset..buf_offset + remaining_cols_bytes].copy_from_slice(src_slice);
+    }
+
+    avx_transpose128x128((&mut buf[..64]).try_into().expect("slice has length 64"));
+
+    let output_offset = j * 128 * out_stride + i * 16;
+    let buf_as_bytes: &[u8] = must_cast_slice(&*buf);
+
+    for k in 0..c_rest {
+        let src_slice = &buf_as_bytes[k * 16..(k + 1) * 16];
+        let dst_slice =
+            &mut output[output_offset + k * out_stride..output_offset + k * out_stride + 16];
+        dst_slice.copy_from_slice(src_slice);
+    }
 }

@@ -292,4 +371,63 @@ mod tests {

         assert_eq!(sse_transposed, avx_transposed);
     }
+
+    #[test]
+    fn test_avx_transpose_unaligned_data() {
+        let rows = 128 * 2;
+        let cols = 128 * 2;
+        let mut v = vec![0_u8; rows * (cols + 128) / 8];
+        StdRng::seed_from_u64(42).fill_bytes(&mut v);
+
+        let v = {
+            let addr = v.as_ptr().addr();
+            let offset = addr.next_multiple_of(3) - addr;
+            &v[offset..offset + rows * cols / 8]
+        };
+        assert_eq!(0, v.as_ptr().addr() % 3);
+        // allocate out bufs with same dims
+        let mut avx_transposed = v.to_owned();
+        let mut sse_transposed = v.to_owned();
+
+        unsafe {
+            transpose_bitmatrix(&v, &mut avx_transposed, rows);
+        }
+        crate::transpose::portable::transpose_bitmatrix(&v, &mut sse_transposed, rows);
+
+        assert_eq!(sse_transposed, avx_transposed);
+    }
+
+    #[test]
+    fn test_avx_transpose_larger_cols_divisable_by_4_times_128() {
+        let rows = 128;
+        let cols = 128 * 8;
+        let mut v = vec![0_u8; rows * cols / 8];
+        StdRng::seed_from_u64(42).fill_bytes(&mut v);
+
+        let mut avx_transposed = v.clone();
+        let mut sse_transposed = v.clone();
+        unsafe {
+            transpose_bitmatrix(&v, &mut avx_transposed, rows);
+        }
+        crate::transpose::portable::transpose_bitmatrix(&v, &mut sse_transposed, rows);
+
+        assert_eq!(sse_transposed, avx_transposed);
+    }
+
+    #[test]
+    fn test_avx_transpose_larger_cols_divisable_by_8() {
+        let rows = 128;
+        let cols = 128 + 32;
+        let mut v = vec![0_u8; rows * cols / 8];
+        StdRng::seed_from_u64(42).fill_bytes(&mut v);
+
+        let mut avx_transposed = v.clone();
+        let mut sse_transposed = v.clone();
+        unsafe {
+            transpose_bitmatrix(&v, &mut avx_transposed, rows);
+        }
+        crate::transpose::portable::transpose_bitmatrix(&v, &mut sse_transposed, rows);
+
+        assert_eq!(sse_transposed, avx_transposed);
+    }
 }
diff --git a/cryprot-net/src/lib.rs b/cryprot-net/src/lib.rs
index e130e3a..621c819 100644
--- a/cryprot-net/src/lib.rs
+++ b/cryprot-net/src/lib.rs
@@ -469,7 +469,7 @@ impl SendStreamBytes {
         self.inner.close().await.map_err(StreamError::Close)
     }

-    pub fn as_stream(&mut self) -> SendStreamTemp<T> {
+    pub fn as_stream(&mut self) -> SendStreamTemp<'_, T> {
         let framed_send = default_codec().new_write(self);
         SymmetricallyFramed::new(framed_send, Bincode::default())
     }
@@ -517,7 +517,7 @@ fn trace_poll(p: Poll>) -> Poll> {
 }

 impl ReceiveStreamBytes {
-    pub fn as_stream(&mut self) -> ReceiveStreamTemp<T> {
+    pub fn as_stream(&mut self) -> ReceiveStreamTemp<'_, T> {
         let framed_read = default_codec().new_read(self);
         SymmetricallyFramed::new(framed_read, Bincode::default())
     }
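
Note on the now-safe `transpose_bitmatrix`: it still carries `#[target_feature(enable = "avx2")]`, so a caller that does not itself enable AVX2 needs a runtime feature check plus an `unsafe` block, exactly as the new tests do. A minimal sketch, not part of the patch; the `cryprot_core::transpose::{avx2, portable}` paths are assumed here to be publicly re-exported, which this diff does not show:

fn transpose_128_by_l(input: &[u8], output: &mut [u8]) {
    // Transposes a 128 x l bit matrix, the shape used in OT extension.
    let rows = 128;
    if is_x86_feature_detected!("avx2") {
        // SAFETY: AVX2 availability was just checked at runtime.
        unsafe { cryprot_core::transpose::avx2::transpose_bitmatrix(input, output, rows) };
    } else {
        // Fall back to the portable implementation on CPUs without AVX2.
        cryprot_core::transpose::portable::transpose_bitmatrix(input, output, rows);
    }
}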
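
To sanity-check the `mask` helper introduced in avx2.rs, here is a self-contained sketch that repeats its definition from the diff and asserts the values that `MASK_1` and `MASK_2` expand to (the expected constants are written out by hand from the pattern-repetition rule, not taken from the patch):

const fn mask(pattern: u64, pattern_len: u32) -> u64 {
    let mut mask = pattern;
    let mut current_block_len = pattern_len;
    // Keep doubling the repeating block until it fills all 64 bits.
    while current_block_len < 64 {
        mask = (mask << current_block_len) | mask;
        current_block_len *= 2;
    }
    mask
}

fn main() {
    // 2x2 swap step: lower 2 bits zero, next 2 bits one, repeated across the word.
    assert_eq!(mask(0b1100, 4), 0xCCCC_CCCC_CCCC_CCCC);
    // 4x4 swap step: alternating nibbles.
    assert_eq!(mask(0b11110000, 8), 0xF0F0_F0F0_F0F0_F0F0);
    // The 32x32 step uses the literal 0xffffffff00000000 directly in the diff.
}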