diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 1758d83c261..3163e194982 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -55,6 +55,7 @@ #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 +#define ggml_gemv_q1_0_8x4_q8_0_generic ggml_gemv_q1_0_8x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -71,6 +72,7 @@ #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 +#define ggml_gemm_q1_0_8x4_q8_0_generic ggml_gemm_q1_0_8x4_q8_0 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) // repack.cpp #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 @@ -78,15 +80,16 @@ #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q1_0_8x4_q8_0_generic ggml_gemv_q1_0_8x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q1_0_8x4_q8_0_generic ggml_gemm_q1_0_8x4_q8_0 #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 #define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0 // repack.cpp -#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 @@ -141,6 +144,7 @@ #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 +#define ggml_gemv_q1_0_8x4_q8_0_generic ggml_gemv_q1_0_8x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -157,6 +161,7 @@ #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 +#define ggml_gemm_q1_0_8x4_q8_0_generic ggml_gemm_q1_0_8x4_q8_0 #elif defined(__loongarch64) // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K @@ -188,6 +193,7 @@ #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 +#define ggml_gemv_q1_0_8x4_q8_0_generic ggml_gemv_q1_0_8x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -204,6 +210,7 @@ #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 +#define ggml_gemm_q1_0_8x4_q8_0_generic ggml_gemm_q1_0_8x4_q8_0 #elif defined(__riscv) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 @@ -230,6 +237,7 @@ #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 +#define ggml_gemv_q1_0_8x4_q8_0_generic ggml_gemv_q1_0_8x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K @@ -245,6 +253,8 @@ #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 +#define ggml_gemm_q1_0_8x4_q8_0_generic ggml_gemm_q1_0_8x4_q8_0 + #elif defined(__s390x__) // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K @@ -282,6 +292,7 @@ #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 +#define ggml_gemv_q1_0_8x4_q8_0_generic ggml_gemv_q1_0_8x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -298,6 +309,7 @@ #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 +#define ggml_gemm_q1_0_8x4_q8_0_generic ggml_gemm_q1_0_8x4_q8_0 #elif defined(__wasm__) // quants.c #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1 @@ -337,6 +349,7 @@ #define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 +#define ggml_gemv_q1_0_8x4_q8_0_generic ggml_gemv_q1_0_8x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -353,4 +366,5 @@ #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 +#define ggml_gemm_q1_0_8x4_q8_0_generic ggml_gemm_q1_0_8x4_q8_0 #endif diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp index af1cebad131..4a5834ad548 100644 --- a/ggml/src/ggml-cpu/arch/x86/repack.cpp +++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp @@ -175,6 +175,74 @@ static inline __m256i mul_sum_i8_pairs_acc_int32x8(const __m256i acc, const __m2 } #endif +void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; + +#if defined(__AVX2__) + const __m256 signBit = _mm256_set1_ps(-0.0f); + + for (int i = 0; i < nb; ++i) { + __m256 srcv[4][4]; + __m256 idv[4]; + + for (int r = 0; r < 4; ++r) { + const float *xr = x + r * k + i * 32; + + srcv[r][0] = _mm256_loadu_ps(xr + 0); + srcv[r][1] = _mm256_loadu_ps(xr + 8); + srcv[r][2] = _mm256_loadu_ps(xr + 16); + srcv[r][3] = _mm256_loadu_ps(xr + 24); + + __m256 maxAbs = _mm256_andnot_ps(signBit, srcv[r][0]); + maxAbs = _mm256_max_ps(maxAbs, _mm256_andnot_ps(signBit, srcv[r][1])); + maxAbs = _mm256_max_ps(maxAbs, _mm256_andnot_ps(signBit, srcv[r][2])); + maxAbs = _mm256_max_ps(maxAbs, _mm256_andnot_ps(signBit, srcv[r][3])); + + __m128 max4 = _mm_max_ps(_mm256_castps256_ps128(maxAbs), _mm256_extractf128_ps(maxAbs, 1)); + max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); + + const float amax = _mm_cvtss_f32(max4); + const float d = amax / 127.0f; + const float id = amax != 0.0f ? 127.0f / amax : 0.0f; + + y[i].d[r] = GGML_CPU_FP32_TO_FP16(d); + idv[r] = _mm256_set1_ps(id); + } + + for (int j = 0; j < 4; ++j) { + __m256 q0 = _mm256_mul_ps(srcv[0][j], idv[0]); + __m256 q1 = _mm256_mul_ps(srcv[1][j], idv[1]); + __m256 q2 = _mm256_mul_ps(srcv[2][j], idv[2]); + __m256 q3 = _mm256_mul_ps(srcv[3][j], idv[3]); + + q0 = _mm256_round_ps(q0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + q1 = _mm256_round_ps(q1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + q2 = _mm256_round_ps(q2, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + q3 = _mm256_round_ps(q3, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + + __m256i i0 = _mm256_cvtps_epi32(q0); + __m256i i1 = _mm256_cvtps_epi32(q1); + __m256i i2 = _mm256_cvtps_epi32(q2); + __m256i i3 = _mm256_cvtps_epi32(q3); + + const __m256i p01 = _mm256_packs_epi32(i0, i1); + const __m256i p23 = _mm256_packs_epi32(i2, i3); + + const __m256i packed = _mm256_packs_epi16(p01, p23); + + _mm256_storeu_si256((__m256i *)(y[i].qs + j * 32), packed); + } + } +#else + ggml_quantize_mat_q8_0_4x4_generic(x, vy, k); +#endif +} + void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); @@ -2039,6 +2107,621 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +// Rotated Q1.0 layout helpers and kernels + +static const int8_t q1_byte_sel[32] = { + 0,0,0,0, 0,0,0,0, + 1,1,1,1, 1,1,1,1, + 2,2,2,2, 2,2,2,2, + 3,3,3,3, 3,3,3,3, +}; + +static const int8_t q1_bit_mask[32] = { + 1, 2, 4, 8, 16, 32, 64, -128, + 1, 2, 4, 8, 16, 32, 64, -128, + 1, 2, 4, 8, 16, 32, 64, -128, + 1, 2, 4, 8, 16, 32, 64, -128, +}; + +static inline uint32_t ggml_load_u32(const void * p) { + return *(const uint32_t *)p; +} + +#if defined(__F16C__) +static inline __m256 ggml_cvt_fp16x4_to_fp32x2(const ggml_half * p) { + __m128i h = _mm_loadu_si64((const void *)p); + __m128 f = _mm_cvtph_ps(h); + return _mm256_insertf128_ps(_mm256_castps128_ps256(f), f, 1); +} +#endif + + + +static inline __m256i ggml_q1_negmask_8cols( + uint32_t qpack, + __m256i byte_sel, + __m256i bit_mask, + __m256i zero) { + const __m256i qsrc = _mm256_set1_epi32((int32_t)qpack); + const __m256i qbyte = _mm256_shuffle_epi8(qsrc, byte_sel); + const __m256i qbit = _mm256_and_si256(qbyte, bit_mask); + // 0xff where q1 bit is zero, 0x00 where q1 bit is set + return _mm256_cmpeq_epi8(qbit, zero); +} + +// xor-sub pairs: accumulate in int16 (safe: 8*2*127=2032 < 32767) +static inline __m256i ggml_q1_dotpairs_xorsub_i16( + __m256i negmask, + __m256i yrep, + __m256i ones_8) { + const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(yrep, negmask), negmask); + return _mm256_maddubs_epi16(ones_8, sy); +} + +// Bit extraction as 0/1 for positive-sum formulation: dot = 2*pos - total +static inline __m256i ggml_q1_bit01_8cols( + uint32_t qpack, + __m256i byte_sel, + __m256i bit_mask, + __m256i ones_8, + __m256i zero) { + const __m256i qsrc = _mm256_set1_epi32((int32_t) qpack); + const __m256i qbyte = _mm256_shuffle_epi8(qsrc, byte_sel); + const __m256i qbit = _mm256_and_si256(qbyte, bit_mask); + const __m256i is_zero = _mm256_cmpeq_epi8(qbit, zero); + return _mm256_andnot_si256(is_zero, ones_8); +} + +static inline __m256i ggml_q8_totalpairs_i16( + __m256i yrep, + __m256i ones_8) { + return _mm256_maddubs_epi16(ones_8, yrep); +} + +static inline __m256i ggml_q1_pospairs_i16( + __m256i bit01, + __m256i yrep) { + return _mm256_maddubs_epi16(bit01, yrep); +} + +void ggml_gemv_q1_0_8x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined(__AVX2__) && defined(__FMA__) + assert(n % QK1_0 == 0); + assert(nc % 16 == 0); + assert(nr == 1); + + UNUSED(bs); + UNUSED(nr); + + const int nb = n / QK1_0; + + const block_q1_0x8 * GGML_RESTRICT x = (const block_q1_0x8 *) vx; + const block_q8_0 * GGML_RESTRICT y = (const block_q8_0 *) vy; + + const __m256i byte_sel = _mm256_loadu_si256((const __m256i *) q1_byte_sel); + const __m256i bit_mask = _mm256_loadu_si256((const __m256i *) q1_bit_mask); + const __m256i ones_8 = _mm256_set1_epi8(1); + const __m256i ones_16 = _mm256_set1_epi16(1); + const __m256i zero = _mm256_setzero_si256(); + + // 48-column main loop (AVX-512 register file) / 32-column (AVX2) +#if defined(__AVX512VL__) + { + const int ncols48 = nc / 48; + const int nc_tail = nc - ncols48 * 48; + + for (int cx = 0; cx < ncols48; ++cx) { + const block_q1_0x8 * GGML_RESTRICT x0 = x + (size_t)(cx * 6) * nb; + const block_q1_0x8 * GGML_RESTRICT x1 = x0 + nb; + const block_q1_0x8 * GGML_RESTRICT x2 = x1 + nb; + const block_q1_0x8 * GGML_RESTRICT x3 = x2 + nb; + const block_q1_0x8 * GGML_RESTRICT x4 = x3 + nb; + const block_q1_0x8 * GGML_RESTRICT x5 = x4 + nb; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + __m256 acc4 = _mm256_setzero_ps(); + __m256 acc5 = _mm256_setzero_ps(); + + for (int l = 0; l < nb; ++l) { + __m256 tmp0 = _mm256_setzero_ps(); + __m256 tmp1 = _mm256_setzero_ps(); + __m256 tmp2 = _mm256_setzero_ps(); + __m256 tmp3 = _mm256_setzero_ps(); + __m256 tmp4 = _mm256_setzero_ps(); + __m256 tmp5 = _mm256_setzero_ps(); + + for (int sb = 0; sb < 4; ++sb) { + const block_q8_0 * GGML_RESTRICT yb = &y[l * 4 + sb]; + + __m256i total16 = _mm256_setzero_si256(); + __m256i pos0 = _mm256_setzero_si256(); + __m256i pos1 = _mm256_setzero_si256(); + __m256i pos2 = _mm256_setzero_si256(); + __m256i pos3 = _mm256_setzero_si256(); + __m256i pos4 = _mm256_setzero_si256(); + __m256i pos5 = _mm256_setzero_si256(); + + for (int g = 0; g < 8; ++g) { + const __m256i yrep = _mm256_set1_epi32(ggml_load_u32(yb->qs + g * 4)); + + total16 = _mm256_add_epi16(total16, ggml_q8_totalpairs_i16(yrep, ones_8)); + + const uint32_t qpack0 = ggml_load_u32(x0[l].qs + (sb * 8 + g) * 4); + const __m256i bit0 = ggml_q1_bit01_8cols(qpack0, byte_sel, bit_mask, ones_8, zero); + pos0 = _mm256_add_epi16(pos0, ggml_q1_pospairs_i16(bit0, yrep)); + + const uint32_t qpack1 = ggml_load_u32(x1[l].qs + (sb * 8 + g) * 4); + const __m256i bit1 = ggml_q1_bit01_8cols(qpack1, byte_sel, bit_mask, ones_8, zero); + pos1 = _mm256_add_epi16(pos1, ggml_q1_pospairs_i16(bit1, yrep)); + + const uint32_t qpack2 = ggml_load_u32(x2[l].qs + (sb * 8 + g) * 4); + const __m256i bit2 = ggml_q1_bit01_8cols(qpack2, byte_sel, bit_mask, ones_8, zero); + pos2 = _mm256_add_epi16(pos2, ggml_q1_pospairs_i16(bit2, yrep)); + + const uint32_t qpack3 = ggml_load_u32(x3[l].qs + (sb * 8 + g) * 4); + const __m256i bit3 = ggml_q1_bit01_8cols(qpack3, byte_sel, bit_mask, ones_8, zero); + pos3 = _mm256_add_epi16(pos3, ggml_q1_pospairs_i16(bit3, yrep)); + + const uint32_t qpack4 = ggml_load_u32(x4[l].qs + (sb * 8 + g) * 4); + const __m256i bit4 = ggml_q1_bit01_8cols(qpack4, byte_sel, bit_mask, ones_8, zero); + pos4 = _mm256_add_epi16(pos4, ggml_q1_pospairs_i16(bit4, yrep)); + + const uint32_t qpack5 = ggml_load_u32(x5[l].qs + (sb * 8 + g) * 4); + const __m256i bit5 = ggml_q1_bit01_8cols(qpack5, byte_sel, bit_mask, ones_8, zero); + pos5 = _mm256_add_epi16(pos5, ggml_q1_pospairs_i16(bit5, yrep)); + } + + const __m256i dot0_16 = _mm256_sub_epi16(_mm256_add_epi16(pos0, pos0), total16); + const __m256i dot1_16 = _mm256_sub_epi16(_mm256_add_epi16(pos1, pos1), total16); + const __m256i dot2_16 = _mm256_sub_epi16(_mm256_add_epi16(pos2, pos2), total16); + const __m256i dot3_16 = _mm256_sub_epi16(_mm256_add_epi16(pos3, pos3), total16); + const __m256i dot4_16 = _mm256_sub_epi16(_mm256_add_epi16(pos4, pos4), total16); + const __m256i dot5_16 = _mm256_sub_epi16(_mm256_add_epi16(pos5, pos5), total16); + + const __m256 ydv = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d)); + tmp0 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dot0_16, ones_16)), ydv, tmp0); + tmp1 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dot1_16, ones_16)), ydv, tmp1); + tmp2 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dot2_16, ones_16)), ydv, tmp2); + tmp3 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dot3_16, ones_16)), ydv, tmp3); + tmp4 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dot4_16, ones_16)), ydv, tmp4); + tmp5 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dot5_16, ones_16)), ydv, tmp5); + } + + const __m256 bdv0 = GGML_F32Cx8_LOAD(x0[l].d); + const __m256 bdv1 = GGML_F32Cx8_LOAD(x1[l].d); + const __m256 bdv2 = GGML_F32Cx8_LOAD(x2[l].d); + const __m256 bdv3 = GGML_F32Cx8_LOAD(x3[l].d); + const __m256 bdv4 = GGML_F32Cx8_LOAD(x4[l].d); + const __m256 bdv5 = GGML_F32Cx8_LOAD(x5[l].d); + + acc0 = _mm256_fmadd_ps(tmp0, bdv0, acc0); + acc1 = _mm256_fmadd_ps(tmp1, bdv1, acc1); + acc2 = _mm256_fmadd_ps(tmp2, bdv2, acc2); + acc3 = _mm256_fmadd_ps(tmp3, bdv3, acc3); + acc4 = _mm256_fmadd_ps(tmp4, bdv4, acc4); + acc5 = _mm256_fmadd_ps(tmp5, bdv5, acc5); + } + + float *sbase = s + cx * 48; + _mm256_storeu_ps(sbase + 0, acc0); + _mm256_storeu_ps(sbase + 8, acc1); + _mm256_storeu_ps(sbase + 16, acc2); + _mm256_storeu_ps(sbase + 24, acc3); + _mm256_storeu_ps(sbase + 32, acc4); + _mm256_storeu_ps(sbase + 40, acc5); + } + + // 32-column tail + if (nc_tail >= 32) { + const block_q1_0x8 * GGML_RESTRICT xa = x + (size_t)(ncols48 * 6) * nb; + const block_q1_0x8 * GGML_RESTRICT xb = xa + nb; + const block_q1_0x8 * GGML_RESTRICT xc = xb + nb; + const block_q1_0x8 * GGML_RESTRICT xd = xc + nb; + + __m256 acca = _mm256_setzero_ps(); + __m256 accb = _mm256_setzero_ps(); + __m256 accc = _mm256_setzero_ps(); + __m256 accd = _mm256_setzero_ps(); + + for (int l = 0; l < nb; ++l) { + __m256 tmpa = _mm256_setzero_ps(); + __m256 tmpb = _mm256_setzero_ps(); + __m256 tmpc = _mm256_setzero_ps(); + __m256 tmpd = _mm256_setzero_ps(); + + for (int sb = 0; sb < 4; ++sb) { + const block_q8_0 * GGML_RESTRICT yb = &y[l * 4 + sb]; + + __m256i total16 = _mm256_setzero_si256(); + __m256i posa = _mm256_setzero_si256(); + __m256i posb = _mm256_setzero_si256(); + __m256i posc = _mm256_setzero_si256(); + __m256i posd = _mm256_setzero_si256(); + + for (int g = 0; g < 8; ++g) { + const __m256i yrep = _mm256_set1_epi32(ggml_load_u32(yb->qs + g * 4)); + + total16 = _mm256_add_epi16(total16, ggml_q8_totalpairs_i16(yrep, ones_8)); + + const uint32_t qpacka = ggml_load_u32(xa[l].qs + (sb * 8 + g) * 4); + const __m256i bita = ggml_q1_bit01_8cols(qpacka, byte_sel, bit_mask, ones_8, zero); + posa = _mm256_add_epi16(posa, ggml_q1_pospairs_i16(bita, yrep)); + + const uint32_t qpackb = ggml_load_u32(xb[l].qs + (sb * 8 + g) * 4); + const __m256i bitb = ggml_q1_bit01_8cols(qpackb, byte_sel, bit_mask, ones_8, zero); + posb = _mm256_add_epi16(posb, ggml_q1_pospairs_i16(bitb, yrep)); + + const uint32_t qpackc = ggml_load_u32(xc[l].qs + (sb * 8 + g) * 4); + const __m256i bitc = ggml_q1_bit01_8cols(qpackc, byte_sel, bit_mask, ones_8, zero); + posc = _mm256_add_epi16(posc, ggml_q1_pospairs_i16(bitc, yrep)); + + const uint32_t qpackd = ggml_load_u32(xd[l].qs + (sb * 8 + g) * 4); + const __m256i bitd = ggml_q1_bit01_8cols(qpackd, byte_sel, bit_mask, ones_8, zero); + posd = _mm256_add_epi16(posd, ggml_q1_pospairs_i16(bitd, yrep)); + } + + const __m256i dota_16 = _mm256_sub_epi16(_mm256_add_epi16(posa, posa), total16); + const __m256i dotb_16 = _mm256_sub_epi16(_mm256_add_epi16(posb, posb), total16); + const __m256i dotc_16 = _mm256_sub_epi16(_mm256_add_epi16(posc, posc), total16); + const __m256i dotd_16 = _mm256_sub_epi16(_mm256_add_epi16(posd, posd), total16); + + const __m256 ydv = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d)); + tmpa = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dota_16, ones_16)), ydv, tmpa); + tmpb = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dotb_16, ones_16)), ydv, tmpb); + tmpc = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dotc_16, ones_16)), ydv, tmpc); + tmpd = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dotd_16, ones_16)), ydv, tmpd); + } + + const __m256 bdva = GGML_F32Cx8_LOAD(xa[l].d); + const __m256 bdvb = GGML_F32Cx8_LOAD(xb[l].d); + const __m256 bdvc = GGML_F32Cx8_LOAD(xc[l].d); + const __m256 bdvd = GGML_F32Cx8_LOAD(xd[l].d); + + acca = _mm256_fmadd_ps(tmpa, bdva, acca); + accb = _mm256_fmadd_ps(tmpb, bdvb, accb); + accc = _mm256_fmadd_ps(tmpc, bdvc, accc); + accd = _mm256_fmadd_ps(tmpd, bdvd, accd); + } + + float *sbase = s + ncols48 * 48; + _mm256_storeu_ps(sbase + 0, acca); + _mm256_storeu_ps(sbase + 8, accb); + _mm256_storeu_ps(sbase + 16, accc); + _mm256_storeu_ps(sbase + 24, accd); + } + + // 16-column tail + if (nc_tail == 16) { + const block_q1_0x8 * GGML_RESTRICT xa = x + (size_t)(ncols48 * 6) * nb; + const block_q1_0x8 * GGML_RESTRICT xb = xa + nb; + + __m256 acca = _mm256_setzero_ps(); + __m256 accb = _mm256_setzero_ps(); + + for (int l = 0; l < nb; ++l) { + __m256 tmpa = _mm256_setzero_ps(); + __m256 tmpb = _mm256_setzero_ps(); + + for (int sb = 0; sb < 4; ++sb) { + const block_q8_0 * GGML_RESTRICT yb = &y[l * 4 + sb]; + + __m256i total16 = _mm256_setzero_si256(); + __m256i posa = _mm256_setzero_si256(); + __m256i posb = _mm256_setzero_si256(); + + for (int g = 0; g < 8; ++g) { + const __m256i yrep = _mm256_set1_epi32(ggml_load_u32(yb->qs + g * 4)); + + total16 = _mm256_add_epi16(total16, ggml_q8_totalpairs_i16(yrep, ones_8)); + + const uint32_t qpacka = ggml_load_u32(xa[l].qs + (sb * 8 + g) * 4); + const __m256i bita = ggml_q1_bit01_8cols(qpacka, byte_sel, bit_mask, ones_8, zero); + posa = _mm256_add_epi16(posa, ggml_q1_pospairs_i16(bita, yrep)); + + const uint32_t qpackb = ggml_load_u32(xb[l].qs + (sb * 8 + g) * 4); + const __m256i bitb = ggml_q1_bit01_8cols(qpackb, byte_sel, bit_mask, ones_8, zero); + posb = _mm256_add_epi16(posb, ggml_q1_pospairs_i16(bitb, yrep)); + } + + const __m256i dota_16 = _mm256_sub_epi16(_mm256_add_epi16(posa, posa), total16); + const __m256i dotb_16 = _mm256_sub_epi16(_mm256_add_epi16(posb, posb), total16); + + const __m256 ydv = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d)); + tmpa = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dota_16, ones_16)), ydv, tmpa); + tmpb = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dotb_16, ones_16)), ydv, tmpb); + } + + const __m256 bdva = GGML_F32Cx8_LOAD(xa[l].d); + const __m256 bdvb = GGML_F32Cx8_LOAD(xb[l].d); + + acca = _mm256_fmadd_ps(tmpa, bdva, acca); + accb = _mm256_fmadd_ps(tmpb, bdvb, accb); + } + + float *sbase = s + ncols48 * 48; + _mm256_storeu_ps(sbase, acca); + _mm256_storeu_ps(sbase + 8, accb); + } + } +#else + { + const int ncols32 = nc / 32; + const int nc_tail = nc - ncols32 * 32; + + for (int cx = 0; cx < ncols32; ++cx) { + const block_q1_0x8 * GGML_RESTRICT x0 = x + (size_t)(cx * 4) * nb; + const block_q1_0x8 * GGML_RESTRICT x1 = x0 + nb; + const block_q1_0x8 * GGML_RESTRICT x2 = x1 + nb; + const block_q1_0x8 * GGML_RESTRICT x3 = x2 + nb; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + + for (int l = 0; l < nb; ++l) { + __m256 tmp0 = _mm256_setzero_ps(); + __m256 tmp1 = _mm256_setzero_ps(); + __m256 tmp2 = _mm256_setzero_ps(); + __m256 tmp3 = _mm256_setzero_ps(); + + for (int sb = 0; sb < 4; ++sb) { + const block_q8_0 * GGML_RESTRICT yb = &y[l * 4 + sb]; + + __m256i total16 = _mm256_setzero_si256(); + __m256i pos0 = _mm256_setzero_si256(); + __m256i pos1 = _mm256_setzero_si256(); + __m256i pos2 = _mm256_setzero_si256(); + __m256i pos3 = _mm256_setzero_si256(); + + for (int g = 0; g < 8; ++g) { + const __m256i yrep = _mm256_set1_epi32(ggml_load_u32(yb->qs + g * 4)); + + total16 = _mm256_add_epi16(total16, ggml_q8_totalpairs_i16(yrep, ones_8)); + + const uint32_t qpack0 = ggml_load_u32(x0[l].qs + (sb * 8 + g) * 4); + const __m256i bit0 = ggml_q1_bit01_8cols(qpack0, byte_sel, bit_mask, ones_8, zero); + pos0 = _mm256_add_epi16(pos0, ggml_q1_pospairs_i16(bit0, yrep)); + + const uint32_t qpack1 = ggml_load_u32(x1[l].qs + (sb * 8 + g) * 4); + const __m256i bit1 = ggml_q1_bit01_8cols(qpack1, byte_sel, bit_mask, ones_8, zero); + pos1 = _mm256_add_epi16(pos1, ggml_q1_pospairs_i16(bit1, yrep)); + + const uint32_t qpack2 = ggml_load_u32(x2[l].qs + (sb * 8 + g) * 4); + const __m256i bit2 = ggml_q1_bit01_8cols(qpack2, byte_sel, bit_mask, ones_8, zero); + pos2 = _mm256_add_epi16(pos2, ggml_q1_pospairs_i16(bit2, yrep)); + + const uint32_t qpack3 = ggml_load_u32(x3[l].qs + (sb * 8 + g) * 4); + const __m256i bit3 = ggml_q1_bit01_8cols(qpack3, byte_sel, bit_mask, ones_8, zero); + pos3 = _mm256_add_epi16(pos3, ggml_q1_pospairs_i16(bit3, yrep)); + } + + const __m256i dot0_16 = _mm256_sub_epi16(_mm256_add_epi16(pos0, pos0), total16); + const __m256i dot1_16 = _mm256_sub_epi16(_mm256_add_epi16(pos1, pos1), total16); + const __m256i dot2_16 = _mm256_sub_epi16(_mm256_add_epi16(pos2, pos2), total16); + const __m256i dot3_16 = _mm256_sub_epi16(_mm256_add_epi16(pos3, pos3), total16); + + const __m256 ydv = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d)); + tmp0 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dot0_16, ones_16)), ydv, tmp0); + tmp1 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dot1_16, ones_16)), ydv, tmp1); + tmp2 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dot2_16, ones_16)), ydv, tmp2); + tmp3 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dot3_16, ones_16)), ydv, tmp3); + } + + const __m256 bdv0 = GGML_F32Cx8_LOAD(x0[l].d); + const __m256 bdv1 = GGML_F32Cx8_LOAD(x1[l].d); + const __m256 bdv2 = GGML_F32Cx8_LOAD(x2[l].d); + const __m256 bdv3 = GGML_F32Cx8_LOAD(x3[l].d); + + acc0 = _mm256_fmadd_ps(tmp0, bdv0, acc0); + acc1 = _mm256_fmadd_ps(tmp1, bdv1, acc1); + acc2 = _mm256_fmadd_ps(tmp2, bdv2, acc2); + acc3 = _mm256_fmadd_ps(tmp3, bdv3, acc3); + } + + float *sbase = s + cx * 32; + _mm256_storeu_ps(sbase + 0, acc0); + _mm256_storeu_ps(sbase + 8, acc1); + _mm256_storeu_ps(sbase + 16, acc2); + _mm256_storeu_ps(sbase + 24, acc3); + } + + // 16-column tail + if (nc_tail >= 16) { + const block_q1_0x8 * GGML_RESTRICT xa = x + (size_t)(ncols32 * 4) * nb; + const block_q1_0x8 * GGML_RESTRICT xb = xa + nb; + + __m256 acca = _mm256_setzero_ps(); + __m256 accb = _mm256_setzero_ps(); + + for (int l = 0; l < nb; ++l) { + __m256 tmpa = _mm256_setzero_ps(); + __m256 tmpb = _mm256_setzero_ps(); + + for (int sb = 0; sb < 4; ++sb) { + const block_q8_0 * GGML_RESTRICT yb = &y[l * 4 + sb]; + + __m256i total16 = _mm256_setzero_si256(); + __m256i posa = _mm256_setzero_si256(); + __m256i posb = _mm256_setzero_si256(); + + for (int g = 0; g < 8; ++g) { + const __m256i yrep = _mm256_set1_epi32(ggml_load_u32(yb->qs + g * 4)); + + total16 = _mm256_add_epi16(total16, ggml_q8_totalpairs_i16(yrep, ones_8)); + + const uint32_t qpacka = ggml_load_u32(xa[l].qs + (sb * 8 + g) * 4); + const __m256i bita = ggml_q1_bit01_8cols(qpacka, byte_sel, bit_mask, ones_8, zero); + posa = _mm256_add_epi16(posa, ggml_q1_pospairs_i16(bita, yrep)); + + const uint32_t qpackb = ggml_load_u32(xb[l].qs + (sb * 8 + g) * 4); + const __m256i bitb = ggml_q1_bit01_8cols(qpackb, byte_sel, bit_mask, ones_8, zero); + posb = _mm256_add_epi16(posb, ggml_q1_pospairs_i16(bitb, yrep)); + } + + const __m256i dota_16 = _mm256_sub_epi16(_mm256_add_epi16(posa, posa), total16); + const __m256i dotb_16 = _mm256_sub_epi16(_mm256_add_epi16(posb, posb), total16); + + const __m256 ydv = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb->d)); + tmpa = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dota_16, ones_16)), ydv, tmpa); + tmpb = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_madd_epi16(dotb_16, ones_16)), ydv, tmpb); + } + + const __m256 bdva = GGML_F32Cx8_LOAD(xa[l].d); + const __m256 bdvb = GGML_F32Cx8_LOAD(xb[l].d); + + acca = _mm256_fmadd_ps(tmpa, bdva, acca); + accb = _mm256_fmadd_ps(tmpb, bdvb, accb); + } + + _mm256_storeu_ps(s + ncols32 * 32, acca); + _mm256_storeu_ps(s + ncols32 * 32 + 8, accb); + } + } +#endif // __AVX512VL__ + + return; +#else + ggml_gemv_q1_0_8x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} + +void ggml_gemm_q1_0_8x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined(__AVX2__) && defined(__FMA__) + assert(n % QK1_0 == 0); + assert(nr % 4 == 0); + assert(nc % 16 == 0); + + const int nb = n / QK1_0; + const int nb_q8_0 = n / QK8_0; + const int ncols16 = nc / 16; + const int nrows4 = nr / 4; + + const block_q1_0x8 * vx_bi = (const block_q1_0x8 *)vx; + + const __m256i byte_sel = _mm256_loadu_si256((const __m256i *) q1_byte_sel); + const __m256i bit_mask = _mm256_loadu_si256((const __m256i *) q1_bit_mask); + const __m256i ones_8 = _mm256_set1_epi8(1); + const __m256i ones_16 = _mm256_set1_epi16(1); + const __m256i zero = _mm256_setzero_si256(); + + for (int y = 0; y < nrows4; ++y) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *)vy + (y * nb_q8_0); + + for (int x = 0; x < ncols16; ++x) { + const block_q1_0x8 * ba = vx_bi + (size_t)(x * 2) * nb; + const block_q1_0x8 * bb = ba + nb; + + __m256 acc0a = _mm256_setzero_ps(); + __m256 acc1a = _mm256_setzero_ps(); + __m256 acc2a = _mm256_setzero_ps(); + __m256 acc3a = _mm256_setzero_ps(); + __m256 acc0b = _mm256_setzero_ps(); + __m256 acc1b = _mm256_setzero_ps(); + __m256 acc2b = _mm256_setzero_ps(); + __m256 acc3b = _mm256_setzero_ps(); + + for (int l = 0; l < nb; ++l) { + __m256 tmp0a = _mm256_setzero_ps(); + __m256 tmp1a = _mm256_setzero_ps(); + __m256 tmp2a = _mm256_setzero_ps(); + __m256 tmp3a = _mm256_setzero_ps(); + __m256 tmp0b = _mm256_setzero_ps(); + __m256 tmp1b = _mm256_setzero_ps(); + __m256 tmp2b = _mm256_setzero_ps(); + __m256 tmp3b = _mm256_setzero_ps(); + + for (int sb = 0; sb < 4; ++sb) { + const block_q8_0x4 * yb = &a_ptr[l * 4 + sb]; + + __m256i isum0a16 = _mm256_setzero_si256(); + __m256i isum1a16 = _mm256_setzero_si256(); + __m256i isum2a16 = _mm256_setzero_si256(); + __m256i isum3a16 = _mm256_setzero_si256(); + __m256i isum0b16 = _mm256_setzero_si256(); + __m256i isum1b16 = _mm256_setzero_si256(); + __m256i isum2b16 = _mm256_setzero_si256(); + __m256i isum3b16 = _mm256_setzero_si256(); + + for (int g = 0; g < 8; ++g) { + const uint32_t qpacka = ggml_load_u32(ba[l].qs + (sb * 8 + g) * 4); + const __m256i nega = ggml_q1_negmask_8cols(qpacka, byte_sel, bit_mask, zero); + const uint32_t qpackb = ggml_load_u32(bb[l].qs + (sb * 8 + g) * 4); + const __m256i negb = ggml_q1_negmask_8cols(qpackb, byte_sel, bit_mask, zero); + + const __m256i yrep0 = _mm256_set1_epi32(ggml_load_u32(yb->qs + g * 16 + 0 * 4)); + const __m256i yrep1 = _mm256_set1_epi32(ggml_load_u32(yb->qs + g * 16 + 1 * 4)); + const __m256i yrep2 = _mm256_set1_epi32(ggml_load_u32(yb->qs + g * 16 + 2 * 4)); + const __m256i yrep3 = _mm256_set1_epi32(ggml_load_u32(yb->qs + g * 16 + 3 * 4)); + + isum0a16 = _mm256_add_epi16(isum0a16, ggml_q1_dotpairs_xorsub_i16(nega, yrep0, ones_8)); + isum1a16 = _mm256_add_epi16(isum1a16, ggml_q1_dotpairs_xorsub_i16(nega, yrep1, ones_8)); + isum2a16 = _mm256_add_epi16(isum2a16, ggml_q1_dotpairs_xorsub_i16(nega, yrep2, ones_8)); + isum3a16 = _mm256_add_epi16(isum3a16, ggml_q1_dotpairs_xorsub_i16(nega, yrep3, ones_8)); + + isum0b16 = _mm256_add_epi16(isum0b16, ggml_q1_dotpairs_xorsub_i16(negb, yrep0, ones_8)); + isum1b16 = _mm256_add_epi16(isum1b16, ggml_q1_dotpairs_xorsub_i16(negb, yrep1, ones_8)); + isum2b16 = _mm256_add_epi16(isum2b16, ggml_q1_dotpairs_xorsub_i16(negb, yrep2, ones_8)); + isum3b16 = _mm256_add_epi16(isum3b16, ggml_q1_dotpairs_xorsub_i16(negb, yrep3, ones_8)); + } + + const __m256i isum0a = _mm256_madd_epi16(isum0a16, ones_16); + const __m256i isum1a = _mm256_madd_epi16(isum1a16, ones_16); + const __m256i isum2a = _mm256_madd_epi16(isum2a16, ones_16); + const __m256i isum3a = _mm256_madd_epi16(isum3a16, ones_16); + const __m256i isum0b = _mm256_madd_epi16(isum0b16, ones_16); + const __m256i isum1b = _mm256_madd_epi16(isum1b16, ones_16); + const __m256i isum2b = _mm256_madd_epi16(isum2b16, ones_16); + const __m256i isum3b = _mm256_madd_epi16(isum3b16, ones_16); + + const __m256 yd_all = ggml_cvt_fp16x4_to_fp32x2(yb->d); + const __m256 yd0 = _mm256_permute_ps(yd_all, 0); + const __m256 yd1 = _mm256_permute_ps(yd_all, 85); + const __m256 yd2 = _mm256_permute_ps(yd_all, 170); + const __m256 yd3 = _mm256_permute_ps(yd_all, 255); + tmp0a = _mm256_fmadd_ps(_mm256_cvtepi32_ps(isum0a), yd0, tmp0a); + tmp1a = _mm256_fmadd_ps(_mm256_cvtepi32_ps(isum1a), yd1, tmp1a); + tmp2a = _mm256_fmadd_ps(_mm256_cvtepi32_ps(isum2a), yd2, tmp2a); + tmp3a = _mm256_fmadd_ps(_mm256_cvtepi32_ps(isum3a), yd3, tmp3a); + tmp0b = _mm256_fmadd_ps(_mm256_cvtepi32_ps(isum0b), yd0, tmp0b); + tmp1b = _mm256_fmadd_ps(_mm256_cvtepi32_ps(isum1b), yd1, tmp1b); + tmp2b = _mm256_fmadd_ps(_mm256_cvtepi32_ps(isum2b), yd2, tmp2b); + tmp3b = _mm256_fmadd_ps(_mm256_cvtepi32_ps(isum3b), yd3, tmp3b); + } + + const __m256 bdva = GGML_F32Cx8_LOAD(ba[l].d); + const __m256 bdvb = GGML_F32Cx8_LOAD(bb[l].d); + + acc0a = _mm256_fmadd_ps(tmp0a, bdva, acc0a); + acc1a = _mm256_fmadd_ps(tmp1a, bdva, acc1a); + acc2a = _mm256_fmadd_ps(tmp2a, bdva, acc2a); + acc3a = _mm256_fmadd_ps(tmp3a, bdva, acc3a); + acc0b = _mm256_fmadd_ps(tmp0b, bdvb, acc0b); + acc1b = _mm256_fmadd_ps(tmp1b, bdvb, acc1b); + acc2b = _mm256_fmadd_ps(tmp2b, bdvb, acc2b); + acc3b = _mm256_fmadd_ps(tmp3b, bdvb, acc3b); + } + + _mm256_storeu_ps(s + (y * 4 + 0) * bs + x * 16, acc0a); + _mm256_storeu_ps(s + (y * 4 + 1) * bs + x * 16, acc1a); + _mm256_storeu_ps(s + (y * 4 + 2) * bs + x * 16, acc2a); + _mm256_storeu_ps(s + (y * 4 + 3) * bs + x * 16, acc3a); + _mm256_storeu_ps(s + (y * 4 + 0) * bs + x * 16 + 8, acc0b); + _mm256_storeu_ps(s + (y * 4 + 1) * bs + x * 16 + 8, acc1b); + _mm256_storeu_ps(s + (y * 4 + 2) * bs + x * 16 + 8, acc2b); + _mm256_storeu_ps(s + (y * 4 + 3) * bs + x * 16 + 8, acc3b); + } + } + + return; +#else + ggml_gemm_q1_0_8x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif +} + void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index f18758f16bb..8c87c603efc 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -884,6 +884,70 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemv_q1_0_8x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert (n % QK1_0 == 0); + assert (nc % 8 == 0); + assert (nr == 1); + + UNUSED(bs); + UNUSED(nr); + + const int nb = n / QK1_0; + const int ncols8 = nc / 8; + + const block_q1_0x8 * vx_bi = (const block_q1_0x8 *)vx; + const block_q8_0 * a_ptr = (const block_q8_0 *)vy; + + for (int x = 0; x < ncols8; ++x) { + const block_q1_0x8 * b_ptr = vx_bi + (size_t)x * nb; + + float acc[8] = {0}; + + for (int l = 0; l < nb; ++l) { + float bd[8]; + for (int c = 0; c < 8; ++c) + bd[c] = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[c]); + + float block_acc[8] = {0}; + + for (int sb = 0; sb < 4; ++sb) { + const block_q8_0 * yb = &a_ptr[l * 4 + sb]; + const float dy = GGML_CPU_FP16_TO_FP32(yb->d); + const int8_t * y = yb->qs; + + for (int g = 0; g < 8; ++g) { + const uint8_t * qs_row = (const uint8_t *)b_ptr[l].qs + (sb * 8 + g) * 4; + + for (int pair = 0; pair < 4; ++pair) { + uint8_t byte = qs_row[pair]; + int col_even = 2 * pair; + int col_odd = 2 * pair + 1; + int sumi_even = 0; + int sumi_odd = 0; + for (int t = 0; t < 4; ++t) { + int k = g * 4 + t; + int yv = y[k]; + if ((byte >> t) & 1) sumi_even += yv; + else sumi_even -= yv; + if ((byte >> (t + 4)) & 1) sumi_odd += yv; + else sumi_odd -= yv; + } + block_acc[col_even] += dy * (float)sumi_even; + block_acc[col_odd] += dy * (float)sumi_odd; + } + } + } + + for (int c = 0; c < 8; ++c) { + acc[c] += bd[c] * block_acc[c]; + } + } + + static_assert(sizeof(acc) == 32); + memcpy(s + x * 8, acc, sizeof(acc)); + } +} + void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -1819,6 +1883,83 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemm_q1_0_8x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert (n % QK1_0 == 0); + assert (nr % 4 == 0); + assert (nc % 8 == 0); + + const int nb = n / QK1_0; + const int nb_q8_0 = n / QK8_0; + const int ncols8 = nc / 8; + const int nrows4 = nr / 4; + + const block_q1_0x8 * vx_bi = (const block_q1_0x8 *)vx; + + for (int y = 0; y < nrows4; ++y) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *)vy + (y * nb_q8_0); + + for (int row_base = 0; row_base < 4; row_base += 2) { + for (int x = 0; x < ncols8; ++x) { + const block_q1_0x8 * b_ptr = vx_bi + (size_t)x * nb; + + float acc[2][8] = {{0}}; + + for (int l = 0; l < nb; ++l) { + float bd[8]; + for (int c = 0; c < 8; ++c) + bd[c] = GGML_CPU_FP16_TO_FP32(b_ptr[l].d[c]); + + float block_acc[2][8] = {{0}}; + + for (int sb = 0; sb < 4; ++sb) { + const block_q8_0x4 * yb = &a_ptr[l * 4 + sb]; + + for (int r = 0; r < 2; ++r) { + const float dy = GGML_CPU_FP16_TO_FP32(yb->d[row_base + r]); + const int row = row_base + r; + + for (int g = 0; g < 8; ++g) { + const uint8_t * qs_row = (const uint8_t *)b_ptr[l].qs + (sb * 8 + g) * 4; + const int8_t * y = yb->qs + g * 16 + row * 4; + + for (int pair = 0; pair < 4; ++pair) { + uint8_t byte = qs_row[pair]; + int col_even = 2 * pair; + int col_odd = 2 * pair + 1; + int sumi_even = 0; + int sumi_odd = 0; + for (int t = 0; t < 4; ++t) { + int yv = y[t]; + if ((byte >> t) & 1) sumi_even += yv; + else sumi_even -= yv; + if ((byte >> (t + 4)) & 1) sumi_odd += yv; + else sumi_odd -= yv; + } + block_acc[r][col_even] += dy * (float)sumi_even; + block_acc[r][col_odd] += dy * (float)sumi_odd; + } + } + } + } + + for (int r = 0; r < 2; ++r) { + for (int c = 0; c < 8; ++c) { + acc[r][c] += bd[c] * block_acc[r][c]; + } + } + } + + for (int r = 0; r < 2; ++r) { + float * row_out = s + (y * 4 + row_base + r) * bs + x * 8; + for (int c = 0; c < 8; ++c) { + row_out[c] = acc[r][c]; + } + } + } + } + } +} + void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -2808,6 +2949,36 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in return out; } +static block_q1_0x8 make_block_q1_0x8(block_q1_0 * in) { + block_q1_0x8 out; + + for (int i = 0; i < 8; ++i) { + out.d[i] = in[i].d; + } + + memset(out.qs, 0, sizeof(out.qs)); + + for (int sb = 0; sb < 4; ++sb) { + for (int g = 0; g < 8; ++g) { + for (int pair = 0; pair < 4; ++pair) { + uint8_t byte = 0; + for (int t = 0; t < 4; ++t) { + int k = g * 4 + t; + int col_even = 2 * pair; + int col_odd = 2 * pair + 1; + uint8_t src_even = in[col_even].qs[sb * 4 + k / 8]; + uint8_t src_odd = in[col_odd].qs[sb * 4 + k / 8]; + if ((src_even >> (k % 8)) & 1) byte |= (1 << t); + if ((src_odd >> (k % 8)) & 1) byte |= (1 << (t + 4)); + } + out.qs[(sb * 8 + g) * 4 + pair] = byte; + } + } + } + + return out; +} + static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { block_q4_0x16 out; @@ -3477,6 +3648,36 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } +static int repack_q1_0_to_q1_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + UNUSED(interleave_block); + GGML_ASSERT(t->type == GGML_TYPE_Q1_0); + constexpr int nrows_interleaved = 8; + + block_q1_0x8 * dst = (block_q1_0x8 *) t->data; + const block_q1_0 * src = (const block_q1_0 *) data; + block_q1_0 dst_tmp[8]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK1_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q1_0)); + + if (t->ne[1] % nrows_interleaved != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; ++x) { + for (int i = 0; i < nrows_interleaved; ++i) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q1_0x8(dst_tmp); + } + src += nrows_interleaved * nblocks; + } + + return 0; +} + static int repack_q8_0_to_q8_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, @@ -3934,6 +4135,10 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q1_0_to_q1_0_8_bl(t, 0, data, data_size); +} + #if defined __riscv_zvfh template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q4_0_to_q4_0_16_bl(t, 1, data, data_size); @@ -4031,6 +4236,10 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q1_0_8x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + #if defined __riscv_zvfh template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); @@ -4128,6 +4337,10 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q1_0_8x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + #if defined __riscv_zvfh template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); @@ -4558,6 +4771,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q8_0_4x4_q8_0; static const ggml::cpu::repack::tensor_traits q8_0_4x8_q8_0; + // instance for Q1_0 + static const ggml::cpu::repack::tensor_traits q1_0_8x4_q8_0; + // instances for RISC-V // // These implement outer-product style matrix multiplication kernels with @@ -4597,6 +4813,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons } #endif } + } else if (cur->type == GGML_TYPE_Q1_0) { + if (ggml_cpu_has_avx2()) { + if (cur->ne[1] % 16 == 0) { + return &q1_0_8x4_q8_0; + } + } } else if (cur->type == GGML_TYPE_Q4_K) { if (ggml_cpu_has_avx2()) { if (cur->ne[1] % 8 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index cb21edf6239..4307cc4f2cb 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -11,6 +11,9 @@ ggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void); template constexpr int QK_0() { + if constexpr (K == 1) { + return QK1_0; + } if constexpr (K == 4) { return QK4_0; } @@ -32,7 +35,9 @@ static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 8, "wrong static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding"); static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding"); static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 16, "wrong block<8,16> size/padding"); +static_assert(sizeof(block<1, 8>) == 8 * sizeof(ggml_half) + QK1_0, "wrong block_q1_0x8 size"); +using block_q1_0x8 = block<1, 8>; using block_q4_0x4 = block<4, 4>; using block_q4_0x8 = block<4, 8>; using block_q4_0x16 = block<4, 16>; @@ -157,6 +162,7 @@ void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q1_0_8x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -173,6 +179,7 @@ void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q1_0_8x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #if defined __riscv_zvfh void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); @@ -209,6 +216,7 @@ void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q1_0_8x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -225,6 +233,7 @@ void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q1_0_8x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #if defined __riscv_zvfh void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);