// sapient_backends_cpu/kernels/quant.rs
//! Quantized weight storage and on-the-fly dequantizing dot-products.
//!
//! The whole point of running huge models on small devices is to **keep weights
//! quantized in memory** and dequantize one block at a time *inside* the
//! dot-product, instead of expanding the whole weight matrix to F32 (which costs
//! roughly 7× the RAM of Q4_0: 4 bytes vs 0.5625 bytes per weight). The Q4_0 path
//! here is the Phase-0 spike proving that mechanic: a Q4_0 matrix-vector product
//! computed straight from the packed blocks matches the F32 reference, while
//! storing only 0.5625 bytes/weight.
//!
//! Block layouts follow the canonical **ggml** conventions (so the same code
//! reads real GGUF files in Phase 1):
//! - `Q4_0`: 32 weights per block, 18 bytes = f16 scale + 16 packed nibble bytes.
//!   Byte `j` holds element `j` (low nibble) and element `j+16` (high nibble).
//! - `Q8_0`: 32 weights per block, 34 bytes = f16 scale + 32 × i8.
//! - `Q4_K`, `Q5_K`, `Q6_K`: 256-weight super-blocks (see the K-quant section).
15
16use half::f16;
17
18// ── SIMD helpers ─────────────────────────────────────────────────────────────
19
20// aarch64: NEON intrinsics (always available on Apple Silicon and ARM64 Linux)
21#[cfg(target_arch = "aarch64")]
22use std::arch::aarch64::*;
23
24// x86_64: AVX2 runtime check wrapper for the hot row-dot path
25#[cfg(target_arch = "x86_64")]
26use std::arch::x86_64::*;
27
/// Number of weights in one quantized block; Q4_0 and Q8_0 both use 32.
pub const QK: usize = 32;
/// Size of one Q4_0 block in bytes: a 2-byte f16 scale followed by `QK / 2`
/// bytes that each pack two 4-bit quants (18 bytes in total).
pub const Q4_0_BLOCK_BYTES: usize = 2 + QK / 2;
/// Size of one Q8_0 block in bytes: a 2-byte f16 scale followed by one i8
/// quant per weight (34 bytes in total).
pub const Q8_0_BLOCK_BYTES: usize = 2 + QK;
34
35// ── Q4_0 ──────────────────────────────────────────────────────────────────────
36
37/// Quantize a length-`QK` slice of f32 into one Q4_0 block (ggml convention).
38pub fn quantize_q4_0_block(x: &[f32]) -> [u8; Q4_0_BLOCK_BYTES] {
39    debug_assert_eq!(x.len(), QK);
40    // Scale from the value with the largest magnitude, preserving its sign
41    // (this is how ggml derives `d`, which is why d can be negative).
42    let mut amax = 0.0f32;
43    let mut vmax = 0.0f32;
44    for &v in x {
45        if v.abs() > amax {
46            amax = v.abs();
47            vmax = v;
48        }
49    }
50    let d = vmax / -8.0;
51    let id = if d != 0.0 { 1.0 / d } else { 0.0 };
52
53    let mut out = [0u8; Q4_0_BLOCK_BYTES];
54    out[0..2].copy_from_slice(&f16::from_f32(d).to_le_bytes());
55    for j in 0..QK / 2 {
56        let q0 = nibble(x[j] * id);
57        let q1 = nibble(x[j + QK / 2] * id);
58        out[2 + j] = q0 | (q1 << 4);
59    }
60    out
61}
62
/// Converts a value that has already been divided by the block scale into a
/// 4-bit Q4_0 quant.
///
/// Mirrors ggml's `MIN(15, (int)(x*id + 8.5))`. It adds 8.5, truncates toward
/// zero (a NaN input becomes 0), and then limits the result to `0..=15`.
#[inline]
fn nibble(scaled: f32) -> u8 {
    let shifted = (scaled + 8.5) as i32;
    shifted.max(0).min(15) as u8
}
69
70/// Dequantize one Q4_0 block into `out` (length `QK`).
71pub fn dequantize_q4_0_block(block: &[u8], out: &mut [f32]) {
72    debug_assert_eq!(block.len(), Q4_0_BLOCK_BYTES);
73    debug_assert_eq!(out.len(), QK);
74    let d = f16::from_le_bytes([block[0], block[1]]).to_f32();
75    for j in 0..QK / 2 {
76        let byte = block[2 + j];
77        let lo = (byte & 0x0f) as i32 - 8;
78        let hi = (byte >> 4) as i32 - 8;
79        out[j] = lo as f32 * d;
80        out[j + QK / 2] = hi as f32 * d;
81    }
82}
83
84/// Dot product of one Q4_0 block with a length-`QK` f32 activation slice.
85///
86/// On aarch64 the NEON path vectorises the nibble unpacking and FMA in ~8
87/// NEON instructions per block (4× width vs scalar). Falls back to scalar
88/// on every other target.
89#[inline]
90pub fn dot_q4_0_block_f32(block: &[u8], x: &[f32]) -> f32 {
91    debug_assert_eq!(block.len(), Q4_0_BLOCK_BYTES);
92    debug_assert_eq!(x.len(), QK);
93    #[cfg(target_arch = "aarch64")]
94    return unsafe { dot_q4_0_block_neon(block, x) };
95    #[cfg(not(target_arch = "aarch64"))]
96    dot_q4_0_block_scalar(block, x)
97}
98
99#[inline(always)]
100#[allow(dead_code)] // used on non-aarch64 targets only
101fn dot_q4_0_block_scalar(block: &[u8], x: &[f32]) -> f32 {
102    let d = f16::from_le_bytes([block[0], block[1]]).to_f32();
103    let mut acc = 0.0f32;
104    for j in 0..QK / 2 {
105        let byte = block[2 + j];
106        let lo = (byte & 0x0f) as i32 - 8;
107        let hi = (byte >> 4) as i32 - 8;
108        acc += lo as f32 * x[j] + hi as f32 * x[j + QK / 2];
109    }
110    acc * d
111}
112
/// NEON Q4_0 block dot product (aarch64).
///
/// Steps:
/// 1. Load the 16 packed bytes as one u8x16 vector.
/// 2. Mask with `0x0F` to get the low nibbles (elements 0..16), and shift right
///    by 4 to get the high nibbles (elements 16..32).
/// 3. Subtract 8 with wrapping u8 arithmetic and reinterpret as i8. This gives
///    the signed quants in `[-8, 7]`.
/// 4. Widen i8 → i16 → i32 → f32 and fused-multiply-add against the 32
///    activations.
/// 5. Multiply the lane sum by the block's f16 scale.
///
/// # Safety
/// Reads 16 bytes at `block[2..18]` and 32 floats from `x` through raw
/// pointers, without bounds checks. The caller must guarantee
/// `block.len() >= Q4_0_BLOCK_BYTES` and `x.len() >= QK`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dot_q4_0_block_neon(block: &[u8], x: &[f32]) -> f32 {
    let scale = f16::from_le_bytes([block[0], block[1]]).to_f32();
    let packed_ptr = block.as_ptr().add(2); // 16 bytes of packed nibbles

    // Load 16 packed bytes (= 32 nibbles)
    let packed = vld1q_u8(packed_ptr);

    // Extract lo nibbles (elements 0..15) and hi nibbles (elements 16..31)
    let lo_u8 = vandq_u8(packed, vdupq_n_u8(0x0F));
    let hi_u8 = vshrq_n_u8(packed, 4);

    // Subtract 8 (u8 wrap-sub, then reinterpret as signed i8 giving [-8, 7])
    let eight = vdupq_n_u8(8);
    let lo_i8 = vreinterpretq_s8_u8(vsubq_u8(lo_u8, eight));
    let hi_i8 = vreinterpretq_s8_u8(vsubq_u8(hi_u8, eight));

    // Widen i8 → i16 → i32 → f32. `$half` selects the low or high 8 lanes of
    // the i8x16 vector and widens them to i16x8; that i16x8 is then split into
    // two i32x4 halves and converted to two f32x4 vectors.
    macro_rules! to_f32x4 {
        ($i8vec:expr, $half:ident) => {{
            let i16v = $half($i8vec);
            let lo32 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(i16v)));
            let hi32 = vcvtq_f32_s32(vmovl_high_s16(i16v));
            (lo32, hi32)
        }};
    }

    let (lo_f32_0, lo_f32_1) = to_f32x4!(lo_i8, vmovl_s8_low);
    let (lo_f32_2, lo_f32_3) = to_f32x4!(lo_i8, vmovl_s8_high);
    let (hi_f32_0, hi_f32_1) = to_f32x4!(hi_i8, vmovl_s8_low);
    let (hi_f32_2, hi_f32_3) = to_f32x4!(hi_i8, vmovl_s8_high);

    // Activations: x[0..16] pair with the low nibbles, x[16..32] with the high ones.
    let xp = x.as_ptr();
    let x0 = vld1q_f32(xp);
    let x1 = vld1q_f32(xp.add(4));
    let x2 = vld1q_f32(xp.add(8));
    let x3 = vld1q_f32(xp.add(12));
    let x4 = vld1q_f32(xp.add(16));
    let x5 = vld1q_f32(xp.add(20));
    let x6 = vld1q_f32(xp.add(24));
    let x7 = vld1q_f32(xp.add(28));

    let mut acc = vmulq_f32(lo_f32_0, x0);
    acc = vfmaq_f32(acc, lo_f32_1, x1);
    acc = vfmaq_f32(acc, lo_f32_2, x2);
    acc = vfmaq_f32(acc, lo_f32_3, x3);
    acc = vfmaq_f32(acc, hi_f32_0, x4);
    acc = vfmaq_f32(acc, hi_f32_1, x5);
    acc = vfmaq_f32(acc, hi_f32_2, x6);
    acc = vfmaq_f32(acc, hi_f32_3, x7);

    // Horizontal lane sum, then apply the block scale once.
    vaddvq_f32(acc) * scale
}
174
/// Sign-extends lanes 0..8 of an `i8x16` vector into an `i16x8` vector.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn vmovl_s8_low(v: int8x16_t) -> int16x8_t {
    let lower_half = vget_low_s8(v);
    vmovl_s8(lower_half)
}

/// Sign-extends lanes 8..16 of an `i8x16` vector into an `i16x8` vector.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn vmovl_s8_high(v: int8x16_t) -> int16x8_t {
    let upper_half = vget_high_s8(v);
    vmovl_s8(upper_half)
}
188
189/// Dot product of a full Q4_0-quantized weight row.
190///
191/// Dispatches to the SIMD block kernel (NEON on aarch64, scalar + AVX2-auto
192/// on x86) block-by-block. The row loop itself is intentionally kept simple;
193/// rayon parallelises across rows at the matmul level.
194pub fn dot_q4_0_row_f32(row_blocks: &[u8], x: &[f32]) -> f32 {
195    let k = x.len();
196    debug_assert_eq!(k % QK, 0);
197    let mut acc = 0.0f32;
198    for (b, chunk) in row_blocks.chunks_exact(Q4_0_BLOCK_BYTES).enumerate() {
199        acc += dot_q4_0_block_f32(chunk, &x[b * QK..b * QK + QK]);
200    }
201    acc
202}
203
204/// Quantize a full f32 weight row (`k % QK == 0`) into packed Q4_0 blocks.
205pub fn quantize_q4_0_row(w: &[f32]) -> Vec<u8> {
206    debug_assert_eq!(w.len() % QK, 0);
207    let mut out = Vec::with_capacity(w.len() / QK * Q4_0_BLOCK_BYTES);
208    for chunk in w.chunks_exact(QK) {
209        out.extend_from_slice(&quantize_q4_0_block(chunk));
210    }
211    out
212}
213
214// ── Q8_0 ──────────────────────────────────────────────────────────────────────
215
216/// Quantize a length-`QK` (32) slice of f32 into one Q8_0 block (ggml convention).
217///
218/// Returns the 34-byte block: 2-byte f16 scale followed by 32 i8 quantized values.
219/// Used for online quantization of F16/BF16 weight matrices at load time.
220pub fn quantize_q8_0_block(x: &[f32]) -> [u8; Q8_0_BLOCK_BYTES] {
221    debug_assert_eq!(x.len(), QK);
222    let max_abs = x.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
223    let scale = max_abs / 127.0;
224    let d = half::f16::from_f32(scale);
225    let inv_scale = if scale > 0.0 { 1.0 / scale } else { 0.0 };
226    let mut out = [0u8; Q8_0_BLOCK_BYTES];
227    out[0..2].copy_from_slice(&d.to_le_bytes());
228    for (i, &v) in x.iter().enumerate() {
229        out[2 + i] = (v * inv_scale).round().clamp(-127.0, 127.0) as i8 as u8;
230    }
231    out
232}
233
234/// Dot product of a Q8_0 block with an f32 activation slice.
235///
236/// On aarch64 widens i8→i16→f32 with NEON vfmaq, processing 8 elements per
237/// instruction. On x86 the scalar loop auto-vectorises to SSE/AVX.
238#[inline]
239pub fn dot_q8_0_block_f32(block: &[u8], x: &[f32]) -> f32 {
240    debug_assert_eq!(block.len(), Q8_0_BLOCK_BYTES);
241    debug_assert_eq!(x.len(), QK);
242    #[cfg(target_arch = "aarch64")]
243    return unsafe { dot_q8_0_block_neon(block, x) };
244    #[cfg(not(target_arch = "aarch64"))]
245    dot_q8_0_block_scalar(block, x)
246}
247
248#[inline(always)]
249#[allow(dead_code)] // used on non-aarch64 targets only
250fn dot_q8_0_block_scalar(block: &[u8], x: &[f32]) -> f32 {
251    let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
252    let mut acc = 0.0f32;
253    for j in 0..QK {
254        acc += block[2 + j] as i8 as f32 * x[j];
255    }
256    acc * d
257}
258
/// NEON Q8_0 block dot product (aarch64).
///
/// The 32 quants are handled as four groups of 8. Each group is
/// sign-extended i8 → i16 → i32, converted to f32, and fused-multiply-added
/// against its 8 activations. The lane sum is scaled by the block's f16 scale.
///
/// # Safety
/// Reads 32 bytes starting at `block[2]` and 32 floats from `x` through raw
/// pointers. `block` must hold at least `Q8_0_BLOCK_BYTES` bytes and `x` at
/// least `QK` values.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dot_q8_0_block_neon(block: &[u8], x: &[f32]) -> f32 {
    let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
    let quants = block.as_ptr().add(2) as *const i8;
    let acts = x.as_ptr();
    let mut sum = vdupq_n_f32(0.0);

    for group in 0..QK / 8 {
        let off = group * 8;
        let wide = vmovl_s8(vld1_s8(quants.add(off)));
        let first = vcvtq_f32_s32(vmovl_s16(vget_low_s16(wide)));
        let second = vcvtq_f32_s32(vmovl_high_s16(wide));
        sum = vfmaq_f32(sum, first, vld1q_f32(acts.add(off)));
        sum = vfmaq_f32(sum, second, vld1q_f32(acts.add(off + 4)));
    }

    vaddvq_f32(sum) * d
}
287
// ── SDOT (Arm dot-product extension) hot path ─────────────────────────────────
//
// Strategy: quantize the f32 activation vector to i8 once (`quantize_row_to_i8`)
// and reuse it for every weight row of the GEMV. Each weight block then uses the
// SDOT instruction, which is the operation behind the `vdotq_s32` intrinsic: four
// 4-element i8 dot products accumulated into i32 lanes per instruction. Compared
// to the widening path (i8→i16→i32→f32 per element), this needs ~6 NEON ops per
// Q8_0 block vs ~40, delivering a 4–5× compute uplift on Apple Silicon M-series
// and DGX Spark (Grace ARM64 CPU).
//
// SDOT requires FEAT_DotProd (optional from Armv8.2-A, mandatory from Armv8.4-A).
// Check `is_aarch64_feature_detected!("dotprod")` before using these kernels.
//
// Accuracy: rounding x to i8 (the same resolution as Q8_0 weights) moves each
// activation by at most half a quantization step, i.e. max|x| / 254.
298
/// Symmetrically quantizes a whole activation row to i8 with one shared scale
/// (Q8_0 style).
///
/// Returns `(q, scale)` such that `x[i] ≈ q[i] as f32 * scale`:
/// - `scale` is `max|x| / 127`. An all-zero or empty row gets `scale = 1.0`.
/// - Each value is rounded to nearest (ties away from zero) and clamped to
///   `[-127, 127]`.
///
/// To recover f32 results, the caller multiplies each block's weight scale by
/// `scale`.
pub fn quantize_row_to_i8(x: &[f32]) -> (Vec<i8>, f32) {
    let mut max_abs = 0.0f32;
    for &v in x {
        max_abs = f32::max(max_abs, v.abs());
    }
    let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
    let inv = if scale > 0.0 { scale.recip() } else { 0.0 };

    let mut quants = Vec::with_capacity(x.len());
    for &v in x {
        quants.push((v * inv).round().clamp(-127.0, 127.0) as i8);
    }
    (quants, scale)
}
312
/// Exact integer dot product of one Q8_0 block with 32 pre-quantized i8
/// activations.
///
/// Returns `Σ w_i·x_i` as i32; the caller applies the scales. The `SDOT`
/// instruction is emitted through inline assembly because the `vdotq_s32`
/// intrinsic is still behind an unstable feature gate. One
/// `sdot v0.4s, v1.16b, v2.16b` computes four 4-element i8 dot products into
/// an i32x4 accumulator, and two of them cover a 32-element block.
///
/// # Safety
/// The CPU must implement the Arm dot-product extension; check
/// `is_aarch64_feature_detected!("dotprod")`. The `dotprod` target feature is
/// enabled here so the instruction is legal (and assembles) even on targets
/// whose baseline CPU lacks it.
///
/// # Panics
/// Panics if `block` is shorter than `Q8_0_BLOCK_BYTES` or `x_i8` is shorter
/// than `QK`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon,dotprod")]
unsafe fn dot_q8_0_block_sdot(block: &[u8], x_i8: &[i8]) -> i32 {
    use std::arch::aarch64::*;
    debug_assert_eq!(block.len(), Q8_0_BLOCK_BYTES);
    debug_assert_eq!(x_i8.len(), QK);

    // Bounds-checked views: the vector loads below read exactly these bytes.
    let weights = &block[2..Q8_0_BLOCK_BYTES];
    let acts = &x_i8[..QK];
    let w_ptr = weights.as_ptr() as *const i8;
    let x_ptr = acts.as_ptr();

    let w0 = vld1q_s8(w_ptr);
    let x0 = vld1q_s8(x_ptr);
    let w1 = vld1q_s8(w_ptr.add(16));
    let x1 = vld1q_s8(x_ptr.add(16));

    let mut acc = vdupq_n_s32(0i32);
    // sdot v_acc.4s, v_w.16b, v_x.16b. The `:v` modifier prints the register
    // as its 128-bit v-register name.
    core::arch::asm!(
        "sdot {0:v}.4s, {1:v}.16b, {2:v}.16b",
        inout(vreg) acc,
        in(vreg) w0,
        in(vreg) x0,
        options(pure, nomem, nostack),
    );
    core::arch::asm!(
        "sdot {0:v}.4s, {1:v}.16b, {2:v}.16b",
        inout(vreg) acc,
        in(vreg) w1,
        in(vreg) x1,
        options(pure, nomem, nostack),
    );
    vaddvq_s32(acc)
}
356
/// Full Q8_0 row dot product with pre-quantized i8 activations, using SDOT.
///
/// `x_i8` and `x_scale` are the output of `quantize_row_to_i8`. The matmul
/// layer (e.g. `matmul_nt_q8_0`) calls this when dotprod is available. Each
/// block contributes `w_scale * x_scale * Σ w_i·x_i`. The integer sum is exact,
/// so the result matches `dot_q8_0_row_i8_scalar`.
///
/// # Panics
/// Panics if `x_i8` is shorter than `QK` × (number of complete blocks in `row_blocks`).
///
/// # Safety
/// The caller must verify `is_aarch64_feature_detected!("dotprod")` before
/// calling. This function enables `dotprod` to match that requirement.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon,dotprod")]
pub unsafe fn dot_q8_0_row_sdot(row_blocks: &[u8], x_i8: &[i8], x_scale: f32) -> f32 {
    let mut acc = 0.0f32;
    for (b, block) in row_blocks.chunks_exact(Q8_0_BLOCK_BYTES).enumerate() {
        let w_scale = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
        let x_block = &x_i8[b * QK..(b + 1) * QK];
        let dot = dot_q8_0_block_sdot(block, x_block);
        acc += w_scale * x_scale * dot as f32;
    }
    acc
}
377
378/// Scalar fallback for dot_q8_0_row_sdot (non-dotprod aarch64 or other platforms).
379/// Uses i32 integer arithmetic — no widening chain, still faster than the f32 path
380/// for targets without AVX2, and correct everywhere.
381pub fn dot_q8_0_row_i8_scalar(row_blocks: &[u8], x_i8: &[i8], x_scale: f32) -> f32 {
382    let mut acc = 0.0f32;
383    let mut x_off = 0usize;
384    for block in row_blocks.chunks_exact(Q8_0_BLOCK_BYTES) {
385        let w_scale = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
386        let w = &block[2..];
387        let dot: i32 = w[..QK]
388            .iter()
389            .zip(&x_i8[x_off..x_off + QK])
390            .map(|(&wi, &xi)| wi as i8 as i32 * xi as i32)
391            .sum();
392        acc += w_scale * x_scale * dot as f32;
393        x_off += QK;
394    }
395    acc
396}
397
398/// Dot product of a full Q8_0-quantized weight row with an f32 activation vector.
399///
400/// On x86_64 with AVX2 at runtime, uses the wider FMA path that processes
401/// 8 floats per cycle; otherwise falls back to the NEON or scalar block kernel.
402pub fn dot_q8_0_row_f32(row_blocks: &[u8], x: &[f32]) -> f32 {
403    #[cfg(target_arch = "x86_64")]
404    if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
405        return unsafe { dot_q8_0_row_avx2(row_blocks, x) };
406    }
407    let k = x.len();
408    debug_assert_eq!(k % QK, 0);
409    let mut acc = 0.0f32;
410    for (b, chunk) in row_blocks.chunks_exact(Q8_0_BLOCK_BYTES).enumerate() {
411        acc += dot_q8_0_block_f32(chunk, &x[b * QK..b * QK + QK]);
412    }
413    acc
414}
415
416/// AVX2+FMA path for Q8_0 row dot product on x86_64.
417///
418/// Processes 8 f32 values per `_mm256_fmadd_ps` instruction.  The i8→f32
419/// widening uses `_mm256_cvtepi8_epi32` to convert 8 i8 at a time.
420#[cfg(target_arch = "x86_64")]
421#[target_feature(enable = "avx2,fma")]
422unsafe fn dot_q8_0_row_avx2(row_blocks: &[u8], x: &[f32]) -> f32 {
423    let k = x.len();
424    debug_assert_eq!(k % QK, 0);
425    let mut row_acc = _mm256_setzero_ps();
426
427    for (b, block) in row_blocks.chunks_exact(Q8_0_BLOCK_BYTES).enumerate() {
428        let scale = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
429        let q_ptr = block.as_ptr().add(2) as *const i32; // load 4 bytes at a time
430        let xp = x.as_ptr().add(b * QK);
431        let mut block_acc = _mm256_setzero_ps();
432
433        // 4 groups of 8 values each (8 i8 → 8 i32 → 8 f32)
434        for g in 0..4usize {
435            let q_i32_4 = _mm_loadu_si32(q_ptr.add(2 * g) as *const _); // 4 bytes
436            let q_i32_4b = _mm_loadu_si32(q_ptr.add(2 * g + 1) as *const _);
437            let q_a = _mm256_cvtepi8_epi32(q_i32_4); // 4 i8 → 4 i32 (low lane)
438            let q_b = _mm256_cvtepi8_epi32(q_i32_4b); // next 4
439            let xv_a = _mm256_loadu_ps(xp.add(g * 8));
440            let xv_b = _mm256_loadu_ps(xp.add(g * 8 + 4));
441            let qf_a = _mm256_cvtepi32_ps(q_a);
442            let qf_b = _mm256_cvtepi32_ps(q_b);
443            block_acc = _mm256_fmadd_ps(qf_a, xv_a, block_acc);
444            block_acc = _mm256_fmadd_ps(qf_b, xv_b, block_acc);
445        }
446        // Horizontal sum of block_acc, multiply by scale, add to row accumulator
447        let scale_v = _mm256_set1_ps(scale);
448        row_acc = _mm256_fmadd_ps(block_acc, scale_v, row_acc);
449    }
450
451    // Horizontal sum of the 8-lane AVX2 accumulator
452    let lo = _mm256_castps256_ps128(row_acc);
453    let hi = _mm256_extractf128_ps(row_acc, 1);
454    let sum4 = _mm_add_ps(lo, hi);
455    let shuf = _mm_movehdup_ps(sum4);
456    let sum2 = _mm_add_ps(sum4, shuf);
457    let sum1 = _mm_add_ss(sum2, _mm_movehl_ps(shuf, sum2));
458    _mm_cvtss_f32(sum1)
459}
460
461// ── K-quants (Q4_K, Q5_K, Q6_K) ─────────────────────────────────────────────
462//
463// K-quant blocks use QK_K = 256 elements per block with super-block scaling.
464// Weights are kept as packed blocks and dequantized one block at a time inside
465// the dot product — no F32 expansion at load time.
466
/// Weights per K-quant super-block.
pub const QK_K: usize = 256;
/// Q4_K block size (144 bytes): f16 `d` + f16 `dmin` + 12 packed scale bytes
/// + `QK_K / 2` bytes of 4-bit quants.
pub const Q4_K_BLOCK_BYTES: usize = 2 + 2 + 12 + QK_K / 2;
/// Q5_K block size (176 bytes): the Q4_K fields plus `QK_K / 8` bytes holding
/// each weight's fifth bit.
pub const Q5_K_BLOCK_BYTES: usize = 2 + 2 + 12 + QK_K / 8 + QK_K / 2;
/// Q6_K block size (210 bytes): `QK_K / 2` low-nibble bytes + `QK_K / 4`
/// high-bit bytes + `QK_K / 16` i8 scales + f16 `d`.
pub const Q6_K_BLOCK_BYTES: usize = QK_K / 2 + QK_K / 4 + QK_K / 16 + 2;
471
/// Unpacks the 6-bit `(scale, min)` pair for K-quant sub-block `j` (0..8)
/// from the 12-byte `scales` field of a Q4_K or Q5_K block.
///
/// Packing:
/// - Sub-blocks 0..4: the scale and min are the low 6 bits of `scales[j]` and
///   `scales[j + 4]`.
/// - Sub-blocks 4..8: the low 4 bits come from the low and high nibbles of
///   `scales[j + 4]`. The top 2 bits come from the top bits of `scales[j - 4]`
///   (scale) and `scales[j]` (min).
#[inline(always)]
fn get_scale_min_k4(j: usize, scales: &[u8]) -> (u8, u8) {
    match j {
        0..=3 => (scales[j] & 0x3F, scales[j + 4] & 0x3F),
        _ => {
            let scale = (scales[j + 4] & 0x0F) | ((scales[j - 4] >> 6) << 4);
            let min = (scales[j + 4] >> 4) | ((scales[j] >> 6) << 4);
            (scale, min)
        }
    }
}
485
486/// Dot product of a full Q4_K-quantized weight row with an f32 activation vector.
487///
488/// Q4_K block layout (144 bytes, 256 weights):
489///   [0..1] d (f16) — super-block scale
490///   [2..3] dmin (f16) — super-block min scale
491///   [4..15] scales (12 bytes) — 8 pairs of 6-bit (scale, min) packed
492///   [16..143] qs (128 bytes) — 256 × 4-bit quantized values (lo/hi nibble)
493///
494/// Dispatches to the NEON path on aarch64, scalar otherwise.
495pub fn dot_q4_k_row_f32(row_data: &[u8], x: &[f32]) -> f32 {
496    #[cfg(target_arch = "aarch64")]
497    return unsafe { dot_q4_k_row_f32_neon(row_data, x) };
498    #[cfg(not(target_arch = "aarch64"))]
499    dot_q4_k_row_f32_scalar(row_data, x)
500}
501
502/// Scalar fallback for Q4_K row dot product.
503#[allow(dead_code)]
504fn dot_q4_k_row_f32_scalar(row_data: &[u8], x: &[f32]) -> f32 {
505    let mut acc = 0.0f32;
506    let mut x_off = 0usize;
507    for block in row_data.chunks_exact(Q4_K_BLOCK_BYTES) {
508        let d = f16::from_le_bytes([block[0], block[1]]).to_f32();
509        let dmin = f16::from_le_bytes([block[2], block[3]]).to_f32();
510        let scales = &block[4..16];
511        let qs = &block[16..Q4_K_BLOCK_BYTES];
512        let mut q_off = 0usize;
513        let mut is = 0usize;
514        for _ in 0..(QK_K / 64) {
515            let (sc1, m1) = get_scale_min_k4(is, scales);
516            let d1 = d * sc1 as f32;
517            let m1v = dmin * m1 as f32;
518            let (sc2, m2) = get_scale_min_k4(is + 1, scales);
519            let d2 = d * sc2 as f32;
520            let m2v = dmin * m2 as f32;
521            for l in 0..32 {
522                acc += (d1 * (qs[q_off + l] & 0x0F) as f32 - m1v) * x[x_off + l];
523                acc += (d2 * (qs[q_off + l] >> 4) as f32 - m2v) * x[x_off + l + 32];
524            }
525            x_off += 64;
526            q_off += 32;
527            is += 2;
528        }
529    }
530    acc
531}
532
/// NEON Q4_K row dot product (aarch64).
///
/// For each 64-weight pass of a block, the 32 packed bytes are consumed 8 at a
/// time. Low nibbles pair with the first 32 activations of the pass, and high
/// nibbles pair with the next 32. Plain activation sums are accumulated
/// alongside the nibble·activation products, so each sub-block's min is
/// applied once as `d·Σ(q·x) − m·Σx` rather than per element.
///
/// # Safety
/// Requires NEON (always present on aarch64). Every raw-pointer load stays
/// inside a slice whose length was established by indexing.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dot_q4_k_row_f32_neon(row_data: &[u8], x: &[f32]) -> f32 {
    // Widen 8 × u8 into two f32x4 vectors (lanes 0..4 and 4..8).
    macro_rules! widen_u8x8 {
        ($v:expr) => {{
            let wide = vmovl_u8($v);
            (
                vcvtq_f32_u32(vmovl_u16(vget_low_u16(wide))),
                vcvtq_f32_u32(vmovl_high_u16(wide)),
            )
        }};
    }

    let low_mask = vdup_n_u8(0x0F);
    let mut total = 0.0f32;
    let mut base = 0usize;

    for blk in row_data.chunks_exact(Q4_K_BLOCK_BYTES) {
        let super_scale = f16::from_le_bytes([blk[0], blk[1]]).to_f32();
        let super_min = f16::from_le_bytes([blk[2], blk[3]]).to_f32();
        let packed_scales = &blk[4..16];
        let quants = &blk[16..Q4_K_BLOCK_BYTES];

        // QK_K / 64 = 4 passes; each covers 32 low-nibble + 32 high-nibble weights.
        for (pass, qbytes) in quants.chunks_exact(32).enumerate() {
            let (sc_a, mn_a) = get_scale_min_k4(2 * pass, packed_scales);
            let (sc_b, mn_b) = get_scale_min_k4(2 * pass + 1, packed_scales);
            let scale_a = super_scale * sc_a as f32;
            let min_a = super_min * mn_a as f32;
            let scale_b = super_scale * sc_b as f32;
            let min_b = super_min * mn_b as f32;

            // Activations for the low-nibble (a) and high-nibble (b) sub-blocks.
            let xa = &x[base..base + 32];
            let xb = &x[base + 32..base + 64];

            let mut dot_a = vdupq_n_f32(0.0); // Σ q_a · x_a
            let mut dot_b = vdupq_n_f32(0.0); // Σ q_b · x_b
            let mut sum_a = vdupq_n_f32(0.0); // Σ x_a (min correction)
            let mut sum_b = vdupq_n_f32(0.0); // Σ x_b (min correction)

            for off in (0..32).step_by(8) {
                let packed = vld1_u8(qbytes.as_ptr().add(off));
                let (qa0, qa1) = widen_u8x8!(vand_u8(packed, low_mask));
                let (qb0, qb1) = widen_u8x8!(vshr_n_u8::<4>(packed));

                let xa0 = vld1q_f32(xa.as_ptr().add(off));
                let xa1 = vld1q_f32(xa.as_ptr().add(off + 4));
                let xb0 = vld1q_f32(xb.as_ptr().add(off));
                let xb1 = vld1q_f32(xb.as_ptr().add(off + 4));

                dot_a = vfmaq_f32(dot_a, qa0, xa0);
                dot_a = vfmaq_f32(dot_a, qa1, xa1);
                dot_b = vfmaq_f32(dot_b, qb0, xb0);
                dot_b = vfmaq_f32(dot_b, qb1, xb1);
                sum_a = vaddq_f32(sum_a, vaddq_f32(xa0, xa1));
                sum_b = vaddq_f32(sum_b, vaddq_f32(xb0, xb1));
            }

            total += scale_a * vaddvq_f32(dot_a) - min_a * vaddvq_f32(sum_a);
            total += scale_b * vaddvq_f32(dot_b) - min_b * vaddvq_f32(sum_b);
            base += 64;
        }
    }
    total
}
616
617/// Dot product of a full Q5_K-quantized weight row with an f32 activation vector.
618///
619/// Q5_K block layout (176 bytes, 256 weights):
620///   [0..1] d (f16), [2..3] dmin (f16), [4..15] scales (12B),
621///   [16..47] qh (32B — high bits, one per 32-weight sub-block),
622///   [48..175] ql (128B — low 4-bit nibbles)
623pub fn dot_q5_k_row_f32(row_data: &[u8], x: &[f32]) -> f32 {
624    let mut acc = 0.0f32;
625    let mut x_off = 0usize;
626    for block in row_data.chunks_exact(Q5_K_BLOCK_BYTES) {
627        let d = f16::from_le_bytes([block[0], block[1]]).to_f32();
628        let dmin = f16::from_le_bytes([block[2], block[3]]).to_f32();
629        let scales = &block[4..16];
630        let qh = &block[16..48];
631        let ql = &block[48..Q5_K_BLOCK_BYTES];
632        let mut ql_off = 0usize;
633        let mut is = 0usize;
634        let mut u1: u8 = 1;
635        let mut u2: u8 = 2;
636        for _ in 0..(QK_K / 64) {
637            let (sc1, m1) = get_scale_min_k4(is, scales);
638            let d1 = d * sc1 as f32;
639            let m1v = dmin * m1 as f32;
640            let (sc2, m2) = get_scale_min_k4(is + 1, scales);
641            let d2 = d * sc2 as f32;
642            let m2v = dmin * m2 as f32;
643            let qh_byte = qh[is / 8];
644            for l in 0..32 {
645                let hi1 = if qh_byte & u1 != 0 { 16.0f32 } else { 0.0 };
646                let hi2 = if qh_byte & u2 != 0 { 16.0f32 } else { 0.0 };
647                acc += (d1 * ((ql[ql_off + l] & 0x0F) as f32 + hi1) - m1v) * x[x_off + l];
648                acc += (d2 * ((ql[ql_off + l] >> 4) as f32 + hi2) - m2v) * x[x_off + l + 32];
649            }
650            x_off += 64;
651            ql_off += 32;
652            is += 2;
653            if is % 8 == 0 {
654                u1 = 1;
655                u2 = 2;
656            } else {
657                u1 <<= 2;
658                u2 <<= 2;
659            }
660        }
661    }
662    acc
663}
664
665/// Dot product of a full Q6_K-quantized weight row with an f32 activation vector.
666///
667/// Q6_K block layout (210 bytes, 256 weights):
668///   [0..127] ql (128B — low 4-bit nibbles)
669///   [128..191] qh (64B — upper 2 bits, two per byte)
670///   [192..207] scales (16B — i8 per 16-element group)
671///   [208..209] d (f16)
672pub fn dot_q6_k_row_f32(row_data: &[u8], x: &[f32]) -> f32 {
673    let mut acc = 0.0f32;
674    let mut x_off = 0usize;
675    for block in row_data.chunks_exact(Q6_K_BLOCK_BYTES) {
676        let ql = &block[0..128];
677        let qh = &block[128..192];
678        let sc = &block[192..208];
679        let d = f16::from_le_bytes([block[208], block[209]]).to_f32();
680        let mut ql_off = 0usize;
681        let mut qh_off = 0usize;
682        let mut ib = 0usize;
683        for _ in 0..(QK_K / 128) {
684            for l in 0..32 {
685                let q1 =
686                    (((ql[ql_off + l] & 0x0F) | ((qh[qh_off + l] & 3) << 4)) as i32 - 32) as f32;
687                let q2 = (((ql[ql_off + l + 32] & 0x0F) | (((qh[qh_off + l] >> 2) & 3) << 4))
688                    as i32
689                    - 32) as f32;
690                let q3 = (((ql[ql_off + l] >> 4) | (((qh[qh_off + l] >> 4) & 3) << 4)) as i32 - 32)
691                    as f32;
692                let q4 = (((ql[ql_off + l + 32] >> 4) | (((qh[qh_off + l] >> 6) & 3) << 4)) as i32
693                    - 32) as f32;
694                acc += d * sc[ib] as i8 as f32 * q1 * x[x_off + l];
695                acc += d * sc[ib + 1] as i8 as f32 * q2 * x[x_off + l + 32];
696                acc += d * sc[ib + 2] as i8 as f32 * q3 * x[x_off + l + 64];
697                acc += d * sc[ib + 3] as i8 as f32 * q4 * x[x_off + l + 96];
698            }
699            x_off += 128;
700            ql_off += 64;
701            qh_off += 32;
702            ib += 4;
703        }
704    }
705    acc
706}
707
#[cfg(test)]
mod tests {
    use super::*;

    // Deterministic pseudo-random f32 in roughly [-1, 1] (no rand dependency).
    fn seq(n: usize) -> Vec<f32> {
        let mut s: u64 = 0x9E3779B97F4A7C15;
        (0..n)
            .map(|_| {
                s ^= s << 13;
                s ^= s >> 7;
                s ^= s << 17;
                ((s >> 40) as f32 / (1u32 << 24) as f32) * 2.0 - 1.0
            })
            .collect()
    }

    // Packs a weight row (length a multiple of QK) into consecutive Q8_0 blocks.
    fn quantize_q8_0_row(w: &[f32]) -> Vec<u8> {
        w.chunks_exact(QK).flat_map(quantize_q8_0_block).collect()
    }

    // Reference dequantizer for packed Q8_0 blocks: each weight is q * d.
    fn dequantize_q8_0_row(blocks: &[u8]) -> Vec<f32> {
        let mut out = Vec::with_capacity(blocks.len() / Q8_0_BLOCK_BYTES * QK);
        for block in blocks.chunks_exact(Q8_0_BLOCK_BYTES) {
            let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
            out.extend(block[2..].iter().map(|&q| q as i8 as f32 * d));
        }
        out
    }

    #[test]
    fn q4_0_on_the_fly_dot_matches_dequantized_reference() {
        let k = 256;
        let w = seq(k);
        let x = seq(k).iter().map(|v| v * 0.5).collect::<Vec<_>>();

        let blocks = quantize_q4_0_row(&w);
        // Storage: 18 bytes per 32 weights = 0.5625 B/weight vs 4 B for F32.
        assert_eq!(blocks.len(), k / QK * Q4_0_BLOCK_BYTES);

        // Reference: dequantize fully, then dot.
        let mut w_hat = vec![0.0f32; k];
        for (b, chunk) in blocks.chunks_exact(Q4_0_BLOCK_BYTES).enumerate() {
            dequantize_q4_0_block(chunk, &mut w_hat[b * QK..b * QK + QK]);
        }
        let reference: f32 = w_hat.iter().zip(&x).map(|(a, b)| a * b).sum();

        // On-the-fly path (what the real kernel will do): must match the
        // dequantized reference to floating-point tolerance.
        let on_the_fly = dot_q4_0_row_f32(&blocks, &x);
        assert!(
            (on_the_fly - reference).abs() < 1e-3,
            "on-the-fly {on_the_fly} vs reference {reference}"
        );
    }

    #[test]
    fn q4_0_quantization_error_is_bounded() {
        // Dequantized weights should track the originals within Q4 granularity.
        let w = seq(QK * 4);
        let blocks = quantize_q4_0_row(&w);
        let mut w_hat = vec![0.0f32; w.len()];
        for (b, chunk) in blocks.chunks_exact(Q4_0_BLOCK_BYTES).enumerate() {
            dequantize_q4_0_block(chunk, &mut w_hat[b * QK..b * QK + QK]);
        }
        // Max abs error within a block ≤ ~scale (one quant step). With |w|≤1 and
        // 4-bit range, the step is small; assert a loose but real bound.
        let max_err = w
            .iter()
            .zip(&w_hat)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(max_err < 0.2, "max quant error {max_err} too large");
    }

    #[test]
    fn q8_0_row_dot_matches_dequantized_reference() {
        let k = 256;
        let w = seq(k);
        let x: Vec<f32> = seq(k).iter().map(|v| v * 0.5).collect();

        let blocks = quantize_q8_0_row(&w);
        assert_eq!(blocks.len(), k / QK * Q8_0_BLOCK_BYTES);

        let w_hat = dequantize_q8_0_row(&blocks);
        let reference: f32 = w_hat.iter().zip(&x).map(|(a, b)| a * b).sum();

        // Runs the AVX2 kernel on capable x86_64 hosts, NEON on aarch64 and
        // the scalar loop elsewhere.
        let got = dot_q8_0_row_f32(&blocks, &x);
        let tol = 1e-4 * reference.abs().max(1.0);
        assert!(
            (got - reference).abs() < tol,
            "q8_0 row dot {got} vs reference {reference}"
        );
    }

    #[test]
    fn q8_0_i8_activation_path_tracks_f32_path() {
        let k = 256;
        let w = seq(k);
        let x: Vec<f32> = seq(k).iter().map(|v| v * 0.5).collect();
        let blocks = quantize_q8_0_row(&w);
        let w_hat = dequantize_q8_0_row(&blocks);

        let (x_i8, x_scale) = quantize_row_to_i8(&x);
        assert_eq!(x_i8.len(), k);
        let approx = dot_q8_0_row_i8_scalar(&blocks, &x_i8, x_scale);
        let exact = dot_q8_0_row_f32(&blocks, &x);

        // Rounding moves each activation by at most x_scale / 2, so the dot
        // product moves by at most Σ|ŵ|·x_scale/2 (plus float noise).
        let bound = w_hat.iter().map(|v| v.abs()).sum::<f32>() * x_scale * 0.5 + 1e-3;
        assert!(
            (approx - exact).abs() <= bound,
            "i8 path {approx} vs f32 path {exact} (bound {bound})"
        );

        #[cfg(target_arch = "aarch64")]
        if std::arch::is_aarch64_feature_detected!("dotprod") {
            // SAFETY: dotprod support was just verified at runtime.
            let fast = unsafe { dot_q8_0_row_sdot(&blocks, &x_i8, x_scale) };
            assert_eq!(fast, approx, "SDOT and scalar i8 paths must agree exactly");
        }
    }
}