diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp
index 43dc7537c33..2745490bf5e 100644
--- a/ggml/src/ggml-cpu/vec.cpp
+++ b/ggml/src/ggml-cpu/vec.cpp
@@ -407,8 +407,6 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
 ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean) {
     int i = 0;
     ggml_float sum = 0;
-// TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE
-// ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344
 #if defined(__AVX512F__) && defined(__AVX512DQ__)
     for (; i + 15 < n; i += 16) {
         __m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i),
@@ -416,18 +414,63 @@ ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const floa
         _mm512_storeu_ps(y + i, val);
         sum += (ggml_float)_mm512_reduce_add_ps(_mm512_mul_ps(val, val));
     }
+    #if defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        __m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i),
+                                   _mm256_set1_ps(mean));
+        _mm256_storeu_ps(y + i, val);
+        val = _mm256_mul_ps(val, val);
+        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
+                                 _mm256_castps256_ps128(val));
+        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
+        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
+        sum += (ggml_float)_mm_cvtss_f32(val2);
+    }
+    #endif // __AVX2__ && __FMA__
+    for (; i + 3 < n; i += 4) {
+        __m128 val = _mm_sub_ps(_mm_loadu_ps(x + i),
+                                _mm_set1_ps(mean));
+        _mm_storeu_ps(y + i, val);
+        val = _mm_mul_ps(val, val);
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
+        val = _mm_add_ss(val, _mm_movehdup_ps(val));
+#else
+        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
+        val = _mm_add_ps(val, tmp);
+        tmp = _mm_movehl_ps(tmp, val);
+        val = _mm_add_ss(val, tmp);
+#endif // __AVX__ || __AVX2__ || __AVX512F__
+        sum += (ggml_float)_mm_cvtss_f32(val);
+    }
 #elif defined(__AVX2__) && defined(__FMA__)
     for (; i + 7 < n; i += 8) {
         __m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i),
                                    _mm256_set1_ps(mean));
         _mm256_storeu_ps(y + i, val);
-        val = _mm256_mul_ps(val,val);
+        val = _mm256_mul_ps(val, val);
         __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
                                  _mm256_castps256_ps128(val));
         val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
         val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
         sum += (ggml_float)_mm_cvtss_f32(val2);
     }
+    for (; i + 3 < n; i += 4) {
+        __m128 val = _mm_sub_ps(_mm_loadu_ps(x + i),
+                                _mm_set1_ps(mean));
+        _mm_storeu_ps(y + i, val);
+        val = _mm_mul_ps(val, val);
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
+        val = _mm_add_ss(val, _mm_movehdup_ps(val));
+#else
+        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
+        val = _mm_add_ps(val, tmp);
+        tmp = _mm_movehl_ps(tmp, val);
+        val = _mm_add_ss(val, tmp);
+#endif // __AVX__ || __AVX2__ || __AVX512F__
+        sum += (ggml_float)_mm_cvtss_f32(val);
+    }
 #elif defined(__SSE2__)
     for (; i + 3 < n; i += 4) {
         __m128 val = _mm_sub_ps(_mm_loadu_ps(x + i),
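The change above resolves the removed TODO: after the 16-wide AVX-512 loop it now drains the remainder with an 8-wide AVX2 loop and a 4-wide SSE loop (and adds a 4-wide tail after the 8-wide AVX2 loop), so fewer elements fall through to the scalar tail. For reference, this is a minimal scalar sketch of what all the SIMD paths compute, written from reading the diff; only ggml_vec_cvar_f32 and its signature come from the PR, the reference function name is hypothetical:

#include <stddef.h>

/* Hypothetical scalar reference for ggml_vec_cvar_f32: center the input
 * around `mean`, store the centered values in y, and return the sum of
 * squared deviations. A double accumulator stands in for ggml_float. */
static double cvar_ref_f32(const int n, float * y, const float * x, const float mean) {
    double sum = 0;
    for (int i = 0; i < n; ++i) {
        const float d = x[i] - mean; // same subtraction as the _mm*_sub_ps paths
        y[i] = d;                    // stored like the _mm*_storeu_ps calls
        sum += (double) d * d;       // accumulated like the horizontal-add reductions
    }
    return sum;                      // presumably divided by n by the caller to get the variance
}

The per-loop reduction in the diff follows the usual horizontal-sum pattern: multiply the centered vector by itself, fold the upper 128-bit lane onto the lower one (for 256-bit values), then collapse the 4 remaining floats with _mm_movehl_ps/_mm_movehdup_ps (or _mm_shuffle_ps on pre-AVX targets, where SSE3's movehdup is not assumed).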