From c7a20ea80ad4089a60390d45cf2f47c4a79edeeb Mon Sep 17 00:00:00 2001 From: taimur-10x Date: Fri, 13 Feb 2026 20:51:44 +0500 Subject: [PATCH 1/4] ggml-cpu: add rvv 512b,1024b impls for iq4_xs --- ggml/src/ggml-cpu/arch/riscv/quants.c | 201 ++++++++++++++++++++++++++ 1 file changed, 201 insertions(+) diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index d3278d6489f..9a485b32778 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -3911,6 +3911,201 @@ static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT } #endif +static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); + float sumf = 0; + + // Indices for re-ordering IQ4 data. + const uint16_t index[32] = { + 0, 1, 16, 17, + 2, 3, 18, 19, + 4, 5,20, 21, + 6, 7, 22, 23, + 8, 9, 24, 25, + 10, 11, 26, 27, + 12, 13,28, 29, + 14, 15, 30, 31, + }; + const vuint16m1_t i_vec = __riscv_vle16_v_u16m1(index, 32); + + for (int ibl = 0; ibl < nb; ++ibl) { + const int8_t * q8 = y[ibl].qs; + const uint8_t * iq4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; + + int sumi = 0; + + #pragma GCC unroll 1 + // Process the entire super-block together. + for (int ib = 0; ib < QK_K / 256; ++ib) { + // Weights and activations. + const vuint8m2_t iq4_packed = __riscv_vle8_v_u8m2(iq4, 128); + iq4 += 128; + + // Unpack the weight blocks. + const vuint8m2_t iq4bits_lo = __riscv_vand_vx_u8m2(iq4_packed, 0xf, 128); + const vuint8m2_t iq4bits_hi = __riscv_vsrl_vx_u8m2(iq4_packed, 4, 128); + const vuint8m4_t iq4bits = __riscv_vcreate_v_u8m2_u8m4(iq4bits_lo, iq4bits_hi); + const vuint8m4_t iq4bits_reorder = __riscv_vreinterpret_v_u64m4_u8m4(__riscv_vrgatherei16_vv_u64m4(__riscv_vreinterpret_v_u8m4_u64m4(iq4bits), i_vec, 32)); + const vint8m4_t iq4b = __riscv_vrgather_vv_i8m4(values, iq4bits_reorder, 256); + + __asm__ __volatile__("" ::: "memory"); + + // Multiply with activations. + const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 256); + const vint16m8_t prod = __riscv_vwmul_vv_i16m8(iq4b, q8b, 256); + q8 += 256; + + // Reduce separately. 
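+            // At VLEN=512 each i16m1 part of `prod` covers exactly one 32-element
+            // sub-block, so the eight vl=32 reductions below yield the unscaled
+            // per-sub-block dot products acc_k = sum_j iq4b[32*k + j] * q8b[32*k + j];
+            // the 6-bit scales ls0..ls7 are applied to them afterwards.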
+ const int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 4), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 5), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 6), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 7), __riscv_vmv_v_x_i32m1(0, 1), 32)); + + + const int ls0 = ((x[ibl].scales_l[0] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls1 = ((x[ibl].scales_l[0] >> 4) | ((h << 2) & 0x30)) - 32; + const int ls2 = ((x[ibl].scales_l[1] & 0xf) | ((h << 0) & 0x30)) - 32; + const int ls3 = ((x[ibl].scales_l[1] >> 4) | ((h >> 2) & 0x30)) - 32; + h >>= 8; + const int ls4 = ((x[ibl].scales_l[2] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls5 = ((x[ibl].scales_l[2] >> 4) | ((h << 2) & 0x30)) - 32; + const int ls6 = ((x[ibl].scales_l[3] & 0xf) | ((h << 0) & 0x30)) - 32; + const int ls7 = ((x[ibl].scales_l[3] >> 4) | ((h >> 2) & 0x30)) - 32; + + sumi += acc0 * ls0; + sumi += acc1 * ls1; + sumi += acc2 * ls2; + sumi += acc3 * ls3; + sumi += acc4 * ls4; + sumi += acc5 * ls5; + sumi += acc6 * ls6; + sumi += acc7 * ls7; + + __asm__ __volatile__("" ::: "memory"); + } + + sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16); + float sumf = 0; + + // Indices for re-ordering IQ4 data. + const uint16_t index[32] = { + 0, 1, 16, 17, + 2, 3, 18, 19, + 4, 5,20, 21, + 6, 7, 22, 23, + 8, 9, 24, 25, + 10, 11, 26, 27, + 12, 13,28, 29, + 14, 15, 30, 31, + }; + const vuint16mf2_t i_vec = __riscv_vle16_v_u16mf2(index, 32); + + for (int ibl = 0; ibl < nb; ++ibl) { + const int8_t * q8 = y[ibl].qs; + const uint8_t * iq4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; + + int sumi = 0; + + #pragma GCC unroll 1 + // Process the entire super-block together. + for (int ib = 0; ib < QK_K / 256; ++ib) { + // Weights and activations. + const vuint8m1_t iq4_packed = __riscv_vle8_v_u8m1(iq4, 128); + iq4 += 128; + + // Unpack the weight blocks. 
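+            // Layout: each 32-element sub-block keeps elements 0..15 in the low
+            // nibbles and elements 16..31 in the high nibbles of 16 bytes. After the
+            // nibble split, the 64-bit vrgatherei16 with i_vec interleaves the lo/hi
+            // 16-byte chunks back into natural element order (matching q8), and the
+            // final vrgather maps each 4-bit code through the kvalues_iq4nl table.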
+ const vuint8m1_t iq4bits_lo = __riscv_vand_vx_u8m1(iq4_packed, 0xf, 128); + const vuint8m1_t iq4bits_hi = __riscv_vsrl_vx_u8m1(iq4_packed, 4, 128); + const vuint8m2_t iq4bits = __riscv_vcreate_v_u8m1_u8m2(iq4bits_lo, iq4bits_hi); + const vuint8m2_t iq4bits_reorder = __riscv_vreinterpret_v_u64m2_u8m2(__riscv_vrgatherei16_vv_u64m2(__riscv_vreinterpret_v_u8m2_u64m2(iq4bits), i_vec, 32)); + const vint8m2_t iq4b = __riscv_vrgather_vv_i8m2(values, iq4bits_reorder, 256); + + __asm__ __volatile__("" ::: "memory"); + + // Multiply with activations. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 256); + const vint16m4_t prod = __riscv_vwmul_vv_i16m4(iq4b, q8b, 256); + q8 += 256; + + // Mask for processing 32 elements per prod register. + const vuint16m1_t p_index = __riscv_vid_v_u16m1(64); + const vbool16_t p_mask = __riscv_vmsgtu_vx_u16m1_b16(p_index, 31, 64); + + // Reduce separately. + const int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 64)); + const int acc2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 64)); + const int acc4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 64)); + const int acc6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 64)); + + const int ls0 = ((x[ibl].scales_l[0] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls1 = ((x[ibl].scales_l[0] >> 4) | ((h << 2) & 0x30)) - 32; + const int ls2 = ((x[ibl].scales_l[1] & 0xf) | ((h << 0) & 0x30)) - 32; + const int ls3 = ((x[ibl].scales_l[1] >> 4) | ((h >> 2) & 0x30)) - 32; + h >>= 8; + const int ls4 = ((x[ibl].scales_l[2] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls5 = ((x[ibl].scales_l[2] >> 4) | ((h << 2) & 0x30)) - 32; + const int ls6 = ((x[ibl].scales_l[3] & 0xf) | ((h << 0) & 0x30)) - 32; + const int ls7 = ((x[ibl].scales_l[3] >> 4) | ((h >> 2) & 0x30)) - 32; + + sumi += acc0 * ls0; + sumi += acc1 * ls1; + sumi += acc2 * ls2; + sumi += acc3 * ls3; + sumi += acc4 * ls4; + sumi += acc5 * ls5; + sumi += acc6 * ls6; + sumi += acc7 * ls7; + + __asm__ __volatile__("" ::: "memory"); + } + + sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi); + } + + *s = sumf; +} + void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { @@ -3920,6 +4115,12 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v case 256: ggml_vec_dot_iq4_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; + case 512: + ggml_vec_dot_iq4_xs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + 
ggml_vec_dot_iq4_xs_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; default: ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); break; From da99a5942220e98dc2e48642a35697a9bee08666 Mon Sep 17 00:00:00 2001 From: taimur-10x Date: Sat, 14 Feb 2026 21:44:02 +0500 Subject: [PATCH 2/4] ggml-cpu: refactor; add rvv 512b, 1024b impls for q6_K, i-quants --- ggml/src/ggml-cpu/arch/riscv/quants.c | 1581 ++++++++++++++++++------- 1 file changed, 1163 insertions(+), 418 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index 9a485b32778..0080e6b6672 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -1220,7 +1220,8 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } -void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector +static NOINLINE void ggml_vec_dot_q4_K_q8_K_xtheadvector(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -1239,8 +1240,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi uint32_t utmp[4]; -#if defined __riscv_xtheadvector - const uint8_t * scales = (const uint8_t*)&utmp[0]; const uint8_t * mins = (const uint8_t*)&utmp[2]; @@ -1354,277 +1353,315 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = sumf; +} +#endif -#elif defined __riscv_v +static NOINLINE void ggml_vec_dot_q4_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; const uint8_t * scales = (const uint8_t*)&utmp[0]; const uint8_t * mins = (const uint8_t*)&utmp[2]; float sumf = 0; - const int vector_length = __riscv_vlenb() * 8; + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { + float ftmp, ft2; + const uint8_t * restrict q40; + const uint8_t * restrict q41; + const uint8_t * restrict q42; + const uint8_t * restrict q43; + const int8_t * restrict q80; + const int8_t * restrict q81; + const int8_t * restrict q82; + const int8_t * restrict q83; + int s0, s1, s2, s3; - size_t vl = 8; + __asm__ __volatile__( + "li %[s1], 8\n\t" + "vsetivli zero, 4, e32, m1, ta, ma\n\t" + "vle32.v v1, (%[s6b])\n\t" + "vslide1down.vx v1, v1, zero\n\t" + "vmv.v.x v16, zero\n\t" + "vslidedown.vi v2, v1, 2\n\t" + "vmv1r.v v3, v2\n\t" + "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]} + "vsetivli zero, 2, e32, m1, ta, ma\n\t" + "vmv.v.i v4, 4\n\t" + "vand.vx v8, v1, %[kmask1]\n\t" + "vslide1up.vx v5, v4, zero\n\t" // {0, 4} + "vsrl.vi v6, v1, 6\n\t" + "vsrl.vv v7, v2, v5\n\t" + "vsse32.v v8, (%[utmp]), %[s1]\n\t" + "vand.vx v0, v6, %[kmask3]\n\t" + "vand.vx v2, v7, %[kmask2]\n\t" + "vsll.vi v6, v0, 
4\n\t" + "addi %[s0], %[utmp], 4\n\t" + "vor.vv v1, v6, v2\n\t" + "vsse32.v v1, (%[s0]), %[s1]\n\t" + "vsetivli zero, 8, e16, m1, ta, ma\n\t" + "vle32.v v2, (%[bsums])\n\t" + "vnsrl.wi v0, v2, 0\n\t" + "vnsrl.wi v1, v2, 16\n\t" + "vadd.vv v2, v0, v1\n\t" + "vle8.v v3, (%[mins])\n\t" + "vzext.vf2 v4, v3\n\t" + "vwmul.vv v6, v4, v2\n\t" + "vsetivli zero, 4, e32, m1, ta, ma\n\t" + "vredsum.vs v0, v6, v16\n\t" + "vredsum.vs v0, v7, v0\n\t" + "vfcvt.f.x.v v0, v0\n\t" + "vfmv.f.s %[ftmp], v0\n\t" + "vsetivli zero, 16, e8, m1, ta, ma\n\t" + "vle8.v v0, (%[xs])\n\t" + "fnmsub.s %[sumf], %[dmin], %[ftmp], %[sumf]\n\t" + "addi %[q40], %[xs], 64\n\t" + "addi %[q41], %[xs], 16\n\t" + "addi %[q42], %[xs], 32\n\t" + "addi %[q43], %[xs], 48\n\t" + "addi %[q80], %[ys], 64\n\t" + "vle8.v v1, (%[q41])\n\t" + "vle8.v v2, (%[q42])\n\t" + "addi %[q81], %[ys], 16\n\t" + "addi %[q41], %[q41], 64\n\t" + "addi %[q82], %[ys], 32\n\t" + "vle8.v v3, (%[q43])\n\t" + "vle8.v v8, (%[ys])\n\t" + "addi %[q42], %[q42], 64\n\t" + "addi %[q83], %[ys], 48\n\t" + "addi %[q43], %[q43], 64\n\t" + "vsrl.vi v4, v0, 4\n\t" + "vle8.v v9, (%[q81])\n\t" + "vle8.v v10, (%[q82])\n\t" + "vand.vi v0, v0, 0xF\n\t" + "addi %[q81], %[q81], 64\n\t" + "vsrl.vi v5, v1, 4\n\t" + "addi %[q82], %[q82], 64\n\t" + "vle8.v v11, (%[q83])\n\t" + "vle8.v v12, (%[q80])\n\t" + "vand.vi v1, v1, 0xF\n\t" + "addi %[q83], %[q83], 64\n\t" + "vsrl.vi v6, v2, 4\n\t" + "addi %[q80], %[q80], 64\n\t" + "vle8.v v13, (%[q81])\n\t" + "vle8.v v14, (%[q82])\n\t" + "vand.vi v2, v2, 0xF\n\t" + "addi %[q81], %[q81], 64\n\t" + "vsrl.vi v7, v3, 4\n\t" + "addi %[q82], %[q82], 64\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vle8.v v15, (%[q83])\n\t" + "vle8.v v0, (%[q40])\n\t" + "vand.vi v3, v3, 0xF\n\t" + "addi %[q83], %[q83], 64\n\t" + "vwmul.vv v24, v2, v12\n\t" + "vwmul.vv v20, v4, v10\n\t" + "vwmul.vv v28, v6, v14\n\t" + "vwmacc.vv v16, v1, v9\n\t" + "vle8.v v1, (%[q41])\n\t" + "vle8.v v2, (%[q42])\n\t" + "vwmacc.vv v24, v3, v13\n\t" + "vwmacc.vv v20, v5, v11\n\t" + "vwmacc.vv v28, v7, v15\n\t" + "addi %[q40], %[q80], 64\n\t" + "addi %[q41], %[q81], 64\n\t" + "vle8.v v3, (%[q43])\n\t" + "vle8.v v8, (%[q80])\n\t" + "addi %[q42], %[q82], 64\n\t" + "addi %[q43], %[q83], 64\n\t" + "vsrl.vi v4, v0, 4\n\t" + "vle8.v v9, (%[q81])\n\t" + "vle8.v v10, (%[q82])\n\t" + "vand.vi v0, v0, 0xF\n\t" + "vsrl.vi v5, v1, 4\n\t" + "vsrl.vi v7, v3, 4\n\t" + "vand.vi v3, v3, 0xF\n\t" + "vle8.v v11, (%[q83])\n\t" + "vle8.v v12, (%[q40])\n\t" + "vand.vi v1, v1, 0xF\n\t" + "vsrl.vi v6, v2, 4\n\t" + "vand.vi v2, v2, 0xF\n\t" + "vwmul.vv v18, v0, v8\n\t" + "vle8.v v13, (%[q41])\n\t" + "vle8.v v14, (%[q42])\n\t" + "vwmul.vv v26, v2, v12\n\t" + "vwmul.vv v22, v4, v10\n\t" + "vwmul.vv v30, v6, v14\n\t" + "vwmacc.vv v18, v1, v9\n\t" + "vle8.v v15, (%[q43])\n\t" + "vwmacc.vv v26, v3, v13\n\t" + "vwmacc.vv v22, v5, v11\n\t" + "vwmacc.vv v30, v7, v15\n\t" + "vmv.v.x v0, zero\n\t" + "vsetivli zero, 16, e16, m2, ta, ma\n\t" + "vwredsum.vs v4, v16, v0\n\t" + "lbu %[s0], 0(%[scale])\n\t" + "vwredsum.vs v5, v20, v0\n\t" + "lbu %[s1], 1(%[scale])\n\t" + "vwredsum.vs v6, v24, v0\n\t" + "lbu %[s2], 2(%[scale])\n\t" + "vwredsum.vs v7, v28, v0\n\t" + "lbu %[s3], 3(%[scale])\n\t" + "vwredsum.vs v8, v18, v0\n\t" + "lbu %[q40], 4(%[scale])\n\t" + "vwredsum.vs v9, v22, v0\n\t" + "lbu %[q41], 5(%[scale])\n\t" + "vwredsum.vs v10, v26, v0\n\t" + "lbu %[q42], 6(%[scale])\n\t" + "vwredsum.vs v11, v30, v0\n\t" + "lbu %[q43], 7(%[scale])\n\t" + "vsetivli zero, 4, e32, m1, ta, ma\n\t" + "vmul.vx v0, v4, %[s0]\n\t" + "vmul.vx v1, v8, 
%[q40]\n\t" + "vmacc.vx v0, %[s1], v5\n\t" + "vmacc.vx v1, %[q41], v9\n\t" + "vmacc.vx v0, %[s2], v6\n\t" + "vmacc.vx v1, %[q42], v10\n\t" + "vmacc.vx v0, %[s3], v7\n\t" + "vmacc.vx v1, %[q43], v11\n\t" + "vfcvt.f.x.v v0, v0\n\t" + "vfcvt.f.x.v v1, v1\n\t" + "vfmv.f.s %[ft2], v0\n\t" + "vfmv.f.s %[ftmp], v1\n\t" + "fadd.s %[ft2], %[ft2], %[ftmp]\n\t" + "fmadd.s %[sumf], %[d], %[ft2], %[sumf]" + : [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2) + , [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3) + , [q40] "=&r" (q40), [q41] "=&r" (q41), [q42] "=&r" (q42), [q43] "=&r" (q43) + , [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83) + : [d] "f" (d), [ys] "r" (y[i].qs), [xs] "r" (x[i].qs), [scale] "r" (scales) + , [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp) + , [s6b] "r" (&x[i]), [kmask1] "r" (kmask1), [dmin] "f" (dmin) + , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + } - const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + *s = sumf; +} - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); +static NOINLINE void ggml_vec_dot_q4_K_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + const int nb = n / QK_K; - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; - const uint8_t * GGML_RESTRICT q4 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + uint32_t utmp[4]; - vl = 32; + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; - int32_t sum_1 = 0; - int32_t sum_2 = 0; + float sumf = 0; + for (int i = 0; i < nb; ++i) { + size_t vl = 8; - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - for (int j = 0; j < QK_K/64; ++j) { - // load Q4 - vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); - // load Q8 and multiply it with lower 
Q4 nibble - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); - vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); - vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; - sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); - // load Q8 and multiply it with upper Q4 nibble - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); - vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); - vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); - sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - q4 += 32; q8 += 64; + vl = 32; - } + int32_t sum_1 = 0; + int32_t sum_2 = 0; - sumf += d*(sum_1 + sum_2); + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - - float ftmp, ft2; - const uint8_t * restrict q40; - const uint8_t * restrict q41; - const uint8_t * restrict q42; - const uint8_t * restrict q43; - const int8_t * restrict q80; - const int8_t * restrict q81; - const int8_t * restrict q82; - const int8_t * restrict q83; - int s0, s1, s2, s3; + for (int j = 0; j < QK_K/64; ++j) { + // load Q4 + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + + // load Q8 and multiply it with lower Q4 nibble + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); + + sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + + // load Q8 and multiply it with upper Q4 nibble + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + + sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + + q4 += 32; q8 += 64; - __asm__ __volatile__( - "li %[s1], 8\n\t" - "vsetivli zero, 4, e32, m1, ta, ma\n\t" - "vle32.v v1, (%[s6b])\n\t" - "vslide1down.vx v1, v1, zero\n\t" - "vmv.v.x v16, zero\n\t" - "vslidedown.vi v2, v1, 2\n\t" - "vmv1r.v v3, v2\n\t" - "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]} - "vsetivli zero, 2, e32, m1, ta, ma\n\t" - "vmv.v.i v4, 4\n\t" - "vand.vx v8, v1, %[kmask1]\n\t" - "vslide1up.vx v5, v4, zero\n\t" // {0, 4} - "vsrl.vi v6, v1, 6\n\t" - "vsrl.vv v7, v2, v5\n\t" - "vsse32.v v8, (%[utmp]), %[s1]\n\t" - "vand.vx v0, v6, %[kmask3]\n\t" - "vand.vx v2, v7, %[kmask2]\n\t" - "vsll.vi v6, v0, 4\n\t" 
- "addi %[s0], %[utmp], 4\n\t" - "vor.vv v1, v6, v2\n\t" - "vsse32.v v1, (%[s0]), %[s1]\n\t" - "vsetivli zero, 8, e16, m1, ta, ma\n\t" - "vle32.v v2, (%[bsums])\n\t" - "vnsrl.wi v0, v2, 0\n\t" - "vnsrl.wi v1, v2, 16\n\t" - "vadd.vv v2, v0, v1\n\t" - "vle8.v v3, (%[mins])\n\t" - "vzext.vf2 v4, v3\n\t" - "vwmul.vv v6, v4, v2\n\t" - "vsetivli zero, 4, e32, m1, ta, ma\n\t" - "vredsum.vs v0, v6, v16\n\t" - "vredsum.vs v0, v7, v0\n\t" - "vfcvt.f.x.v v0, v0\n\t" - "vfmv.f.s %[ftmp], v0\n\t" - "vsetivli zero, 16, e8, m1, ta, ma\n\t" - "vle8.v v0, (%[xs])\n\t" - "fnmsub.s %[sumf], %[dmin], %[ftmp], %[sumf]\n\t" - "addi %[q40], %[xs], 64\n\t" - "addi %[q41], %[xs], 16\n\t" - "addi %[q42], %[xs], 32\n\t" - "addi %[q43], %[xs], 48\n\t" - "addi %[q80], %[ys], 64\n\t" - "vle8.v v1, (%[q41])\n\t" - "vle8.v v2, (%[q42])\n\t" - "addi %[q81], %[ys], 16\n\t" - "addi %[q41], %[q41], 64\n\t" - "addi %[q82], %[ys], 32\n\t" - "vle8.v v3, (%[q43])\n\t" - "vle8.v v8, (%[ys])\n\t" - "addi %[q42], %[q42], 64\n\t" - "addi %[q83], %[ys], 48\n\t" - "addi %[q43], %[q43], 64\n\t" - "vsrl.vi v4, v0, 4\n\t" - "vle8.v v9, (%[q81])\n\t" - "vle8.v v10, (%[q82])\n\t" - "vand.vi v0, v0, 0xF\n\t" - "addi %[q81], %[q81], 64\n\t" - "vsrl.vi v5, v1, 4\n\t" - "addi %[q82], %[q82], 64\n\t" - "vle8.v v11, (%[q83])\n\t" - "vle8.v v12, (%[q80])\n\t" - "vand.vi v1, v1, 0xF\n\t" - "addi %[q83], %[q83], 64\n\t" - "vsrl.vi v6, v2, 4\n\t" - "addi %[q80], %[q80], 64\n\t" - "vle8.v v13, (%[q81])\n\t" - "vle8.v v14, (%[q82])\n\t" - "vand.vi v2, v2, 0xF\n\t" - "addi %[q81], %[q81], 64\n\t" - "vsrl.vi v7, v3, 4\n\t" - "addi %[q82], %[q82], 64\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vle8.v v15, (%[q83])\n\t" - "vle8.v v0, (%[q40])\n\t" - "vand.vi v3, v3, 0xF\n\t" - "addi %[q83], %[q83], 64\n\t" - "vwmul.vv v24, v2, v12\n\t" - "vwmul.vv v20, v4, v10\n\t" - "vwmul.vv v28, v6, v14\n\t" - "vwmacc.vv v16, v1, v9\n\t" - "vle8.v v1, (%[q41])\n\t" - "vle8.v v2, (%[q42])\n\t" - "vwmacc.vv v24, v3, v13\n\t" - "vwmacc.vv v20, v5, v11\n\t" - "vwmacc.vv v28, v7, v15\n\t" - "addi %[q40], %[q80], 64\n\t" - "addi %[q41], %[q81], 64\n\t" - "vle8.v v3, (%[q43])\n\t" - "vle8.v v8, (%[q80])\n\t" - "addi %[q42], %[q82], 64\n\t" - "addi %[q43], %[q83], 64\n\t" - "vsrl.vi v4, v0, 4\n\t" - "vle8.v v9, (%[q81])\n\t" - "vle8.v v10, (%[q82])\n\t" - "vand.vi v0, v0, 0xF\n\t" - "vsrl.vi v5, v1, 4\n\t" - "vsrl.vi v7, v3, 4\n\t" - "vand.vi v3, v3, 0xF\n\t" - "vle8.v v11, (%[q83])\n\t" - "vle8.v v12, (%[q40])\n\t" - "vand.vi v1, v1, 0xF\n\t" - "vsrl.vi v6, v2, 4\n\t" - "vand.vi v2, v2, 0xF\n\t" - "vwmul.vv v18, v0, v8\n\t" - "vle8.v v13, (%[q41])\n\t" - "vle8.v v14, (%[q42])\n\t" - "vwmul.vv v26, v2, v12\n\t" - "vwmul.vv v22, v4, v10\n\t" - "vwmul.vv v30, v6, v14\n\t" - "vwmacc.vv v18, v1, v9\n\t" - "vle8.v v15, (%[q43])\n\t" - "vwmacc.vv v26, v3, v13\n\t" - "vwmacc.vv v22, v5, v11\n\t" - "vwmacc.vv v30, v7, v15\n\t" - "vmv.v.x v0, zero\n\t" - "vsetivli zero, 16, e16, m2, ta, ma\n\t" - "vwredsum.vs v4, v16, v0\n\t" - "lbu %[s0], 0(%[scale])\n\t" - "vwredsum.vs v5, v20, v0\n\t" - "lbu %[s1], 1(%[scale])\n\t" - "vwredsum.vs v6, v24, v0\n\t" - "lbu %[s2], 2(%[scale])\n\t" - "vwredsum.vs v7, v28, v0\n\t" - "lbu %[s3], 3(%[scale])\n\t" - "vwredsum.vs v8, v18, v0\n\t" - "lbu %[q40], 4(%[scale])\n\t" - "vwredsum.vs v9, v22, v0\n\t" - "lbu %[q41], 5(%[scale])\n\t" - "vwredsum.vs v10, v26, v0\n\t" - "lbu %[q42], 6(%[scale])\n\t" - "vwredsum.vs v11, v30, v0\n\t" - "lbu %[q43], 7(%[scale])\n\t" - "vsetivli zero, 4, e32, m1, ta, ma\n\t" - "vmul.vx v0, v4, %[s0]\n\t" - "vmul.vx v1, v8, 
%[q40]\n\t" - "vmacc.vx v0, %[s1], v5\n\t" - "vmacc.vx v1, %[q41], v9\n\t" - "vmacc.vx v0, %[s2], v6\n\t" - "vmacc.vx v1, %[q42], v10\n\t" - "vmacc.vx v0, %[s3], v7\n\t" - "vmacc.vx v1, %[q43], v11\n\t" - "vfcvt.f.x.v v0, v0\n\t" - "vfcvt.f.x.v v1, v1\n\t" - "vfmv.f.s %[ft2], v0\n\t" - "vfmv.f.s %[ftmp], v1\n\t" - "fadd.s %[ft2], %[ft2], %[ftmp]\n\t" - "fmadd.s %[sumf], %[d], %[ft2], %[sumf]" - : [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2) - , [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3) - , [q40] "=&r" (q40), [q41] "=&r" (q41), [q42] "=&r" (q42), [q43] "=&r" (q43) - , [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83) - : [d] "f" (d), [ys] "r" (y[i].qs), [xs] "r" (x[i].qs), [scale] "r" (scales) - , [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp) - , [s6b] "r" (&x[i]), [kmask1] "r" (kmask1), [dmin] "f" (dmin) - , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); } - break; - default: - assert(false && "Unsupported vector length"); - break; + + sumf += d*(sum_1 + sum_2); + } *s = sumf; +} +void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector + ggml_vec_dot_q4_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); +#elif defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_q4_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_q4_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + } #else - - UNUSED(x); - UNUSED(y); - UNUSED(kmask1); - UNUSED(kmask2); - UNUSED(kmask3); - UNUSED(nb); - UNUSED(utmp); - ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } @@ -1725,7 +1762,6 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2); q5 += 32; q8 += 64; - } sums += aux32 * d; @@ -1748,7 +1784,8 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } -void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector +static NOINLINE void ggml_vec_dot_q6_K_q8_K_xtheadvector(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -1761,8 +1798,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int nb = n / QK_K; -#if defined __riscv_xtheadvector - float sumf = 0; for (int i = 0; i < nb; ++i) { @@ -1841,219 +1876,455 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = sumf; +} +#endif -#elif defined __riscv_v +static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + 
float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + __builtin_prefetch(&x[i + 1].d, 0, 1); + + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + int q6h; + float ftmp; + + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "addi %[q6h], %[q6], 32\n\t" + "ld t0, 0(%[scale])\n\t" + "addi %[scale], %[scale], 8\n\t" + "slli t6, t0, 1 * 8\n\t" + "lb zero, 0(%[q6])\n\t" + "slli t5, t0, 2 * 8\n\t" + "slli t4, t0, 3 * 8\n\t" + "lb zero, 0(%[q6h])\n\t" + "slli t3, t0, 4 * 8\n\t" + "slli t2, t0, 5 * 8\n\t" + "lb zero, 0(%[qh])\n\t" + "lb zero, 31(%[q6h])\n\t" + "slli t1, t0, 6 * 8\n\t" + "srai a7, t0, 56\n\t" + "vsetvli zero, %[vl32], e8, m2\n\t" + "vle8.v v8, (%[q6])\n\t" + "srai t6, t6, 56\n\t" + "srai t5, t5, 56\n\t" + "srai t4, t4, 56\n\t" + "srai t3, t3, 56\n\t" + "vle8.v v10, (%[q6h])\n\t" + "addi %[q6], %[q6], 64\n\t" + "slli t0, t0, 7 * 8\n\t" + "srai t2, t2, 56\n\t" + "srai t1, t1, 56\n\t" + "srai t0, t0, 56\n\t" + "vle8.v v4, (%[qh])\n\t" + "vsrl.vi v12, v8, 4\n\t" + "vsrl.vi v14, v10, 4\n\t" + "lb zero, 0(%[q8])\n\t" + "vand.vi v8, v8, 0xF\n\t" + "vand.vi v10, v10, 0xF\n\t" + "lb zero, 32(%[q8])\n\t" + "vsll.vi v0, v4, 4\n\t" + "vsll.vi v2, v4, 2\n\t" + "lb zero, 64(%[q8])\n\t" + "vsrl.vi v6, v4, 2\n\t" + "vand.vx v0, v0, %[mask]\n\t" + "lb zero, 96(%[q8])\n\t" + "vand.vx v2, v2, %[mask]\n\t" + "vand.vx v4, v4, %[mask]\n\t" + "vand.vx v6, v6, %[mask]\n\t" + "vor.vv v8, v8, v0\n\t" + "lb zero, 127(%[q8])\n\t" + "vor.vv v10, v10, v2\n\t" + "vor.vv v12, v12, v4\n\t" + "vor.vv v14, v14, v6\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v0, (%[q8])\n\t" + "vsub.vx v8, v8, %[vl32]\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v10, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v8, v20, v0\n\t" + "vwredsum.vs v7, v22, v0\n\t" + "vwredsum.vs v11, v24, v0\n\t" + "vwredsum.vs v12, v26, v0\n\t" + "vwredsum.vs v13, v28, v0\n\t" + "vwredsum.vs v14, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vmul.vx v0, v10, t0\n\t" + "vmul.vx v1, v9, t1\n\t" + "vmacc.vx v0, t2, v8\n\t" + "vmacc.vx v1, t3, v7\n\t" + "vmacc.vx v0, t4, v11\n\t" + "vmacc.vx v1, t5, v12\n\t" + "vmacc.vx v0, t6, v13\n\t" + "vmacc.vx v1, a7, v14\n\t" + "vadd.vv v0, v0, v1\n\t" + "vfcvt.f.x.v v0, v0\n\t" + "vfmv.f.s %[ftmp], v0\n\t" + "fmadd.s %[sumf], %[d], %[ftmp], %[sumf]" + : [q6] "+&r" (q6), [q6h] "=&r" (q6h) + , [scale] "+&r" (scale) + , [sumf] "+&f" (sumf), [ftmp] "=&f" (ftmp) + : [qh] "r" (qh), [q8] "r" (q8) + , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) + , [mask] "r" (0x30), [d] "f" (d) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + , "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a7" + , "a6", "a5", "a4", "a3" + ); + qh += 32; q8 += 128; + } + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x 
= vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; float sumf = 0; - const int vector_length = __riscv_vlenb() * 8; + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + const int8_t * GGML_RESTRICT scale = x[i].scales; - const uint8_t * GGML_RESTRICT q6 = x[i].ql; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + size_t vl; - const int8_t * GGML_RESTRICT scale = x[i].scales; + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - size_t vl; + int sum_t = 0; + int is = 0; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + for (int j = 0; j < QK_K/128; ++j) { - int sum_t = 0; - int is = 0; + vl = 32; - for (int j = 0; j < QK_K/128; ++j) { + // load qh + vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); - vl = 32; + // load Q6 + vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); + vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); - // load qh - vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); + vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); + vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); + vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); - // load Q6 - vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); - vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); + vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); + vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); + vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); + vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); - vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); - vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); - vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); - vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); + vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); + vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); + vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); + vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); - vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); - vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); - vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); - vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); + vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); + vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); + vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); + vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); - vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); - vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); - vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); - vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, 
__riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); + // load Q8 and take product + vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); - vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); - vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); - vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); + vl = 16; - // load Q8 and take product - vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); + vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); + vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); + vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); + vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); + vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); + vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); + vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); - vl = 16; + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); - vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); - vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); - vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); - vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); - vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); - vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); - vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); - vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); - vint32m1_t isum2 = 
__riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); + q6 += 64; qh += 32; q8 += 128; is=8; - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + } - q6 += 64; qh += 32; q8 += 128; is=8; + sumf += d * sum_t; - } + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + // mask for processing 16 elements per prod register + const vuint16m1_t va_index = __riscv_vid_v_u16m1(32); + const vbool16_t va_mask = __riscv_vmsgtu_vx_u16m1_b16(va_index, 15, 32); + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const int8_t * GGML_RESTRICT scale = x[i].scales; - sumf += d * sum_t; + size_t vl = 32; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + int sum_t = 0; + int is = 0; + + vint32m2_t vaux_0 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_1 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_2 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_3 = __riscv_vmv_v_x_i32m2(0, vl); + + for (int j = 0; j < QK_K/128; ++j) { + // load qh + vuint8mf2_t qh_x = __riscv_vle8_v_u8mf2(qh, vl); + + // load Q6 + vuint8mf2_t q6_0 = __riscv_vle8_v_u8mf2(q6, vl); + vuint8mf2_t q6_1 = __riscv_vle8_v_u8mf2(q6+32, vl); + + vuint8mf2_t q6a_0 = __riscv_vand_vx_u8mf2(q6_0, 0x0F, vl); + vuint8mf2_t q6a_1 = __riscv_vand_vx_u8mf2(q6_1, 0x0F, vl); + vuint8mf2_t q6s_0 = __riscv_vsrl_vx_u8mf2(q6_0, 0x04, vl); + vuint8mf2_t q6s_1 = __riscv_vsrl_vx_u8mf2(q6_1, 0x04, vl); + + vuint8mf2_t qh_0 = __riscv_vand_vx_u8mf2(qh_x, 0x03, vl); + vuint8mf2_t qh_1 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl), 0x03 , vl); + vuint8mf2_t qh_2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), 0x03 , vl); + vuint8mf2_t qh_3 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x6, vl), 0x03 , vl); + + vuint8mf2_t qhi_0 = __riscv_vor_vv_u8mf2(q6a_0, __riscv_vsll_vx_u8mf2(qh_0, 0x04, vl), vl); + vuint8mf2_t qhi_1 = __riscv_vor_vv_u8mf2(q6a_1, __riscv_vsll_vx_u8mf2(qh_1, 0x04, vl), vl); + vuint8mf2_t qhi_2 = __riscv_vor_vv_u8mf2(q6s_0, __riscv_vsll_vx_u8mf2(qh_2, 0x04, vl), vl); + vuint8mf2_t qhi_3 = __riscv_vor_vv_u8mf2(q6s_1, __riscv_vsll_vx_u8mf2(qh_3, 0x04, vl), vl); + + vint8mf2_t a_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(qhi_0), 32, vl); + vint8mf2_t a_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(qhi_1), 32, vl); + vint8mf2_t a_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(qhi_2), 32, vl); + vint8mf2_t a_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(qhi_3), 32, vl); + + // load Q8 and take product + vint16m1_t va_q_0 = __riscv_vwmul_vv_i16m1(a_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t va_q_1 = __riscv_vwmul_vv_i16m1(a_1, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t va_q_2 = __riscv_vwmul_vv_i16m1(a_2, __riscv_vle8_v_i8mf2(q8+64, vl), vl); + vint16m1_t va_q_3 = __riscv_vwmul_vv_i16m1(a_3, __riscv_vle8_v_i8mf2(q8+96, vl), vl); + + // 
accumulate + vaux_0 = __riscv_vwmacc_vx_i32m2(vaux_0, scale[is+0], va_q_0, 16); + vaux_1 = __riscv_vwmacc_vx_i32m2(vaux_1, scale[is+2], va_q_1, 16); + vaux_2 = __riscv_vwmacc_vx_i32m2(vaux_2, scale[is+4], va_q_2, 16); + vaux_3 = __riscv_vwmacc_vx_i32m2(vaux_3, scale[is+6], va_q_3, 16); + // + vaux_0 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_0, scale[is+1], va_q_0, vl); + vaux_1 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_1, scale[is+3], va_q_1, vl); + vaux_2 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_2, scale[is+5], va_q_2, vl); + vaux_3 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_3, scale[is+7], va_q_3, vl); + + q6 += 64; qh += 32; q8 += 128; is=8; } - break; - case 128: - for (int i = 0; i < nb; ++i) { - __builtin_prefetch(&x[i + 1].d, 0, 1); + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); - const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + sum_t += __riscv_vmv_x_s_i32m1_i32(isum1); - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; + sumf += d * sum_t; - const int8_t * restrict scale = x[i].scales; + } - int q6h; - float ftmp; + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + // mask for processing 16 elements per prod register + const vuint16mf2_t va_index = __riscv_vid_v_u16mf2(32); + const vbool32_t va_mask = __riscv_vmsgtu_vx_u16mf2_b32(va_index, 15, 32); + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const int8_t * GGML_RESTRICT scale = x[i].scales; + + size_t vl = 32; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + int sum_t = 0; + int is = 0; + + vint32m1_t vaux_0 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_1 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_2 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_3 = __riscv_vmv_v_x_i32m1(0, vl); + + for (int j = 0; j < QK_K/128; ++j) { + // load qh + vuint8mf4_t qh_x = __riscv_vle8_v_u8mf4(qh, vl); + + // load Q6 + vuint8mf4_t q6_0 = __riscv_vle8_v_u8mf4(q6, vl); + vuint8mf4_t q6_1 = __riscv_vle8_v_u8mf4(q6+32, vl); + + vuint8mf4_t q6a_0 = __riscv_vand_vx_u8mf4(q6_0, 0x0F, vl); + vuint8mf4_t q6a_1 = __riscv_vand_vx_u8mf4(q6_1, 0x0F, vl); + vuint8mf4_t q6s_0 = __riscv_vsrl_vx_u8mf4(q6_0, 0x04, vl); + vuint8mf4_t q6s_1 = __riscv_vsrl_vx_u8mf4(q6_1, 0x04, vl); + + vuint8mf4_t qh_0 = __riscv_vand_vx_u8mf4(qh_x, 0x03, vl); + vuint8mf4_t qh_1 = __riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(qh_x, 0x2, vl), 0x03 , vl); + vuint8mf4_t qh_2 = __riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(qh_x, 0x4, vl), 0x03 , vl); + vuint8mf4_t qh_3 = __riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(qh_x, 0x6, vl), 0x03 , vl); + + vuint8mf4_t qhi_0 = __riscv_vor_vv_u8mf4(q6a_0, __riscv_vsll_vx_u8mf4(qh_0, 0x04, vl), vl); + vuint8mf4_t qhi_1 = __riscv_vor_vv_u8mf4(q6a_1, __riscv_vsll_vx_u8mf4(qh_1, 0x04, vl), vl); + vuint8mf4_t qhi_2 = 
__riscv_vor_vv_u8mf4(q6s_0, __riscv_vsll_vx_u8mf4(qh_2, 0x04, vl), vl); + vuint8mf4_t qhi_3 = __riscv_vor_vv_u8mf4(q6s_1, __riscv_vsll_vx_u8mf4(qh_3, 0x04, vl), vl); + + vint8mf4_t a_0 = __riscv_vsub_vx_i8mf4(__riscv_vreinterpret_v_u8mf4_i8mf4(qhi_0), 32, vl); + vint8mf4_t a_1 = __riscv_vsub_vx_i8mf4(__riscv_vreinterpret_v_u8mf4_i8mf4(qhi_1), 32, vl); + vint8mf4_t a_2 = __riscv_vsub_vx_i8mf4(__riscv_vreinterpret_v_u8mf4_i8mf4(qhi_2), 32, vl); + vint8mf4_t a_3 = __riscv_vsub_vx_i8mf4(__riscv_vreinterpret_v_u8mf4_i8mf4(qhi_3), 32, vl); + + // load Q8 and take product + vint16mf2_t va_q_0 = __riscv_vwmul_vv_i16mf2(a_0, __riscv_vle8_v_i8mf4(q8, vl), vl); + vint16mf2_t va_q_1 = __riscv_vwmul_vv_i16mf2(a_1, __riscv_vle8_v_i8mf4(q8+32, vl), vl); + vint16mf2_t va_q_2 = __riscv_vwmul_vv_i16mf2(a_2, __riscv_vle8_v_i8mf4(q8+64, vl), vl); + vint16mf2_t va_q_3 = __riscv_vwmul_vv_i16mf2(a_3, __riscv_vle8_v_i8mf4(q8+96, vl), vl); + + // accumulate + vaux_0 = __riscv_vwmacc_vx_i32m1(vaux_0, scale[is+0], va_q_0, 16); + vaux_1 = __riscv_vwmacc_vx_i32m1(vaux_1, scale[is+2], va_q_1, 16); + vaux_2 = __riscv_vwmacc_vx_i32m1(vaux_2, scale[is+4], va_q_2, 16); + vaux_3 = __riscv_vwmacc_vx_i32m1(vaux_3, scale[is+6], va_q_3, 16); + // + vaux_0 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_0, scale[is+1], va_q_0, vl); + vaux_1 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_1, scale[is+3], va_q_1, vl); + vaux_2 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_2, scale[is+5], va_q_2, vl); + vaux_3 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_3, scale[is+7], va_q_3, vl); + + q6 += 64; qh += 32; q8 += 128; is=8; - for (int j = 0; j < QK_K/128; ++j) { - __asm__ __volatile__( - "addi %[q6h], %[q6], 32\n\t" - "ld t0, 0(%[scale])\n\t" - "addi %[scale], %[scale], 8\n\t" - "slli t6, t0, 1 * 8\n\t" - "lb zero, 0(%[q6])\n\t" - "slli t5, t0, 2 * 8\n\t" - "slli t4, t0, 3 * 8\n\t" - "lb zero, 0(%[q6h])\n\t" - "slli t3, t0, 4 * 8\n\t" - "slli t2, t0, 5 * 8\n\t" - "lb zero, 0(%[qh])\n\t" - "lb zero, 31(%[q6h])\n\t" - "slli t1, t0, 6 * 8\n\t" - "srai a7, t0, 56\n\t" - "vsetvli zero, %[vl32], e8, m2\n\t" - "vle8.v v8, (%[q6])\n\t" - "srai t6, t6, 56\n\t" - "srai t5, t5, 56\n\t" - "srai t4, t4, 56\n\t" - "srai t3, t3, 56\n\t" - "vle8.v v10, (%[q6h])\n\t" - "addi %[q6], %[q6], 64\n\t" - "slli t0, t0, 7 * 8\n\t" - "srai t2, t2, 56\n\t" - "srai t1, t1, 56\n\t" - "srai t0, t0, 56\n\t" - "vle8.v v4, (%[qh])\n\t" - "vsrl.vi v12, v8, 4\n\t" - "vsrl.vi v14, v10, 4\n\t" - "lb zero, 0(%[q8])\n\t" - "vand.vi v8, v8, 0xF\n\t" - "vand.vi v10, v10, 0xF\n\t" - "lb zero, 32(%[q8])\n\t" - "vsll.vi v0, v4, 4\n\t" - "vsll.vi v2, v4, 2\n\t" - "lb zero, 64(%[q8])\n\t" - "vsrl.vi v6, v4, 2\n\t" - "vand.vx v0, v0, %[mask]\n\t" - "lb zero, 96(%[q8])\n\t" - "vand.vx v2, v2, %[mask]\n\t" - "vand.vx v4, v4, %[mask]\n\t" - "vand.vx v6, v6, %[mask]\n\t" - "vor.vv v8, v8, v0\n\t" - "lb zero, 127(%[q8])\n\t" - "vor.vv v10, v10, v2\n\t" - "vor.vv v12, v12, v4\n\t" - "vor.vv v14, v14, v6\n\t" - "vsetvli zero, %[vl128], e8, m8\n\t" - "vle8.v v0, (%[q8])\n\t" - "vsub.vx v8, v8, %[vl32]\n\t" - "vsetvli zero, %[vl64], e8, m4\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vwmul.vv v24, v4, v12\n\t" - "vsetivli zero, 16, e16, m2\n\t" - "vmv.v.x v0, zero\n\t" - "vwredsum.vs v10, v16, v0\n\t" - "vwredsum.vs v9, v18, v0\n\t" - "vwredsum.vs v8, v20, v0\n\t" - "vwredsum.vs v7, v22, v0\n\t" - "vwredsum.vs v11, v24, v0\n\t" - "vwredsum.vs v12, v26, v0\n\t" - "vwredsum.vs v13, v28, v0\n\t" - "vwredsum.vs v14, v30, v0\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vmul.vx v0, v10, t0\n\t" - "vmul.vx v1, v9, t1\n\t" - 
"vmacc.vx v0, t2, v8\n\t" - "vmacc.vx v1, t3, v7\n\t" - "vmacc.vx v0, t4, v11\n\t" - "vmacc.vx v1, t5, v12\n\t" - "vmacc.vx v0, t6, v13\n\t" - "vmacc.vx v1, a7, v14\n\t" - "vadd.vv v0, v0, v1\n\t" - "vfcvt.f.x.v v0, v0\n\t" - "vfmv.f.s %[ftmp], v0\n\t" - "fmadd.s %[sumf], %[d], %[ftmp], %[sumf]" - : [q6] "+&r" (q6), [q6h] "=&r" (q6h) - , [scale] "+&r" (scale) - , [sumf] "+&f" (sumf), [ftmp] "=&f" (ftmp) - : [qh] "r" (qh), [q8] "r" (q8) - , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) - , [mask] "r" (0x30), [d] "f" (d) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - , "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a7" - , "a6", "a5", "a4", "a3" - ); - qh += 32; q8 += 128; - } } - break; - default: - assert(false && "Unsupported vector length"); - break; + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m1_i32m1(__riscv_vadd_vv_i32m1(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m1_i32m1(__riscv_vadd_vv_i32m1(vaux_2, vaux_3, vl), isum0, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum1); + + sumf += d * sum_t; + } *s = sumf; +} +void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector + ggml_vec_dot_q6_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); +#elif defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_q6_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_q6_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + case 512: + ggml_vec_dot_q6_K_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_q6_K_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } #else - - UNUSED(x); - UNUSED(y); - UNUSED(nb); - ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } @@ -2268,6 +2539,186 @@ static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT } #endif +static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + // Load qh once for the entire superblock. + vuint16mf4_t qh = __riscv_vle16_v_u16mf4(x[i].qh, 8); + + // Calculate ls. + vuint16mf4_t temp = __riscv_vsrl_vx_u16mf4(qh, 12, 8); + temp = __riscv_vand_vx_u16mf4(temp, 7, 8); + vint32mf2_t ls = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vwmulu_vx_u32mf2(temp, 2, 8)); + ls = __riscv_vadd_vx_i32mf2(ls, 1, 8); + + // Calculate delta. + vbool64_t mask = __riscv_vmseq_vx_u16mf4_b64(__riscv_vand_vx_u16mf4(qh, 0x8000, 8), 0, 8); + vint32mf2_t delta_neg = __riscv_vmv_v_x_i32mf2(-1, 8); + vint32mf2_t delta_pos = __riscv_vmv_v_x_i32mf2(1, 8); + vint32mf2_t delta = __riscv_vmerge_vvm_i32mf2(delta_neg, delta_pos, mask, 8); + + // Load qs. + vuint8mf2_t qs = __riscv_vle8_v_u8mf2(x[i].qs, 32); + + // Prepare the indices. 
+ const uint64_t shift = 0x0009000600030000; + vuint16m1_t qh_shift = __riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(shift, 8)); + vuint16m1_t qh_gather_index = __riscv_vreinterpret_v_i16m1_u16m1( + __riscv_vdiv_vx_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vid_v_u16m1(32)), 4, 32)); + vuint16m1_t qh_ext = __riscv_vlmul_ext_v_u16mf2_u16m1(__riscv_vlmul_ext_v_u16mf4_u16mf2(qh)); + vuint16m1_t qh_index = __riscv_vrgather_vv_u16m1(qh_ext, qh_gather_index, 32); + qh_index = __riscv_vsrl_vv_u16m1(qh_index, qh_shift, 32); + qh_index = __riscv_vand_vx_u16m1(qh_index, 7, 32); + qh_index = __riscv_vsll_vx_u16m1(qh_index, 8, 32); + qh_index = __riscv_vor_vv_u16m1(qh_index, __riscv_vzext_vf2_u16m1(qs, 32), 32); + vuint16m1_t index = __riscv_vsll_vx_u16m1(qh_index, 3, 32); + + // Final lsums. + int32_t lsums_s[8]; + vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1); + + // Sub-blocks 1-8 + { + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, index, 32)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(y[i].qs, 256); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 256); + lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 0), one_scalar, 32)); + lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 1), one_scalar, 32)); + lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 2), one_scalar, 32)); + lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 3), one_scalar, 32)); + lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 4), one_scalar, 32)); + lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 5), one_scalar, 32)); + lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 6), one_scalar, 32)); + lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 7), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + vint32mf2_t lsums = __riscv_vle32_v_i32mf2(&lsums_s[0], 8); + + // Calculate the bsums. + vint16mf2_t bsums_0 = __riscv_vle16_v_i16mf2(y[i].bsums, 16); + const vuint32mf2_t bsums_i32 = __riscv_vreinterpret_v_u16mf2_u32mf2(__riscv_vreinterpret_v_i16mf2_u16mf2(bsums_0)); + const vint16mf4_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf4_i16mf4(__riscv_vnsrl_wx_u16mf4(bsums_i32, 0, 8)); + const vint16mf4_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf4_i16mf4(__riscv_vnsrl_wx_u16mf4(bsums_i32, 16, 8)); + const vint32mf2_t bsums = __riscv_vwadd_vv_i32mf2(bsums_i32_0, bsums_i32_1, 8); + + // Accumulation. + vint32mf2_t sumi_v = __riscv_vmul_vv_i32mf2(ls, lsums, 8); + vint32mf2_t sumi1_v = __riscv_vmul_vv_i32mf2(__riscv_vmul_vv_i32mf2(ls, delta, 8), bsums, 8); + + // Update sumf. 
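+        // sumi = sum_ib ls_ib * (grid . q8)_ib, while sumi1 folds the constant
+        // per-sub-block +/-IQ1S_DELTA offset through the q8 block sums (bsums)
+        // instead of adding it element-wise.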
+ int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32mf2_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32mf2_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + // Mask for processing 32 elements per lsum register. + vuint16m1_t l_index = __riscv_vid_v_u16m1(64); + vbool16_t l_mask = __riscv_vmsgtu_vx_u16m1_b16(l_index, 31, 64); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + // Load qh once for the entire superblock. + vuint16mf4_t qh = __riscv_vle16_v_u16mf4(x[i].qh, 8); + + // Calculate ls. + vuint16mf4_t temp = __riscv_vsrl_vx_u16mf4(qh, 12, 8); + temp = __riscv_vand_vx_u16mf4(temp, 7, 8); + vint32mf2_t ls = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vwmulu_vx_u32mf2(temp, 2, 8)); + ls = __riscv_vadd_vx_i32mf2(ls, 1, 8); + + // Calculate delta. + vbool64_t mask = __riscv_vmseq_vx_u16mf4_b64(__riscv_vand_vx_u16mf4(qh, 0x8000, 8), 0, 8); + vint32mf2_t delta_neg = __riscv_vmv_v_x_i32mf2(-1, 8); + vint32mf2_t delta_pos = __riscv_vmv_v_x_i32mf2(1, 8); + vint32mf2_t delta = __riscv_vmerge_vvm_i32mf2(delta_neg, delta_pos, mask, 8); + + // Load qs. + vuint8mf2_t qs = __riscv_vle8_v_u8mf2(x[i].qs, 32); + + // Prepare the indices. + const uint64_t shift = 0x0009000600030000; + vuint16m1_t qh_shift = __riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(shift, 8)); + vuint16m1_t qh_gather_index = __riscv_vreinterpret_v_i16m1_u16m1( + __riscv_vdiv_vx_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vid_v_u16m1(32)), 4, 32)); + vuint16m1_t qh_ext = __riscv_vlmul_ext_v_u16mf2_u16m1(__riscv_vlmul_ext_v_u16mf4_u16mf2(qh)); + vuint16m1_t qh_index = __riscv_vrgather_vv_u16m1(qh_ext, qh_gather_index, 32); + qh_index = __riscv_vsrl_vv_u16m1(qh_index, qh_shift, 32); + qh_index = __riscv_vand_vx_u16m1(qh_index, 7, 32); + qh_index = __riscv_vsll_vx_u16m1(qh_index, 8, 32); + qh_index = __riscv_vor_vv_u16m1(qh_index, __riscv_vzext_vf2_u16m1(qs, 32), 32); + vuint16mf2_t index = __riscv_vlmul_trunc_v_u16m1_u16mf2(__riscv_vsll_vx_u16m1(qh_index, 3, 32)); + + // Final lsums. + int32_t lsums_s[8]; + vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1); + + // Sub-blocks 1-8 + { + vint8m2_t grid0 = __riscv_vreinterpret_v_i64m2_i8m2(__riscv_vluxei16_v_i64m2((const int64_t*)iq1s_grid, index, 32)); + vint8m2_t q80 = __riscv_vle8_v_i8m2(y[i].qs, 256); + vint16m4_t lsum0 = __riscv_vwmul_vv_i16m4(grid0, q80, 256); + + // Reduce. 
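+            // At VLEN=1024 each i16m1 part of lsum0 holds 64 products, i.e. two
+            // 32-weight sub-blocks: the unmasked reduction (vl=32) sums the low
+            // half and the l_mask reduction (vl=64) sums lanes 32..63.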
+ lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(lsum0, 0), one_scalar, 32)); + lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(l_mask, __riscv_vget_v_i16m4_i16m1(lsum0, 0), one_scalar, 64)); + lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(lsum0, 1), one_scalar, 32)); + lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(l_mask, __riscv_vget_v_i16m4_i16m1(lsum0, 1), one_scalar, 64)); + lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(lsum0, 2), one_scalar, 32)); + lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(l_mask, __riscv_vget_v_i16m4_i16m1(lsum0, 2), one_scalar, 64)); + lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(lsum0, 3), one_scalar, 32)); + lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(l_mask, __riscv_vget_v_i16m4_i16m1(lsum0, 3), one_scalar, 64)); + } + __asm__ __volatile__("" ::: "memory"); + vint32mf2_t lsums = __riscv_vle32_v_i32mf2(&lsums_s[0], 8); + + // Calculate the bsums. + vint16mf2_t bsums_0 = __riscv_vle16_v_i16mf2(y[i].bsums, 16); + const vuint32mf2_t bsums_i32 = __riscv_vreinterpret_v_u16mf2_u32mf2(__riscv_vreinterpret_v_i16mf2_u16mf2(bsums_0)); + const vint16mf4_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf4_i16mf4(__riscv_vnsrl_wx_u16mf4(bsums_i32, 0, 8)); + const vint16mf4_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf4_i16mf4(__riscv_vnsrl_wx_u16mf4(bsums_i32, 16, 8)); + const vint32mf2_t bsums = __riscv_vwadd_vv_i32mf2(bsums_i32_0, bsums_i32_1, 8); + + // Accumulation. + vint32mf2_t sumi_v = __riscv_vmul_vv_i32mf2(ls, lsums, 8); + vint32mf2_t sumi1_v = __riscv_vmul_vv_i32mf2(__riscv_vmul_vv_i32mf2(ls, delta, 8), bsums, 8); + + // Update sumf. + int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32mf2_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32mf2_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; +} + void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { @@ -2277,6 +2728,12 @@ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo case 256: ggml_vec_dot_iq1_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; + case 512: + ggml_vec_dot_iq1_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_iq1_s_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; default: ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); break; @@ -2568,6 +3025,283 @@ static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT } #endif +static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + + // Mask for processing 16 elements per lsum register. 
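+    // At VLEN=512 an i16m1 register holds 32 products, i.e. two 16-weight
+    // iq1_m sub-blocks; l_mask selects the upper 16 lanes so the two halves
+    // can be accumulated with different sub-block scales.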
+ const vuint16m1_t l_index = __riscv_vid_v_u16m1(32); + const vbool16_t l_mask = __riscv_vmsgtu_vx_u16m1_b16(l_index, 15, 32); + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + // Accumulators. + vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 32); + vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 32); + + // We process all the sub-blocks together. + #pragma GCC unroll 1 + for (int ib = 0; ib < QK_K/256; ib++) { + // Load qh for all 16 sub-blocks. + const vuint8mf4_t qh_8 = __riscv_vle8_v_u8mf4(qh, 16); + const vuint16mf2_t qh_16_lo = __riscv_vzext_vf2_u16mf2(qh_8, 16); + const vuint16mf2_t qh_16_hi = __riscv_vsll_vx_u16mf2(qh_16_lo, 8, 16); + const vuint16m1_t qhb = __riscv_vzext_vf2_u16m1( + __riscv_vreinterpret_v_u16mf2_u8mf2(__riscv_vor_vv_u16mf2(qh_16_lo, qh_16_hi, 16)), 32); + __asm__ __volatile__("" ::: "memory"); + + // Prepare grid indices. + const vuint16m1_t qsb = __riscv_vzext_vf2_u16m1(__riscv_vle8_v_u8mf2(&qs[0], 32), 32); + const vuint16m1_t shift = __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00040008, 16)); + vuint16m1_t index = __riscv_vor_vv_u16m1(qsb, __riscv_vand_vx_u16m1(__riscv_vsll_vv_u16m1(qhb, shift, 32), 0x700, 32), 32); + index = __riscv_vsll_vx_u16m1(index, 3, 32); + __asm__ __volatile__("" ::: "memory"); + + // Load the grid. + const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( + __riscv_vluxei16_v_u64m4(iq1s_grid, index, 32))); + + // Prepare the deltas. + const vbool16_t mask = __riscv_vmsgtu_vx_u16m1_b16( + __riscv_vand_vv_u16m1(qhb, __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00800008, 16)), 32), 0, 32); + const vint64m4_t delta_pos = __riscv_vmv_v_x_i64m4(0x0101010101010101, 32); + const vint8m4_t delta = __riscv_vreinterpret_v_i64m4_i8m4( + __riscv_vmerge_vxm_i64m4(delta_pos, 0xffffffffffffffff, mask, 32)); + + // Load q8 for sub-blocks. + const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 256); + + // Calculate the lsums. + const vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(iq1b, q8b, 256); + const vint16m8_t lsum2 = __riscv_vwmul_vv_i16m8(delta, q8b, 256); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 3) & 0x7) + 1; + const int16_t ls_2 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_3 = 2*((sc[0] >> 9) & 0x7) + 1; + const int16_t ls_4 = 2*((sc[1] >> 0) & 0x7) + 1; + const int16_t ls_5 = 2*((sc[1] >> 3) & 0x7) + 1; + const int16_t ls_6 = 2*((sc[1] >> 6) & 0x7) + 1; + const int16_t ls_7 = 2*((sc[1] >> 9) & 0x7) + 1; + const int16_t ls_8 = 2*((sc[2] >> 0) & 0x7) + 1; + const int16_t ls_9 = 2*((sc[2] >> 3) & 0x7) + 1; + const int16_t ls_10 = 2*((sc[2] >> 6) & 0x7) + 1; + const int16_t ls_11 = 2*((sc[2] >> 9) & 0x7) + 1; + const int16_t ls_12 = 2*((sc[3] >> 0) & 0x7) + 1; + const int16_t ls_13 = 2*((sc[3] >> 3) & 0x7) + 1; + const int16_t ls_14 = 2*((sc[3] >> 6) & 0x7) + 1; + const int16_t ls_15 = 2*((sc[3] >> 9) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. 
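+            // acc1 accumulates ls * (grid . q8) and acc2 accumulates
+            // ls * (delta . q8). The unmasked vwmacc (vl=16) applies the even
+            // sub-block scale to the low 16 lanes of each product register; the
+            // masked vwmacc (vl=32) applies the odd sub-block scale to lanes 16..31.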
+ acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_0, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_1, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_0, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_1, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_2, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_3, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_2, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_3, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_4, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_5, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_4, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_5, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_6, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_7, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_6, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_7, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_8, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_9, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_8, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_9, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_10, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_11, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_10, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_11, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_12, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_13, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_12, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_13, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_14, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_15, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_14, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_15, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 32); + + __asm__ __volatile__("" ::: "memory"); + } + + // Reduce and accumulate in `sumf`. 
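+        // At VLEN=512 the i32m2 accumulators span 32 lanes; reduce them and
+        // apply the shared super-block scale: d * (sumi1 + IQ1M_DELTA * sumi2).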
+ vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 32)); + int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 32)); + sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + // Accumulators. + vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 64); + vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 64); + + // We process all the sub-blocks together. + #pragma GCC unroll 1 + for (int ib = 0; ib < QK_K/256; ib++) { + // Load qh for all 16 sub-blocks. + const vuint8mf8_t qh_8 = __riscv_vle8_v_u8mf8(qh, 16); + const vuint16mf4_t qh_16_lo = __riscv_vzext_vf2_u16mf4(qh_8, 16); + const vuint16mf4_t qh_16_hi = __riscv_vsll_vx_u16mf4(qh_16_lo, 8, 16); + const vuint16mf2_t qhb = __riscv_vzext_vf2_u16mf2( + __riscv_vreinterpret_v_u16mf4_u8mf4(__riscv_vor_vv_u16mf4(qh_16_lo, qh_16_hi, 16)), 32); + __asm__ __volatile__("" ::: "memory"); + + // Prepare grid indices. + const vuint16mf2_t qsb = __riscv_vzext_vf2_u16mf2(__riscv_vle8_v_u8mf4(&qs[0], 32), 32); + const vuint16mf2_t shift = __riscv_vreinterpret_v_u32mf2_u16mf2(__riscv_vmv_v_x_u32mf2(0x00040008, 16)); + vuint16mf2_t index = __riscv_vor_vv_u16mf2(qsb, __riscv_vand_vx_u16mf2(__riscv_vsll_vv_u16mf2(qhb, shift, 32), 0x700, 32), 32); + index = __riscv_vsll_vx_u16mf2(index, 3, 32); + __asm__ __volatile__("" ::: "memory"); + + // Load the grid. + const vint8m2_t iq1b = __riscv_vreinterpret_v_i64m2_i8m2(__riscv_vreinterpret_v_u64m2_i64m2( + __riscv_vluxei16_v_u64m2(iq1s_grid, index, 32))); + + // Prepare the deltas. + const vbool32_t mask = __riscv_vmsgtu_vx_u16mf2_b32( + __riscv_vand_vv_u16mf2(qhb, __riscv_vreinterpret_v_u32mf2_u16mf2(__riscv_vmv_v_x_u32mf2(0x00800008, 16)), 32), 0, 32); + const vint64m2_t delta_pos = __riscv_vmv_v_x_i64m2(0x0101010101010101, 32); + const vint8m2_t delta = __riscv_vreinterpret_v_i64m2_i8m2( + __riscv_vmerge_vxm_i64m2(delta_pos, 0xffffffffffffffff, mask, 32)); + + // Load q8 for sub-blocks. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 256); + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(iq1b, q8b, 256); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(delta, q8b, 256); + + // Prepare the scales. 
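+            // Each 16-bit word of sc packs four 3-bit sub-block scales at bit
+            // offsets 0, 3, 6 and 9 (the top nibble feeds the shared fp16
+            // super-block scale assembled above); the effective multiplier is 2*s + 1.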
+ const int16_t ls_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 3) & 0x7) + 1; + const int16_t ls_2 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_3 = 2*((sc[0] >> 9) & 0x7) + 1; + const int16_t ls_4 = 2*((sc[1] >> 0) & 0x7) + 1; + const int16_t ls_5 = 2*((sc[1] >> 3) & 0x7) + 1; + const int16_t ls_6 = 2*((sc[1] >> 6) & 0x7) + 1; + const int16_t ls_7 = 2*((sc[1] >> 9) & 0x7) + 1; + const int16_t ls_8 = 2*((sc[2] >> 0) & 0x7) + 1; + const int16_t ls_9 = 2*((sc[2] >> 3) & 0x7) + 1; + const int16_t ls_10 = 2*((sc[2] >> 6) & 0x7) + 1; + const int16_t ls_11 = 2*((sc[2] >> 9) & 0x7) + 1; + const int16_t ls_12 = 2*((sc[3] >> 0) & 0x7) + 1; + const int16_t ls_13 = 2*((sc[3] >> 3) & 0x7) + 1; + const int16_t ls_14 = 2*((sc[3] >> 6) & 0x7) + 1; + const int16_t ls_15 = 2*((sc[3] >> 9) & 0x7) + 1; + + // Mask for processing 16 elements per lsum register. + const vuint16m1_t l_index = __riscv_vid_v_u16m1(64); + + // Accumulate in acc1 and acc2 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0, __riscv_vget_v_i16m4_i16m1(lsum1, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0, __riscv_vget_v_i16m4_i16m1(lsum2, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_4, __riscv_vget_v_i16m4_i16m1(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_4, __riscv_vget_v_i16m4_i16m1(lsum2, 1), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_8, __riscv_vget_v_i16m4_i16m1(lsum1, 2), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_8, __riscv_vget_v_i16m4_i16m1(lsum2, 2), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_12, __riscv_vget_v_i16m4_i16m1(lsum1, 3), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_12, __riscv_vget_v_i16m4_i16m1(lsum2, 3), 16); + // + const vbool16_t l_mask_16_32 = __riscv_vmsgtu_vx_u16m1_b16(l_index, 15, 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc1, ls_1, __riscv_vget_v_i16m4_i16m1(lsum1, 0), 32); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc2, ls_1, __riscv_vget_v_i16m4_i16m1(lsum2, 0), 32); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc1, ls_5, __riscv_vget_v_i16m4_i16m1(lsum1, 1), 32); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc2, ls_5, __riscv_vget_v_i16m4_i16m1(lsum2, 1), 32); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc1, ls_9, __riscv_vget_v_i16m4_i16m1(lsum1, 2), 32); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc2, ls_9, __riscv_vget_v_i16m4_i16m1(lsum2, 2), 32); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc1, ls_13, __riscv_vget_v_i16m4_i16m1(lsum1, 3), 32); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc2, ls_13, __riscv_vget_v_i16m4_i16m1(lsum2, 3), 32); + // + const vbool16_t l_mask_32_48 = __riscv_vmsgtu_vx_u16m1_b16(l_index, 31, 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc1, ls_2, __riscv_vget_v_i16m4_i16m1(lsum1, 0), 48); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc2, ls_2, __riscv_vget_v_i16m4_i16m1(lsum2, 0), 48); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc1, ls_6, __riscv_vget_v_i16m4_i16m1(lsum1, 1), 48); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc2, ls_6, __riscv_vget_v_i16m4_i16m1(lsum2, 1), 48); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc1, ls_10, __riscv_vget_v_i16m4_i16m1(lsum1, 2), 48); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc2, ls_10, __riscv_vget_v_i16m4_i16m1(lsum2, 2), 48); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc1, ls_14, __riscv_vget_v_i16m4_i16m1(lsum1, 3), 48); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc2, ls_14, __riscv_vget_v_i16m4_i16m1(lsum2, 3), 48); + // + const 
vbool16_t l_mask_48_64 = __riscv_vmsgtu_vx_u16m1_b16(l_index, 47, 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc1, ls_3, __riscv_vget_v_i16m4_i16m1(lsum1, 0), 64); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc2, ls_3, __riscv_vget_v_i16m4_i16m1(lsum2, 0), 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc1, ls_7, __riscv_vget_v_i16m4_i16m1(lsum1, 1), 64); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc2, ls_7, __riscv_vget_v_i16m4_i16m1(lsum2, 1), 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc1, ls_11, __riscv_vget_v_i16m4_i16m1(lsum1, 2), 64); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc2, ls_11, __riscv_vget_v_i16m4_i16m1(lsum2, 2), 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc1, ls_15, __riscv_vget_v_i16m4_i16m1(lsum1, 3), 64); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc2, ls_15, __riscv_vget_v_i16m4_i16m1(lsum2, 3), 64); + + __asm__ __volatile__("" ::: "memory"); + } + + // Reduce and accumulate in `sumf`. + vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 64)); + int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 64)); + sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; +} + void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { @@ -2577,6 +3311,12 @@ void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo case 256: ggml_vec_dot_iq1_m_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; + case 512: + ggml_vec_dot_iq1_m_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_iq1_m_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; default: ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); break; @@ -2964,13 +3704,17 @@ static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT vint8m2_t q2u = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64)); vint8m2_t q2s = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64)); + // Apply signs vint8m2_t q2_final = __riscv_vmul_vv_i8m2(q2u, q2s, 64); + // Load Q8 weights (64 elements) vint8m2_t q8v = __riscv_vle8_v_i8m2(q8, 64); q8 += 64; + // Multiply (Widening to int16, 64 elements -> LMUL=4) vint16m4_t prod = __riscv_vwmul_vv_i16m4(q2_final, q8v, 64); + // Reduction vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); int32_t sum0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( @@ -2982,6 +3726,7 @@ static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT int32_t sum3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 3), zero_vec, 16)); + // Apply Scales const uint8_t scale_byte_1 = scales[0]; const uint8_t scale_byte_2 = scales[1]; scales += 2; From 9565a409c7edbb8b6b10509272e7a7c79c2f490c Mon Sep 17 00:00:00 2001 From: Rehan Qasim Date: Tue, 24 Feb 2026 11:52:54 +0500 Subject: [PATCH 3/4] added 512 and 1024 implementations of tq3_s, iq3_xxs, iq2_s, iq2_xs, iq2_xxs --- ggml/src/ggml-cpu/arch/riscv/quants.c | 682 +++++++++++++++++++++++++- 1 file changed, 677 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index 0080e6b6672..9b0d8bc1fb5 100644 --- 
a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -3529,7 +3529,272 @@ static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT } *s = 0.125f * sumf; } -#endif + +static void ggml_vec_dot_iq2_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * grid64 = (const uint64_t *)iq2s_grid; + + vuint8m2_t v_ids = __riscv_vid_v_u8m2(128); + vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 128); + + vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 128); + vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 128); + vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 128); + + uint16_t gather_qh_arr[16] = {0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}; + vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 16); + + uint16_t shift_qh_arr[16] = {11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5}; + vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 16); + + // Masks for selecting lower/upper 16 lanes within a 32-lane i16m1 register + vuint16m1_t v_ids16 = __riscv_vid_v_u16m1(32); + vbool16_t m_hi16 = __riscv_vmsgeu_vx_u16m1_b16(v_ids16, 16, 32); + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const uint8_t * signs_ptr = qs + 32; + + float sum_block = 0.0f; + + for (int ib = 0; ib < 2; ++ib) { + vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 16); + qs += 16; + + vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8(qh, 4); + qh += 4; + + vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 4); + vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16); + vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 16); + v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 16); + v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 16); + + vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 16); + v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 16); + + vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 16); + vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 16); + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals); + vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8); + + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 16); + signs_ptr += 16; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 128); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 128); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 128); + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 128); + q8 += 128; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 128); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 128); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + vint16m1_t 
v0 = __riscv_vget_v_i16m4_i16m1(v_dot, 0); + vint16m1_t v1 = __riscv_vget_v_i16m4_i16m1(v_dot, 1); + vint16m1_t v2 = __riscv_vget_v_i16m4_i16m1(v_dot, 2); + vint16m1_t v3 = __riscv_vget_v_i16m4_i16m1(v_dot, 3); + + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(v0, v_zero, 16)); + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(m_hi16, v0, v_zero, 32)); + int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(v1, v_zero, 16)); + int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(m_hi16, v1, v_zero, 32)); + int32_t s4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(v2, v_zero, 16)); + int32_t s5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(m_hi16, v2, v_zero, 32)); + int32_t s6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( v3, v_zero, 16)); + int32_t s7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(m_hi16, v3, v_zero, 32)); + + uint8_t sc0 = scales[0]; + uint8_t sc1 = scales[1]; + uint8_t sc2 = scales[2]; + uint8_t sc3 = scales[3]; + scales += 4; + + sum_block += s0 * (2 * (sc0 & 0xF) + 1); + sum_block += s1 * (2 * (sc0 >> 4) + 1); + sum_block += s2 * (2 * (sc1 & 0xF) + 1); + sum_block += s3 * (2 * (sc1 >> 4) + 1); + sum_block += s4 * (2 * (sc2 & 0xF) + 1); + sum_block += s5 * (2 * (sc2 >> 4) + 1); + sum_block += s6 * (2 * (sc3 & 0xF) + 1); + sum_block += s7 * (2 * (sc3 >> 4) + 1); + } + + sumf += sum_block * combined_scale; + } + *s = 0.125f * sumf; +} + +static void ggml_vec_dot_iq2_s_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * grid64 = (const uint64_t *)iq2s_grid; + vuint8m2_t v_ids = __riscv_vid_v_u8m2(256); + vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 256); + + vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 256); + vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 256); + vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 256); + + uint16_t gather_qh_arr[32] = { + 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, + 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7 + }; + vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 32); + + uint16_t shift_qh_arr[32] = { + 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5, + 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5 + }; + vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 32); + + // Masks for 4 groups of 16 lanes within a 64-lane i16m4 chunk + vuint16m4_t v_ids64 = __riscv_vid_v_u16m4(64); + vbool4_t m_g0 = __riscv_vmsltu_vx_u16m4_b4(v_ids64, 16, 64); + vbool4_t m_g1 = __riscv_vmand_mm_b4( + __riscv_vmsgeu_vx_u16m4_b4(v_ids64, 16, 64), + __riscv_vmsltu_vx_u16m4_b4(v_ids64, 32, 64), 64); + vbool4_t m_g2 = __riscv_vmand_mm_b4( + __riscv_vmsgeu_vx_u16m4_b4(v_ids64, 32, 64), + __riscv_vmsltu_vx_u16m4_b4(v_ids64, 48, 64), 64); + vbool4_t m_g3 = __riscv_vmsgeu_vx_u16m4_b4(v_ids64, 48, 64); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const 
uint8_t * signs_ptr = qs + 32; + + float sum_block = 0.0f; + + vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 32); + qs += 32; + + vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8(qh, 8); + qh += 8; + + vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 8); + vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16); + vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 32); + v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 32); + v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 32); + + vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 32); + v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 32); + + vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 32); + vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 32); + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals); + vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8); + + //loading signs + vuint8mf2_t v_signs_raw = __riscv_vle8_v_u8mf2(signs_ptr, 32); + signs_ptr += 32; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf2_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 256); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 256); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 256); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 256); + q8 += 256; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 256); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 256); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + vint16m4_t c = v_dot; + + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g0, c, v_zero, 64)); + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g1, c, v_zero, 64)); + int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g2, c, v_zero, 64)); + int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g3, c, v_zero, 64)); + + c = __riscv_vslidedown_vx_i16m4(c, 64, 256); + int32_t s4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g0, c, v_zero, 64)); + int32_t s5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g1, c, v_zero, 64)); + int32_t s6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g2, c, v_zero, 64)); + int32_t s7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g3, c, v_zero, 64)); + + c = __riscv_vslidedown_vx_i16m4(c, 64, 256); + int32_t s8 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g0, c, v_zero, 64)); + int32_t s9 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g1, c, v_zero, 64)); + int32_t s10 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g2, c, v_zero, 64)); + int32_t s11 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g3, c, v_zero, 64)); + + c = __riscv_vslidedown_vx_i16m4(c, 64, 256); + int32_t s12 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g0, c, v_zero, 64)); + int32_t s13 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g1, c, v_zero, 64)); + int32_t s14 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g2, c, v_zero, 64)); + int32_t s15 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g3, c, v_zero, 64)); + + int32_t sums_arr[16] = { s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, 
s14, s15 }; + + // Load 8 scale bytes and split into 16 nibbles + vuint8mf2_t v_sc8 = __riscv_vle8_v_u8mf2(scales, 8); + scales += 8; + + vuint8mf2_t v_lo8 = __riscv_vand_vx_u8mf2(v_sc8, 0x0F, 8); + vuint8mf2_t v_hi8 = __riscv_vsrl_vx_u8mf2(v_sc8, 4, 8); + + vuint8m1_t v_idx16 = __riscv_vid_v_u8m1(16); + vuint8m1_t v_half = __riscv_vsrl_vx_u8m1(v_idx16, 1, 16); + vbool8_t m_even = __riscv_vmseq_vx_u8m1_b8(__riscv_vand_vx_u8m1(v_idx16, 1, 16), 0, 16); + + vuint8m1_t v_lo_ext = __riscv_vlmul_ext_v_u8mf2_u8m1(v_lo8); + vuint8m1_t v_hi_ext = __riscv_vlmul_ext_v_u8mf2_u8m1(v_hi8); + vuint8m1_t v_lo_g = __riscv_vrgather_vv_u8m1(v_lo_ext, v_half, 16); + vuint8m1_t v_hi_g = __riscv_vrgather_vv_u8m1(v_hi_ext, v_half, 16); + vuint8m1_t v_nib = __riscv_vmerge_vvm_u8m1(v_lo_g, v_hi_g, m_even, 16); + + static const uint8_t iq2s_scale_lut_16_local[16] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31 + }; + vuint8m1_t v_lut = __riscv_vle8_v_u8m1(iq2s_scale_lut_16_local, 16); + vuint8m1_t v_sc8v = __riscv_vrgather_vv_u8m1(v_lut, v_nib, 16); + + vint32m4_t v_sums = __riscv_vle32_v_i32m4(sums_arr, 16); + vuint16m2_t v_sc16 = __riscv_vwcvtu_x_x_v_u16m2(v_sc8v, 16); + vuint32m4_t v_sc32u = __riscv_vwcvtu_x_x_v_u32m4(v_sc16, 16); + vint32m4_t v_sc32 = __riscv_vreinterpret_v_u32m4_i32m4(v_sc32u); + vint32m4_t v_prod = __riscv_vmul_vv_i32m4(v_sums, v_sc32, 16); + + vint32m1_t v_zero32 = __riscv_vmv_v_x_i32m1(0, 1); + int32_t sum_part = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(v_prod, v_zero32, 16)); + sum_block += sum_part; + + sumf += sum_block * combined_scale; + } + *s = 0.125f * sumf; +} + void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -3741,7 +4006,71 @@ static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT } *s = 0.125f * sumf; } -#endif + +static void ggml_vec_dot_iq2_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xs_grid; + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint16_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + vint8m4_t q8_all = __riscv_vle8_v_i8m4(q8, 256); + + // Load indices --- + vuint16m1_t v_qs = __riscv_vle16_v_u16m1(qs, 32); + + // Extract low 9 bits and multiply by 8 (shift left 3) for byte offset into uint64 table + vuint16m1_t vidx_grid = __riscv_vsll_vx_u16m1(__riscv_vand_vx_u16m1(v_qs, 511, 32), 3, 32); + + // Extract high 7 bits (shift right 9) and multiply by 8 (shift left 3) for byte offset + vuint16m1_t vidx_sign = __riscv_vsll_vx_u16m1(__riscv_vsrl_vx_u16m1(v_qs, 9, 32), 3, 32); + + vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_grid, 32); + vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_sign, 32); + + vint8m4_t q2_all = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64)); + vint8m4_t s2_all = 
__riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64)); + + vint8m4_t q2_signed = __riscv_vmul_vv_i8m4(q2_all, s2_all, 256); + vint16m8_t dot_all = __riscv_vwmul_vv_i16m8(q2_signed, q8_all, 256); + float sum = 0.0f; + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + +#pragma GCC unroll 1 + for (int j = 0; j < 8; ++j) { + uint8_t sc = scales[j]; + int16_t sc_lo = 2 * (sc & 0x0F) + 1; + int16_t sc_hi = 2 * (sc >> 4) + 1; + + vint32m1_t sum_v0 = __riscv_vwredsum_vs_i16m8_i32m1( + __riscv_vslidedown_vx_i16m8(dot_all, j * 32, 16), zero_vec, 16); + int32_t isum0 = __riscv_vmv_x_s_i32m1_i32(sum_v0); + + vint32m1_t sum_v1 = __riscv_vwredsum_vs_i16m8_i32m1( + __riscv_vslidedown_vx_i16m8(dot_all, j * 32 + 16, 16), zero_vec, 16); + int32_t isum1 = __riscv_vmv_x_s_i32m1_i32(sum_v1); + + sum += (float)isum0 * sc_lo + (float)isum1 * sc_hi; + } + + sumf += sum * combined_scale; + } + *s = 0.125f * sumf; +} + void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -3946,7 +4275,79 @@ static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRIC } *s = 0.125f * sumf; } -#endif + + +static void ggml_vec_dot_iq2_xxs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xxs_grid; + // Shift pattern {0,7,14,21} repeated 8 times for all 8 sub-blocks + uint8_t shift_arr[32] = { + 0, 7, 14, 21, 0, 7, 14, 21, 0, 7, 14, 21, 0, 7, 14, 21, + 0, 7, 14, 21, 0, 7, 14, 21, 0, 7, 14, 21, 0, 7, 14, 21 + }; + vuint8mf2_t v_shifts = __riscv_vle8_v_u8mf2(shift_arr, 32); + + // Gather pattern to broadcast the 8 sub-block scales across the 32 lookup slots + uint8_t gather_arr[32] = { + 0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3, + 4,4,4,4, 5,5,5,5, 6,6,6,6, 7,7,7,7 + }; + vuint8mf2_t v_sign_gather_idx = __riscv_vle8_v_u8mf2(gather_arr, 32); + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q2_ptr = (const uint8_t *) x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + vint8m4_t q8_all = __riscv_vle8_v_i8m4(q8, 256); + + // De-interleave all 8 Index/Scale pairs for the 8x32-element sub-blocks + vuint32mf2x2_t tuple = __riscv_vlseg2e32_v_u32mf2x2((const uint32_t*)q2_ptr, 8); + vuint32mf2_t v_ind32 = __riscv_vget_v_u32mf2x2_u32mf2(tuple, 0); + vuint32mf2_t v_sc32 = __riscv_vget_v_u32mf2x2_u32mf2(tuple, 1); + + vuint8mf2_t v_raw_q2 = __riscv_vreinterpret_v_u32mf2_u8mf2(v_ind32); + vuint16m1_t vidx_q2 = __riscv_vwcvtu_x_x_v_u16m1(v_raw_q2, 32); + vidx_q2 = __riscv_vsll_vx_u16m1(vidx_q2, 3, 32); + + vuint32m2_t v_s = __riscv_vrgatherei16_vv_u32m2(__riscv_vlmul_ext_v_u32mf2_u32m2(v_sc32), __riscv_vwcvtu_x_x_v_u16m1(v_sign_gather_idx,32), 32); + v_s = __riscv_vsrl_vv_u32m2(v_s, __riscv_vwcvtu_x_x_v_u32m2(__riscv_vwcvtu_x_x_v_u16m1(v_shifts,32),32), 32); + v_s = __riscv_vand_vx_u32m2(v_s, 127, 32); + vuint16m1_t vidx_s2 = __riscv_vsll_vx_u16m1(__riscv_vncvt_x_x_w_u16m1(v_s, 32), 3, 32); + + vuint64m4_t vq2_64 = 
__riscv_vluxei16_v_u64m4(grid64, vidx_q2, 32); + vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_s2, 32); + vint8m4_t q2_all = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64)); + vint8m4_t s2_all = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64)); + + vint8m4_t q8s_all = __riscv_vmul_vv_i8m4(q8_all, s2_all, 256); + vint16m8_t dot_all = __riscv_vwmul_vv_i16m8(q8s_all, q2_all, 256); + + float sum = 0.0f; + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + + for (int j = 0; j < 8; ++j) { + uint32_t s_p = __riscv_vmv_x_s_u32mf2_u32(__riscv_vslidedown_vx_u32mf2(v_sc32, j, 8)); + int16_t sc = 2 * ((s_p >> 28) & 0xF) + 1; + dot_all=__riscv_vslidedown_vx_i16m8(dot_all,j*32,32); + vint32m1_t sum_v = __riscv_vwredsum_vs_i16m8_i32m1(dot_all, zero_vec, 32); + int32_t isum = __riscv_vmv_x_s_i32m1_i32(sum_v); + sum += (float)isum * sc; + } + + sumf += sum * combined_scale; + } + *s = 0.125f * sumf; +} void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -4153,7 +4554,95 @@ static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT } *s = sumf; } -#endif + +static void ggml_vec_dot_iq3_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + const block_iq3_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint32_t * grid32 = (const uint32_t *)iq3s_grid; + + // Generate Constants + vuint8mf2_t v_id_32 = __riscv_vid_v_u8mf2(32); + vuint8mf2_t v_qh_gather = __riscv_vsrl_vx_u8mf2(v_id_32, 3, 32); + vuint8mf2_t v_qh_shifts = __riscv_vand_vx_u8mf2(v_id_32, 7, 32); + vuint8m2_t v_id_128 = __riscv_vid_v_u8m2(128); + vuint8m2_t v_sign_gather = __riscv_vsrl_vx_u8m2(v_id_128, 3, 128); // byte index + vuint8m2_t v_sign_shift_amts = __riscv_vand_vx_u8m2(v_id_128, 7, 128); // bit shift + vuint8m2_t v_one_128 = __riscv_vmv_v_x_u8m2(1, 128); + vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_one_128, v_sign_shift_amts, 128); + vuint8m2_t v_scale_indices = __riscv_vsrl_vx_u8m2(v_id_128, 5, 128); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const uint8_t * GGML_RESTRICT signs = x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float sum_block = 0.0f; + for (int ib = 0; ib < 2; ++ib) { + vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 32); + qs += 32; + vuint8mf2_t v_qh_loaded = __riscv_vle8_v_u8mf2(qh, 4); + qh += 4; + vuint8mf2_t v_qh_expanded = __riscv_vrgather_vv_u8mf2(v_qh_loaded, v_qh_gather, 32); + v_qh_expanded = __riscv_vsrl_vv_u8mf2(v_qh_expanded, v_qh_shifts, 32); + v_qh_expanded = __riscv_vand_vx_u8mf2(v_qh_expanded, 1, 32); + vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 32); + v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 32); // * 4 + + vuint16m1_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qh_expanded, 32); + v_qh_u16 = __riscv_vsll_vx_u16m1(v_qh_u16, 10, 32); // * 256 * 4 + + vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_u16, 32); + vuint32m2_t v_grid_packed = 
__riscv_vluxei16_v_u32m2(grid32, v_grid_offsets, 32); + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed); + vuint8mf2_t v_signs_raw = __riscv_vle8_v_u8mf2(signs, 16); + signs += 16; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf2_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather, 128); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 128); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 128); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 128); + q8 += 128; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 128); + vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 128); + uint16_t sc_raw; + memcpy(&sc_raw, scales, 2); + scales += 2; // Advance 2 bytes + + uint8_t sc_unpacked[4]; + sc_unpacked[0] = (sc_raw & 0xF); + sc_unpacked[1] = (sc_raw >> 4) & 0xF; + sc_unpacked[2] = (sc_raw >> 8) & 0xF; + sc_unpacked[3] = (sc_raw >> 12) & 0xF; + + vuint8mf2_t v_sc_4 = __riscv_vle8_v_u8mf2(sc_unpacked, 4); + v_sc_4 = __riscv_vmul_vx_u8mf2(v_sc_4, 2, 4); + v_sc_4 = __riscv_vadd_vx_u8mf2(v_sc_4, 1, 4); + vuint8m2_t v_sc_4_expanded = __riscv_vlmul_ext_v_u8mf2_u8m2(v_sc_4); + vuint8m2_t v_scales_bcast = __riscv_vrgather_vv_u8m2(v_sc_4_expanded, v_scale_indices, 128); + vint16m4_t v_scales_i16 = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vwcvtu_x_x_v_u16m4(v_scales_bcast, 128)); + vint32m8_t v_weighted_sum = __riscv_vwmul_vv_i32m8(v_dot, v_scales_i16, 128); + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + int32_t s_val = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m8_i32m1(v_weighted_sum, v_zero, 128)); + + sum_block += s_val; + } + sumf += sum_block * combined_scale; + } + *s = sumf; +} void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -4359,7 +4848,190 @@ static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRIC } *s = 0.25f * sumf; } -#endif +static void ggml_vec_dot_iq3_xxs_q8_K_vl512(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + size_t bx, + const void * GGML_RESTRICT vy, + size_t by, + int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + const int nb = n / QK_K; + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint32_t * grid32 = (const uint32_t *)iq3xxs_grid; + + // generate constants for unpacking metadata words into sign indices + vuint32m1_t v_shifts; + { + vuint32m1_t v_base = __riscv_vid_v_u32m1(16); + vuint32m1_t v_mod4 = __riscv_vand_vx_u32m1(v_base, 3, 16); + v_shifts = __riscv_vmul_vx_u32m1(v_mod4, 7, 16); + } + + vuint16mf2_t v_gather_idx; + { + vuint16mf2_t v_idx = __riscv_vid_v_u16mf2(16); + v_gather_idx = __riscv_vsrl_vx_u16mf2(v_idx, 2, 16); + } + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q3_indices = x[i].qs; + const uint8_t * GGML_RESTRICT metadata = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float block_sum = 0.0f; + for (int ib128 = 0; ib128 < 2; ++ib128) { + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 128); + q8 += 128; + vuint8mf2_t v_q3_idx_u8 = __riscv_vle8_v_u8mf2(q3_indices, 32); + q3_indices += 
32; + + vuint16m1_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m1(v_q3_idx_u8, 4, 32); + vuint32m2_t v_q3_mag_u32 = __riscv_vluxei16_v_u32m2(grid32, v_q3_idx_u16, 32); + vint8m2_t v_q3_magnitudes = __riscv_vreinterpret_v_u8m2_i8m2( + __riscv_vreinterpret_v_u32m2_u8m2(v_q3_mag_u32)); + vuint32m1_t v_aux = __riscv_vreinterpret_v_u8m1_u32m1(__riscv_vle8_v_u8m1(metadata, 16)); + metadata += 4 * sizeof(uint32_t); + + vuint32m1_t v_aux_expanded = __riscv_vrgatherei16_vv_u32m1(v_aux, v_gather_idx, 16); + + vuint32m1_t v_s_raw = __riscv_vand_vx_u32m1( + __riscv_vsrl_vv_u32m1(v_aux_expanded, v_shifts, 16), 127, 16); + vuint16mf2_t sign_byte_offset = __riscv_vsll_vx_u16mf2( + __riscv_vncvt_x_x_w_u16mf2(v_s_raw, 16), 3, 16); + vuint64m2_t v_s_u64 = __riscv_vluxei16_v_u64m2(signs64, sign_byte_offset, 16); + vint8m2_t v_signs = __riscv_vreinterpret_v_u8m2_i8m2( + __riscv_vreinterpret_v_u64m2_u8m2(v_s_u64)); + vint8m2_t v_q3_signed = __riscv_vmul_vv_i8m2(v_q3_magnitudes, v_signs, 128); + vint16m4_t prod = __riscv_vwmul_vv_i16m4(v_q3_signed, v_q8, 128); + + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + int32_t group0_sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 0), zero_vec, 32)); + int32_t group1_sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 1), zero_vec, 32)); + int32_t group2_sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 2), zero_vec, 32)); + int32_t group3_sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 3), zero_vec, 32)); + + vuint32m1_t v_scales_raw = __riscv_vsrl_vx_u32m1(v_aux, 28, 4); + vuint32m1_t v_scales = __riscv_vadd_vx_u32m1( + __riscv_vsll_vx_u32m1(v_scales_raw, 1, 4), + 1, 4); + int32_t scale0 = (int32_t)__riscv_vmv_x_s_u32m1_u32(v_scales); + int32_t scale1 = (int32_t)__riscv_vmv_x_s_u32m1_u32(__riscv_vslidedown_vx_u32m1(v_scales, 1, 4)); + int32_t scale2 = (int32_t)__riscv_vmv_x_s_u32m1_u32(__riscv_vslidedown_vx_u32m1(v_scales, 2, 4)); + int32_t scale3 = (int32_t)__riscv_vmv_x_s_u32m1_u32(__riscv_vslidedown_vx_u32m1(v_scales, 3, 4)); + + block_sum += (float)(group0_sum * scale0 + group1_sum * scale1 + + group2_sum * scale2 + group3_sum * scale3); + } + + sumf += d * block_sum; + } + *s = 0.25f * sumf; +} + +static void ggml_vec_dot_iq3_xxs_q8_K_vl1024(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + size_t bx, + const void * GGML_RESTRICT vy, + size_t by, + int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + const int nb = n / QK_K; + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint32_t * grid32 = (const uint32_t *)iq3xxs_grid; + + vuint32m1_t v_shifts; + { + vuint32m1_t v_id = __riscv_vid_v_u32m1(32); + vuint32m1_t v_mod4 = __riscv_vand_vx_u32m1(v_id, 3, 32); + v_shifts = __riscv_vmul_vx_u32m1(v_mod4, 7, 32); + } + vuint16mf2_t v_gather_idx; + { + vuint16mf2_t v_id_16 = __riscv_vid_v_u16mf2(32); + v_gather_idx = __riscv_vsrl_vx_u16mf2(v_id_16, 2, 32); + } + + float sumf = 0.0f; + uint32_t aux32[8]; // Buffer for block metadata + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q3_indices = x[i].qs; + const uint8_t * GGML_RESTRICT metadata = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = 
y[i].qs; + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 256); + vuint8mf2_t v_q3_idx_raw = __riscv_vle8_v_u8mf2(q3_indices, 64); + vuint16m1_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m1(v_q3_idx_raw, 4, 64); + + vuint32m2_t v_q3_grid_vals = __riscv_vluxei16_v_u32m2(grid32, v_q3_idx_u16, 64); + + vint8m2_t v_q3_mags = __riscv_vreinterpret_v_u8m2_i8m2( + __riscv_vreinterpret_v_u32m2_u8m2(v_q3_grid_vals)); + + memcpy(aux32, metadata, 8 * sizeof(uint32_t)); + vuint32m1_t v_aux_8 = __riscv_vle32_v_u32m1(aux32, 8); + + vuint32m1_t v_aux_32 = __riscv_vrgatherei16_vv_u32m1(v_aux_8, v_gather_idx, 32); + + vuint32m1_t v_sign_idx_raw = __riscv_vand_vx_u32m1( + __riscv_vsrl_vv_u32m1(v_aux_32, v_shifts, 32), 127, 32); + + vuint16mf2_t v_sign_offsets = __riscv_vsll_vx_u16mf2( + __riscv_vncvt_x_x_w_u16mf2(v_sign_idx_raw, 32), 3, 32); + + vuint64m2_t v_signs_u64 = __riscv_vluxei16_v_u64m2(signs64, v_sign_offsets, 32); + + vint8m2_t v_signs = __riscv_vreinterpret_v_u8m2_i8m2( + __riscv_vreinterpret_v_u64m2_u8m2(v_signs_u64)); + + vint8m2_t v_q3_final = __riscv_vmul_vv_i8m2(v_q3_mags, v_signs, 256); + + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_q8, v_q3_final, 256); + float block_sum = 0.0f; + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + vint16m4_t v_accum = v_dot; + + for (int j = 0; j < 8; ++j) { + float scale = (float)(2 * (aux32[j] >> 28) + 1); + + vint32m1_t v_partial_sum = __riscv_vwredsum_vs_i16m4_i32m1(v_accum, v_zero, 32); + + int32_t partial_sum_i = __riscv_vmv_x_s_i32m1_i32(v_partial_sum); + block_sum += partial_sum_i * scale; + v_accum = __riscv_vslidedown_vx_i16m4(v_accum, 32, 32); + + } + + sumf += d * block_sum; + } + *s = 0.25f * sumf; +} void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic From e3e2c6c86049558d8383bf1965b8c5f42834ff1c Mon Sep 17 00:00:00 2001 From: Rehan Qasim Date: Tue, 24 Feb 2026 15:14:44 +0500 Subject: [PATCH 4/4] ggml-cpu: refactor; improve iq2_xs impl for rvv 256 Co-authored-by: Rehan Qasim --- ggml/src/ggml-cpu/arch/riscv/quants.c | 1481 ++++++++++++++++--------- 1 file changed, 953 insertions(+), 528 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index 9b0d8bc1fb5..9ebfaae1ddc 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -123,7 +123,7 @@ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in assert(k % QK_K == 0); size_t nb = k / QK_K; -#if defined __riscv_v_intrinsic +#if defined __riscv_v block_q8_K * y_blocks = (block_q8_K *)y; const size_t vlmax_f32m8 = __riscv_vsetvlmax_e32m8(); @@ -480,7 +480,8 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } -void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector +void ggml_vec_dot_q2_K_q8_K_xtheadvector(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -492,8 +493,6 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int nb = n / QK_K; -#if defined __riscv_xtheadvector - float sumf = 0; uint8_t atmp[16]; @@ -588,246 +587,281 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, 
size_t bs, const voi } *s = sumf; +} +#endif -#elif defined __riscv_v +#if defined __riscv_v +void ggml_vec_dot_q2_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q2_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; float sumf = 0; uint8_t atmp[16]; - const int vector_length = __riscv_vlenb() * 8; uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }; - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + uint8_t *patmp = atmp; + int vsums; + int tmp, t1, t2, t3, t4, t5, t6, t7; + __asm__ __volatile__( + "vsetivli zero, 16, e8, m1\n\t" + "vmv.v.x v8, zero\n\t" + "lb zero, 15(%[sc])\n\t" + "vle8.v v1, (%[sc])\n\t" + "vle8.v v2, (%[bsums])\n\t" + "addi %[tmp], %[bsums], 16\n\t" + "vand.vi v0, v1, 0xF\n\t" + "vsrl.vi v1, v1, 4\n\t" + "vle8.v v3, (%[tmp])\n\t" + "vse8.v v0, (%[scale])\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vzext.vf2 v0, v1\n\t" + "vwmul.vv v4, v0, v2\n\t" + "vsetivli zero, 16, e32, m4\n\t" + "vredsum.vs v8, v4, v8\n\t" + "vmv.x.s %[vsums], v8" + : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums) + : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + sumf += dmin * vsums; + int isum = 0; - const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "lb zero, 31(%[q2])\n\t" + "addi %[tmp], %[q2], 16\n\t" + "addi %[t1], %[q8], 16\n\t" + "vsetivli zero, 16, e8, m1\n\t" + "vle8.v v0, (%[q2])\n\t" + "vle8.v v1, (%[tmp])\n\t" + "vsrl.vi v2, v0, 2\n\t" + "vsrl.vi v3, v1, 2\n\t" + "vsrl.vi v4, v0, 4\n\t" + "addi %[tmp], %[q8], 32\n\t" + "vle8.v v8, (%[q8])\n\t" + "vle8.v v9, (%[t1])\n\t" + "addi %[t1], %[t1], 32\n\t" + "vsrl.vi v5, v1, 4\n\t" + "vsrl.vi v6, v0, 6\n\t" + "vsrl.vi v7, v1, 6\n\t" + "vle8.v v10, (%[tmp])\n\t" + "vle8.v v11, (%[t1])\n\t" + "addi %[tmp], %[tmp], 32\n\t" + "addi %[t1], %[t1], 32\n\t" + "vand.vi v0, v0, 0x3\n\t" + "vand.vi v1, v1, 0x3\n\t" + "vand.vi v2, v2, 0x3\n\t" + "vle8.v v12, (%[tmp])\n\t" + "vle8.v v13, (%[t1])\n\t" + "addi %[tmp], %[tmp], 32\n\t" + "addi %[t1], %[t1], 32\n\t" + "vand.vi v3, v3, 0x3\n\t" + "vand.vi v4, v4, 0x3\n\t" + "vand.vi v5, v5, 0x3\n\t" + "vle8.v v14, (%[tmp])\n\t" + "vle8.v v15, (%[t1])\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v18, v1, v9\n\t" + "vwmul.vv v20, v2, v10\n\t" + "vwmul.vv v22, v3, v11\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vwmul.vv v26, v5, v13\n\t" + "vwmul.vv v28, v6, v14\n\t" + "vwmul.vv v30, v7, v15\n\t" + "vsetivli zero, 8, e16, m1\n\t" + "vmv.v.x v0, zero\n\t" + "lbu %[tmp], 0(%[scale])\n\t" + "vwredsum.vs v8, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "lbu 
%[t1], 1(%[scale])\n\t" + "vwredsum.vs v10, v20, v0\n\t" + "vwredsum.vs v11, v22, v0\n\t" + "lbu %[t2], 2(%[scale])\n\t" + "vwredsum.vs v12, v24, v0\n\t" + "vwredsum.vs v13, v26, v0\n\t" + "lbu %[t3], 3(%[scale])\n\t" + "vwredsum.vs v14, v28, v0\n\t" + "vwredsum.vs v15, v30, v0\n\t" + "lbu %[t4], 4(%[scale])\n\t" + "vwredsum.vs v8, v17, v8\n\t" + "vwredsum.vs v9, v19, v9\n\t" + "lbu %[t5], 5(%[scale])\n\t" + "vwredsum.vs v10, v21, v10\n\t" + "vwredsum.vs v11, v23, v11\n\t" + "lbu %[t6], 6(%[scale])\n\t" + "vwredsum.vs v12, v25, v12\n\t" + "vwredsum.vs v13, v27, v13\n\t" + "lbu %[t7], 7(%[scale])\n\t" + "vwredsum.vs v14, v29, v14\n\t" + "vwredsum.vs v15, v31, v15\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vmul.vx v0, v8, %[tmp]\n\t" + "vmul.vx v1, v9, %[t1]\n\t" + "vmacc.vx v0, %[t2], v10\n\t" + "vmacc.vx v1, %[t3], v11\n\t" + "vmacc.vx v0, %[t4], v12\n\t" + "vmacc.vx v1, %[t5], v13\n\t" + "vmacc.vx v0, %[t6], v14\n\t" + "vmacc.vx v1, %[t7], v15\n\t" + "vmv.x.s %[tmp], v0\n\t" + "vmv.x.s %[t1], v1\n\t" + "add %[isum], %[isum], %[tmp]\n\t" + "add %[isum], %[isum], %[t1]" + : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) + , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) + , [isum] "+&r" (isum) + : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q2 += 32; q8 += 128; patmp += 8; + } - size_t vl = 16; + sumf += dall * isum; + } - vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); - vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); + *s = sumf; +} - vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); +void ggml_vec_dot_q2_K_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); - vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); - vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); - vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); - vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + const block_q2_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; - sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); + const int nb = n / QK_K; - vl = 32; + float sumf = 0; + uint8_t atmp[16]; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); + uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }; - uint8_t is = 0; - int isum = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; - for (int j = 0; j < QK_K / 128; ++j) { - // load Q2 - vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); + const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); - vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl); - vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl); - vuint8m1_t q2_3 = 
__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl); + size_t vl = 16; - // duplicate scale elements for product - vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl); - vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl); - vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl); - vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl); + vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); + vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); - vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); - vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); - vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); - vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); - // load Q8 - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl); - vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl); - vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl); + vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); + vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); + vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); + vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); - vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); - vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); - vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); - vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + vl = 32; - isum += __riscv_vmv_x_s_i32m1_i32(isum1); + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); - q2 += 32; - q8 += 128; - is = 8; - } + uint8_t is = 0; + int isum = 0; - sumf += dall * isum; - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; - const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - uint8_t *patmp = atmp; - int vsums; - int tmp, t1, t2, t3, t4, t5, t6, t7; - __asm__ __volatile__( - "vsetivli zero, 16, e8, m1\n\t" - "vmv.v.x v8, zero\n\t" - "lb zero, 15(%[sc])\n\t" - "vle8.v v1, (%[sc])\n\t" - "vle8.v v2, (%[bsums])\n\t" - "addi %[tmp], %[bsums], 16\n\t" - "vand.vi v0, v1, 0xF\n\t" - "vsrl.vi v1, v1, 4\n\t" - "vle8.v v3, (%[tmp])\n\t" - "vse8.v v0, (%[scale])\n\t" - "vsetivli zero, 16, e16, m2\n\t" - "vzext.vf2 v0, v1\n\t" - "vwmul.vv v4, v0, v2\n\t" - "vsetivli zero, 16, e32, m4\n\t" - "vredsum.vs v8, v4, v8\n\t" - "vmv.x.s %[vsums], v8" - : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums) - : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", 
"v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - sumf += dmin * vsums; - int isum = 0; - - for (int j = 0; j < QK_K/128; ++j) { - __asm__ __volatile__( - "lb zero, 31(%[q2])\n\t" - "addi %[tmp], %[q2], 16\n\t" - "addi %[t1], %[q8], 16\n\t" - "vsetivli zero, 16, e8, m1\n\t" - "vle8.v v0, (%[q2])\n\t" - "vle8.v v1, (%[tmp])\n\t" - "vsrl.vi v2, v0, 2\n\t" - "vsrl.vi v3, v1, 2\n\t" - "vsrl.vi v4, v0, 4\n\t" - "addi %[tmp], %[q8], 32\n\t" - "vle8.v v8, (%[q8])\n\t" - "vle8.v v9, (%[t1])\n\t" - "addi %[t1], %[t1], 32\n\t" - "vsrl.vi v5, v1, 4\n\t" - "vsrl.vi v6, v0, 6\n\t" - "vsrl.vi v7, v1, 6\n\t" - "vle8.v v10, (%[tmp])\n\t" - "vle8.v v11, (%[t1])\n\t" - "addi %[tmp], %[tmp], 32\n\t" - "addi %[t1], %[t1], 32\n\t" - "vand.vi v0, v0, 0x3\n\t" - "vand.vi v1, v1, 0x3\n\t" - "vand.vi v2, v2, 0x3\n\t" - "vle8.v v12, (%[tmp])\n\t" - "vle8.v v13, (%[t1])\n\t" - "addi %[tmp], %[tmp], 32\n\t" - "addi %[t1], %[t1], 32\n\t" - "vand.vi v3, v3, 0x3\n\t" - "vand.vi v4, v4, 0x3\n\t" - "vand.vi v5, v5, 0x3\n\t" - "vle8.v v14, (%[tmp])\n\t" - "vle8.v v15, (%[t1])\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vwmul.vv v18, v1, v9\n\t" - "vwmul.vv v20, v2, v10\n\t" - "vwmul.vv v22, v3, v11\n\t" - "vwmul.vv v24, v4, v12\n\t" - "vwmul.vv v26, v5, v13\n\t" - "vwmul.vv v28, v6, v14\n\t" - "vwmul.vv v30, v7, v15\n\t" - "vsetivli zero, 8, e16, m1\n\t" - "vmv.v.x v0, zero\n\t" - "lbu %[tmp], 0(%[scale])\n\t" - "vwredsum.vs v8, v16, v0\n\t" - "vwredsum.vs v9, v18, v0\n\t" - "lbu %[t1], 1(%[scale])\n\t" - "vwredsum.vs v10, v20, v0\n\t" - "vwredsum.vs v11, v22, v0\n\t" - "lbu %[t2], 2(%[scale])\n\t" - "vwredsum.vs v12, v24, v0\n\t" - "vwredsum.vs v13, v26, v0\n\t" - "lbu %[t3], 3(%[scale])\n\t" - "vwredsum.vs v14, v28, v0\n\t" - "vwredsum.vs v15, v30, v0\n\t" - "lbu %[t4], 4(%[scale])\n\t" - "vwredsum.vs v8, v17, v8\n\t" - "vwredsum.vs v9, v19, v9\n\t" - "lbu %[t5], 5(%[scale])\n\t" - "vwredsum.vs v10, v21, v10\n\t" - "vwredsum.vs v11, v23, v11\n\t" - "lbu %[t6], 6(%[scale])\n\t" - "vwredsum.vs v12, v25, v12\n\t" - "vwredsum.vs v13, v27, v13\n\t" - "lbu %[t7], 7(%[scale])\n\t" - "vwredsum.vs v14, v29, v14\n\t" - "vwredsum.vs v15, v31, v15\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vmul.vx v0, v8, %[tmp]\n\t" - "vmul.vx v1, v9, %[t1]\n\t" - "vmacc.vx v0, %[t2], v10\n\t" - "vmacc.vx v1, %[t3], v11\n\t" - "vmacc.vx v0, %[t4], v12\n\t" - "vmacc.vx v1, %[t5], v13\n\t" - "vmacc.vx v0, %[t6], v14\n\t" - "vmacc.vx v1, %[t7], v15\n\t" - "vmv.x.s %[tmp], v0\n\t" - "vmv.x.s %[t1], v1\n\t" - "add %[isum], %[isum], %[tmp]\n\t" - "add %[isum], %[isum], %[t1]" - : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) - , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) - , [isum] "+&r" (isum) - : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - q2 += 32; q8 += 128; patmp += 8; - } + for (int j = 0; j < QK_K / 128; ++j) { + // load Q2 + vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); - sumf += dall * isum; + vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); + vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl); + vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl); + vuint8m1_t q2_3 = 
__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl); + + // duplicate scale elements for product + vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl); + vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl); + vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl); + vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl); + + vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); + vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); + vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); + vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + + // load Q8 + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl); + vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl); + vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl); + + vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); + vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); + vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); + vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(isum1); + + q2 += 32; + q8 += 128; + is = 8; } - break; - default: - assert(false && "Unsupported vector length"); - break; + + sumf += dall * isum; } *s = sumf; +} +#endif +void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector + ggml_vec_dot_q2_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); +#elif defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_q2_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_q2_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + } #else - - UNUSED(x); - UNUSED(y); - UNUSED(nb); - ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } -void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector +void ggml_vec_dot_q3_K_q8_K_xtheadvector(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -843,8 +877,6 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int nb = n / QK_K; -#if defined __riscv_xtheadvector - uint32_t utmp[4]; float sumf = 0; @@ -966,258 +998,538 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - sumf += d * isum; + sumf += d * isum; + } + + *s = sumf; +} +#endif + +#if defined __riscv_v +void ggml_vec_dot_q3_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); 
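+    // q3_K layout (as decoded below): 256 2-bit values in qs, one high bit per
+    // value in hmask (a clear hmask bit means the decoded value is reduced by 4),
+    // and 16 6-bit scales packed into the 12 bytes of scales[], stored with a
+    // +32 bias that is removed before use.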
+ assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + uint32_t utmp[4]; + float sumf = 0; + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + int8_t * scale = (int8_t *)utmp; + int tmp, t1, t2, t3, t4, t5, t6, t7; + __asm__ __volatile__( + "vsetivli zero, 12, e8, m1\n\t" + "vle8.v v0, (%[s6b])\n\t" + "vmv1r.v v2, v0\n\t" + "vsetivli zero, 2, e64, m1\n\t" + "vmv.v.x v9, %[sh]\n\t"\ + "vslidedown.vi v1, v0, 1\n\t" + "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4} + "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]} + "vsetivli zero, 4, e32, m1\n\t" + "vid.v v9\n\t" + "vmv.x.s %[tmp], v1\n\t" + "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6} + "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]} + "vsrl.vv v4, v1, v9\n\t" + "vsrl.vv v2, v0, v8\n\t" + "vand.vx v5, v4, %[kmask1]\n\t" + "vand.vx v3, v2, %[kmask2]\n\t" + "vsll.vi v6, v5, 4\n\t" + "vor.vv v7, v6, v3\n\t" + "vsetivli zero, 16, e8, m1\n\t" + "vsub.vx v0, v7, %[c]\n\t" + "vse8.v v0, (%[scale])" + : [tmp] "=&r" (tmp) + : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32) + , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + + uint8_t m = 1; + int isum = 0; + for (int j = 0; j < QK_K; j += 128) { + __asm__ __volatile__( + "lb zero, 31(%[q3])\n\t" + "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t" + "vle8.v v8, (%[q3])\n\t" + "vsrl.vi v10, v8, 2\n\t" + "vsrl.vi v12, v8, 4\n\t" + "vsrl.vi v14, v8, 6\n\t" + "lb zero, 64(%[q8])\n\t" + "vand.vi v8, v8, 3\n\t" + "vand.vi v10, v10, 3\n\t" + "vand.vi v12, v12, 3\n\t" + "vle8.v v2, (%[qh])\n\t" + "lb zero, 127(%[q8])\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v8, v8, -4, v0.t\n\t" + "lb zero, 0(%[q8])\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v10, v10, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v12, v12, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v14, v14, -4, v0.t\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v0, (%[q8])\n\t" + "lb %[tmp], 0(%[scale])\n\t" + "lb %[t1], 1(%[scale])\n\t" + "lb %[t2], 2(%[scale])\n\t" + "lb %[t3], 3(%[scale])\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v8, v16, v0\n\t" + "lb %[t4], 4(%[scale])\n\t" + "lb %[t5], 5(%[scale])\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v10, v20, v0\n\t" + "vwredsum.vs v11, v22, v0\n\t" + "vwredsum.vs v12, v24, v0\n\t" + "lb %[t6], 6(%[scale])\n\t" + "lb %[t7], 7(%[scale])\n\t" + "vwredsum.vs v13, v26, v0\n\t" + "vwredsum.vs v14, v28, v0\n\t" + "vwredsum.vs v15, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vmul.vx v0, v8, %[tmp]\n\t" + "vmul.vx v1, v9, %[t1]\n\t" + "vmacc.vx v0, %[t2], v10\n\t" + 
"vmacc.vx v1, %[t3], v11\n\t" + "vmacc.vx v0, %[t4], v12\n\t" + "vmacc.vx v1, %[t5], v13\n\t" + "vmacc.vx v0, %[t6], v14\n\t" + "vmacc.vx v1, %[t7], v15\n\t" + "vmv.x.s %[tmp], v0\n\t" + "vmv.x.s %[t1], v1\n\t" + "add %[isum], %[isum], %[tmp]\n\t" + "add %[isum], %[isum], %[t1]" + : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) + , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) + , [m] "+&r" (m), [isum] "+&r" (isum) + : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32) + , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q3 += 32; q8 += 128; scale += 8; + } + + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + sumf += d * isum; + } + + *s = sumf; +} + +void ggml_vec_dot_q3_K_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + uint32_t utmp[4]; + float sumf = 0; + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + + size_t vl = 32; + uint8_t m = 1; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); + + int sum_t = 0; + + for (int j = 0; j < QK_K; j += 128) { + + vl = 32; + + // load Q3 + vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); + + vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); + vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); + vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); + vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); + + // compute mask for subtraction + vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); + vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_3 = 
__riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); + vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; + + // load Q8 and take product with Q3 + vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + + vl = 16; + + // retrieve lane to multiply with scale + vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); + vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); + vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); + vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); + vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); + vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); + vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); + vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q3 += 32; q8 += 128; scale += 8; + + } + + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + sumf += d*sum_t; + + } + + *s = sumf; +} + +void ggml_vec_dot_q3_K_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + // mask for processing 16 elements per prod register + const vuint16m1_t va_index = __riscv_vid_v_u16m1(32); + const vbool16_t va_mask = __riscv_vmsgtu_vx_u16m1_b16(va_index, 15, 32); + + uint32_t utmp[4]; + float sumf = 0; + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + + size_t vl = 32; + uint8_t m = 1; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(qh, vl); + + int sum_t = 0; + + vint32m2_t vaux_0 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_1 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_2 = 
__riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_3 = __riscv_vmv_v_x_i32m2(0, vl); + + for (int j = 0; j < QK_K; j += 128) { + + vl = 32; + + // load Q3 + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(q3, vl); + + vint8mf2_t q3_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q3_x, 0x03, vl)); + vint8mf2_t q3_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x2, vl), 0x03 , vl)); + vint8mf2_t q3_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x4, vl), 0x03 , vl)); + vint8mf2_t q3_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x6, vl), 0x03 , vl)); + + // compute mask for subtraction + vuint8mf2_t qh_m0 = __riscv_vand_vx_u8mf2(vqh, m, vl); + vbool16_t vmask_0 = __riscv_vmseq_vx_u8mf2_b16(qh_m0, 0, vl); + vint8mf2_t q3_m0 = __riscv_vsub_vx_i8mf2_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; + + vuint8mf2_t qh_m1 = __riscv_vand_vx_u8mf2(vqh, m, vl); + vbool16_t vmask_1 = __riscv_vmseq_vx_u8mf2_b16(qh_m1, 0, vl); + vint8mf2_t q3_m1 = __riscv_vsub_vx_i8mf2_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; + + vuint8mf2_t qh_m2 = __riscv_vand_vx_u8mf2(vqh, m, vl); + vbool16_t vmask_2 = __riscv_vmseq_vx_u8mf2_b16(qh_m2, 0, vl); + vint8mf2_t q3_m2 = __riscv_vsub_vx_i8mf2_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; + + vuint8mf2_t qh_m3 = __riscv_vand_vx_u8mf2(vqh, m, vl); + vbool16_t vmask_3 = __riscv_vmseq_vx_u8mf2_b16(qh_m3, 0, vl); + vint8mf2_t q3_m3 = __riscv_vsub_vx_i8mf2_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; + + // load Q8 and take product + vint16m1_t va_q_0 = __riscv_vwmul_vv_i16m1(q3_m0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t va_q_1 = __riscv_vwmul_vv_i16m1(q3_m1, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t va_q_2 = __riscv_vwmul_vv_i16m1(q3_m2, __riscv_vle8_v_i8mf2(q8+64, vl), vl); + vint16m1_t va_q_3 = __riscv_vwmul_vv_i16m1(q3_m3, __riscv_vle8_v_i8mf2(q8+96, vl), vl); + + // accumulate + vaux_0 = __riscv_vwmacc_vx_i32m2(vaux_0, scale[0], va_q_0, 16); + vaux_1 = __riscv_vwmacc_vx_i32m2(vaux_1, scale[2], va_q_1, 16); + vaux_2 = __riscv_vwmacc_vx_i32m2(vaux_2, scale[4], va_q_2, 16); + vaux_3 = __riscv_vwmacc_vx_i32m2(vaux_3, scale[6], va_q_3, 16); + // + vaux_0 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_0, scale[1], va_q_0, vl); + vaux_1 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_1, scale[3], va_q_1, vl); + vaux_2 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_2, scale[5], va_q_2, vl); + vaux_3 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_3, scale[7], va_q_3, vl); + + q3 += 32; q8 += 128; scale += 8; + } + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum1); + + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + sumf += d*sum_t; } *s = sumf; +} -#elif defined __riscv_v - - uint32_t utmp[4]; - float sumf = 0; - uint32_t aux[3]; - const int vector_length = __riscv_vlenb() * 8; +void ggml_vec_dot_q3_K_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - 
const uint8_t * GGML_RESTRICT qh = x[i].hmask; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + const int nb = n / QK_K; - int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= 32; + // mask for processing 16 elements per prod register + const vuint16mf2_t va_index = __riscv_vid_v_u16mf2(32); + const vbool32_t va_mask = __riscv_vmsgtu_vx_u16mf2_b32(va_index, 15, 32); + uint32_t utmp[4]; + float sumf = 0; + uint32_t aux[3]; - size_t vl = 32; - uint8_t m = 1; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); - int sum_t = 0; + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; - for (int j = 0; j < QK_K; j += 128) { - vl = 32; + size_t vl = 32; + uint8_t m = 1; - // load Q3 - vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8mf4_t vqh = __riscv_vle8_v_u8mf4(qh, vl); - vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); - vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); - vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); - vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); + int sum_t = 0; - // compute mask for subtraction - vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); - vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); - m <<= 1; + vint32m1_t vaux_0 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_1 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_2 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_3 = __riscv_vmv_v_x_i32m1(0, vl); - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); - m <<= 1; + for (int j = 0; j < QK_K; j += 128) { - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); - m <<= 1; + vl = 32; - vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); - vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); - m <<= 1; + // load Q3 + vuint8mf4_t q3_x = __riscv_vle8_v_u8mf4(q3, vl); - // load Q8 and take product with Q3 - vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), 
vl); - vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + vint8mf4_t q3_0 = __riscv_vreinterpret_v_u8mf4_i8mf4(__riscv_vand_vx_u8mf4(q3_x, 0x03, vl)); + vint8mf4_t q3_1 = __riscv_vreinterpret_v_u8mf4_i8mf4(__riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(q3_x, 0x2, vl), 0x03 , vl)); + vint8mf4_t q3_2 = __riscv_vreinterpret_v_u8mf4_i8mf4(__riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(q3_x, 0x4, vl), 0x03 , vl)); + vint8mf4_t q3_3 = __riscv_vreinterpret_v_u8mf4_i8mf4(__riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(q3_x, 0x6, vl), 0x03 , vl)); - vl = 16; + // compute mask for subtraction + vuint8mf4_t qh_m0 = __riscv_vand_vx_u8mf4(vqh, m, vl); + vbool32_t vmask_0 = __riscv_vmseq_vx_u8mf4_b32(qh_m0, 0, vl); + vint8mf4_t q3_m0 = __riscv_vsub_vx_i8mf4_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; - // retrieve lane to multiply with scale - vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); - vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); - vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); - vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); - vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); - vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); - vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); - vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); + vuint8mf4_t qh_m1 = __riscv_vand_vx_u8mf4(vqh, m, vl); + vbool32_t vmask_1 = __riscv_vmseq_vx_u8mf4_b32(qh_m1, 0, vl); + vint8mf4_t q3_m1 = __riscv_vsub_vx_i8mf4_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); + vuint8mf4_t qh_m2 = __riscv_vand_vx_u8mf4(vqh, m, vl); + vbool32_t vmask_2 = __riscv_vmseq_vx_u8mf4_b32(qh_m2, 0, vl); + vint8mf4_t q3_m2 = __riscv_vsub_vx_i8mf4_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + vuint8mf4_t qh_m3 = __riscv_vand_vx_u8mf4(vqh, m, vl); + vbool32_t vmask_3 = __riscv_vmseq_vx_u8mf4_b32(qh_m3, 0, vl); + vint8mf4_t q3_m3 = __riscv_vsub_vx_i8mf4_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; - q3 += 32; q8 += 128; scale += 8; + // load Q8 and take product + vint16mf2_t va_q_0 = __riscv_vwmul_vv_i16mf2(q3_m0, __riscv_vle8_v_i8mf4(q8, vl), vl); + vint16mf2_t va_q_1 = __riscv_vwmul_vv_i16mf2(q3_m1, __riscv_vle8_v_i8mf4(q8+32, vl), vl); + vint16mf2_t va_q_2 = __riscv_vwmul_vv_i16mf2(q3_m2, __riscv_vle8_v_i8mf4(q8+64, vl), vl); + vint16mf2_t va_q_3 = __riscv_vwmul_vv_i16mf2(q3_m3, __riscv_vle8_v_i8mf4(q8+96, vl), vl); - } + // accumulate + vaux_0 = __riscv_vwmacc_vx_i32m1(vaux_0, scale[0], va_q_0, 16); + vaux_1 = __riscv_vwmacc_vx_i32m1(vaux_1, scale[2], va_q_1, 16); + vaux_2 = __riscv_vwmacc_vx_i32m1(vaux_2, scale[4], va_q_2, 
16); + vaux_3 = __riscv_vwmacc_vx_i32m1(vaux_3, scale[6], va_q_3, 16); + // + vaux_0 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_0, scale[1], va_q_0, vl); + vaux_1 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_1, scale[3], va_q_1, vl); + vaux_2 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_2, scale[5], va_q_2, vl); + vaux_3 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_3, scale[7], va_q_3, vl); - const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + q3 += 32; q8 += 128; scale += 8; + } - sumf += d*sum_t; + vint32m1_t isum0 = __riscv_vredsum_vs_i32m1_i32m1(__riscv_vadd_vv_i32m1(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m1_i32m1(__riscv_vadd_vv_i32m1(vaux_2, vaux_3, vl), isum0, vl); - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict qh = x[i].hmask; - const int8_t * restrict q8 = y[i].qs; - - int8_t * scale = (int8_t *)utmp; - int tmp, t1, t2, t3, t4, t5, t6, t7; - __asm__ __volatile__( - "vsetivli zero, 12, e8, m1\n\t" - "vle8.v v0, (%[s6b])\n\t" - "vmv1r.v v2, v0\n\t" - "vsetivli zero, 2, e64, m1\n\t" - "vmv.v.x v9, %[sh]\n\t"\ - "vslidedown.vi v1, v0, 1\n\t" - "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4} - "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]} - "vsetivli zero, 4, e32, m1\n\t" - "vid.v v9\n\t" - "vmv.x.s %[tmp], v1\n\t" - "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6} - "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]} - "vsrl.vv v4, v1, v9\n\t" - "vsrl.vv v2, v0, v8\n\t" - "vand.vx v5, v4, %[kmask1]\n\t" - "vand.vx v3, v2, %[kmask2]\n\t" - "vsll.vi v6, v5, 4\n\t" - "vor.vv v7, v6, v3\n\t" - "vsetivli zero, 16, e8, m1\n\t" - "vsub.vx v0, v7, %[c]\n\t" - "vse8.v v0, (%[scale])" - : [tmp] "=&r" (tmp) - : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32) - , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); + sum_t += __riscv_vmv_x_s_i32m1_i32(isum1); - uint8_t m = 1; - int isum = 0; - for (int j = 0; j < QK_K; j += 128) { - __asm__ __volatile__( - "lb zero, 31(%[q3])\n\t" - "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t" - "vle8.v v8, (%[q3])\n\t" - "vsrl.vi v10, v8, 2\n\t" - "vsrl.vi v12, v8, 4\n\t" - "vsrl.vi v14, v8, 6\n\t" - "lb zero, 64(%[q8])\n\t" - "vand.vi v8, v8, 3\n\t" - "vand.vi v10, v10, 3\n\t" - "vand.vi v12, v12, 3\n\t" - "vle8.v v2, (%[qh])\n\t" - "lb zero, 127(%[q8])\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v8, v8, -4, v0.t\n\t" - "lb zero, 0(%[q8])\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v10, v10, -4, v0.t\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v12, v12, -4, v0.t\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v14, v14, -4, v0.t\n\t" - "vsetvli zero, %[vl128], e8, m8\n\t" - "vle8.v v0, (%[q8])\n\t" - "lb %[tmp], 0(%[scale])\n\t" - "lb %[t1], 1(%[scale])\n\t" - "lb %[t2], 2(%[scale])\n\t" - "lb %[t3], 3(%[scale])\n\t" - "vsetvli zero, %[vl64], e8, m4\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vwmul.vv v24, v4, v12\n\t" - "vsetivli zero, 16, e16, m2\n\t" - "vmv.v.x v0, zero\n\t" - "vwredsum.vs v8, v16, v0\n\t" - "lb %[t4], 4(%[scale])\n\t" - "lb 
%[t5], 5(%[scale])\n\t" - "vwredsum.vs v9, v18, v0\n\t" - "vwredsum.vs v10, v20, v0\n\t" - "vwredsum.vs v11, v22, v0\n\t" - "vwredsum.vs v12, v24, v0\n\t" - "lb %[t6], 6(%[scale])\n\t" - "lb %[t7], 7(%[scale])\n\t" - "vwredsum.vs v13, v26, v0\n\t" - "vwredsum.vs v14, v28, v0\n\t" - "vwredsum.vs v15, v30, v0\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vmul.vx v0, v8, %[tmp]\n\t" - "vmul.vx v1, v9, %[t1]\n\t" - "vmacc.vx v0, %[t2], v10\n\t" - "vmacc.vx v1, %[t3], v11\n\t" - "vmacc.vx v0, %[t4], v12\n\t" - "vmacc.vx v1, %[t5], v13\n\t" - "vmacc.vx v0, %[t6], v14\n\t" - "vmacc.vx v1, %[t7], v15\n\t" - "vmv.x.s %[tmp], v0\n\t" - "vmv.x.s %[t1], v1\n\t" - "add %[isum], %[isum], %[tmp]\n\t" - "add %[isum], %[isum], %[t1]" - : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) - , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) - , [m] "+&r" (m), [isum] "+&r" (isum) - : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32) - , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - q3 += 32; q8 += 128; scale += 8; - } + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - sumf += d * isum; - } - break; - default: - assert(false && "Unsupported vector length"); - break; + sumf += d*sum_t; } *s = sumf; +} +#endif +void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector + ggml_vec_dot_q3_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); +#elif defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_q3_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_q3_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + case 512: + ggml_vec_dot_q3_K_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_q3_K_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } #else - - UNUSED(kmask1); - UNUSED(kmask2); - UNUSED(x); - UNUSED(y); - UNUSED(nb); - ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif - } #if defined __riscv_xtheadvector @@ -1356,6 +1668,7 @@ static NOINLINE void ggml_vec_dot_q4_K_q8_K_xtheadvector(int n, float * GGML_RES } #endif +#if defined __riscv_v static NOINLINE void ggml_vec_dot_q4_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -1648,16 +1961,17 @@ static NOINLINE void ggml_vec_dot_q4_K_q8_K_vl256(int n, float * GGML_RESTRICT s *s = sumf; } +#endif void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_xtheadvector ggml_vec_dot_q4_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); -#elif defined __riscv_v_intrinsic +#elif defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_q4_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); break; - default: + default: // 256 and above ggml_vec_dot_q4_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); 
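            // No dedicated 512/1024-bit q4_K kernel in this patch; any wider VLEN
            // reuses the 256-bit implementation.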
break; } @@ -1879,6 +2193,7 @@ static NOINLINE void ggml_vec_dot_q6_K_q8_K_xtheadvector(int n, float * GGML_RES } #endif +#if defined __riscv_v static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -2302,11 +2617,12 @@ static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl1024(int n, float * GGML_RESTRICT *s = sumf; } +#endif void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_xtheadvector ggml_vec_dot_q6_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); -#elif defined __riscv_v_intrinsic +#elif defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_q6_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -2329,7 +2645,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -2537,7 +2853,6 @@ static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT *s = sumf; } -#endif static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); @@ -2718,9 +3033,10 @@ static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl1024(int n, float * GGML_RESTRICT *s = sumf; } +#endif void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq1_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -2743,7 +3059,7 @@ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -3023,7 +3339,6 @@ static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT *s = sumf; } -#endif static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); @@ -3301,9 +3616,10 @@ static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl1024(int n, float * GGML_RESTRICT *s = sumf; } +#endif void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq1_m_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -3326,7 +3642,7 @@ void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static const uint8_t sign_gather_indices_arr[64] = { 0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 
2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3, 4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7 @@ -3530,7 +3846,7 @@ static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT *s = 0.125f * sumf; } -static void ggml_vec_dot_iq2_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); @@ -3640,7 +3956,7 @@ static void ggml_vec_dot_iq2_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t *s = 0.125f * sumf; } -static void ggml_vec_dot_iq2_s_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); @@ -3794,10 +4110,10 @@ static void ggml_vec_dot_iq2_s_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_ } *s = 0.125f * sumf; } - +#endif void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq2_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -3805,8 +4121,11 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo case 256: ggml_vec_dot_iq2_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; + case 512: + ggml_vec_dot_iq2_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; default: - ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_iq2_s_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); break; } #else @@ -3814,7 +4133,7 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static const int8_t keven_signs_q2xs[1024] = { 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, @@ -3952,62 +4271,74 @@ static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT int32_t sum_int = 0; - // Loop over 4 subblocks of 64 elements (QK_K = 256) - for (int ib64 = 0; ib64 < QK_K / 64; ++ib64) { - // Load 8 uint16 indices (controls 64 values) - vuint16mf2_t v_qs = __riscv_vle16_v_u16mf2(qs, 8); - qs += 8; + for (int ib128 = 0; ib128 < 2; ++ib128) { - // Extract indices for grid (low 9 bits) and signs (high 7 bits) - // Multiply by 8 (<< 3) for byte offsets into the uint64 tables - vuint16mf2_t vidx_grid = __riscv_vsll_vx_u16mf2(__riscv_vand_vx_u16mf2(v_qs, 511, 8), 3, 8); - vuint16mf2_t vidx_sign = __riscv_vsll_vx_u16mf2(__riscv_vsrl_vx_u16mf2(v_qs, 9, 8), 3, 8); + vuint16m1_t v_qs = __riscv_vle16_v_u16m1(qs, 16); + qs += 16; - vuint64m2_t vq2_64 = __riscv_vluxei16_v_u64m2(grid64, vidx_grid, 8); - vuint64m2_t vs2_64 = __riscv_vluxei16_v_u64m2(signs64, vidx_sign, 8); + // Prepare offsets for grid and signs + vuint16m1_t vidx_grid = 
__riscv_vsll_vx_u16m1(__riscv_vand_vx_u16m1(v_qs, 511, 16), 3, 16); + vuint16m1_t vidx_sign = __riscv_vsll_vx_u16m1(__riscv_vsrl_vx_u16m1(v_qs, 9, 16), 3, 16); - vint8m2_t q2u = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64)); - vint8m2_t q2s = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64)); + // Indexed load 128 weights (16 x 8-byte chunks) + vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_grid, 16); + vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_sign, 16); - // Apply signs - vint8m2_t q2_final = __riscv_vmul_vv_i8m2(q2u, q2s, 64); + vint8m4_t q2u = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64)); + vint8m4_t q2s = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64)); - // Load Q8 weights (64 elements) - vint8m2_t q8v = __riscv_vle8_v_i8m2(q8, 64); - q8 += 64; + // Apply signs to get dequantized IQ2 values + vint8m4_t q2_final = __riscv_vmul_vv_i8m4(q2u, q2s, 128); + asm volatile("" ::: "memory"); - // Multiply (Widening to int16, 64 elements -> LMUL=4) - vint16m4_t prod = __riscv_vwmul_vv_i16m4(q2_final, q8v, 64); + // Load corresponding Q8 weights + vint8m4_t q8v = __riscv_vle8_v_i8m4(q8, 128); + q8 += 128; + + vint16m8_t prod = __riscv_vwmul_vv_i16m8(q2_final, q8v, 128); + asm volatile("" ::: "memory"); + + uint8_t sc0 = scales[0]; + uint8_t sc1 = scales[1]; + uint8_t sc2 = scales[2]; + uint8_t sc3 = scales[3]; + scales += 4; - // Reduction vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); - int32_t sum0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(prod, 0), zero_vec, 16)); - int32_t sum1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(prod, 1), zero_vec, 16)); - int32_t sum2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(prod, 2), zero_vec, 16)); - int32_t sum3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(prod, 3), zero_vec, 16)); + // 9. 
Reduce each 16-element chunk and apply corresponding nibble scale - // Apply Scales - const uint8_t scale_byte_1 = scales[0]; - const uint8_t scale_byte_2 = scales[1]; - scales += 2; + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 0), zero_vec, 16)); + sum_int += s0 * ((sc0 & 0x0F) * 2 + 1); - sum_int += sum0 * ((scale_byte_1 & 0x0F) * 2 + 1); - sum_int += sum1 * ((scale_byte_1 >> 4) * 2 + 1); - sum_int += sum2 * ((scale_byte_2 & 0x0F) * 2 + 1); - sum_int += sum3 * ((scale_byte_2 >> 4) * 2 + 1); + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 1), zero_vec, 16)); + sum_int += s1 * ((sc0 >> 4) * 2 + 1); + + int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 2), zero_vec, 16)); + sum_int += s2 * ((sc1 & 0x0F) * 2 + 1); + + int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 3), zero_vec, 16)); + sum_int += s3 * ((sc1 >> 4) * 2 + 1); + + int32_t s4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 4), zero_vec, 16)); + sum_int += s4 * ((sc2 & 0x0F) * 2 + 1); + + int32_t s5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 5), zero_vec, 16)); + sum_int += s5 * ((sc2 >> 4) * 2 + 1); + + int32_t s6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 6), zero_vec, 16)); + sum_int += s6 * ((sc3 & 0x0F) * 2 + 1); + + int32_t s7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 7), zero_vec, 16)); + sum_int += s7 * ((sc3 >> 4) * 2 + 1); } - sumf += d * sum_int; + sumf += d * (float)sum_int; } *s = 0.125f * sumf; } -static void ggml_vec_dot_iq2_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); @@ -4070,10 +4401,10 @@ static void ggml_vec_dot_iq2_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_ } *s = 0.125f * sumf; } - +#endif void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq2_xs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -4081,8 +4412,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v case 256: ggml_vec_dot_iq2_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; - default: - ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + default: // 512 and above + ggml_vec_dot_iq2_xs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); break; } #else @@ -4090,7 +4421,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -4276,8 +4607,7 @@ static NOINLINE 
void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRIC *s = 0.125f * sumf; } - -static void ggml_vec_dot_iq2_xxs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); @@ -4348,23 +4678,27 @@ static void ggml_vec_dot_iq2_xxs_q8_K_vl512(int n, float * GGML_RESTRICT s, size } *s = 0.125f * sumf; } +#endif void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq2_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); break; - default: // 256 and above + case 256: ggml_vec_dot_iq2_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; + default: // 512 and above + ggml_vec_dot_iq2_xxs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; } #else ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); @@ -4555,7 +4889,7 @@ static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT *s = sumf; } -static void ggml_vec_dot_iq3_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); const block_iq3_s * GGML_RESTRICT x = vx; @@ -4643,18 +4977,19 @@ static void ggml_vec_dot_iq3_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t } *s = sumf; } +#endif void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: - ggml_vec_dot_iq3_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_iq3_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); break; case 256: ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; - default: - ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + default: // 512 and above + ggml_vec_dot_iq3_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); break; } #else @@ -4662,7 +4997,7 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); @@ -4848,14 +5183,8 @@ static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * 
     }
     *s = 0.25f * sumf;
 }
-static void ggml_vec_dot_iq3_xxs_q8_K_vl512(int n,
-                                            float * GGML_RESTRICT s,
-                                            size_t bs,
-                                            const void * GGML_RESTRICT vx,
-                                            size_t bx,
-                                            const void * GGML_RESTRICT vy,
-                                            size_t by,
-                                            int nrc) {
+
+static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
     UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);
@@ -4944,14 +5273,7 @@ static void ggml_vec_dot_iq3_xxs_q8_K_vl512(int n,
     *s = 0.25f * sumf;
 }
 
-static void ggml_vec_dot_iq3_xxs_q8_K_vl1024(int n,
-                                             float * GGML_RESTRICT s,
-                                             size_t bs,
-                                             const void * GGML_RESTRICT vx,
-                                             size_t bx,
-                                             const void * GGML_RESTRICT vy,
-                                             size_t by,
-                                             int nrc) {
+static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
     UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);
@@ -5032,9 +5354,10 @@ static void ggml_vec_dot_iq3_xxs_q8_K_vl1024(int n,
     }
     *s = 0.25f * sumf;
 }
+#endif
 
 void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-#if defined __riscv_v_intrinsic
+#if defined __riscv_v
     switch (__riscv_vlenb() * 8) {
         case 128:
             ggml_vec_dot_iq3_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
@@ -5042,8 +5365,11 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
         case 256:
             ggml_vec_dot_iq3_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
             break;
-        default:
-            ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+        case 512:
+            ggml_vec_dot_iq3_xxs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+        default: // 1024 and above
+            ggml_vec_dot_iq3_xxs_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc);
             break;
     }
 #else
@@ -5051,7 +5377,7 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
 #endif
 }
 
-#if defined __riscv_v_intrinsic
+#if defined __riscv_v
 static NOINLINE void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(nrc == 1);
     UNUSED(nrc);
@@ -5166,7 +5492,7 @@ static NOINLINE void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT
 #endif
 
 void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-#if defined __riscv_v_intrinsic
+#if defined __riscv_v
     switch (__riscv_vlenb() * 8) {
         case 128:
             ggml_vec_dot_iq4_nl_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc);
@@ -5180,7 +5506,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v
 #endif
 }
 
-#if defined __riscv_v_intrinsic
+#if defined __riscv_v
 static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(nrc == 1);
     UNUSED(nrc);
@@ -5326,7 +5652,6 @@ static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT
     *s = sumf;
 }
 
-#endif
 
 static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(nrc == 1);
@@ -5522,9 +5847,10 @@ static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl1024(int n, float * GGML_RESTRICT
     }
     *s = sumf;
 }
+#endif
 
 void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-#if defined __riscv_v_intrinsic
+#if defined __riscv_v
     switch (__riscv_vlenb() * 8) {
         case 128:
             ggml_vec_dot_iq4_xs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
@@ -5547,7 +5873,7 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
 #endif
 }
 
-#if defined __riscv_v_intrinsic
+#if defined __riscv_v
 static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(nrc == 1);
     UNUSED(nrc);
@@ -5750,10 +6076,112 @@ static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT
     }
     *s = sumf;
 }
+
+static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_tq1_0 * GGML_RESTRICT x = vx;
+    const block_q8_K * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+
+    float sumf = 0.0f;
+    uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
+
+    for (int i = 0; i < nb; i++) {
+        // First loop.
+        vint16m1_t suml1;
+        {
+            const int vl = 32;
+            vuint8mf2_t tq = __riscv_vle8_v_u8mf2(x[i].qs, vl);
+
+            vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(tq, 3, vl), 8, vl);
+            vuint16m1_t tq1 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 3, vl), 3, vl), 8, vl);
+            vuint16m1_t tq2 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 9, vl), 3, vl), 8, vl);
+            vuint16m1_t tq3 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 27, vl), 3, vl), 8, vl);
+            vuint16m1_t tq4 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 81, vl), 3, vl), 8, vl);
+
+            vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 0, vl), vl);
+            vint16m1_t q81 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 32, vl), vl);
+            vint16m1_t q82 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 64, vl), vl);
+            vint16m1_t q83 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 96, vl), vl);
+            vint16m1_t q84 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 128, vl), vl);
+
+            vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl);
+            vint16m1_t sum1 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq1, 1, vl)), q81, vl);
+            vint16m1_t sum2 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq2, 1, vl)), q82, vl);
+            vint16m1_t sum3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq3, 1, vl)), q83, vl);
+            vint16m1_t sum4 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq4, 1, vl)), q84, vl);
+
+            vint16m1_t sumi0 = __riscv_vadd_vv_i16m1(sum0, sum1, vl);
+            vint16m1_t sumi1 = __riscv_vadd_vv_i16m1(sum2, sum3, vl);
+            suml1 = __riscv_vadd_vv_i16m1(sum4, __riscv_vadd_vv_i16m1(sumi0, sumi1, vl), vl);
+        }
+
+        // Second loop.
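+        // Same base-3 unpacking as the first loop: each packed qs byte holds five
+        // ternary digits, recovered by multiplying by successive powers of three
+        // (wrapping in 8 bits), widening-multiplying by 3 and keeping the high byte;
+        // subtracting 1 maps each digit from 0..2 to {-1, 0, 1}. This block covers the
+        // remaining 16 packed bytes (x[i].qs + 32) against y[i].qs[160..239].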
+        vint16mf2_t suml2;
+        {
+            const int vl = 16;
+            vuint8mf4_t tq = __riscv_vle8_v_u8mf4(x[i].qs + 32, vl);
+
+            vuint16mf2_t tq0 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(tq, 3 * 1, vl), 8, vl);
+            vuint16mf2_t tq1 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vx_u8mf4(tq, 3, vl), 3, vl), 8, vl);
+            vuint16mf2_t tq2 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vx_u8mf4(tq, 9, vl), 3, vl), 8, vl);
+            vuint16mf2_t tq3 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vx_u8mf4(tq, 27, vl), 3, vl), 8, vl);
+            vuint16mf2_t tq4 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vx_u8mf4(tq, 81, vl), 3, vl), 8, vl);
+
+            vint16mf2_t q80 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 160, vl), vl);
+            vint16mf2_t q81 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 176, vl), vl);
+            vint16mf2_t q82 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 192, vl), vl);
+            vint16mf2_t q83 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 208, vl), vl);
+            vint16mf2_t q84 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 224, vl), vl);
+
+            vint16mf2_t sum0 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq0, 1, vl)), q80, vl);
+            vint16mf2_t sum1 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq1, 1, vl)), q81, vl);
+            vint16mf2_t sum2 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq2, 1, vl)), q82, vl);
+            vint16mf2_t sum3 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq3, 1, vl)), q83, vl);
+            vint16mf2_t sum4 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq4, 1, vl)), q84, vl);
+
+            vint16mf2_t sumi0 = __riscv_vadd_vv_i16mf2(sum0, sum1, vl);
+            vint16mf2_t sumi1 = __riscv_vadd_vv_i16mf2(sum2, sum3, vl);
+            suml2 = __riscv_vadd_vv_i16mf2(sum4, __riscv_vadd_vv_i16mf2(sumi0, sumi1, vl), vl);
+        }
+
+        // Third loop.
+        vint16mf2_t suml3;
+        {
+            const int vl = 16;
+
+            uint32_t qh;
+            memcpy(&qh, &x[i].qh[0], 4);
+            // Prevent fusion with vmv.
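+            // The empty asm with a "+r" constraint acts as a compiler barrier so the
+            // 4-byte load above stays a scalar load and is not folded into the vector
+            // splat below; the splat repeats the four qh bytes across 16 lanes, and the
+            // replicated pow[] table then selects the power of three for each lane.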
+            __asm__ __volatile__("" : "+r"(qh));
+            vuint8mf4_t tq = __riscv_vlmul_trunc_v_u8mf2_u8mf4(__riscv_vreinterpret_v_u32mf2_u8mf2(__riscv_vmv_v_x_u32mf2(qh, vl / 4)));
+
+            vuint8mf4_t p = __riscv_vle8_v_u8mf4(pow, vl);
+
+            vuint16mf2_t tq0 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vv_u8mf4(tq, p, vl), 3, vl), 8, vl);
+
+            vint16mf2_t q80 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 240, vl), vl);
+
+            suml3 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq0, 1, vl)), q80, vl);
+        }
+
+        vint32m1_t sum = __riscv_vwredsum_vs_i16m1_i32m1(suml1, __riscv_vmv_v_x_i32m1(0, 1), 32);
+        sum = __riscv_vwredsum_vs_i16mf2_i32m1(__riscv_vadd_vv_i16mf2(suml2, suml3, 16), sum, 16);
+        sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
+    }
+
+    *s = sumf;
+}
 #endif
 
 void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-#if defined __riscv_v_intrinsic
+#if defined __riscv_v
     switch (__riscv_vlenb() * 8) {
         case 128:
             ggml_vec_dot_tq1_0_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
@@ -5761,8 +6189,8 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
         case 256:
             ggml_vec_dot_tq1_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
             break;
-        default:
-            ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+        default: // 512 and above
+            ggml_vec_dot_tq1_0_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc);
             break;
     }
 #else
@@ -5770,7 +6198,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 #endif
 }
 
-#if defined __riscv_v_intrinsic
+#if defined __riscv_v
 static NOINLINE void ggml_vec_dot_tq2_0_q8_K_vl128(const int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
@@ -5926,24 +6354,21 @@ static NOINLINE void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT
 #endif
 
 void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-#if defined __riscv_v_intrinsic
+#if defined __riscv_v
     switch (__riscv_vlenb() * 8) {
         case 128:
             ggml_vec_dot_tq2_0_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
             break;
-        case 256:
+        default: // 256 and above
            ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
            break;
-        default:
-            ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
-            break;
     }
 #else
     ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
-#if defined __riscv_v_intrinsic
+#if defined __riscv_v
 static NOINLINE void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(nrc == 1);
     UNUSED(nrc);
@@ -6058,7 +6483,7 @@ static NOINLINE void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT
 #endif
 
 void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
-#if defined __riscv_v_intrinsic
+#if defined __riscv_v
     switch (__riscv_vlenb() * 8) {
         case 128:
            ggml_vec_dot_mxfp4_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc);