From d11c45d59dc433d4045baeb1b0e6be01c63181aa Mon Sep 17 00:00:00 2001 From: pl752 Date: Sat, 2 May 2026 03:09:06 +0500 Subject: [PATCH 1/5] Implemented bit-interleaved Q1_0 8x32 repack kernels for x86 AVX2 --- ggml/src/ggml-cpu/arch-fallback.h | 22 ++ ggml/src/ggml-cpu/arch/x86/repack.cpp | 304 ++++++++++++++++++++++++++ ggml/src/ggml-cpu/repack.cpp | 238 ++++++++++++++++++++ ggml/src/ggml-cpu/repack.h | 11 + 4 files changed, 575 insertions(+) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 1758d83c261..1e3a4521c5b 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -39,6 +39,7 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_quantize_mat_q8_0_4x32_generic ggml_quantize_mat_q8_0_4x32 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 @@ -55,6 +56,7 @@ #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 +#define ggml_gemv_q1_0_8x32_q8_0_generic ggml_gemv_q1_0_8x32_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -71,16 +73,20 @@ #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 +#define ggml_gemm_q1_0_8x32_q8_0_generic ggml_gemm_q1_0_8x32_q8_0 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) // repack.cpp #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_quantize_mat_q8_0_4x32_generic ggml_quantize_mat_q8_0_4x32 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q1_0_8x32_q8_0_generic ggml_gemv_q1_0_8x32_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q1_0_8x32_q8_0_generic ggml_gemm_q1_0_8x32_q8_0 #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 @@ -125,6 +131,7 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_quantize_mat_q8_0_4x32_generic ggml_quantize_mat_q8_0_4x32 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 @@ -141,6 +148,7 @@ #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 +#define ggml_gemv_q1_0_8x32_q8_0_generic ggml_gemv_q1_0_8x32_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -157,6 +165,7 @@ #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 +#define ggml_gemm_q1_0_8x32_q8_0_generic ggml_gemm_q1_0_8x32_q8_0 #elif defined(__loongarch64) // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K @@ -172,6 +181,7 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_quantize_mat_q8_0_4x32_generic ggml_quantize_mat_q8_0_4x32 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 @@ -188,6 +198,7 @@ #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 +#define ggml_gemv_q1_0_8x32_q8_0_generic ggml_gemv_q1_0_8x32_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -204,12 +215,14 @@ #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 +#define ggml_gemm_q1_0_8x32_q8_0_generic ggml_gemm_q1_0_8x32_q8_0 #elif defined(__riscv) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 #define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 #define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0 // repack.cpp +#define ggml_quantize_mat_q8_0_4x32_generic ggml_quantize_mat_q8_0_4x32 #define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_K_4x1_generic ggml_quantize_mat_q8_K_4x1 @@ -230,6 +243,7 @@ #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 +#define ggml_gemv_q1_0_8x32_q8_0_generic ggml_gemv_q1_0_8x32_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K @@ -245,6 +259,8 @@ #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 +#define ggml_gemm_q1_0_8x32_q8_0_generic ggml_gemm_q1_0_8x32_q8_0 + #elif defined(__s390x__) // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K @@ -266,6 +282,7 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_quantize_mat_q8_0_4x32_generic ggml_quantize_mat_q8_0_4x32 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 @@ -282,6 +299,7 @@ #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 +#define ggml_gemv_q1_0_8x32_q8_0_generic ggml_gemv_q1_0_8x32_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -298,6 +316,7 @@ #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 +#define ggml_gemm_q1_0_8x32_q8_0_generic ggml_gemm_q1_0_8x32_q8_0 #elif defined(__wasm__) // quants.c #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1 @@ -321,6 +340,7 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_quantize_mat_q8_0_4x32_generic ggml_quantize_mat_q8_0_4x32 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 @@ -337,6 +357,7 @@ #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 +#define ggml_gemv_q1_0_8x32_q8_0_generic ggml_gemv_q1_0_8x32_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -353,4 +374,5 @@ #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 +#define ggml_gemm_q1_0_8x32_q8_0_generic ggml_gemm_q1_0_8x32_q8_0 #endif diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp index af1cebad131..d97c90c21c7 100644 --- a/ggml/src/ggml-cpu/arch/x86/repack.cpp +++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp @@ -287,6 +287,75 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR #endif } +void ggml_quantize_mat_q8_0_4x32(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; + +#if defined(__AVX2__) + for (int i = 0; i < nb; i++) { + for (int r = 0; r < 4; r++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x + r * k + i * 32 ); + __m256 v1 = _mm256_loadu_ps( x + r * k + i * 32 + 8 ); + __m256 v2 = _mm256_loadu_ps( x + r * k + i * 32 + 16 ); + __m256 v3 = _mm256_loadu_ps( x + r * k + i * 32 + 24 ); + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + const float d = maxScalar / 127.f; + y[i].d[r] = GGML_CPU_FP32_TO_FP16(d); + const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert to int + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + + i0 = _mm256_packs_epi32( i0, i1 ); + i2 = _mm256_packs_epi32( i2, i3 ); + i0 = _mm256_packs_epi16( i0, i2 ); + + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + // Store row r contiguously + _mm256_storeu_si256((__m256i *)(y[i].qs + r * 32), i0); + } + } +#else + UNUSED(nb); + UNUSED(y); + ggml_quantize_mat_q8_0_4x32_generic(x, vy, k); +#endif +} + void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK_K == 256); assert(k % QK_K == 0); @@ -1461,6 +1530,147 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_q1_0_8x32_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined( __AVX2__ ) || defined( __AVX512F__ ) + { + assert (n % QK1_0 == 0); + assert (nc % 8 == 0); + + UNUSED(bs); + UNUSED(nr); + + const int nb = n / QK1_0; + const int nb32 = n / QK8_0; + const int ncols8 = nc / 8; + + const __m256i ones_8 = _mm256_set1_epi8(1); + const __m256i ones_16 = _mm256_set1_epi16(1); + const __m256i zero = _mm256_setzero_si256(); + + // Shuffle LUTs for columns 0-3: LUT[b & 0xF] = (b >> c) & 1 ? 0x00 : 0xFF + alignas(32) static const uint8_t sm_lut_c0[16] = { + 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, + 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00 + }; + alignas(32) static const uint8_t sm_lut_c1[16] = { + 0xFF, 0xFF, 0x00, 0x00, 0xFF, 0xFF, 0x00, 0x00, + 0xFF, 0xFF, 0x00, 0x00, 0xFF, 0xFF, 0x00, 0x00 + }; + alignas(32) static const uint8_t sm_lut_c2[16] = { + 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00 + }; + alignas(32) static const uint8_t sm_lut_c3[16] = { + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + }; + + const __m256i lut[4] = { + _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i *)sm_lut_c0)), + _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i *)sm_lut_c1)), + _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i *)sm_lut_c2)), + _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i *)sm_lut_c3)), + }; + + // Column masks for columns 4-7 (AND+cmpeq path) + const __m256i col_mask_4 = _mm256_set1_epi8(16); + const __m256i col_mask_5 = _mm256_set1_epi8(32); + const __m256i col_mask_6 = _mm256_set1_epi8(64); + const __m256i col_mask_7 = _mm256_set1_epi8((int8_t)-128); + + const block_q1_0x8 * vx_bi = (const block_q1_0x8 *)vx; + const block_q8_0 * a_ptr = (const block_q8_0 *)vy; + + for (int y = 0; y < nr; ++y) { + const block_q8_0 * a_row = a_ptr + (size_t)y * nb32; + float * row_out = s + (size_t)y * nc; + + for (int x = 0; x < ncols8; ++x) { + const block_q1_0x8 * b_ptr = vx_bi + (size_t)x * nb; + + __m256 acc[8]; + for (int c = 0; c < 8; ++c) acc[c] = _mm256_setzero_ps(); + + for (int l = 0; l < nb; ++l) { + float bd[8]; + for (int c = 0; c < 8; ++c) + bd[c] = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[c]); + + __m256 block_acc[8]; + for (int c = 0; c < 8; ++c) block_acc[c] = _mm256_setzero_ps(); + + const uint8_t * qs_base = (const uint8_t *)b_ptr[l].qs; + + for (int sb = 0; sb < 4; ++sb) { + const block_q8_0 * yb = &a_row[l * 4 + sb]; + const __m256i rhs = _mm256_loadu_si256((const __m256i *)yb->qs); + const __m256 dy = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d)); + + const __m256i qs_vec = _mm256_loadu_si256((const __m256i *)(qs_base + sb * 32)); + + // Columns 0-3: shuffle LUT on low 7 bits + const __m256i qs_lo7 = _mm256_and_si256(qs_vec, _mm256_set1_epi8(0x7F)); + const __m256i sm0 = _mm256_shuffle_epi8(lut[0], qs_lo7); + const __m256i sm1 = _mm256_shuffle_epi8(lut[1], qs_lo7); + const __m256i sm2 = _mm256_shuffle_epi8(lut[2], qs_lo7); + const __m256i sm3 = _mm256_shuffle_epi8(lut[3], qs_lo7); + + // Columns 4-7: AND + cmpeq + const __m256i sm4 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, col_mask_4), zero); + const __m256i sm5 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, col_mask_5), zero); + const __m256i sm6 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, col_mask_6), zero); + const __m256i sm7 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, col_mask_7), zero); + + // Sign-flip and accumulate for all 8 columns + const __m256i sy0 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm0), sm0); + const __m256i sy1 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm1), sm1); + const __m256i sy2 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm2), sm2); + const __m256i sy3 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm3), sm3); + const __m256i sy4 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm4), sm4); + const __m256i sy5 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm5), sm5); + const __m256i sy6 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm6), sm6); + const __m256i sy7 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm7), sm7); + + const __m256i s32_0 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy0), ones_16); + const __m256i s32_1 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy1), ones_16); + const __m256i s32_2 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy2), ones_16); + const __m256i s32_3 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy3), ones_16); + const __m256i s32_4 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy4), ones_16); + const __m256i s32_5 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy5), ones_16); + const __m256i s32_6 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy6), ones_16); + const __m256i s32_7 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy7), ones_16); + + block_acc[0] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_0), block_acc[0]); + block_acc[1] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_1), block_acc[1]); + block_acc[2] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_2), block_acc[2]); + block_acc[3] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_3), block_acc[3]); + block_acc[4] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_4), block_acc[4]); + block_acc[5] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_5), block_acc[5]); + block_acc[6] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_6), block_acc[6]); + block_acc[7] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_7), block_acc[7]); + } + + for (int c = 0; c < 8; ++c) { + acc[c] = _mm256_fmadd_ps(_mm256_set1_ps(bd[c]), block_acc[c], acc[c]); + } + } + + // Reduce 8 lanes to 1 value per column and store + for (int c = 0; c < 8; ++c) { + const __m128 v = _mm_add_ps(_mm256_castps256_ps128(acc[c]), _mm256_extractf128_ps(acc[c], 1)); + const __m128 t = _mm_hadd_ps(v, v); + row_out[x * 8 + c] = _mm_cvtss_f32(_mm_hadd_ps(t, t)); + } + } + } + + return; + } +#endif // defined( __AVX2__ ) || defined( __AVX512F__ ) + + ggml_gemv_q1_0_8x32_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -2039,6 +2249,100 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemm_q1_0_8x32_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined( __AVX2__ ) || defined( __AVX512F__ ) + { + assert (n % QK1_0 == 0); + assert (nr % 4 == 0); + assert (nc % 8 == 0); + + UNUSED(bs); + + const int nb = n / QK1_0; + const int nb_q8_0 = n / QK8_0; + const int ncols8 = nc / 8; + const int nrows4 = nr / 4; + + const __m256i ones_8 = _mm256_set1_epi8(1); + const __m256i ones_16 = _mm256_set1_epi16(1); + const __m256i zero = _mm256_setzero_si256(); + + const uint8_t col_masks[8] = {1, 2, 4, 8, 16, 32, 64, 128}; + + const block_q1_0x8 * vx_bi = (const block_q1_0x8 *)vx; + + for (int y = 0; y < nrows4; ++y) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *)vy + (y * nb_q8_0); + + for (int row_base = 0; row_base < 4; row_base += 2) { + for (int x = 0; x < ncols8; ++x) { + const block_q1_0x8 * b_ptr = vx_bi + (size_t)x * nb; + + __m256 acc[2][8]; + for (int r = 0; r < 2; ++r) + for (int c = 0; c < 8; ++c) + acc[r][c] = _mm256_setzero_ps(); + + for (int l = 0; l < nb; ++l) { + float bd[8]; + for (int c = 0; c < 8; ++c) + bd[c] = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[c]); + + __m256 block_acc[2][8]; + for (int r = 0; r < 2; ++r) + for (int c = 0; c < 8; ++c) + block_acc[r][c] = _mm256_setzero_ps(); + + const uint8_t * qs_base = (const uint8_t *)b_ptr[l].qs; + + for (int sb = 0; sb < 4; ++sb) { + const block_q8_0x4 * yb = &a_ptr[l * 4 + sb]; + const __m256i qs_vec = _mm256_loadu_si256((const __m256i *)(qs_base + sb * 32)); + + __m256i sm[8]; + for (int c = 0; c < 8; ++c) { + const __m256i mask_c = _mm256_set1_epi8((int8_t)col_masks[c]); + sm[c] = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, mask_c), zero); + } + + for (int r = 0; r < 2; ++r) { + const __m256i rhs = _mm256_loadu_si256((const __m256i *)(yb->qs + (row_base + r) * 32)); + const __m256 dy = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d[row_base + r])); + + for (int c = 0; c < 8; ++c) { + const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm[c]), sm[c]); + const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16); + block_acc[r][c] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32), block_acc[r][c]); + } + } + } + + for (int r = 0; r < 2; ++r) { + for (int c = 0; c < 8; ++c) { + acc[r][c] = _mm256_fmadd_ps(_mm256_set1_ps(bd[c]), block_acc[r][c], acc[r][c]); + } + } + } + + float * s_row0 = s + (y * 4 + row_base + 0) * bs + x * 8; + float * s_row1 = s + (y * 4 + row_base + 1) * bs + x * 8; + for (int c = 0; c < 8; ++c) { + const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(acc[0][c]), _mm256_extractf128_ps(acc[0][c], 1)); + const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(acc[1][c]), _mm256_extractf128_ps(acc[1][c], 1)); + s_row0[c] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); + s_row1[c] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); + } + } + } + } + + return; + } +#endif // defined( __AVX2__ ) || defined( __AVX512F__ ) + + ggml_gemm_q1_0_8x32_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index f18758f16bb..13ba4434e6e 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -208,6 +208,42 @@ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG } } +void ggml_quantize_mat_q8_0_4x32_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; + + // scalar + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d); + } + + // Store each row's 32 bytes contiguously + for (int r = 0; r < 4; r++) { + for (int j = 0; j < QK8_0; j++) { + float x0 = srcv[r][j] * id[r]; + y[i].qs[r * QK8_0 + j] = roundf(x0); + } + } + } +} + void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK_K == 256); assert(k % QK_K == 0); @@ -339,6 +375,12 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTR ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row); } +template <> void ggml_quantize_mat_t<32, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { + assert(nrow == 4); + UNUSED(nrow); + ggml_quantize_mat_q8_0_4x32(x, vy, n_per_row); +} + #if defined __riscv_zvfh template <> void ggml_quantize_mat_t<1, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); @@ -884,6 +926,64 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemv_q1_0_8x32_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert (n % QK1_0 == 0); + assert (nc % 8 == 0); + + UNUSED(bs); + UNUSED(nr); + + const int nb = n / QK1_0; + const int nb32 = n / QK8_0; + const int ncols8 = nc / 8; + + const block_q1_0x8 * vx_bi = (const block_q1_0x8 *)vx; + const block_q8_0 * a_ptr = (const block_q8_0 *)vy; + + for (int y = 0; y < nr; ++y) { + const block_q8_0 * a_row = a_ptr + (size_t)y * nb32; + float * row_out = s + (size_t)y * nc; + + for (int x = 0; x < ncols8; ++x) { + const block_q1_0x8 * b_ptr = vx_bi + (size_t)x * nb; + + float acc[8] = {0}; + + for (int l = 0; l < nb; ++l) { + float bd[8]; + for (int c = 0; c < 8; ++c) + bd[c] = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[c]); + + float block_acc[8] = {0}; + + for (int sb = 0; sb < 4; ++sb) { + const block_q8_0 * yb = &a_row[l * 4 + sb]; + const float dy = GGML_CPU_FP16_TO_FP32(yb->d); + + const uint8_t * qs = (const uint8_t *)b_ptr[l].qs + sb * 32; + const int8_t * y = yb->qs; + + for (int c = 0; c < 8; ++c) { + int sumi = 0; + for (int i = 0; i < QK8_0; ++i) { + sumi += ((qs[i] >> c) & 1) ? y[i] : -y[i]; + } + block_acc[c] += dy * (float)sumi; + } + } + + for (int c = 0; c < 8; ++c) { + acc[c] += bd[c] * block_acc[c]; + } + } + + for (int c = 0; c < 8; ++c) { + row_out[x * 8 + c] = acc[c]; + } + } + } +} + void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -1819,6 +1919,70 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemm_q1_0_8x32_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert (n % QK1_0 == 0); + assert (nr % 4 == 0); + assert (nc % 8 == 0); + + const int nb = n / QK1_0; + const int nb_q8_0 = n / QK8_0; + const int ncols8 = nc / 8; + const int nrows4 = nr / 4; + + const block_q1_0x8 * vx_bi = (const block_q1_0x8 *)vx; + + for (int y = 0; y < nrows4; ++y) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *)vy + (y * nb_q8_0); + + for (int row_base = 0; row_base < 4; row_base += 2) { + for (int x = 0; x < ncols8; ++x) { + const block_q1_0x8 * b_ptr = vx_bi + (size_t)x * nb; + + float acc[2][8] = {{0}}; + + for (int l = 0; l < nb; ++l) { + float bd[8]; + for (int c = 0; c < 8; ++c) + bd[c] = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[c]); + + float block_acc[2][8] = {{0}}; + + for (int sb = 0; sb < 4; ++sb) { + const block_q8_0x4 * yb = &a_ptr[l * 4 + sb]; + const uint8_t * qs = (const uint8_t *)b_ptr[l].qs + sb * 32; + + for (int r = 0; r < 2; ++r) { + const float dy = GGML_CPU_FP16_TO_FP32(yb->d[row_base + r]); + + for (int c = 0; c < 8; ++c) { + int sumi = 0; + for (int i = 0; i < QK8_0; ++i) { + const int8_t y_val = yb->qs[(row_base + r) * 32 + i]; + sumi += ((qs[i] >> c) & 1) ? y_val : -y_val; + } + block_acc[r][c] += dy * (float)sumi; + } + } + } + + for (int r = 0; r < 2; ++r) { + for (int c = 0; c < 8; ++c) { + acc[r][c] += bd[c] * block_acc[r][c]; + } + } + } + + for (int r = 0; r < 2; ++r) { + float * row_out = s + (y * 4 + row_base + r) * bs + x * 8; + for (int c = 0; c < 8; ++c) { + row_out[c] = acc[r][c]; + } + } + } + } + } +} + void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -2808,6 +2972,29 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in return out; } +static block_q1_0x8 make_block_q1_0x8(block_q1_0 * in, unsigned int blck_size_interleave) { + block_q1_0x8 out; + + GGML_ASSERT(blck_size_interleave == 8); + + for (int i = 0; i < 8; ++i) { + out.d[i] = in[i].d; + } + + for (int sb = 0; sb < 4; ++sb) { + for (int i = 0; i < 32; ++i) { + uint8_t byte = 0; + for (int c = 0; c < 8; ++c) { + uint8_t src = in[c].qs[sb * 4 + i / 8]; + byte |= ((src >> (i % 8)) & 1) << c; + } + out.qs[sb * 32 + i] = byte; + } + } + + return out; +} + static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { block_q4_0x16 out; @@ -3477,6 +3664,36 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } +static int repack_q1_0_to_q1_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q1_0); + GGML_ASSERT(interleave_block == 8); + constexpr int nrows_interleaved = 8; + + block_q1_0x8 * dst = (block_q1_0x8 *) t->data; + const block_q1_0 * src = (const block_q1_0 *) data; + block_q1_0 dst_tmp[8]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK1_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q1_0)); + + if (t->ne[1] % nrows_interleaved != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; ++x) { + for (int i = 0; i < nrows_interleaved; ++i) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q1_0x8(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + + return 0; +} + static int repack_q8_0_to_q8_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, @@ -3934,6 +4151,10 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q1_0_to_q1_0_8_bl(t, 8, data, data_size); +} + #if defined __riscv_zvfh template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q4_0_to_q4_0_16_bl(t, 1, data, data_size); @@ -4031,6 +4252,10 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q1_0_8x32_q8_0(n, s, bs, vx, vy, nr, nc); +} + #if defined __riscv_zvfh template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); @@ -4128,6 +4353,10 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q1_0_8x32_q8_0(n, s, bs, vx, vy, nr, nc); +} + #if defined __riscv_zvfh template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); @@ -4558,6 +4787,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q8_0_4x4_q8_0; static const ggml::cpu::repack::tensor_traits q8_0_4x8_q8_0; + // instance for Q1_0 + static const ggml::cpu::repack::tensor_traits q1_0_8x32_q8_0; + // instances for RISC-V // // These implement outer-product style matrix multiplication kernels with @@ -4597,6 +4829,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons } #endif } + } else if (cur->type == GGML_TYPE_Q1_0) { + if (ggml_cpu_has_avx2()) { + if (cur->ne[1] % 8 == 0) { + return &q1_0_8x32_q8_0; + } + } } else if (cur->type == GGML_TYPE_Q4_K) { if (ggml_cpu_has_avx2()) { if (cur->ne[1] % 8 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index cb21edf6239..5db7ca8269f 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -11,6 +11,9 @@ ggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void); template constexpr int QK_0() { + if constexpr (K == 1) { + return QK1_0; + } if constexpr (K == 4) { return QK4_0; } @@ -32,7 +35,9 @@ static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 8, "wrong static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding"); static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding"); static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 16, "wrong block<8,16> size/padding"); +static_assert(sizeof(block<1, 8>) == 8 * sizeof(ggml_half) + QK1_0, "wrong block<1,8> size/padding"); +using block_q1_0x8 = block<1, 8>; using block_q4_0x4 = block<4, 4>; using block_q4_0x8 = block<4, 8>; using block_q4_0x16 = block<4, 16>; @@ -139,6 +144,7 @@ extern "C" { void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_quantize_mat_q8_0_4x32(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -157,6 +163,7 @@ void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q1_0_8x32_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -173,6 +180,7 @@ void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q1_0_8x32_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #if defined __riscv_zvfh void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); @@ -191,6 +199,7 @@ void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v // Native implementations void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_quantize_mat_q8_0_4x32_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -209,6 +218,7 @@ void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q1_0_8x32_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -225,6 +235,7 @@ void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q1_0_8x32_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #if defined __riscv_zvfh void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); From 5a9ab761476ae224d6b146db3cc3642d02fcd787 Mon Sep 17 00:00:00 2001 From: pl752 Date: Tue, 5 May 2026 19:38:49 +0500 Subject: [PATCH 2/5] Corrected gemv with assumption of nr==1 for consistency --- ggml/src/ggml-cpu/arch/x86/repack.cpp | 165 +++++++++++++------------- ggml/src/ggml-cpu/repack.cpp | 52 ++++---- 2 files changed, 103 insertions(+), 114 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp index d97c90c21c7..f2d5c983630 100644 --- a/ggml/src/ggml-cpu/arch/x86/repack.cpp +++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp @@ -1535,12 +1535,12 @@ void ggml_gemv_q1_0_8x32_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v { assert (n % QK1_0 == 0); assert (nc % 8 == 0); + assert (nr == 1); UNUSED(bs); UNUSED(nr); const int nb = n / QK1_0; - const int nb32 = n / QK8_0; const int ncols8 = nc / 8; const __m256i ones_8 = _mm256_set1_epi8(1); @@ -1548,28 +1548,28 @@ void ggml_gemv_q1_0_8x32_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i zero = _mm256_setzero_si256(); // Shuffle LUTs for columns 0-3: LUT[b & 0xF] = (b >> c) & 1 ? 0x00 : 0xFF - alignas(32) static const uint8_t sm_lut_c0[16] = { + alignas(16) static const uint8_t sm_lut_c0[16] = { 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00 }; - alignas(32) static const uint8_t sm_lut_c1[16] = { + alignas(16) static const uint8_t sm_lut_c1[16] = { 0xFF, 0xFF, 0x00, 0x00, 0xFF, 0xFF, 0x00, 0x00, 0xFF, 0xFF, 0x00, 0x00, 0xFF, 0xFF, 0x00, 0x00 }; - alignas(32) static const uint8_t sm_lut_c2[16] = { + alignas(16) static const uint8_t sm_lut_c2[16] = { 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00 }; - alignas(32) static const uint8_t sm_lut_c3[16] = { + alignas(16) static const uint8_t sm_lut_c3[16] = { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }; const __m256i lut[4] = { - _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i *)sm_lut_c0)), - _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i *)sm_lut_c1)), - _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i *)sm_lut_c2)), - _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i *)sm_lut_c3)), + _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i *)sm_lut_c0)), + _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i *)sm_lut_c1)), + _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i *)sm_lut_c2)), + _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i *)sm_lut_c3)), }; // Column masks for columns 4-7 (AND+cmpeq path) @@ -1581,87 +1581,82 @@ void ggml_gemv_q1_0_8x32_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v const block_q1_0x8 * vx_bi = (const block_q1_0x8 *)vx; const block_q8_0 * a_ptr = (const block_q8_0 *)vy; - for (int y = 0; y < nr; ++y) { - const block_q8_0 * a_row = a_ptr + (size_t)y * nb32; - float * row_out = s + (size_t)y * nc; - - for (int x = 0; x < ncols8; ++x) { - const block_q1_0x8 * b_ptr = vx_bi + (size_t)x * nb; - - __m256 acc[8]; - for (int c = 0; c < 8; ++c) acc[c] = _mm256_setzero_ps(); - - for (int l = 0; l < nb; ++l) { - float bd[8]; - for (int c = 0; c < 8; ++c) - bd[c] = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[c]); - - __m256 block_acc[8]; - for (int c = 0; c < 8; ++c) block_acc[c] = _mm256_setzero_ps(); - - const uint8_t * qs_base = (const uint8_t *)b_ptr[l].qs; - - for (int sb = 0; sb < 4; ++sb) { - const block_q8_0 * yb = &a_row[l * 4 + sb]; - const __m256i rhs = _mm256_loadu_si256((const __m256i *)yb->qs); - const __m256 dy = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d)); - - const __m256i qs_vec = _mm256_loadu_si256((const __m256i *)(qs_base + sb * 32)); - - // Columns 0-3: shuffle LUT on low 7 bits - const __m256i qs_lo7 = _mm256_and_si256(qs_vec, _mm256_set1_epi8(0x7F)); - const __m256i sm0 = _mm256_shuffle_epi8(lut[0], qs_lo7); - const __m256i sm1 = _mm256_shuffle_epi8(lut[1], qs_lo7); - const __m256i sm2 = _mm256_shuffle_epi8(lut[2], qs_lo7); - const __m256i sm3 = _mm256_shuffle_epi8(lut[3], qs_lo7); - - // Columns 4-7: AND + cmpeq - const __m256i sm4 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, col_mask_4), zero); - const __m256i sm5 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, col_mask_5), zero); - const __m256i sm6 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, col_mask_6), zero); - const __m256i sm7 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, col_mask_7), zero); - - // Sign-flip and accumulate for all 8 columns - const __m256i sy0 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm0), sm0); - const __m256i sy1 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm1), sm1); - const __m256i sy2 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm2), sm2); - const __m256i sy3 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm3), sm3); - const __m256i sy4 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm4), sm4); - const __m256i sy5 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm5), sm5); - const __m256i sy6 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm6), sm6); - const __m256i sy7 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm7), sm7); - - const __m256i s32_0 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy0), ones_16); - const __m256i s32_1 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy1), ones_16); - const __m256i s32_2 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy2), ones_16); - const __m256i s32_3 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy3), ones_16); - const __m256i s32_4 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy4), ones_16); - const __m256i s32_5 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy5), ones_16); - const __m256i s32_6 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy6), ones_16); - const __m256i s32_7 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy7), ones_16); - - block_acc[0] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_0), block_acc[0]); - block_acc[1] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_1), block_acc[1]); - block_acc[2] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_2), block_acc[2]); - block_acc[3] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_3), block_acc[3]); - block_acc[4] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_4), block_acc[4]); - block_acc[5] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_5), block_acc[5]); - block_acc[6] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_6), block_acc[6]); - block_acc[7] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_7), block_acc[7]); - } - - for (int c = 0; c < 8; ++c) { - acc[c] = _mm256_fmadd_ps(_mm256_set1_ps(bd[c]), block_acc[c], acc[c]); - } + for (int x = 0; x < ncols8; ++x) { + const block_q1_0x8 * b_ptr = vx_bi + (size_t)x * nb; + + __m256 acc[8]; + for (int c = 0; c < 8; ++c) acc[c] = _mm256_setzero_ps(); + + for (int l = 0; l < nb; ++l) { + float bd[8]; + for (int c = 0; c < 8; ++c) + bd[c] = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[c]); + + __m256 block_acc[8]; + for (int c = 0; c < 8; ++c) block_acc[c] = _mm256_setzero_ps(); + + const uint8_t * qs_base = (const uint8_t *)b_ptr[l].qs; + + for (int sb = 0; sb < 4; ++sb) { + const block_q8_0 * yb = &a_ptr[l * 4 + sb]; + const __m256i rhs = _mm256_loadu_si256((const __m256i *)yb->qs); + const __m256 dy = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d)); + + const __m256i qs_vec = _mm256_loadu_si256((const __m256i *)(qs_base + sb * 32)); + + // Columns 0-3: shuffle LUT on low 7 bits + const __m256i qs_lo7 = _mm256_and_si256(qs_vec, _mm256_set1_epi8(0x7F)); + const __m256i sm0 = _mm256_shuffle_epi8(lut[0], qs_lo7); + const __m256i sm1 = _mm256_shuffle_epi8(lut[1], qs_lo7); + const __m256i sm2 = _mm256_shuffle_epi8(lut[2], qs_lo7); + const __m256i sm3 = _mm256_shuffle_epi8(lut[3], qs_lo7); + + // Columns 4-7: AND + cmpeq + const __m256i sm4 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, col_mask_4), zero); + const __m256i sm5 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, col_mask_5), zero); + const __m256i sm6 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, col_mask_6), zero); + const __m256i sm7 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, col_mask_7), zero); + + // Sign-flip and accumulate for all 8 columns + const __m256i sy0 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm0), sm0); + const __m256i sy1 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm1), sm1); + const __m256i sy2 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm2), sm2); + const __m256i sy3 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm3), sm3); + const __m256i sy4 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm4), sm4); + const __m256i sy5 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm5), sm5); + const __m256i sy6 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm6), sm6); + const __m256i sy7 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm7), sm7); + + const __m256i s32_0 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy0), ones_16); + const __m256i s32_1 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy1), ones_16); + const __m256i s32_2 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy2), ones_16); + const __m256i s32_3 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy3), ones_16); + const __m256i s32_4 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy4), ones_16); + const __m256i s32_5 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy5), ones_16); + const __m256i s32_6 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy6), ones_16); + const __m256i s32_7 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy7), ones_16); + + block_acc[0] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_0), block_acc[0]); + block_acc[1] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_1), block_acc[1]); + block_acc[2] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_2), block_acc[2]); + block_acc[3] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_3), block_acc[3]); + block_acc[4] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_4), block_acc[4]); + block_acc[5] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_5), block_acc[5]); + block_acc[6] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_6), block_acc[6]); + block_acc[7] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_7), block_acc[7]); } - // Reduce 8 lanes to 1 value per column and store for (int c = 0; c < 8; ++c) { - const __m128 v = _mm_add_ps(_mm256_castps256_ps128(acc[c]), _mm256_extractf128_ps(acc[c], 1)); - const __m128 t = _mm_hadd_ps(v, v); - row_out[x * 8 + c] = _mm_cvtss_f32(_mm_hadd_ps(t, t)); + acc[c] = _mm256_fmadd_ps(_mm256_set1_ps(bd[c]), block_acc[c], acc[c]); } } + + // Reduce 8 lanes to 1 value per column and store + for (int c = 0; c < 8; ++c) { + const __m128 v = _mm_add_ps(_mm256_castps256_ps128(acc[c]), _mm256_extractf128_ps(acc[c], 1)); + const __m128 t = _mm_hadd_ps(v, v); + s[x * 8 + c] = _mm_cvtss_f32(_mm_hadd_ps(t, t)); + } } return; diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 13ba4434e6e..7b0a0a1a8a0 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -929,58 +929,52 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_q1_0_8x32_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { assert (n % QK1_0 == 0); assert (nc % 8 == 0); + assert (nr == 1); UNUSED(bs); UNUSED(nr); const int nb = n / QK1_0; - const int nb32 = n / QK8_0; const int ncols8 = nc / 8; const block_q1_0x8 * vx_bi = (const block_q1_0x8 *)vx; const block_q8_0 * a_ptr = (const block_q8_0 *)vy; - for (int y = 0; y < nr; ++y) { - const block_q8_0 * a_row = a_ptr + (size_t)y * nb32; - float * row_out = s + (size_t)y * nc; + for (int x = 0; x < ncols8; ++x) { + const block_q1_0x8 * b_ptr = vx_bi + (size_t)x * nb; - for (int x = 0; x < ncols8; ++x) { - const block_q1_0x8 * b_ptr = vx_bi + (size_t)x * nb; + float acc[8] = {0}; - float acc[8] = {0}; + for (int l = 0; l < nb; ++l) { + float bd[8]; + for (int c = 0; c < 8; ++c) + bd[c] = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[c]); - for (int l = 0; l < nb; ++l) { - float bd[8]; - for (int c = 0; c < 8; ++c) - bd[c] = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[c]); + float block_acc[8] = {0}; - float block_acc[8] = {0}; + for (int sb = 0; sb < 4; ++sb) { + const block_q8_0 * yb = &a_ptr[l * 4 + sb]; + const float dy = GGML_CPU_FP16_TO_FP32(yb->d); - for (int sb = 0; sb < 4; ++sb) { - const block_q8_0 * yb = &a_row[l * 4 + sb]; - const float dy = GGML_CPU_FP16_TO_FP32(yb->d); - - const uint8_t * qs = (const uint8_t *)b_ptr[l].qs + sb * 32; - const int8_t * y = yb->qs; - - for (int c = 0; c < 8; ++c) { - int sumi = 0; - for (int i = 0; i < QK8_0; ++i) { - sumi += ((qs[i] >> c) & 1) ? y[i] : -y[i]; - } - block_acc[c] += dy * (float)sumi; - } - } + const uint8_t * qs = (const uint8_t *)b_ptr[l].qs + sb * 32; + const int8_t * y = yb->qs; for (int c = 0; c < 8; ++c) { - acc[c] += bd[c] * block_acc[c]; + int sumi = 0; + for (int i = 0; i < QK8_0; ++i) { + sumi += ((qs[i] >> c) & 1) ? y[i] : -y[i]; + } + block_acc[c] += dy * (float)sumi; } } for (int c = 0; c < 8; ++c) { - row_out[x * 8 + c] = acc[c]; + acc[c] += bd[c] * block_acc[c]; } } + + static_assert(sizeof(acc) == 32); + memcpy(s + x*8, acc, sizeof(acc)); } } From 3fa569db05ca2e65d691d5a74133f2e99451f5b5 Mon Sep 17 00:00:00 2001 From: pl752 Date: Tue, 5 May 2026 20:51:50 +0500 Subject: [PATCH 3/5] Unrolled mm256 registers to avoid register pointers --- ggml/src/ggml-cpu/arch/x86/repack.cpp | 257 ++++++++++++++++++-------- 1 file changed, 179 insertions(+), 78 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp index f2d5c983630..768a0952bf8 100644 --- a/ggml/src/ggml-cpu/arch/x86/repack.cpp +++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp @@ -1565,12 +1565,10 @@ void ggml_gemv_q1_0_8x32_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }; - const __m256i lut[4] = { - _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i *)sm_lut_c0)), - _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i *)sm_lut_c1)), - _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i *)sm_lut_c2)), - _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i *)sm_lut_c3)), - }; + const __m256i lut0 = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i *)sm_lut_c0)); + const __m256i lut1 = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i *)sm_lut_c1)); + const __m256i lut2 = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i *)sm_lut_c2)); + const __m256i lut3 = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i *)sm_lut_c3)); // Column masks for columns 4-7 (AND+cmpeq path) const __m256i col_mask_4 = _mm256_set1_epi8(16); @@ -1578,22 +1576,37 @@ void ggml_gemv_q1_0_8x32_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i col_mask_6 = _mm256_set1_epi8(64); const __m256i col_mask_7 = _mm256_set1_epi8((int8_t)-128); + const __m256i low_nibble_mask = _mm256_set1_epi8(0x0F); + const block_q1_0x8 * vx_bi = (const block_q1_0x8 *)vx; const block_q8_0 * a_ptr = (const block_q8_0 *)vy; for (int x = 0; x < ncols8; ++x) { const block_q1_0x8 * b_ptr = vx_bi + (size_t)x * nb; - __m256 acc[8]; - for (int c = 0; c < 8; ++c) acc[c] = _mm256_setzero_ps(); + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + __m256 acc4 = _mm256_setzero_ps(); + __m256 acc5 = _mm256_setzero_ps(); + __m256 acc6 = _mm256_setzero_ps(); + __m256 acc7 = _mm256_setzero_ps(); for (int l = 0; l < nb; ++l) { float bd[8]; - for (int c = 0; c < 8; ++c) + for (int c = 0; c < 8; ++c) { bd[c] = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[c]); + } - __m256 block_acc[8]; - for (int c = 0; c < 8; ++c) block_acc[c] = _mm256_setzero_ps(); + __m256 ba0 = _mm256_setzero_ps(); + __m256 ba1 = _mm256_setzero_ps(); + __m256 ba2 = _mm256_setzero_ps(); + __m256 ba3 = _mm256_setzero_ps(); + __m256 ba4 = _mm256_setzero_ps(); + __m256 ba5 = _mm256_setzero_ps(); + __m256 ba6 = _mm256_setzero_ps(); + __m256 ba7 = _mm256_setzero_ps(); const uint8_t * qs_base = (const uint8_t *)b_ptr[l].qs; @@ -1605,11 +1618,11 @@ void ggml_gemv_q1_0_8x32_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i qs_vec = _mm256_loadu_si256((const __m256i *)(qs_base + sb * 32)); // Columns 0-3: shuffle LUT on low 7 bits - const __m256i qs_lo7 = _mm256_and_si256(qs_vec, _mm256_set1_epi8(0x7F)); - const __m256i sm0 = _mm256_shuffle_epi8(lut[0], qs_lo7); - const __m256i sm1 = _mm256_shuffle_epi8(lut[1], qs_lo7); - const __m256i sm2 = _mm256_shuffle_epi8(lut[2], qs_lo7); - const __m256i sm3 = _mm256_shuffle_epi8(lut[3], qs_lo7); + const __m256i qs_lo4 = _mm256_and_si256(qs_vec, low_nibble_mask); + const __m256i sm0 = _mm256_shuffle_epi8(lut0, qs_lo4); + const __m256i sm1 = _mm256_shuffle_epi8(lut1, qs_lo4); + const __m256i sm2 = _mm256_shuffle_epi8(lut2, qs_lo4); + const __m256i sm3 = _mm256_shuffle_epi8(lut3, qs_lo4); // Columns 4-7: AND + cmpeq const __m256i sm4 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, col_mask_4), zero); @@ -1627,35 +1640,57 @@ void ggml_gemv_q1_0_8x32_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i sy6 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm6), sm6); const __m256i sy7 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm7), sm7); - const __m256i s32_0 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy0), ones_16); - const __m256i s32_1 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy1), ones_16); - const __m256i s32_2 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy2), ones_16); - const __m256i s32_3 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy3), ones_16); - const __m256i s32_4 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy4), ones_16); - const __m256i s32_5 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy5), ones_16); - const __m256i s32_6 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy6), ones_16); - const __m256i s32_7 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy7), ones_16); - - block_acc[0] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_0), block_acc[0]); - block_acc[1] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_1), block_acc[1]); - block_acc[2] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_2), block_acc[2]); - block_acc[3] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_3), block_acc[3]); - block_acc[4] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_4), block_acc[4]); - block_acc[5] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_5), block_acc[5]); - block_acc[6] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_6), block_acc[6]); - block_acc[7] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32_7), block_acc[7]); + ba0 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy0), ones_16)), ba0); + ba1 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy1), ones_16)), ba1); + ba2 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy2), ones_16)), ba2); + ba3 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy3), ones_16)), ba3); + ba4 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy4), ones_16)), ba4); + ba5 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy5), ones_16)), ba5); + ba6 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy6), ones_16)), ba6); + ba7 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy7), ones_16)), ba7); } - for (int c = 0; c < 8; ++c) { - acc[c] = _mm256_fmadd_ps(_mm256_set1_ps(bd[c]), block_acc[c], acc[c]); - } + acc0 = _mm256_fmadd_ps(_mm256_set1_ps(bd[0]), ba0, acc0); + acc1 = _mm256_fmadd_ps(_mm256_set1_ps(bd[1]), ba1, acc1); + acc2 = _mm256_fmadd_ps(_mm256_set1_ps(bd[2]), ba2, acc2); + acc3 = _mm256_fmadd_ps(_mm256_set1_ps(bd[3]), ba3, acc3); + acc4 = _mm256_fmadd_ps(_mm256_set1_ps(bd[4]), ba4, acc4); + acc5 = _mm256_fmadd_ps(_mm256_set1_ps(bd[5]), ba5, acc5); + acc6 = _mm256_fmadd_ps(_mm256_set1_ps(bd[6]), ba6, acc6); + acc7 = _mm256_fmadd_ps(_mm256_set1_ps(bd[7]), ba7, acc7); } - // Reduce 8 lanes to 1 value per column and store - for (int c = 0; c < 8; ++c) { - const __m128 v = _mm_add_ps(_mm256_castps256_ps128(acc[c]), _mm256_extractf128_ps(acc[c], 1)); - const __m128 t = _mm_hadd_ps(v, v); - s[x * 8 + c] = _mm_cvtss_f32(_mm_hadd_ps(t, t)); + { + const __m128 v = _mm_add_ps(_mm256_castps256_ps128(acc0), _mm256_extractf128_ps(acc0, 1)); + s[x * 8 + 0] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v, v), _mm_hadd_ps(v, v))); + } + { + const __m128 v = _mm_add_ps(_mm256_castps256_ps128(acc1), _mm256_extractf128_ps(acc1, 1)); + s[x * 8 + 1] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v, v), _mm_hadd_ps(v, v))); + } + { + const __m128 v = _mm_add_ps(_mm256_castps256_ps128(acc2), _mm256_extractf128_ps(acc2, 1)); + s[x * 8 + 2] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v, v), _mm_hadd_ps(v, v))); + } + { + const __m128 v = _mm_add_ps(_mm256_castps256_ps128(acc3), _mm256_extractf128_ps(acc3, 1)); + s[x * 8 + 3] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v, v), _mm_hadd_ps(v, v))); + } + { + const __m128 v = _mm_add_ps(_mm256_castps256_ps128(acc4), _mm256_extractf128_ps(acc4, 1)); + s[x * 8 + 4] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v, v), _mm_hadd_ps(v, v))); + } + { + const __m128 v = _mm_add_ps(_mm256_castps256_ps128(acc5), _mm256_extractf128_ps(acc5, 1)); + s[x * 8 + 5] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v, v), _mm_hadd_ps(v, v))); + } + { + const __m128 v = _mm_add_ps(_mm256_castps256_ps128(acc6), _mm256_extractf128_ps(acc6, 1)); + s[x * 8 + 6] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v, v), _mm_hadd_ps(v, v))); + } + { + const __m128 v = _mm_add_ps(_mm256_castps256_ps128(acc7), _mm256_extractf128_ps(acc7, 1)); + s[x * 8 + 7] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v, v), _mm_hadd_ps(v, v))); } } @@ -2262,8 +2297,6 @@ void ggml_gemm_q1_0_8x32_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i ones_16 = _mm256_set1_epi16(1); const __m256i zero = _mm256_setzero_si256(); - const uint8_t col_masks[8] = {1, 2, 4, 8, 16, 32, 64, 128}; - const block_q1_0x8 * vx_bi = (const block_q1_0x8 *)vx; for (int y = 0; y < nrows4; ++y) { @@ -2273,20 +2306,21 @@ void ggml_gemm_q1_0_8x32_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v for (int x = 0; x < ncols8; ++x) { const block_q1_0x8 * b_ptr = vx_bi + (size_t)x * nb; - __m256 acc[2][8]; - for (int r = 0; r < 2; ++r) - for (int c = 0; c < 8; ++c) - acc[r][c] = _mm256_setzero_ps(); + __m256 acc00 = _mm256_setzero_ps(), acc01 = _mm256_setzero_ps(), acc02 = _mm256_setzero_ps(), acc03 = _mm256_setzero_ps(); + __m256 acc04 = _mm256_setzero_ps(), acc05 = _mm256_setzero_ps(), acc06 = _mm256_setzero_ps(), acc07 = _mm256_setzero_ps(); + __m256 acc10 = _mm256_setzero_ps(), acc11 = _mm256_setzero_ps(), acc12 = _mm256_setzero_ps(), acc13 = _mm256_setzero_ps(); + __m256 acc14 = _mm256_setzero_ps(), acc15 = _mm256_setzero_ps(), acc16 = _mm256_setzero_ps(), acc17 = _mm256_setzero_ps(); for (int l = 0; l < nb; ++l) { float bd[8]; - for (int c = 0; c < 8; ++c) + for (int c = 0; c < 8; ++c) { bd[c] = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[c]); + } - __m256 block_acc[2][8]; - for (int r = 0; r < 2; ++r) - for (int c = 0; c < 8; ++c) - block_acc[r][c] = _mm256_setzero_ps(); + __m256 ba00 = _mm256_setzero_ps(), ba01 = _mm256_setzero_ps(), ba02 = _mm256_setzero_ps(), ba03 = _mm256_setzero_ps(); + __m256 ba04 = _mm256_setzero_ps(), ba05 = _mm256_setzero_ps(), ba06 = _mm256_setzero_ps(), ba07 = _mm256_setzero_ps(); + __m256 ba10 = _mm256_setzero_ps(), ba11 = _mm256_setzero_ps(), ba12 = _mm256_setzero_ps(), ba13 = _mm256_setzero_ps(); + __m256 ba14 = _mm256_setzero_ps(), ba15 = _mm256_setzero_ps(), ba16 = _mm256_setzero_ps(), ba17 = _mm256_setzero_ps(); const uint8_t * qs_base = (const uint8_t *)b_ptr[l].qs; @@ -2294,38 +2328,105 @@ void ggml_gemm_q1_0_8x32_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v const block_q8_0x4 * yb = &a_ptr[l * 4 + sb]; const __m256i qs_vec = _mm256_loadu_si256((const __m256i *)(qs_base + sb * 32)); - __m256i sm[8]; - for (int c = 0; c < 8; ++c) { - const __m256i mask_c = _mm256_set1_epi8((int8_t)col_masks[c]); - sm[c] = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, mask_c), zero); - } - - for (int r = 0; r < 2; ++r) { - const __m256i rhs = _mm256_loadu_si256((const __m256i *)(yb->qs + (row_base + r) * 32)); - const __m256 dy = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d[row_base + r])); - - for (int c = 0; c < 8; ++c) { - const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm[c]), sm[c]); - const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16); - block_acc[r][c] = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(s32), block_acc[r][c]); - } - } + const __m256i sm0 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, _mm256_set1_epi8(1)), zero); + const __m256i sm1 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, _mm256_set1_epi8(2)), zero); + const __m256i sm2 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, _mm256_set1_epi8(4)), zero); + const __m256i sm3 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, _mm256_set1_epi8(8)), zero); + const __m256i sm4 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, _mm256_set1_epi8(16)), zero); + const __m256i sm5 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, _mm256_set1_epi8(32)), zero); + const __m256i sm6 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, _mm256_set1_epi8(64)), zero); + const __m256i sm7 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, _mm256_set1_epi8((int8_t)-128)), zero); + + const __m256i rhs0 = _mm256_loadu_si256((const __m256i *)(yb->qs + (row_base + 0) * 32)); + const __m256 dy0 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d[row_base + 0])); + ba00 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm0), sm0)), ones_16)), ba00); + ba01 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm1), sm1)), ones_16)), ba01); + ba02 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm2), sm2)), ones_16)), ba02); + ba03 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm3), sm3)), ones_16)), ba03); + ba04 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm4), sm4)), ones_16)), ba04); + ba05 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm5), sm5)), ones_16)), ba05); + ba06 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm6), sm6)), ones_16)), ba06); + ba07 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm7), sm7)), ones_16)), ba07); + + const __m256i rhs1 = _mm256_loadu_si256((const __m256i *)(yb->qs + (row_base + 1) * 32)); + const __m256 dy1 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d[row_base + 1])); + ba10 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm0), sm0)), ones_16)), ba10); + ba11 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm1), sm1)), ones_16)), ba11); + ba12 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm2), sm2)), ones_16)), ba12); + ba13 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm3), sm3)), ones_16)), ba13); + ba14 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm4), sm4)), ones_16)), ba14); + ba15 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm5), sm5)), ones_16)), ba15); + ba16 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm6), sm6)), ones_16)), ba16); + ba17 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm7), sm7)), ones_16)), ba17); } - for (int r = 0; r < 2; ++r) { - for (int c = 0; c < 8; ++c) { - acc[r][c] = _mm256_fmadd_ps(_mm256_set1_ps(bd[c]), block_acc[r][c], acc[r][c]); - } - } + acc00 = _mm256_fmadd_ps(_mm256_set1_ps(bd[0]), ba00, acc00); + acc01 = _mm256_fmadd_ps(_mm256_set1_ps(bd[1]), ba01, acc01); + acc02 = _mm256_fmadd_ps(_mm256_set1_ps(bd[2]), ba02, acc02); + acc03 = _mm256_fmadd_ps(_mm256_set1_ps(bd[3]), ba03, acc03); + acc04 = _mm256_fmadd_ps(_mm256_set1_ps(bd[4]), ba04, acc04); + acc05 = _mm256_fmadd_ps(_mm256_set1_ps(bd[5]), ba05, acc05); + acc06 = _mm256_fmadd_ps(_mm256_set1_ps(bd[6]), ba06, acc06); + acc07 = _mm256_fmadd_ps(_mm256_set1_ps(bd[7]), ba07, acc07); + acc10 = _mm256_fmadd_ps(_mm256_set1_ps(bd[0]), ba10, acc10); + acc11 = _mm256_fmadd_ps(_mm256_set1_ps(bd[1]), ba11, acc11); + acc12 = _mm256_fmadd_ps(_mm256_set1_ps(bd[2]), ba12, acc12); + acc13 = _mm256_fmadd_ps(_mm256_set1_ps(bd[3]), ba13, acc13); + acc14 = _mm256_fmadd_ps(_mm256_set1_ps(bd[4]), ba14, acc14); + acc15 = _mm256_fmadd_ps(_mm256_set1_ps(bd[5]), ba15, acc15); + acc16 = _mm256_fmadd_ps(_mm256_set1_ps(bd[6]), ba16, acc16); + acc17 = _mm256_fmadd_ps(_mm256_set1_ps(bd[7]), ba17, acc17); } float * s_row0 = s + (y * 4 + row_base + 0) * bs + x * 8; float * s_row1 = s + (y * 4 + row_base + 1) * bs + x * 8; - for (int c = 0; c < 8; ++c) { - const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(acc[0][c]), _mm256_extractf128_ps(acc[0][c], 1)); - const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(acc[1][c]), _mm256_extractf128_ps(acc[1][c], 1)); - s_row0[c] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); - s_row1[c] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); + { + const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(acc00), _mm256_extractf128_ps(acc00, 1)); + const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(acc10), _mm256_extractf128_ps(acc10, 1)); + s_row0[0] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); + s_row1[0] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); + } + { + const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(acc01), _mm256_extractf128_ps(acc01, 1)); + const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(acc11), _mm256_extractf128_ps(acc11, 1)); + s_row0[1] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); + s_row1[1] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); + } + { + const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(acc02), _mm256_extractf128_ps(acc02, 1)); + const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(acc12), _mm256_extractf128_ps(acc12, 1)); + s_row0[2] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); + s_row1[2] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); + } + { + const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(acc03), _mm256_extractf128_ps(acc03, 1)); + const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(acc13), _mm256_extractf128_ps(acc13, 1)); + s_row0[3] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); + s_row1[3] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); + } + { + const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(acc04), _mm256_extractf128_ps(acc04, 1)); + const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(acc14), _mm256_extractf128_ps(acc14, 1)); + s_row0[4] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); + s_row1[4] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); + } + { + const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(acc05), _mm256_extractf128_ps(acc05, 1)); + const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(acc15), _mm256_extractf128_ps(acc15, 1)); + s_row0[5] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); + s_row1[5] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); + } + { + const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(acc06), _mm256_extractf128_ps(acc06, 1)); + const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(acc16), _mm256_extractf128_ps(acc16, 1)); + s_row0[6] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); + s_row1[6] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); + } + { + const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(acc07), _mm256_extractf128_ps(acc07, 1)); + const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(acc17), _mm256_extractf128_ps(acc17, 1)); + s_row0[7] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); + s_row1[7] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); } } } From 4e34da1643c12e67bd99bb90012f41526e5bfbec Mon Sep 17 00:00:00 2001 From: pl752 Date: Tue, 5 May 2026 23:01:28 +0500 Subject: [PATCH 4/5] Split 8 cols to two passes to reduce register pressure --- ggml/src/ggml-cpu/arch/x86/repack.cpp | 201 ++++++++++---------------- 1 file changed, 77 insertions(+), 124 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp index 768a0952bf8..c783f0cc063 100644 --- a/ggml/src/ggml-cpu/arch/x86/repack.cpp +++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp @@ -2279,6 +2279,77 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +#define GGML_GEMM_Q1_0_4COL(M0, M1, M2, M3, D0, D1, D2, D3, OFF) \ + { \ + const __m256i cm0 = _mm256_set1_epi8(M0); \ + const __m256i cm1 = _mm256_set1_epi8(M1); \ + const __m256i cm2 = _mm256_set1_epi8(M2); \ + const __m256i cm3 = _mm256_set1_epi8(M3); \ + __m256 a0 = _mm256_setzero_ps(), a1 = _mm256_setzero_ps(), a2 = _mm256_setzero_ps(), a3 = _mm256_setzero_ps(); \ + __m256 a10 = _mm256_setzero_ps(), a11 = _mm256_setzero_ps(), a12 = _mm256_setzero_ps(), a13 = _mm256_setzero_ps(); \ + for (int l = 0; l < nb; ++l) { \ + const float bd0 = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[D0]); \ + const float bd1 = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[D1]); \ + const float bd2 = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[D2]); \ + const float bd3 = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[D3]); \ + __m256 b0 = _mm256_setzero_ps(), b1 = _mm256_setzero_ps(), b2 = _mm256_setzero_ps(), b3 = _mm256_setzero_ps(); \ + __m256 b10 = _mm256_setzero_ps(), b11 = _mm256_setzero_ps(), b12 = _mm256_setzero_ps(), b13 = _mm256_setzero_ps(); \ + const uint8_t * qs_base = (const uint8_t *)b_ptr[l].qs; \ + for (int sb = 0; sb < 4; ++sb) { \ + const block_q8_0x4 * yb = &a_ptr[l * 4 + sb]; \ + const __m256i qs_vec = _mm256_loadu_si256((const __m256i *)(qs_base + sb * 32)); \ + const __m256i sm0 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, cm0), zero); \ + const __m256i sm1 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, cm1), zero); \ + const __m256i sm2 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, cm2), zero); \ + const __m256i sm3 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, cm3), zero); \ + const __m256i rhs0 = _mm256_loadu_si256((const __m256i *)(yb->qs + (row_base + 0) * 32)); \ + const __m256 dy0 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d[row_base + 0])); \ + b0 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm0), sm0)), ones_16)), b0); \ + b1 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm1), sm1)), ones_16)), b1); \ + b2 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm2), sm2)), ones_16)), b2); \ + b3 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm3), sm3)), ones_16)), b3); \ + const __m256i rhs1 = _mm256_loadu_si256((const __m256i *)(yb->qs + (row_base + 1) * 32)); \ + const __m256 dy1 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d[row_base + 1])); \ + b10 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm0), sm0)), ones_16)), b10); \ + b11 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm1), sm1)), ones_16)), b11); \ + b12 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm2), sm2)), ones_16)), b12); \ + b13 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm3), sm3)), ones_16)), b13); \ + } \ + a0 = _mm256_fmadd_ps(_mm256_set1_ps(bd0), b0, a0); \ + a1 = _mm256_fmadd_ps(_mm256_set1_ps(bd1), b1, a1); \ + a2 = _mm256_fmadd_ps(_mm256_set1_ps(bd2), b2, a2); \ + a3 = _mm256_fmadd_ps(_mm256_set1_ps(bd3), b3, a3); \ + a10 = _mm256_fmadd_ps(_mm256_set1_ps(bd0), b10, a10); \ + a11 = _mm256_fmadd_ps(_mm256_set1_ps(bd1), b11, a11); \ + a12 = _mm256_fmadd_ps(_mm256_set1_ps(bd2), b12, a12); \ + a13 = _mm256_fmadd_ps(_mm256_set1_ps(bd3), b13, a13); \ + } \ + { \ + const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(a0), _mm256_extractf128_ps(a0, 1)); \ + const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(a10), _mm256_extractf128_ps(a10, 1)); \ + s_row0[OFF + 0] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); \ + s_row1[OFF + 0] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); \ + } \ + { \ + const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(a1), _mm256_extractf128_ps(a1, 1)); \ + const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(a11), _mm256_extractf128_ps(a11, 1)); \ + s_row0[OFF + 1] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); \ + s_row1[OFF + 1] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); \ + } \ + { \ + const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(a2), _mm256_extractf128_ps(a2, 1)); \ + const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(a12), _mm256_extractf128_ps(a12, 1)); \ + s_row0[OFF + 2] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); \ + s_row1[OFF + 2] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); \ + } \ + { \ + const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(a3), _mm256_extractf128_ps(a3, 1)); \ + const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(a13), _mm256_extractf128_ps(a13, 1)); \ + s_row0[OFF + 3] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); \ + s_row1[OFF + 3] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); \ + } \ + } + void ggml_gemm_q1_0_8x32_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { #if defined( __AVX2__ ) || defined( __AVX512F__ ) { @@ -2303,131 +2374,13 @@ void ggml_gemm_q1_0_8x32_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v const block_q8_0x4 * a_ptr = (const block_q8_0x4 *)vy + (y * nb_q8_0); for (int row_base = 0; row_base < 4; row_base += 2) { - for (int x = 0; x < ncols8; ++x) { - const block_q1_0x8 * b_ptr = vx_bi + (size_t)x * nb; - - __m256 acc00 = _mm256_setzero_ps(), acc01 = _mm256_setzero_ps(), acc02 = _mm256_setzero_ps(), acc03 = _mm256_setzero_ps(); - __m256 acc04 = _mm256_setzero_ps(), acc05 = _mm256_setzero_ps(), acc06 = _mm256_setzero_ps(), acc07 = _mm256_setzero_ps(); - __m256 acc10 = _mm256_setzero_ps(), acc11 = _mm256_setzero_ps(), acc12 = _mm256_setzero_ps(), acc13 = _mm256_setzero_ps(); - __m256 acc14 = _mm256_setzero_ps(), acc15 = _mm256_setzero_ps(), acc16 = _mm256_setzero_ps(), acc17 = _mm256_setzero_ps(); - - for (int l = 0; l < nb; ++l) { - float bd[8]; - for (int c = 0; c < 8; ++c) { - bd[c] = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[c]); - } - - __m256 ba00 = _mm256_setzero_ps(), ba01 = _mm256_setzero_ps(), ba02 = _mm256_setzero_ps(), ba03 = _mm256_setzero_ps(); - __m256 ba04 = _mm256_setzero_ps(), ba05 = _mm256_setzero_ps(), ba06 = _mm256_setzero_ps(), ba07 = _mm256_setzero_ps(); - __m256 ba10 = _mm256_setzero_ps(), ba11 = _mm256_setzero_ps(), ba12 = _mm256_setzero_ps(), ba13 = _mm256_setzero_ps(); - __m256 ba14 = _mm256_setzero_ps(), ba15 = _mm256_setzero_ps(), ba16 = _mm256_setzero_ps(), ba17 = _mm256_setzero_ps(); - - const uint8_t * qs_base = (const uint8_t *)b_ptr[l].qs; - - for (int sb = 0; sb < 4; ++sb) { - const block_q8_0x4 * yb = &a_ptr[l * 4 + sb]; - const __m256i qs_vec = _mm256_loadu_si256((const __m256i *)(qs_base + sb * 32)); - - const __m256i sm0 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, _mm256_set1_epi8(1)), zero); - const __m256i sm1 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, _mm256_set1_epi8(2)), zero); - const __m256i sm2 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, _mm256_set1_epi8(4)), zero); - const __m256i sm3 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, _mm256_set1_epi8(8)), zero); - const __m256i sm4 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, _mm256_set1_epi8(16)), zero); - const __m256i sm5 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, _mm256_set1_epi8(32)), zero); - const __m256i sm6 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, _mm256_set1_epi8(64)), zero); - const __m256i sm7 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, _mm256_set1_epi8((int8_t)-128)), zero); - - const __m256i rhs0 = _mm256_loadu_si256((const __m256i *)(yb->qs + (row_base + 0) * 32)); - const __m256 dy0 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d[row_base + 0])); - ba00 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm0), sm0)), ones_16)), ba00); - ba01 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm1), sm1)), ones_16)), ba01); - ba02 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm2), sm2)), ones_16)), ba02); - ba03 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm3), sm3)), ones_16)), ba03); - ba04 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm4), sm4)), ones_16)), ba04); - ba05 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm5), sm5)), ones_16)), ba05); - ba06 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm6), sm6)), ones_16)), ba06); - ba07 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm7), sm7)), ones_16)), ba07); - - const __m256i rhs1 = _mm256_loadu_si256((const __m256i *)(yb->qs + (row_base + 1) * 32)); - const __m256 dy1 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d[row_base + 1])); - ba10 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm0), sm0)), ones_16)), ba10); - ba11 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm1), sm1)), ones_16)), ba11); - ba12 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm2), sm2)), ones_16)), ba12); - ba13 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm3), sm3)), ones_16)), ba13); - ba14 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm4), sm4)), ones_16)), ba14); - ba15 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm5), sm5)), ones_16)), ba15); - ba16 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm6), sm6)), ones_16)), ba16); - ba17 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm7), sm7)), ones_16)), ba17); - } - - acc00 = _mm256_fmadd_ps(_mm256_set1_ps(bd[0]), ba00, acc00); - acc01 = _mm256_fmadd_ps(_mm256_set1_ps(bd[1]), ba01, acc01); - acc02 = _mm256_fmadd_ps(_mm256_set1_ps(bd[2]), ba02, acc02); - acc03 = _mm256_fmadd_ps(_mm256_set1_ps(bd[3]), ba03, acc03); - acc04 = _mm256_fmadd_ps(_mm256_set1_ps(bd[4]), ba04, acc04); - acc05 = _mm256_fmadd_ps(_mm256_set1_ps(bd[5]), ba05, acc05); - acc06 = _mm256_fmadd_ps(_mm256_set1_ps(bd[6]), ba06, acc06); - acc07 = _mm256_fmadd_ps(_mm256_set1_ps(bd[7]), ba07, acc07); - acc10 = _mm256_fmadd_ps(_mm256_set1_ps(bd[0]), ba10, acc10); - acc11 = _mm256_fmadd_ps(_mm256_set1_ps(bd[1]), ba11, acc11); - acc12 = _mm256_fmadd_ps(_mm256_set1_ps(bd[2]), ba12, acc12); - acc13 = _mm256_fmadd_ps(_mm256_set1_ps(bd[3]), ba13, acc13); - acc14 = _mm256_fmadd_ps(_mm256_set1_ps(bd[4]), ba14, acc14); - acc15 = _mm256_fmadd_ps(_mm256_set1_ps(bd[5]), ba15, acc15); - acc16 = _mm256_fmadd_ps(_mm256_set1_ps(bd[6]), ba16, acc16); - acc17 = _mm256_fmadd_ps(_mm256_set1_ps(bd[7]), ba17, acc17); - } + for (int x8 = 0; x8 < ncols8; ++x8) { + const block_q1_0x8 * b_ptr = vx_bi + (size_t)x8 * nb; + float * s_row0 = s + (y * 4 + row_base + 0) * bs + x8 * 8; + float * s_row1 = s + (y * 4 + row_base + 1) * bs + x8 * 8; - float * s_row0 = s + (y * 4 + row_base + 0) * bs + x * 8; - float * s_row1 = s + (y * 4 + row_base + 1) * bs + x * 8; - { - const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(acc00), _mm256_extractf128_ps(acc00, 1)); - const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(acc10), _mm256_extractf128_ps(acc10, 1)); - s_row0[0] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); - s_row1[0] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); - } - { - const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(acc01), _mm256_extractf128_ps(acc01, 1)); - const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(acc11), _mm256_extractf128_ps(acc11, 1)); - s_row0[1] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); - s_row1[1] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); - } - { - const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(acc02), _mm256_extractf128_ps(acc02, 1)); - const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(acc12), _mm256_extractf128_ps(acc12, 1)); - s_row0[2] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); - s_row1[2] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); - } - { - const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(acc03), _mm256_extractf128_ps(acc03, 1)); - const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(acc13), _mm256_extractf128_ps(acc13, 1)); - s_row0[3] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); - s_row1[3] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); - } - { - const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(acc04), _mm256_extractf128_ps(acc04, 1)); - const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(acc14), _mm256_extractf128_ps(acc14, 1)); - s_row0[4] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); - s_row1[4] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); - } - { - const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(acc05), _mm256_extractf128_ps(acc05, 1)); - const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(acc15), _mm256_extractf128_ps(acc15, 1)); - s_row0[5] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); - s_row1[5] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); - } - { - const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(acc06), _mm256_extractf128_ps(acc06, 1)); - const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(acc16), _mm256_extractf128_ps(acc16, 1)); - s_row0[6] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); - s_row1[6] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); - } - { - const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(acc07), _mm256_extractf128_ps(acc07, 1)); - const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(acc17), _mm256_extractf128_ps(acc17, 1)); - s_row0[7] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); - s_row1[7] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); - } + GGML_GEMM_Q1_0_4COL(1, 2, 4, 8, 0, 1, 2, 3, 0) + GGML_GEMM_Q1_0_4COL(16, 32, 64, -128, 4, 5, 6, 7, 4) } } } From a508eee41c8dc739bcc0a74f9b0f071bed29143c Mon Sep 17 00:00:00 2001 From: pl752 Date: Fri, 8 May 2026 12:51:28 +0500 Subject: [PATCH 5/5] Implement rot4 Q1.0 repack kernels with positive-sum GEMV optimization --- ggml/src/ggml-cpu/arch-fallback.h | 36 +- ggml/src/ggml-cpu/arch/x86/repack.cpp | 1002 ++++++++++++++++--------- ggml/src/ggml-cpu/repack.cpp | 152 ++-- ggml/src/ggml-cpu/repack.h | 12 +- 4 files changed, 756 insertions(+), 446 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 1e3a4521c5b..3163e194982 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -39,7 +39,6 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 -#define ggml_quantize_mat_q8_0_4x32_generic ggml_quantize_mat_q8_0_4x32 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 @@ -56,7 +55,7 @@ #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 -#define ggml_gemv_q1_0_8x32_q8_0_generic ggml_gemv_q1_0_8x32_q8_0 +#define ggml_gemv_q1_0_8x4_q8_0_generic ggml_gemv_q1_0_8x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -73,26 +72,24 @@ #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 -#define ggml_gemm_q1_0_8x32_q8_0_generic ggml_gemm_q1_0_8x32_q8_0 +#define ggml_gemm_q1_0_8x4_q8_0_generic ggml_gemm_q1_0_8x4_q8_0 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) // repack.cpp #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 -#define ggml_quantize_mat_q8_0_4x32_generic ggml_quantize_mat_q8_0_4x32 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K -#define ggml_gemv_q1_0_8x32_q8_0_generic ggml_gemv_q1_0_8x32_q8_0 +#define ggml_gemv_q1_0_8x4_q8_0_generic ggml_gemv_q1_0_8x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K -#define ggml_gemm_q1_0_8x32_q8_0_generic ggml_gemm_q1_0_8x32_q8_0 +#define ggml_gemm_q1_0_8x4_q8_0_generic ggml_gemm_q1_0_8x4_q8_0 #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 #define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0 // repack.cpp -#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 @@ -131,7 +128,6 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 -#define ggml_quantize_mat_q8_0_4x32_generic ggml_quantize_mat_q8_0_4x32 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 @@ -148,7 +144,7 @@ #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 -#define ggml_gemv_q1_0_8x32_q8_0_generic ggml_gemv_q1_0_8x32_q8_0 +#define ggml_gemv_q1_0_8x4_q8_0_generic ggml_gemv_q1_0_8x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -165,7 +161,7 @@ #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 -#define ggml_gemm_q1_0_8x32_q8_0_generic ggml_gemm_q1_0_8x32_q8_0 +#define ggml_gemm_q1_0_8x4_q8_0_generic ggml_gemm_q1_0_8x4_q8_0 #elif defined(__loongarch64) // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K @@ -181,7 +177,6 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 -#define ggml_quantize_mat_q8_0_4x32_generic ggml_quantize_mat_q8_0_4x32 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 @@ -198,7 +193,7 @@ #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 -#define ggml_gemv_q1_0_8x32_q8_0_generic ggml_gemv_q1_0_8x32_q8_0 +#define ggml_gemv_q1_0_8x4_q8_0_generic ggml_gemv_q1_0_8x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -215,14 +210,13 @@ #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 -#define ggml_gemm_q1_0_8x32_q8_0_generic ggml_gemm_q1_0_8x32_q8_0 +#define ggml_gemm_q1_0_8x4_q8_0_generic ggml_gemm_q1_0_8x4_q8_0 #elif defined(__riscv) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 #define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 #define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0 // repack.cpp -#define ggml_quantize_mat_q8_0_4x32_generic ggml_quantize_mat_q8_0_4x32 #define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_K_4x1_generic ggml_quantize_mat_q8_K_4x1 @@ -243,7 +237,7 @@ #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 -#define ggml_gemv_q1_0_8x32_q8_0_generic ggml_gemv_q1_0_8x32_q8_0 +#define ggml_gemv_q1_0_8x4_q8_0_generic ggml_gemv_q1_0_8x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K @@ -259,7 +253,7 @@ #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 -#define ggml_gemm_q1_0_8x32_q8_0_generic ggml_gemm_q1_0_8x32_q8_0 +#define ggml_gemm_q1_0_8x4_q8_0_generic ggml_gemm_q1_0_8x4_q8_0 #elif defined(__s390x__) // quants.c @@ -282,7 +276,6 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 -#define ggml_quantize_mat_q8_0_4x32_generic ggml_quantize_mat_q8_0_4x32 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 @@ -299,7 +292,7 @@ #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 -#define ggml_gemv_q1_0_8x32_q8_0_generic ggml_gemv_q1_0_8x32_q8_0 +#define ggml_gemv_q1_0_8x4_q8_0_generic ggml_gemv_q1_0_8x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -316,7 +309,7 @@ #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 -#define ggml_gemm_q1_0_8x32_q8_0_generic ggml_gemm_q1_0_8x32_q8_0 +#define ggml_gemm_q1_0_8x4_q8_0_generic ggml_gemm_q1_0_8x4_q8_0 #elif defined(__wasm__) // quants.c #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1 @@ -340,7 +333,6 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 -#define ggml_quantize_mat_q8_0_4x32_generic ggml_quantize_mat_q8_0_4x32 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 @@ -357,7 +349,7 @@ #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 -#define ggml_gemv_q1_0_8x32_q8_0_generic ggml_gemv_q1_0_8x32_q8_0 +#define ggml_gemv_q1_0_8x4_q8_0_generic ggml_gemv_q1_0_8x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -374,5 +366,5 @@ #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 -#define ggml_gemm_q1_0_8x32_q8_0_generic ggml_gemm_q1_0_8x32_q8_0 +#define ggml_gemm_q1_0_8x4_q8_0_generic ggml_gemm_q1_0_8x4_q8_0 #endif diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp index c783f0cc063..4a5834ad548 100644 --- a/ggml/src/ggml-cpu/arch/x86/repack.cpp +++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp @@ -175,6 +175,74 @@ static inline __m256i mul_sum_i8_pairs_acc_int32x8(const __m256i acc, const __m2 } #endif +void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; + +#if defined(__AVX2__) + const __m256 signBit = _mm256_set1_ps(-0.0f); + + for (int i = 0; i < nb; ++i) { + __m256 srcv[4][4]; + __m256 idv[4]; + + for (int r = 0; r < 4; ++r) { + const float *xr = x + r * k + i * 32; + + srcv[r][0] = _mm256_loadu_ps(xr + 0); + srcv[r][1] = _mm256_loadu_ps(xr + 8); + srcv[r][2] = _mm256_loadu_ps(xr + 16); + srcv[r][3] = _mm256_loadu_ps(xr + 24); + + __m256 maxAbs = _mm256_andnot_ps(signBit, srcv[r][0]); + maxAbs = _mm256_max_ps(maxAbs, _mm256_andnot_ps(signBit, srcv[r][1])); + maxAbs = _mm256_max_ps(maxAbs, _mm256_andnot_ps(signBit, srcv[r][2])); + maxAbs = _mm256_max_ps(maxAbs, _mm256_andnot_ps(signBit, srcv[r][3])); + + __m128 max4 = _mm_max_ps(_mm256_castps256_ps128(maxAbs), _mm256_extractf128_ps(maxAbs, 1)); + max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); + + const float amax = _mm_cvtss_f32(max4); + const float d = amax / 127.0f; + const float id = amax != 0.0f ? 127.0f / amax : 0.0f; + + y[i].d[r] = GGML_CPU_FP32_TO_FP16(d); + idv[r] = _mm256_set1_ps(id); + } + + for (int j = 0; j < 4; ++j) { + __m256 q0 = _mm256_mul_ps(srcv[0][j], idv[0]); + __m256 q1 = _mm256_mul_ps(srcv[1][j], idv[1]); + __m256 q2 = _mm256_mul_ps(srcv[2][j], idv[2]); + __m256 q3 = _mm256_mul_ps(srcv[3][j], idv[3]); + + q0 = _mm256_round_ps(q0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + q1 = _mm256_round_ps(q1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + q2 = _mm256_round_ps(q2, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + q3 = _mm256_round_ps(q3, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + + __m256i i0 = _mm256_cvtps_epi32(q0); + __m256i i1 = _mm256_cvtps_epi32(q1); + __m256i i2 = _mm256_cvtps_epi32(q2); + __m256i i3 = _mm256_cvtps_epi32(q3); + + const __m256i p01 = _mm256_packs_epi32(i0, i1); + const __m256i p23 = _mm256_packs_epi32(i2, i3); + + const __m256i packed = _mm256_packs_epi16(p01, p23); + + _mm256_storeu_si256((__m256i *)(y[i].qs + j * 32), packed); + } + } +#else + ggml_quantize_mat_q8_0_4x4_generic(x, vy, k); +#endif +} + void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); @@ -287,75 +355,6 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR #endif } -void ggml_quantize_mat_q8_0_4x32(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { - assert(QK8_0 == 32); - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; - - block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; - -#if defined(__AVX2__) - for (int i = 0; i < nb; i++) { - for (int r = 0; r < 4; r++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x + r * k + i * 32 ); - __m256 v1 = _mm256_loadu_ps( x + r * k + i * 32 + 8 ); - __m256 v2 = _mm256_loadu_ps( x + r * k + i * 32 + 16 ); - __m256 v3 = _mm256_loadu_ps( x + r * k + i * 32 + 24 ); - - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float maxScalar = _mm_cvtss_f32( max4 ); - - const float d = maxScalar / 127.f; - y[i].d[r] = GGML_CPU_FP32_TO_FP16(d); - const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; - const __m256 mul = _mm256_set1_ps( id ); - - // Apply multiplier - v0 = _mm256_mul_ps( v0, mul ); - v1 = _mm256_mul_ps( v1, mul ); - v2 = _mm256_mul_ps( v2, mul ); - v3 = _mm256_mul_ps( v3, mul ); - - // Round - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert to int - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - - i0 = _mm256_packs_epi32( i0, i1 ); - i2 = _mm256_packs_epi32( i2, i3 ); - i0 = _mm256_packs_epi16( i0, i2 ); - - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - // Store row r contiguously - _mm256_storeu_si256((__m256i *)(y[i].qs + r * 32), i0); - } - } -#else - UNUSED(nb); - UNUSED(y); - ggml_quantize_mat_q8_0_4x32_generic(x, vy, k); -#endif -} - void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK_K == 256); assert(k % QK_K == 0); @@ -1530,177 +1529,6 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemv_q1_0_8x32_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined( __AVX2__ ) || defined( __AVX512F__ ) - { - assert (n % QK1_0 == 0); - assert (nc % 8 == 0); - assert (nr == 1); - - UNUSED(bs); - UNUSED(nr); - - const int nb = n / QK1_0; - const int ncols8 = nc / 8; - - const __m256i ones_8 = _mm256_set1_epi8(1); - const __m256i ones_16 = _mm256_set1_epi16(1); - const __m256i zero = _mm256_setzero_si256(); - - // Shuffle LUTs for columns 0-3: LUT[b & 0xF] = (b >> c) & 1 ? 0x00 : 0xFF - alignas(16) static const uint8_t sm_lut_c0[16] = { - 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, - 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00, 0xFF, 0x00 - }; - alignas(16) static const uint8_t sm_lut_c1[16] = { - 0xFF, 0xFF, 0x00, 0x00, 0xFF, 0xFF, 0x00, 0x00, - 0xFF, 0xFF, 0x00, 0x00, 0xFF, 0xFF, 0x00, 0x00 - }; - alignas(16) static const uint8_t sm_lut_c2[16] = { - 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, - 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00 - }; - alignas(16) static const uint8_t sm_lut_c3[16] = { - 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 - }; - - const __m256i lut0 = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i *)sm_lut_c0)); - const __m256i lut1 = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i *)sm_lut_c1)); - const __m256i lut2 = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i *)sm_lut_c2)); - const __m256i lut3 = _mm256_broadcastsi128_si256(_mm_load_si128((const __m128i *)sm_lut_c3)); - - // Column masks for columns 4-7 (AND+cmpeq path) - const __m256i col_mask_4 = _mm256_set1_epi8(16); - const __m256i col_mask_5 = _mm256_set1_epi8(32); - const __m256i col_mask_6 = _mm256_set1_epi8(64); - const __m256i col_mask_7 = _mm256_set1_epi8((int8_t)-128); - - const __m256i low_nibble_mask = _mm256_set1_epi8(0x0F); - - const block_q1_0x8 * vx_bi = (const block_q1_0x8 *)vx; - const block_q8_0 * a_ptr = (const block_q8_0 *)vy; - - for (int x = 0; x < ncols8; ++x) { - const block_q1_0x8 * b_ptr = vx_bi + (size_t)x * nb; - - __m256 acc0 = _mm256_setzero_ps(); - __m256 acc1 = _mm256_setzero_ps(); - __m256 acc2 = _mm256_setzero_ps(); - __m256 acc3 = _mm256_setzero_ps(); - __m256 acc4 = _mm256_setzero_ps(); - __m256 acc5 = _mm256_setzero_ps(); - __m256 acc6 = _mm256_setzero_ps(); - __m256 acc7 = _mm256_setzero_ps(); - - for (int l = 0; l < nb; ++l) { - float bd[8]; - for (int c = 0; c < 8; ++c) { - bd[c] = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[c]); - } - - __m256 ba0 = _mm256_setzero_ps(); - __m256 ba1 = _mm256_setzero_ps(); - __m256 ba2 = _mm256_setzero_ps(); - __m256 ba3 = _mm256_setzero_ps(); - __m256 ba4 = _mm256_setzero_ps(); - __m256 ba5 = _mm256_setzero_ps(); - __m256 ba6 = _mm256_setzero_ps(); - __m256 ba7 = _mm256_setzero_ps(); - - const uint8_t * qs_base = (const uint8_t *)b_ptr[l].qs; - - for (int sb = 0; sb < 4; ++sb) { - const block_q8_0 * yb = &a_ptr[l * 4 + sb]; - const __m256i rhs = _mm256_loadu_si256((const __m256i *)yb->qs); - const __m256 dy = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d)); - - const __m256i qs_vec = _mm256_loadu_si256((const __m256i *)(qs_base + sb * 32)); - - // Columns 0-3: shuffle LUT on low 7 bits - const __m256i qs_lo4 = _mm256_and_si256(qs_vec, low_nibble_mask); - const __m256i sm0 = _mm256_shuffle_epi8(lut0, qs_lo4); - const __m256i sm1 = _mm256_shuffle_epi8(lut1, qs_lo4); - const __m256i sm2 = _mm256_shuffle_epi8(lut2, qs_lo4); - const __m256i sm3 = _mm256_shuffle_epi8(lut3, qs_lo4); - - // Columns 4-7: AND + cmpeq - const __m256i sm4 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, col_mask_4), zero); - const __m256i sm5 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, col_mask_5), zero); - const __m256i sm6 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, col_mask_6), zero); - const __m256i sm7 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, col_mask_7), zero); - - // Sign-flip and accumulate for all 8 columns - const __m256i sy0 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm0), sm0); - const __m256i sy1 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm1), sm1); - const __m256i sy2 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm2), sm2); - const __m256i sy3 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm3), sm3); - const __m256i sy4 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm4), sm4); - const __m256i sy5 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm5), sm5); - const __m256i sy6 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm6), sm6); - const __m256i sy7 = _mm256_sub_epi8(_mm256_xor_si256(rhs, sm7), sm7); - - ba0 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy0), ones_16)), ba0); - ba1 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy1), ones_16)), ba1); - ba2 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy2), ones_16)), ba2); - ba3 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy3), ones_16)), ba3); - ba4 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy4), ones_16)), ba4); - ba5 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy5), ones_16)), ba5); - ba6 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy6), ones_16)), ba6); - ba7 = _mm256_fmadd_ps(dy, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy7), ones_16)), ba7); - } - - acc0 = _mm256_fmadd_ps(_mm256_set1_ps(bd[0]), ba0, acc0); - acc1 = _mm256_fmadd_ps(_mm256_set1_ps(bd[1]), ba1, acc1); - acc2 = _mm256_fmadd_ps(_mm256_set1_ps(bd[2]), ba2, acc2); - acc3 = _mm256_fmadd_ps(_mm256_set1_ps(bd[3]), ba3, acc3); - acc4 = _mm256_fmadd_ps(_mm256_set1_ps(bd[4]), ba4, acc4); - acc5 = _mm256_fmadd_ps(_mm256_set1_ps(bd[5]), ba5, acc5); - acc6 = _mm256_fmadd_ps(_mm256_set1_ps(bd[6]), ba6, acc6); - acc7 = _mm256_fmadd_ps(_mm256_set1_ps(bd[7]), ba7, acc7); - } - - { - const __m128 v = _mm_add_ps(_mm256_castps256_ps128(acc0), _mm256_extractf128_ps(acc0, 1)); - s[x * 8 + 0] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v, v), _mm_hadd_ps(v, v))); - } - { - const __m128 v = _mm_add_ps(_mm256_castps256_ps128(acc1), _mm256_extractf128_ps(acc1, 1)); - s[x * 8 + 1] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v, v), _mm_hadd_ps(v, v))); - } - { - const __m128 v = _mm_add_ps(_mm256_castps256_ps128(acc2), _mm256_extractf128_ps(acc2, 1)); - s[x * 8 + 2] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v, v), _mm_hadd_ps(v, v))); - } - { - const __m128 v = _mm_add_ps(_mm256_castps256_ps128(acc3), _mm256_extractf128_ps(acc3, 1)); - s[x * 8 + 3] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v, v), _mm_hadd_ps(v, v))); - } - { - const __m128 v = _mm_add_ps(_mm256_castps256_ps128(acc4), _mm256_extractf128_ps(acc4, 1)); - s[x * 8 + 4] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v, v), _mm_hadd_ps(v, v))); - } - { - const __m128 v = _mm_add_ps(_mm256_castps256_ps128(acc5), _mm256_extractf128_ps(acc5, 1)); - s[x * 8 + 5] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v, v), _mm_hadd_ps(v, v))); - } - { - const __m128 v = _mm_add_ps(_mm256_castps256_ps128(acc6), _mm256_extractf128_ps(acc6, 1)); - s[x * 8 + 6] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v, v), _mm_hadd_ps(v, v))); - } - { - const __m128 v = _mm_add_ps(_mm256_castps256_ps128(acc7), _mm256_extractf128_ps(acc7, 1)); - s[x * 8 + 7] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v, v), _mm_hadd_ps(v, v))); - } - } - - return; - } -#endif // defined( __AVX2__ ) || defined( __AVX512F__ ) - - ggml_gemv_q1_0_8x32_q8_0_generic(n, s, bs, vx, vy, nr, nc); -} - void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -2279,117 +2107,619 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -#define GGML_GEMM_Q1_0_4COL(M0, M1, M2, M3, D0, D1, D2, D3, OFF) \ - { \ - const __m256i cm0 = _mm256_set1_epi8(M0); \ - const __m256i cm1 = _mm256_set1_epi8(M1); \ - const __m256i cm2 = _mm256_set1_epi8(M2); \ - const __m256i cm3 = _mm256_set1_epi8(M3); \ - __m256 a0 = _mm256_setzero_ps(), a1 = _mm256_setzero_ps(), a2 = _mm256_setzero_ps(), a3 = _mm256_setzero_ps(); \ - __m256 a10 = _mm256_setzero_ps(), a11 = _mm256_setzero_ps(), a12 = _mm256_setzero_ps(), a13 = _mm256_setzero_ps(); \ - for (int l = 0; l < nb; ++l) { \ - const float bd0 = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[D0]); \ - const float bd1 = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[D1]); \ - const float bd2 = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[D2]); \ - const float bd3 = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[D3]); \ - __m256 b0 = _mm256_setzero_ps(), b1 = _mm256_setzero_ps(), b2 = _mm256_setzero_ps(), b3 = _mm256_setzero_ps(); \ - __m256 b10 = _mm256_setzero_ps(), b11 = _mm256_setzero_ps(), b12 = _mm256_setzero_ps(), b13 = _mm256_setzero_ps(); \ - const uint8_t * qs_base = (const uint8_t *)b_ptr[l].qs; \ - for (int sb = 0; sb < 4; ++sb) { \ - const block_q8_0x4 * yb = &a_ptr[l * 4 + sb]; \ - const __m256i qs_vec = _mm256_loadu_si256((const __m256i *)(qs_base + sb * 32)); \ - const __m256i sm0 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, cm0), zero); \ - const __m256i sm1 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, cm1), zero); \ - const __m256i sm2 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, cm2), zero); \ - const __m256i sm3 = _mm256_cmpeq_epi8(_mm256_and_si256(qs_vec, cm3), zero); \ - const __m256i rhs0 = _mm256_loadu_si256((const __m256i *)(yb->qs + (row_base + 0) * 32)); \ - const __m256 dy0 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d[row_base + 0])); \ - b0 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm0), sm0)), ones_16)), b0); \ - b1 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm1), sm1)), ones_16)), b1); \ - b2 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm2), sm2)), ones_16)), b2); \ - b3 = _mm256_fmadd_ps(dy0, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs0, sm3), sm3)), ones_16)), b3); \ - const __m256i rhs1 = _mm256_loadu_si256((const __m256i *)(yb->qs + (row_base + 1) * 32)); \ - const __m256 dy1 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d[row_base + 1])); \ - b10 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm0), sm0)), ones_16)), b10); \ - b11 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm1), sm1)), ones_16)), b11); \ - b12 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm2), sm2)), ones_16)), b12); \ - b13 = _mm256_fmadd_ps(dy1, _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, _mm256_sub_epi8(_mm256_xor_si256(rhs1, sm3), sm3)), ones_16)), b13); \ - } \ - a0 = _mm256_fmadd_ps(_mm256_set1_ps(bd0), b0, a0); \ - a1 = _mm256_fmadd_ps(_mm256_set1_ps(bd1), b1, a1); \ - a2 = _mm256_fmadd_ps(_mm256_set1_ps(bd2), b2, a2); \ - a3 = _mm256_fmadd_ps(_mm256_set1_ps(bd3), b3, a3); \ - a10 = _mm256_fmadd_ps(_mm256_set1_ps(bd0), b10, a10); \ - a11 = _mm256_fmadd_ps(_mm256_set1_ps(bd1), b11, a11); \ - a12 = _mm256_fmadd_ps(_mm256_set1_ps(bd2), b12, a12); \ - a13 = _mm256_fmadd_ps(_mm256_set1_ps(bd3), b13, a13); \ - } \ - { \ - const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(a0), _mm256_extractf128_ps(a0, 1)); \ - const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(a10), _mm256_extractf128_ps(a10, 1)); \ - s_row0[OFF + 0] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); \ - s_row1[OFF + 0] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); \ - } \ - { \ - const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(a1), _mm256_extractf128_ps(a1, 1)); \ - const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(a11), _mm256_extractf128_ps(a11, 1)); \ - s_row0[OFF + 1] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); \ - s_row1[OFF + 1] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); \ - } \ - { \ - const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(a2), _mm256_extractf128_ps(a2, 1)); \ - const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(a12), _mm256_extractf128_ps(a12, 1)); \ - s_row0[OFF + 2] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); \ - s_row1[OFF + 2] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); \ - } \ - { \ - const __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(a3), _mm256_extractf128_ps(a3, 1)); \ - const __m128 v1 = _mm_add_ps(_mm256_castps256_ps128(a13), _mm256_extractf128_ps(a13, 1)); \ - s_row0[OFF + 3] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v0, v0), _mm_hadd_ps(v0, v0))); \ - s_row1[OFF + 3] = _mm_cvtss_f32(_mm_hadd_ps(_mm_hadd_ps(v1, v1), _mm_hadd_ps(v1, v1))); \ - } \ +// Rotated Q1.0 layout helpers and kernels + +static const int8_t q1_byte_sel[32] = { + 0,0,0,0, 0,0,0,0, + 1,1,1,1, 1,1,1,1, + 2,2,2,2, 2,2,2,2, + 3,3,3,3, 3,3,3,3, +}; + +static const int8_t q1_bit_mask[32] = { + 1, 2, 4, 8, 16, 32, 64, -128, + 1, 2, 4, 8, 16, 32, 64, -128, + 1, 2, 4, 8, 16, 32, 64, -128, + 1, 2, 4, 8, 16, 32, 64, -128, +}; + +static inline uint32_t ggml_load_u32(const void * p) { + return *(const uint32_t *)p; +} + +#if defined(__F16C__) +static inline __m256 ggml_cvt_fp16x4_to_fp32x2(const ggml_half * p) { + __m128i h = _mm_loadu_si64((const void *)p); + __m128 f = _mm_cvtph_ps(h); + return _mm256_insertf128_ps(_mm256_castps128_ps256(f), f, 1); +} +#endif + + + +static inline __m256i ggml_q1_negmask_8cols( + uint32_t qpack, + __m256i byte_sel, + __m256i bit_mask, + __m256i zero) { + const __m256i qsrc = _mm256_set1_epi32((int32_t)qpack); + const __m256i qbyte = _mm256_shuffle_epi8(qsrc, byte_sel); + const __m256i qbit = _mm256_and_si256(qbyte, bit_mask); + // 0xff where q1 bit is zero, 0x00 where q1 bit is set + return _mm256_cmpeq_epi8(qbit, zero); +} + +// xor-sub pairs: accumulate in int16 (safe: 8*2*127=2032 < 32767) +static inline __m256i ggml_q1_dotpairs_xorsub_i16( + __m256i negmask, + __m256i yrep, + __m256i ones_8) { + const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(yrep, negmask), negmask); + return _mm256_maddubs_epi16(ones_8, sy); +} + +// Bit extraction as 0/1 for positive-sum formulation: dot = 2*pos - total +static inline __m256i ggml_q1_bit01_8cols( + uint32_t qpack, + __m256i byte_sel, + __m256i bit_mask, + __m256i ones_8, + __m256i zero) { + const __m256i qsrc = _mm256_set1_epi32((int32_t) qpack); + const __m256i qbyte = _mm256_shuffle_epi8(qsrc, byte_sel); + const __m256i qbit = _mm256_and_si256(qbyte, bit_mask); + const __m256i is_zero = _mm256_cmpeq_epi8(qbit, zero); + return _mm256_andnot_si256(is_zero, ones_8); +} + +static inline __m256i ggml_q8_totalpairs_i16( + __m256i yrep, + __m256i ones_8) { + return _mm256_maddubs_epi16(ones_8, yrep); +} + +static inline __m256i ggml_q1_pospairs_i16( + __m256i bit01, + __m256i yrep) { + return _mm256_maddubs_epi16(bit01, yrep); +} + +void ggml_gemv_q1_0_8x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined(__AVX2__) && defined(__FMA__) + assert(n % QK1_0 == 0); + assert(nc % 16 == 0); + assert(nr == 1); + + UNUSED(bs); + UNUSED(nr); + + const int nb = n / QK1_0; + + const block_q1_0x8 * GGML_RESTRICT x = (const block_q1_0x8 *) vx; + const block_q8_0 * GGML_RESTRICT y = (const block_q8_0 *) vy; + + const __m256i byte_sel = _mm256_loadu_si256((const __m256i *) q1_byte_sel); + const __m256i bit_mask = _mm256_loadu_si256((const __m256i *) q1_bit_mask); + const __m256i ones_8 = _mm256_set1_epi8(1); + const __m256i ones_16 = _mm256_set1_epi16(1); + const __m256i zero = _mm256_setzero_si256(); + + // 48-column main loop (AVX-512 register file) / 32-column (AVX2) +#if defined(__AVX512VL__) + { + const int ncols48 = nc / 48; + const int nc_tail = nc - ncols48 * 48; + + for (int cx = 0; cx < ncols48; ++cx) { + const block_q1_0x8 * GGML_RESTRICT x0 = x + (size_t)(cx * 6) * nb; + const block_q1_0x8 * GGML_RESTRICT x1 = x0 + nb; + const block_q1_0x8 * GGML_RESTRICT x2 = x1 + nb; + const block_q1_0x8 * GGML_RESTRICT x3 = x2 + nb; + const block_q1_0x8 * GGML_RESTRICT x4 = x3 + nb; + const block_q1_0x8 * GGML_RESTRICT x5 = x4 + nb; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + __m256 acc4 = _mm256_setzero_ps(); + __m256 acc5 = _mm256_setzero_ps(); + + for (int l = 0; l < nb; ++l) { + __m256 tmp0 = _mm256_setzero_ps(); + __m256 tmp1 = _mm256_setzero_ps(); + __m256 tmp2 = _mm256_setzero_ps(); + __m256 tmp3 = _mm256_setzero_ps(); + __m256 tmp4 = _mm256_setzero_ps(); + __m256 tmp5 = _mm256_setzero_ps(); + + for (int sb = 0; sb < 4; ++sb) { + const block_q8_0 * GGML_RESTRICT yb = &y[l * 4 + sb]; + + __m256i total16 = _mm256_setzero_si256(); + __m256i pos0 = _mm256_setzero_si256(); + __m256i pos1 = _mm256_setzero_si256(); + __m256i pos2 = _mm256_setzero_si256(); + __m256i pos3 = _mm256_setzero_si256(); + __m256i pos4 = _mm256_setzero_si256(); + __m256i pos5 = _mm256_setzero_si256(); + + for (int g = 0; g < 8; ++g) { + const __m256i yrep = _mm256_set1_epi32(ggml_load_u32(yb->qs + g * 4)); + + total16 = _mm256_add_epi16(total16, ggml_q8_totalpairs_i16(yrep, ones_8)); + + const uint32_t qpack0 = ggml_load_u32(x0[l].qs + (sb * 8 + g) * 4); + const __m256i bit0 = ggml_q1_bit01_8cols(qpack0, byte_sel, bit_mask, ones_8, zero); + pos0 = _mm256_add_epi16(pos0, ggml_q1_pospairs_i16(bit0, yrep)); + + const uint32_t qpack1 = ggml_load_u32(x1[l].qs + (sb * 8 + g) * 4); + const __m256i bit1 = ggml_q1_bit01_8cols(qpack1, byte_sel, bit_mask, ones_8, zero); + pos1 = _mm256_add_epi16(pos1, ggml_q1_pospairs_i16(bit1, yrep)); + + const uint32_t qpack2 = ggml_load_u32(x2[l].qs + (sb * 8 + g) * 4); + const __m256i bit2 = ggml_q1_bit01_8cols(qpack2, byte_sel, bit_mask, ones_8, zero); + pos2 = _mm256_add_epi16(pos2, ggml_q1_pospairs_i16(bit2, yrep)); + + const uint32_t qpack3 = ggml_load_u32(x3[l].qs + (sb * 8 + g) * 4); + const __m256i bit3 = ggml_q1_bit01_8cols(qpack3, byte_sel, bit_mask, ones_8, zero); + pos3 = _mm256_add_epi16(pos3, ggml_q1_pospairs_i16(bit3, yrep)); + + const uint32_t qpack4 = ggml_load_u32(x4[l].qs + (sb * 8 + g) * 4); + const __m256i bit4 = ggml_q1_bit01_8cols(qpack4, byte_sel, bit_mask, ones_8, zero); + pos4 = _mm256_add_epi16(pos4, ggml_q1_pospairs_i16(bit4, yrep)); + + const uint32_t qpack5 = ggml_load_u32(x5[l].qs + (sb * 8 + g) * 4); + const __m256i bit5 = ggml_q1_bit01_8cols(qpack5, byte_sel, bit_mask, ones_8, zero); + pos5 = _mm256_add_epi16(pos5, ggml_q1_pospairs_i16(bit5, yrep)); + } + + const __m256i dot0_16 = _mm256_sub_epi16(_mm256_add_epi16(pos0, pos0), total16); + const __m256i dot1_16 = _mm256_sub_epi16(_mm256_add_epi16(pos1, pos1), total16); + const __m256i dot2_16 = _mm256_sub_epi16(_mm256_add_epi16(pos2, pos2), total16); + const __m256i dot3_16 = _mm256_sub_epi16(_mm256_add_epi16(pos3, pos3), total16); + const __m256i dot4_16 = _mm256_sub_epi16(_mm256_add_epi16(pos4, pos4), total16); + const __m256i dot5_16 = _mm256_sub_epi16(_mm256_add_epi16(pos5, pos5), total16); + + const __m256 ydv = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d)); + tmp0 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dot0_16, ones_16)), ydv, tmp0); + tmp1 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dot1_16, ones_16)), ydv, tmp1); + tmp2 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dot2_16, ones_16)), ydv, tmp2); + tmp3 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dot3_16, ones_16)), ydv, tmp3); + tmp4 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dot4_16, ones_16)), ydv, tmp4); + tmp5 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dot5_16, ones_16)), ydv, tmp5); + } + + const __m256 bdv0 = GGML_F32Cx8_LOAD(x0[l].d); + const __m256 bdv1 = GGML_F32Cx8_LOAD(x1[l].d); + const __m256 bdv2 = GGML_F32Cx8_LOAD(x2[l].d); + const __m256 bdv3 = GGML_F32Cx8_LOAD(x3[l].d); + const __m256 bdv4 = GGML_F32Cx8_LOAD(x4[l].d); + const __m256 bdv5 = GGML_F32Cx8_LOAD(x5[l].d); + + acc0 = _mm256_fmadd_ps(tmp0, bdv0, acc0); + acc1 = _mm256_fmadd_ps(tmp1, bdv1, acc1); + acc2 = _mm256_fmadd_ps(tmp2, bdv2, acc2); + acc3 = _mm256_fmadd_ps(tmp3, bdv3, acc3); + acc4 = _mm256_fmadd_ps(tmp4, bdv4, acc4); + acc5 = _mm256_fmadd_ps(tmp5, bdv5, acc5); + } + + float *sbase = s + cx * 48; + _mm256_storeu_ps(sbase + 0, acc0); + _mm256_storeu_ps(sbase + 8, acc1); + _mm256_storeu_ps(sbase + 16, acc2); + _mm256_storeu_ps(sbase + 24, acc3); + _mm256_storeu_ps(sbase + 32, acc4); + _mm256_storeu_ps(sbase + 40, acc5); + } + + // 32-column tail + if (nc_tail >= 32) { + const block_q1_0x8 * GGML_RESTRICT xa = x + (size_t)(ncols48 * 6) * nb; + const block_q1_0x8 * GGML_RESTRICT xb = xa + nb; + const block_q1_0x8 * GGML_RESTRICT xc = xb + nb; + const block_q1_0x8 * GGML_RESTRICT xd = xc + nb; + + __m256 acca = _mm256_setzero_ps(); + __m256 accb = _mm256_setzero_ps(); + __m256 accc = _mm256_setzero_ps(); + __m256 accd = _mm256_setzero_ps(); + + for (int l = 0; l < nb; ++l) { + __m256 tmpa = _mm256_setzero_ps(); + __m256 tmpb = _mm256_setzero_ps(); + __m256 tmpc = _mm256_setzero_ps(); + __m256 tmpd = _mm256_setzero_ps(); + + for (int sb = 0; sb < 4; ++sb) { + const block_q8_0 * GGML_RESTRICT yb = &y[l * 4 + sb]; + + __m256i total16 = _mm256_setzero_si256(); + __m256i posa = _mm256_setzero_si256(); + __m256i posb = _mm256_setzero_si256(); + __m256i posc = _mm256_setzero_si256(); + __m256i posd = _mm256_setzero_si256(); + + for (int g = 0; g < 8; ++g) { + const __m256i yrep = _mm256_set1_epi32(ggml_load_u32(yb->qs + g * 4)); + + total16 = _mm256_add_epi16(total16, ggml_q8_totalpairs_i16(yrep, ones_8)); + + const uint32_t qpacka = ggml_load_u32(xa[l].qs + (sb * 8 + g) * 4); + const __m256i bita = ggml_q1_bit01_8cols(qpacka, byte_sel, bit_mask, ones_8, zero); + posa = _mm256_add_epi16(posa, ggml_q1_pospairs_i16(bita, yrep)); + + const uint32_t qpackb = ggml_load_u32(xb[l].qs + (sb * 8 + g) * 4); + const __m256i bitb = ggml_q1_bit01_8cols(qpackb, byte_sel, bit_mask, ones_8, zero); + posb = _mm256_add_epi16(posb, ggml_q1_pospairs_i16(bitb, yrep)); + + const uint32_t qpackc = ggml_load_u32(xc[l].qs + (sb * 8 + g) * 4); + const __m256i bitc = ggml_q1_bit01_8cols(qpackc, byte_sel, bit_mask, ones_8, zero); + posc = _mm256_add_epi16(posc, ggml_q1_pospairs_i16(bitc, yrep)); + + const uint32_t qpackd = ggml_load_u32(xd[l].qs + (sb * 8 + g) * 4); + const __m256i bitd = ggml_q1_bit01_8cols(qpackd, byte_sel, bit_mask, ones_8, zero); + posd = _mm256_add_epi16(posd, ggml_q1_pospairs_i16(bitd, yrep)); + } + + const __m256i dota_16 = _mm256_sub_epi16(_mm256_add_epi16(posa, posa), total16); + const __m256i dotb_16 = _mm256_sub_epi16(_mm256_add_epi16(posb, posb), total16); + const __m256i dotc_16 = _mm256_sub_epi16(_mm256_add_epi16(posc, posc), total16); + const __m256i dotd_16 = _mm256_sub_epi16(_mm256_add_epi16(posd, posd), total16); + + const __m256 ydv = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d)); + tmpa = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dota_16, ones_16)), ydv, tmpa); + tmpb = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dotb_16, ones_16)), ydv, tmpb); + tmpc = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dotc_16, ones_16)), ydv, tmpc); + tmpd = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dotd_16, ones_16)), ydv, tmpd); + } + + const __m256 bdva = GGML_F32Cx8_LOAD(xa[l].d); + const __m256 bdvb = GGML_F32Cx8_LOAD(xb[l].d); + const __m256 bdvc = GGML_F32Cx8_LOAD(xc[l].d); + const __m256 bdvd = GGML_F32Cx8_LOAD(xd[l].d); + + acca = _mm256_fmadd_ps(tmpa, bdva, acca); + accb = _mm256_fmadd_ps(tmpb, bdvb, accb); + accc = _mm256_fmadd_ps(tmpc, bdvc, accc); + accd = _mm256_fmadd_ps(tmpd, bdvd, accd); + } + + float *sbase = s + ncols48 * 48; + _mm256_storeu_ps(sbase + 0, acca); + _mm256_storeu_ps(sbase + 8, accb); + _mm256_storeu_ps(sbase + 16, accc); + _mm256_storeu_ps(sbase + 24, accd); + } + + // 16-column tail + if (nc_tail == 16) { + const block_q1_0x8 * GGML_RESTRICT xa = x + (size_t)(ncols48 * 6) * nb; + const block_q1_0x8 * GGML_RESTRICT xb = xa + nb; + + __m256 acca = _mm256_setzero_ps(); + __m256 accb = _mm256_setzero_ps(); + + for (int l = 0; l < nb; ++l) { + __m256 tmpa = _mm256_setzero_ps(); + __m256 tmpb = _mm256_setzero_ps(); + + for (int sb = 0; sb < 4; ++sb) { + const block_q8_0 * GGML_RESTRICT yb = &y[l * 4 + sb]; + + __m256i total16 = _mm256_setzero_si256(); + __m256i posa = _mm256_setzero_si256(); + __m256i posb = _mm256_setzero_si256(); + + for (int g = 0; g < 8; ++g) { + const __m256i yrep = _mm256_set1_epi32(ggml_load_u32(yb->qs + g * 4)); + + total16 = _mm256_add_epi16(total16, ggml_q8_totalpairs_i16(yrep, ones_8)); + + const uint32_t qpacka = ggml_load_u32(xa[l].qs + (sb * 8 + g) * 4); + const __m256i bita = ggml_q1_bit01_8cols(qpacka, byte_sel, bit_mask, ones_8, zero); + posa = _mm256_add_epi16(posa, ggml_q1_pospairs_i16(bita, yrep)); + + const uint32_t qpackb = ggml_load_u32(xb[l].qs + (sb * 8 + g) * 4); + const __m256i bitb = ggml_q1_bit01_8cols(qpackb, byte_sel, bit_mask, ones_8, zero); + posb = _mm256_add_epi16(posb, ggml_q1_pospairs_i16(bitb, yrep)); + } + + const __m256i dota_16 = _mm256_sub_epi16(_mm256_add_epi16(posa, posa), total16); + const __m256i dotb_16 = _mm256_sub_epi16(_mm256_add_epi16(posb, posb), total16); + + const __m256 ydv = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d)); + tmpa = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dota_16, ones_16)), ydv, tmpa); + tmpb = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dotb_16, ones_16)), ydv, tmpb); + } + + const __m256 bdva = GGML_F32Cx8_LOAD(xa[l].d); + const __m256 bdvb = GGML_F32Cx8_LOAD(xb[l].d); + + acca = _mm256_fmadd_ps(tmpa, bdva, acca); + accb = _mm256_fmadd_ps(tmpb, bdvb, accb); + } + + float *sbase = s + ncols48 * 48; + _mm256_storeu_ps(sbase, acca); + _mm256_storeu_ps(sbase + 8, accb); + } } +#else + { + const int ncols32 = nc / 32; + const int nc_tail = nc - ncols32 * 32; -void ggml_gemm_q1_0_8x32_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined( __AVX2__ ) || defined( __AVX512F__ ) - { - assert (n % QK1_0 == 0); - assert (nr % 4 == 0); - assert (nc % 8 == 0); + for (int cx = 0; cx < ncols32; ++cx) { + const block_q1_0x8 * GGML_RESTRICT x0 = x + (size_t)(cx * 4) * nb; + const block_q1_0x8 * GGML_RESTRICT x1 = x0 + nb; + const block_q1_0x8 * GGML_RESTRICT x2 = x1 + nb; + const block_q1_0x8 * GGML_RESTRICT x3 = x2 + nb; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + + for (int l = 0; l < nb; ++l) { + __m256 tmp0 = _mm256_setzero_ps(); + __m256 tmp1 = _mm256_setzero_ps(); + __m256 tmp2 = _mm256_setzero_ps(); + __m256 tmp3 = _mm256_setzero_ps(); + + for (int sb = 0; sb < 4; ++sb) { + const block_q8_0 * GGML_RESTRICT yb = &y[l * 4 + sb]; - UNUSED(bs); + __m256i total16 = _mm256_setzero_si256(); + __m256i pos0 = _mm256_setzero_si256(); + __m256i pos1 = _mm256_setzero_si256(); + __m256i pos2 = _mm256_setzero_si256(); + __m256i pos3 = _mm256_setzero_si256(); - const int nb = n / QK1_0; - const int nb_q8_0 = n / QK8_0; - const int ncols8 = nc / 8; - const int nrows4 = nr / 4; + for (int g = 0; g < 8; ++g) { + const __m256i yrep = _mm256_set1_epi32(ggml_load_u32(yb->qs + g * 4)); - const __m256i ones_8 = _mm256_set1_epi8(1); - const __m256i ones_16 = _mm256_set1_epi16(1); - const __m256i zero = _mm256_setzero_si256(); + total16 = _mm256_add_epi16(total16, ggml_q8_totalpairs_i16(yrep, ones_8)); - const block_q1_0x8 * vx_bi = (const block_q1_0x8 *)vx; + const uint32_t qpack0 = ggml_load_u32(x0[l].qs + (sb * 8 + g) * 4); + const __m256i bit0 = ggml_q1_bit01_8cols(qpack0, byte_sel, bit_mask, ones_8, zero); + pos0 = _mm256_add_epi16(pos0, ggml_q1_pospairs_i16(bit0, yrep)); - for (int y = 0; y < nrows4; ++y) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *)vy + (y * nb_q8_0); + const uint32_t qpack1 = ggml_load_u32(x1[l].qs + (sb * 8 + g) * 4); + const __m256i bit1 = ggml_q1_bit01_8cols(qpack1, byte_sel, bit_mask, ones_8, zero); + pos1 = _mm256_add_epi16(pos1, ggml_q1_pospairs_i16(bit1, yrep)); - for (int row_base = 0; row_base < 4; row_base += 2) { - for (int x8 = 0; x8 < ncols8; ++x8) { - const block_q1_0x8 * b_ptr = vx_bi + (size_t)x8 * nb; - float * s_row0 = s + (y * 4 + row_base + 0) * bs + x8 * 8; - float * s_row1 = s + (y * 4 + row_base + 1) * bs + x8 * 8; + const uint32_t qpack2 = ggml_load_u32(x2[l].qs + (sb * 8 + g) * 4); + const __m256i bit2 = ggml_q1_bit01_8cols(qpack2, byte_sel, bit_mask, ones_8, zero); + pos2 = _mm256_add_epi16(pos2, ggml_q1_pospairs_i16(bit2, yrep)); + + const uint32_t qpack3 = ggml_load_u32(x3[l].qs + (sb * 8 + g) * 4); + const __m256i bit3 = ggml_q1_bit01_8cols(qpack3, byte_sel, bit_mask, ones_8, zero); + pos3 = _mm256_add_epi16(pos3, ggml_q1_pospairs_i16(bit3, yrep)); + } - GGML_GEMM_Q1_0_4COL(1, 2, 4, 8, 0, 1, 2, 3, 0) - GGML_GEMM_Q1_0_4COL(16, 32, 64, -128, 4, 5, 6, 7, 4) + const __m256i dot0_16 = _mm256_sub_epi16(_mm256_add_epi16(pos0, pos0), total16); + const __m256i dot1_16 = _mm256_sub_epi16(_mm256_add_epi16(pos1, pos1), total16); + const __m256i dot2_16 = _mm256_sub_epi16(_mm256_add_epi16(pos2, pos2), total16); + const __m256i dot3_16 = _mm256_sub_epi16(_mm256_add_epi16(pos3, pos3), total16); + + const __m256 ydv = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d)); + tmp0 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dot0_16, ones_16)), ydv, tmp0); + tmp1 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dot1_16, ones_16)), ydv, tmp1); + tmp2 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dot2_16, ones_16)), ydv, tmp2); + tmp3 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dot3_16, ones_16)), ydv, tmp3); } + + const __m256 bdv0 = GGML_F32Cx8_LOAD(x0[l].d); + const __m256 bdv1 = GGML_F32Cx8_LOAD(x1[l].d); + const __m256 bdv2 = GGML_F32Cx8_LOAD(x2[l].d); + const __m256 bdv3 = GGML_F32Cx8_LOAD(x3[l].d); + + acc0 = _mm256_fmadd_ps(tmp0, bdv0, acc0); + acc1 = _mm256_fmadd_ps(tmp1, bdv1, acc1); + acc2 = _mm256_fmadd_ps(tmp2, bdv2, acc2); + acc3 = _mm256_fmadd_ps(tmp3, bdv3, acc3); } + + float *sbase = s + cx * 32; + _mm256_storeu_ps(sbase + 0, acc0); + _mm256_storeu_ps(sbase + 8, acc1); + _mm256_storeu_ps(sbase + 16, acc2); + _mm256_storeu_ps(sbase + 24, acc3); } - return; + // 16-column tail + if (nc_tail >= 16) { + const block_q1_0x8 * GGML_RESTRICT xa = x + (size_t)(ncols32 * 4) * nb; + const block_q1_0x8 * GGML_RESTRICT xb = xa + nb; + + __m256 acca = _mm256_setzero_ps(); + __m256 accb = _mm256_setzero_ps(); + + for (int l = 0; l < nb; ++l) { + __m256 tmpa = _mm256_setzero_ps(); + __m256 tmpb = _mm256_setzero_ps(); + + for (int sb = 0; sb < 4; ++sb) { + const block_q8_0 * GGML_RESTRICT yb = &y[l * 4 + sb]; + + __m256i total16 = _mm256_setzero_si256(); + __m256i posa = _mm256_setzero_si256(); + __m256i posb = _mm256_setzero_si256(); + + for (int g = 0; g < 8; ++g) { + const __m256i yrep = _mm256_set1_epi32(ggml_load_u32(yb->qs + g * 4)); + + total16 = _mm256_add_epi16(total16, ggml_q8_totalpairs_i16(yrep, ones_8)); + + const uint32_t qpacka = ggml_load_u32(xa[l].qs + (sb * 8 + g) * 4); + const __m256i bita = ggml_q1_bit01_8cols(qpacka, byte_sel, bit_mask, ones_8, zero); + posa = _mm256_add_epi16(posa, ggml_q1_pospairs_i16(bita, yrep)); + + const uint32_t qpackb = ggml_load_u32(xb[l].qs + (sb * 8 + g) * 4); + const __m256i bitb = ggml_q1_bit01_8cols(qpackb, byte_sel, bit_mask, ones_8, zero); + posb = _mm256_add_epi16(posb, ggml_q1_pospairs_i16(bitb, yrep)); + } + + const __m256i dota_16 = _mm256_sub_epi16(_mm256_add_epi16(posa, posa), total16); + const __m256i dotb_16 = _mm256_sub_epi16(_mm256_add_epi16(posb, posb), total16); + + const __m256 ydv = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d)); + tmpa = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dota_16, ones_16)), ydv, tmpa); + tmpb = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dotb_16, ones_16)), ydv, tmpb); + } + + const __m256 bdva = GGML_F32Cx8_LOAD(xa[l].d); + const __m256 bdvb = GGML_F32Cx8_LOAD(xb[l].d); + + acca = _mm256_fmadd_ps(tmpa, bdva, acca); + accb = _mm256_fmadd_ps(tmpb, bdvb, accb); + } + + _mm256_storeu_ps(s + ncols32 * 32, acca); + _mm256_storeu_ps(s + ncols32 * 32 + 8, accb); + } } -#endif // defined( __AVX2__ ) || defined( __AVX512F__ ) +#endif // __AVX512VL__ + + return; +#else + ggml_gemv_q1_0_8x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} - ggml_gemm_q1_0_8x32_q8_0_generic(n, s, bs, vx, vy, nr, nc); +void ggml_gemm_q1_0_8x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined(__AVX2__) && defined(__FMA__) + assert(n % QK1_0 == 0); + assert(nr % 4 == 0); + assert(nc % 16 == 0); + + const int nb = n / QK1_0; + const int nb_q8_0 = n / QK8_0; + const int ncols16 = nc / 16; + const int nrows4 = nr / 4; + + const block_q1_0x8 * vx_bi = (const block_q1_0x8 *)vx; + + const __m256i byte_sel = _mm256_loadu_si256((const __m256i *) q1_byte_sel); + const __m256i bit_mask = _mm256_loadu_si256((const __m256i *) q1_bit_mask); + const __m256i ones_8 = _mm256_set1_epi8(1); + const __m256i ones_16 = _mm256_set1_epi16(1); + const __m256i zero = _mm256_setzero_si256(); + + for (int y = 0; y < nrows4; ++y) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *)vy + (y * nb_q8_0); + + for (int x = 0; x < ncols16; ++x) { + const block_q1_0x8 * ba = vx_bi + (size_t)(x * 2) * nb; + const block_q1_0x8 * bb = ba + nb; + + __m256 acc0a = _mm256_setzero_ps(); + __m256 acc1a = _mm256_setzero_ps(); + __m256 acc2a = _mm256_setzero_ps(); + __m256 acc3a = _mm256_setzero_ps(); + __m256 acc0b = _mm256_setzero_ps(); + __m256 acc1b = _mm256_setzero_ps(); + __m256 acc2b = _mm256_setzero_ps(); + __m256 acc3b = _mm256_setzero_ps(); + + for (int l = 0; l < nb; ++l) { + __m256 tmp0a = _mm256_setzero_ps(); + __m256 tmp1a = _mm256_setzero_ps(); + __m256 tmp2a = _mm256_setzero_ps(); + __m256 tmp3a = _mm256_setzero_ps(); + __m256 tmp0b = _mm256_setzero_ps(); + __m256 tmp1b = _mm256_setzero_ps(); + __m256 tmp2b = _mm256_setzero_ps(); + __m256 tmp3b = _mm256_setzero_ps(); + + for (int sb = 0; sb < 4; ++sb) { + const block_q8_0x4 * yb = &a_ptr[l * 4 + sb]; + + __m256i isum0a16 = _mm256_setzero_si256(); + __m256i isum1a16 = _mm256_setzero_si256(); + __m256i isum2a16 = _mm256_setzero_si256(); + __m256i isum3a16 = _mm256_setzero_si256(); + __m256i isum0b16 = _mm256_setzero_si256(); + __m256i isum1b16 = _mm256_setzero_si256(); + __m256i isum2b16 = _mm256_setzero_si256(); + __m256i isum3b16 = _mm256_setzero_si256(); + + for (int g = 0; g < 8; ++g) { + const uint32_t qpacka = ggml_load_u32(ba[l].qs + (sb * 8 + g) * 4); + const __m256i nega = ggml_q1_negmask_8cols(qpacka, byte_sel, bit_mask, zero); + const uint32_t qpackb = ggml_load_u32(bb[l].qs + (sb * 8 + g) * 4); + const __m256i negb = ggml_q1_negmask_8cols(qpackb, byte_sel, bit_mask, zero); + + const __m256i yrep0 = _mm256_set1_epi32(ggml_load_u32(yb->qs + g * 16 + 0 * 4)); + const __m256i yrep1 = _mm256_set1_epi32(ggml_load_u32(yb->qs + g * 16 + 1 * 4)); + const __m256i yrep2 = _mm256_set1_epi32(ggml_load_u32(yb->qs + g * 16 + 2 * 4)); + const __m256i yrep3 = _mm256_set1_epi32(ggml_load_u32(yb->qs + g * 16 + 3 * 4)); + + isum0a16 = _mm256_add_epi16(isum0a16, ggml_q1_dotpairs_xorsub_i16(nega, yrep0, ones_8)); + isum1a16 = _mm256_add_epi16(isum1a16, ggml_q1_dotpairs_xorsub_i16(nega, yrep1, ones_8)); + isum2a16 = _mm256_add_epi16(isum2a16, ggml_q1_dotpairs_xorsub_i16(nega, yrep2, ones_8)); + isum3a16 = _mm256_add_epi16(isum3a16, ggml_q1_dotpairs_xorsub_i16(nega, yrep3, ones_8)); + + isum0b16 = _mm256_add_epi16(isum0b16, ggml_q1_dotpairs_xorsub_i16(negb, yrep0, ones_8)); + isum1b16 = _mm256_add_epi16(isum1b16, ggml_q1_dotpairs_xorsub_i16(negb, yrep1, ones_8)); + isum2b16 = _mm256_add_epi16(isum2b16, ggml_q1_dotpairs_xorsub_i16(negb, yrep2, ones_8)); + isum3b16 = _mm256_add_epi16(isum3b16, ggml_q1_dotpairs_xorsub_i16(negb, yrep3, ones_8)); + } + + const __m256i isum0a = _mm256_madd_epi16(isum0a16, ones_16); + const __m256i isum1a = _mm256_madd_epi16(isum1a16, ones_16); + const __m256i isum2a = _mm256_madd_epi16(isum2a16, ones_16); + const __m256i isum3a = _mm256_madd_epi16(isum3a16, ones_16); + const __m256i isum0b = _mm256_madd_epi16(isum0b16, ones_16); + const __m256i isum1b = _mm256_madd_epi16(isum1b16, ones_16); + const __m256i isum2b = _mm256_madd_epi16(isum2b16, ones_16); + const __m256i isum3b = _mm256_madd_epi16(isum3b16, ones_16); + + const __m256 yd_all = ggml_cvt_fp16x4_to_fp32x2(yb->d); + const __m256 yd0 = _mm256_permute_ps(yd_all, 0); + const __m256 yd1 = _mm256_permute_ps(yd_all, 85); + const __m256 yd2 = _mm256_permute_ps(yd_all, 170); + const __m256 yd3 = _mm256_permute_ps(yd_all, 255); + tmp0a = _mm256_fmadd_ps(_mm256_cvtepi32_ps(isum0a), yd0, tmp0a); + tmp1a = _mm256_fmadd_ps(_mm256_cvtepi32_ps(isum1a), yd1, tmp1a); + tmp2a = _mm256_fmadd_ps(_mm256_cvtepi32_ps(isum2a), yd2, tmp2a); + tmp3a = _mm256_fmadd_ps(_mm256_cvtepi32_ps(isum3a), yd3, tmp3a); + tmp0b = _mm256_fmadd_ps(_mm256_cvtepi32_ps(isum0b), yd0, tmp0b); + tmp1b = _mm256_fmadd_ps(_mm256_cvtepi32_ps(isum1b), yd1, tmp1b); + tmp2b = _mm256_fmadd_ps(_mm256_cvtepi32_ps(isum2b), yd2, tmp2b); + tmp3b = _mm256_fmadd_ps(_mm256_cvtepi32_ps(isum3b), yd3, tmp3b); + } + + const __m256 bdva = GGML_F32Cx8_LOAD(ba[l].d); + const __m256 bdvb = GGML_F32Cx8_LOAD(bb[l].d); + + acc0a = _mm256_fmadd_ps(tmp0a, bdva, acc0a); + acc1a = _mm256_fmadd_ps(tmp1a, bdva, acc1a); + acc2a = _mm256_fmadd_ps(tmp2a, bdva, acc2a); + acc3a = _mm256_fmadd_ps(tmp3a, bdva, acc3a); + acc0b = _mm256_fmadd_ps(tmp0b, bdvb, acc0b); + acc1b = _mm256_fmadd_ps(tmp1b, bdvb, acc1b); + acc2b = _mm256_fmadd_ps(tmp2b, bdvb, acc2b); + acc3b = _mm256_fmadd_ps(tmp3b, bdvb, acc3b); + } + + _mm256_storeu_ps(s + (y * 4 + 0) * bs + x * 16, acc0a); + _mm256_storeu_ps(s + (y * 4 + 1) * bs + x * 16, acc1a); + _mm256_storeu_ps(s + (y * 4 + 2) * bs + x * 16, acc2a); + _mm256_storeu_ps(s + (y * 4 + 3) * bs + x * 16, acc3a); + _mm256_storeu_ps(s + (y * 4 + 0) * bs + x * 16 + 8, acc0b); + _mm256_storeu_ps(s + (y * 4 + 1) * bs + x * 16 + 8, acc1b); + _mm256_storeu_ps(s + (y * 4 + 2) * bs + x * 16 + 8, acc2b); + _mm256_storeu_ps(s + (y * 4 + 3) * bs + x * 16 + 8, acc3b); + } + } + + return; +#else + ggml_gemm_q1_0_8x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif } void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 7b0a0a1a8a0..8c87c603efc 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -208,42 +208,6 @@ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG } } -void ggml_quantize_mat_q8_0_4x32_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { - assert(QK8_0 == 32); - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; - - block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; - - // scalar - float srcv[4][QK8_0]; - float id[4]; - - for (int i = 0; i < nb; i++) { - for (int row_iter = 0; row_iter < 4; row_iter++) { - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; - amax = MAX(amax, fabsf(srcv[row_iter][j])); - } - - const float d = amax / ((1 << 7) - 1); - id[row_iter] = d ? 1.0f / d : 0.0f; - - y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d); - } - - // Store each row's 32 bytes contiguously - for (int r = 0; r < 4; r++) { - for (int j = 0; j < QK8_0; j++) { - float x0 = srcv[r][j] * id[r]; - y[i].qs[r * QK8_0 + j] = roundf(x0); - } - } - } -} - void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK_K == 256); assert(k % QK_K == 0); @@ -375,12 +339,6 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTR ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row); } -template <> void ggml_quantize_mat_t<32, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { - assert(nrow == 4); - UNUSED(nrow); - ggml_quantize_mat_q8_0_4x32(x, vy, n_per_row); -} - #if defined __riscv_zvfh template <> void ggml_quantize_mat_t<1, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); @@ -926,7 +884,7 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } -void ggml_gemv_q1_0_8x32_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemv_q1_0_8x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { assert (n % QK1_0 == 0); assert (nc % 8 == 0); assert (nr == 1); @@ -955,16 +913,28 @@ void ggml_gemv_q1_0_8x32_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, for (int sb = 0; sb < 4; ++sb) { const block_q8_0 * yb = &a_ptr[l * 4 + sb]; const float dy = GGML_CPU_FP16_TO_FP32(yb->d); - - const uint8_t * qs = (const uint8_t *)b_ptr[l].qs + sb * 32; const int8_t * y = yb->qs; - for (int c = 0; c < 8; ++c) { - int sumi = 0; - for (int i = 0; i < QK8_0; ++i) { - sumi += ((qs[i] >> c) & 1) ? y[i] : -y[i]; + for (int g = 0; g < 8; ++g) { + const uint8_t * qs_row = (const uint8_t *)b_ptr[l].qs + (sb * 8 + g) * 4; + + for (int pair = 0; pair < 4; ++pair) { + uint8_t byte = qs_row[pair]; + int col_even = 2 * pair; + int col_odd = 2 * pair + 1; + int sumi_even = 0; + int sumi_odd = 0; + for (int t = 0; t < 4; ++t) { + int k = g * 4 + t; + int yv = y[k]; + if ((byte >> t) & 1) sumi_even += yv; + else sumi_even -= yv; + if ((byte >> (t + 4)) & 1) sumi_odd += yv; + else sumi_odd -= yv; + } + block_acc[col_even] += dy * (float)sumi_even; + block_acc[col_odd] += dy * (float)sumi_odd; } - block_acc[c] += dy * (float)sumi; } } @@ -972,9 +942,9 @@ void ggml_gemv_q1_0_8x32_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, acc[c] += bd[c] * block_acc[c]; } } - + static_assert(sizeof(acc) == 32); - memcpy(s + x*8, acc, sizeof(acc)); + memcpy(s + x * 8, acc, sizeof(acc)); } } @@ -1913,7 +1883,7 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } -void ggml_gemm_q1_0_8x32_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemm_q1_0_8x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { assert (n % QK1_0 == 0); assert (nr % 4 == 0); assert (nc % 8 == 0); @@ -1943,18 +1913,31 @@ void ggml_gemm_q1_0_8x32_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, for (int sb = 0; sb < 4; ++sb) { const block_q8_0x4 * yb = &a_ptr[l * 4 + sb]; - const uint8_t * qs = (const uint8_t *)b_ptr[l].qs + sb * 32; for (int r = 0; r < 2; ++r) { const float dy = GGML_CPU_FP16_TO_FP32(yb->d[row_base + r]); - - for (int c = 0; c < 8; ++c) { - int sumi = 0; - for (int i = 0; i < QK8_0; ++i) { - const int8_t y_val = yb->qs[(row_base + r) * 32 + i]; - sumi += ((qs[i] >> c) & 1) ? y_val : -y_val; + const int row = row_base + r; + + for (int g = 0; g < 8; ++g) { + const uint8_t * qs_row = (const uint8_t *)b_ptr[l].qs + (sb * 8 + g) * 4; + const int8_t * y = yb->qs + g * 16 + row * 4; + + for (int pair = 0; pair < 4; ++pair) { + uint8_t byte = qs_row[pair]; + int col_even = 2 * pair; + int col_odd = 2 * pair + 1; + int sumi_even = 0; + int sumi_odd = 0; + for (int t = 0; t < 4; ++t) { + int yv = y[t]; + if ((byte >> t) & 1) sumi_even += yv; + else sumi_even -= yv; + if ((byte >> (t + 4)) & 1) sumi_odd += yv; + else sumi_odd -= yv; + } + block_acc[r][col_even] += dy * (float)sumi_even; + block_acc[r][col_odd] += dy * (float)sumi_odd; } - block_acc[r][c] += dy * (float)sumi; } } } @@ -2966,23 +2949,30 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in return out; } -static block_q1_0x8 make_block_q1_0x8(block_q1_0 * in, unsigned int blck_size_interleave) { +static block_q1_0x8 make_block_q1_0x8(block_q1_0 * in) { block_q1_0x8 out; - GGML_ASSERT(blck_size_interleave == 8); - for (int i = 0; i < 8; ++i) { out.d[i] = in[i].d; } + memset(out.qs, 0, sizeof(out.qs)); + for (int sb = 0; sb < 4; ++sb) { - for (int i = 0; i < 32; ++i) { - uint8_t byte = 0; - for (int c = 0; c < 8; ++c) { - uint8_t src = in[c].qs[sb * 4 + i / 8]; - byte |= ((src >> (i % 8)) & 1) << c; + for (int g = 0; g < 8; ++g) { + for (int pair = 0; pair < 4; ++pair) { + uint8_t byte = 0; + for (int t = 0; t < 4; ++t) { + int k = g * 4 + t; + int col_even = 2 * pair; + int col_odd = 2 * pair + 1; + uint8_t src_even = in[col_even].qs[sb * 4 + k / 8]; + uint8_t src_odd = in[col_odd].qs[sb * 4 + k / 8]; + if ((src_even >> (k % 8)) & 1) byte |= (1 << t); + if ((src_odd >> (k % 8)) & 1) byte |= (1 << (t + 4)); + } + out.qs[(sb * 8 + g) * 4 + pair] = byte; } - out.qs[sb * 32 + i] = byte; } } @@ -3659,8 +3649,8 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block } static int repack_q1_0_to_q1_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + UNUSED(interleave_block); GGML_ASSERT(t->type == GGML_TYPE_Q1_0); - GGML_ASSERT(interleave_block == 8); constexpr int nrows_interleaved = 8; block_q1_0x8 * dst = (block_q1_0x8 *) t->data; @@ -3680,7 +3670,7 @@ static int repack_q1_0_to_q1_0_8_bl(struct ggml_tensor * t, int interleave_block for (int i = 0; i < nrows_interleaved; ++i) { dst_tmp[i] = src[x + i * nblocks]; } - *dst++ = make_block_q1_0x8(dst_tmp, interleave_block); + *dst++ = make_block_q1_0x8(dst_tmp); } src += nrows_interleaved * nblocks; } @@ -4145,8 +4135,8 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size); } -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q1_0_to_q1_0_8_bl(t, 8, data, data_size); +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q1_0_to_q1_0_8_bl(t, 0, data, data_size); } #if defined __riscv_zvfh @@ -4246,8 +4236,8 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_q1_0_8x32_q8_0(n, s, bs, vx, vy, nr, nc); +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q1_0_8x4_q8_0(n, s, bs, vx, vy, nr, nc); } #if defined __riscv_zvfh @@ -4347,8 +4337,8 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q1_0_8x32_q8_0(n, s, bs, vx, vy, nr, nc); +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q1_0_8x4_q8_0(n, s, bs, vx, vy, nr, nc); } #if defined __riscv_zvfh @@ -4782,7 +4772,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q8_0_4x8_q8_0; // instance for Q1_0 - static const ggml::cpu::repack::tensor_traits q1_0_8x32_q8_0; + static const ggml::cpu::repack::tensor_traits q1_0_8x4_q8_0; // instances for RISC-V // @@ -4825,8 +4815,8 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons } } else if (cur->type == GGML_TYPE_Q1_0) { if (ggml_cpu_has_avx2()) { - if (cur->ne[1] % 8 == 0) { - return &q1_0_8x32_q8_0; + if (cur->ne[1] % 16 == 0) { + return &q1_0_8x4_q8_0; } } } else if (cur->type == GGML_TYPE_Q4_K) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index 5db7ca8269f..4307cc4f2cb 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -35,7 +35,7 @@ static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 8, "wrong static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding"); static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding"); static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 16, "wrong block<8,16> size/padding"); -static_assert(sizeof(block<1, 8>) == 8 * sizeof(ggml_half) + QK1_0, "wrong block<1,8> size/padding"); +static_assert(sizeof(block<1, 8>) == 8 * sizeof(ggml_half) + QK1_0, "wrong block_q1_0x8 size"); using block_q1_0x8 = block<1, 8>; using block_q4_0x4 = block<4, 4>; @@ -144,7 +144,6 @@ extern "C" { void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); -void ggml_quantize_mat_q8_0_4x32(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -163,7 +162,7 @@ void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q1_0_8x32_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q1_0_8x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -180,7 +179,7 @@ void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q1_0_8x32_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q1_0_8x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #if defined __riscv_zvfh void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); @@ -199,7 +198,6 @@ void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v // Native implementations void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); -void ggml_quantize_mat_q8_0_4x32_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -218,7 +216,7 @@ void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q1_0_8x32_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q1_0_8x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -235,7 +233,7 @@ void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q1_0_8x32_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q1_0_8x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #if defined __riscv_zvfh void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);