1use half::f16;
17
18#[cfg(target_arch = "aarch64")]
22use std::arch::aarch64::*;
23
24#[cfg(target_arch = "x86_64")]
26use std::arch::x86_64::*;
27
/// Number of weights covered by one Q4_0 / Q8_0 block.
pub const QK: usize = 32;
/// Bytes per Q4_0 block: a little-endian `f16` scale followed by `QK / 2`
/// bytes of packed 4-bit quants (two per byte).
pub const Q4_0_BLOCK_BYTES: usize = 2 + QK / 2;
/// Bytes per Q8_0 block: a little-endian `f16` scale followed by `QK` signed
/// 8-bit quants.
pub const Q8_0_BLOCK_BYTES: usize = 2 + QK;
34
35pub fn quantize_q4_0_block(x: &[f32]) -> [u8; Q4_0_BLOCK_BYTES] {
39 debug_assert_eq!(x.len(), QK);
40 let mut amax = 0.0f32;
43 let mut vmax = 0.0f32;
44 for &v in x {
45 if v.abs() > amax {
46 amax = v.abs();
47 vmax = v;
48 }
49 }
50 let d = vmax / -8.0;
51 let id = if d != 0.0 { 1.0 / d } else { 0.0 };
52
53 let mut out = [0u8; Q4_0_BLOCK_BYTES];
54 out[0..2].copy_from_slice(&f16::from_f32(d).to_le_bytes());
55 for j in 0..QK / 2 {
56 let q0 = nibble(x[j] * id);
57 let q1 = nibble(x[j + QK / 2] * id);
58 out[2 + j] = q0 | (q1 << 4);
59 }
60 out
61}
62
/// Converts a value already divided by the block scale into a 4-bit quant.
///
/// Adds the zero point 8 plus 0.5 and truncates toward zero, then clamps to
/// `0..=15`. The `as i32` cast saturates, and NaN becomes 0.
#[inline]
fn nibble(scaled: f32) -> u8 {
    let shifted = (scaled + 8.5) as i32;
    shifted.max(0).min(15) as u8
}
69
70pub fn dequantize_q4_0_block(block: &[u8], out: &mut [f32]) {
72 debug_assert_eq!(block.len(), Q4_0_BLOCK_BYTES);
73 debug_assert_eq!(out.len(), QK);
74 let d = f16::from_le_bytes([block[0], block[1]]).to_f32();
75 for j in 0..QK / 2 {
76 let byte = block[2 + j];
77 let lo = (byte & 0x0f) as i32 - 8;
78 let hi = (byte >> 4) as i32 - 8;
79 out[j] = lo as f32 * d;
80 out[j + QK / 2] = hi as f32 * d;
81 }
82}
83
84#[inline]
90pub fn dot_q4_0_block_f32(block: &[u8], x: &[f32]) -> f32 {
91 debug_assert_eq!(block.len(), Q4_0_BLOCK_BYTES);
92 debug_assert_eq!(x.len(), QK);
93 #[cfg(target_arch = "aarch64")]
94 return unsafe { dot_q4_0_block_neon(block, x) };
95 #[cfg(not(target_arch = "aarch64"))]
96 dot_q4_0_block_scalar(block, x)
97}
98
99#[inline(always)]
100#[allow(dead_code)] fn dot_q4_0_block_scalar(block: &[u8], x: &[f32]) -> f32 {
102 let d = f16::from_le_bytes([block[0], block[1]]).to_f32();
103 let mut acc = 0.0f32;
104 for j in 0..QK / 2 {
105 let byte = block[2 + j];
106 let lo = (byte & 0x0f) as i32 - 8;
107 let hi = (byte >> 4) as i32 - 8;
108 acc += lo as f32 * x[j] + hi as f32 * x[j + QK / 2];
109 }
110 acc * d
111}
112
/// NEON kernel behind `dot_q4_0_block_f32`.
///
/// Unpacks the 16 packed bytes into 32 signed levels and widens them to
/// `f32`. It then multiply-accumulates against `x` as eight 4-lane vectors,
/// and finishes with one horizontal add times the block scale.
///
/// # Safety
/// `block[2..18]` and `x[0..32]` are read through raw pointers without bounds
/// checks, so the caller must guarantee `block.len() >= 18` and
/// `x.len() >= 32`. The CPU must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dot_q4_0_block_neon(block: &[u8], x: &[f32]) -> f32 {
    let scale = f16::from_le_bytes([block[0], block[1]]).to_f32();
    // All 16 packed quant bytes in one register.
    let packed_ptr = block.as_ptr().add(2);
    let packed = vld1q_u8(packed_ptr);

    // Low nibbles hold elements 0..16, high nibbles hold elements 16..32.
    let lo_u8 = vandq_u8(packed, vdupq_n_u8(0x0F));
    let hi_u8 = vshrq_n_u8(packed, 4);

    let eight = vdupq_n_u8(8);
    // Subtract the zero point with modular u8 arithmetic, then reinterpret
    // the bits as i8: 0..=15 becomes -8..=7.
    let lo_i8 = vreinterpretq_s8_u8(vsubq_u8(lo_u8, eight));
    let hi_i8 = vreinterpretq_s8_u8(vsubq_u8(hi_u8, eight));

    // Widen one 8-lane half of an i8x16 (chosen by `$half`) to i16. Then
    // widen again to two f32x4 vectors: lanes 0..4 and 4..8 of that half.
    macro_rules! to_f32x4 {
        ($i8vec:expr, $half:ident) => {{
            let i16v = $half($i8vec);
            let lo32 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(i16v)));
            let hi32 = vcvtq_f32_s32(vmovl_high_s16(i16v));
            (lo32, hi32)
        }};
    }

    // lo_f32_0..3 cover elements 0..16; hi_f32_0..3 cover elements 16..32.
    let (lo_f32_0, lo_f32_1) = to_f32x4!(lo_i8, vmovl_s8_low);
    let (lo_f32_2, lo_f32_3) = to_f32x4!(lo_i8, vmovl_s8_high);
    let (hi_f32_0, hi_f32_1) = to_f32x4!(hi_i8, vmovl_s8_low);
    let (hi_f32_2, hi_f32_3) = to_f32x4!(hi_i8, vmovl_s8_high);

    let xp = x.as_ptr();
    let x0 = vld1q_f32(xp);
    let x1 = vld1q_f32(xp.add(4));
    let x2 = vld1q_f32(xp.add(8));
    let x3 = vld1q_f32(xp.add(12));
    let x4 = vld1q_f32(xp.add(16));
    let x5 = vld1q_f32(xp.add(20));
    let x6 = vld1q_f32(xp.add(24));
    let x7 = vld1q_f32(xp.add(28));

    // A single accumulator chain keeps the summation order fixed. The block
    // scale is applied once, after the horizontal add.
    let mut acc = vmulq_f32(lo_f32_0, x0);
    acc = vfmaq_f32(acc, lo_f32_1, x1);
    acc = vfmaq_f32(acc, lo_f32_2, x2);
    acc = vfmaq_f32(acc, lo_f32_3, x3);
    acc = vfmaq_f32(acc, hi_f32_0, x4);
    acc = vfmaq_f32(acc, hi_f32_1, x5);
    acc = vfmaq_f32(acc, hi_f32_2, x6);
    acc = vfmaq_f32(acc, hi_f32_3, x7);

    vaddvq_f32(acc) * scale
}
174
/// Sign-extends lanes 0..8 of an `int8x16_t` into an `int16x8_t`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn vmovl_s8_low(v: int8x16_t) -> int16x8_t {
    let low_half: int8x8_t = vget_low_s8(v);
    vmovl_s8(low_half)
}
181
/// Sign-extends lanes 8..16 of an `int8x16_t` into an `int16x8_t`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn vmovl_s8_high(v: int8x16_t) -> int16x8_t {
    let high_half: int8x8_t = vget_high_s8(v);
    vmovl_s8(high_half)
}
188
189pub fn dot_q4_0_row_f32(row_blocks: &[u8], x: &[f32]) -> f32 {
195 let k = x.len();
196 debug_assert_eq!(k % QK, 0);
197 let mut acc = 0.0f32;
198 for (b, chunk) in row_blocks.chunks_exact(Q4_0_BLOCK_BYTES).enumerate() {
199 acc += dot_q4_0_block_f32(chunk, &x[b * QK..b * QK + QK]);
200 }
201 acc
202}
203
204pub fn quantize_q4_0_row(w: &[f32]) -> Vec<u8> {
206 debug_assert_eq!(w.len() % QK, 0);
207 let mut out = Vec::with_capacity(w.len() / QK * Q4_0_BLOCK_BYTES);
208 for chunk in w.chunks_exact(QK) {
209 out.extend_from_slice(&quantize_q4_0_block(chunk));
210 }
211 out
212}
213
214pub fn quantize_q8_0_block(x: &[f32]) -> [u8; Q8_0_BLOCK_BYTES] {
221 debug_assert_eq!(x.len(), QK);
222 let max_abs = x.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
223 let scale = max_abs / 127.0;
224 let d = half::f16::from_f32(scale);
225 let inv_scale = if scale > 0.0 { 1.0 / scale } else { 0.0 };
226 let mut out = [0u8; Q8_0_BLOCK_BYTES];
227 out[0..2].copy_from_slice(&d.to_le_bytes());
228 for (i, &v) in x.iter().enumerate() {
229 out[2 + i] = (v * inv_scale).round().clamp(-127.0, 127.0) as i8 as u8;
230 }
231 out
232}
233
234#[inline]
239pub fn dot_q8_0_block_f32(block: &[u8], x: &[f32]) -> f32 {
240 debug_assert_eq!(block.len(), Q8_0_BLOCK_BYTES);
241 debug_assert_eq!(x.len(), QK);
242 #[cfg(target_arch = "aarch64")]
243 return unsafe { dot_q8_0_block_neon(block, x) };
244 #[cfg(not(target_arch = "aarch64"))]
245 dot_q8_0_block_scalar(block, x)
246}
247
248#[inline(always)]
249#[allow(dead_code)] fn dot_q8_0_block_scalar(block: &[u8], x: &[f32]) -> f32 {
251 let d = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
252 let mut acc = 0.0f32;
253 for j in 0..QK {
254 acc += block[2 + j] as i8 as f32 * x[j];
255 }
256 acc * d
257}
258
/// NEON kernel behind `dot_q8_0_block_f32`.
///
/// Processes the 32 signed quants in four groups of 8. Each group is
/// sign-extended i8 -> i16 -> two i32x4 vectors, converted to `f32`, and
/// multiply-accumulated against the matching 8 activations into one `f32x4`
/// accumulator. The accumulator is reduced horizontally and then scaled.
///
/// # Safety
/// Reads `block[2..34]` and `x[0..32]` through raw pointers without bounds
/// checks, so the caller must guarantee `block.len() >= 34` and
/// `x.len() >= 32`. The CPU must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dot_q8_0_block_neon(block: &[u8], x: &[f32]) -> f32 {
    let scale = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
    let q_ptr = block.as_ptr().add(2) as *const i8;
    let xp = x.as_ptr();
    let mut acc = vdupq_n_f32(0.0);

    // One group: 8 quants starting at `$qoff` against 8 activations at `$xoff`.
    macro_rules! fma_group {
        ($qoff:expr, $xoff:expr) => {{
            let q8 = vld1_s8(q_ptr.add($qoff));
            let q16 = vmovl_s8(q8);
            let qlo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(q16)));
            let qhi = vcvtq_f32_s32(vmovl_high_s16(q16));
            acc = vfmaq_f32(acc, qlo, vld1q_f32(xp.add($xoff)));
            acc = vfmaq_f32(acc, qhi, vld1q_f32(xp.add($xoff + 4)));
        }};
    }

    fma_group!(0, 0);
    fma_group!(8, 8);
    fma_group!(16, 16);
    fma_group!(24, 24);

    // The block scale is applied once, after the horizontal add.
    vaddvq_f32(acc) * scale
}
287
/// Symmetrically quantizes an activation row to `i8` with one shared scale.
///
/// Returns `(q, scale)` such that `x[i] ≈ q[i] as f32 * scale`, where
/// `scale = max|x| / 127`. Each quant is `round(x / scale)` clamped to
/// `-127..=127`. An empty or all-zero row yields scale `1.0` and zero quants.
pub fn quantize_row_to_i8(x: &[f32]) -> (Vec<i8>, f32) {
    let mut max_abs = 0.0f32;
    for &v in x {
        max_abs = f32::max(max_abs, v.abs());
    }
    let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
    let inv = if scale > 0.0 { scale.recip() } else { 0.0 };
    let mut q = Vec::with_capacity(x.len());
    q.extend(x.iter().map(|&v| (v * inv).round().clamp(-127.0, 127.0) as i8));
    (q, scale)
}
312
/// Integer dot product of one Q8_0 weight block with [`QK`] pre-quantized
/// `i8` activations, using the Armv8.2 `SDOT` instruction.
///
/// Returns `Σ w_j * x_j` as `i32`. The block's `f16` scale is NOT applied.
///
/// # Safety
/// The CPU must implement NEON and the `dotprod` extension. `block` must hold
/// at least [`Q8_0_BLOCK_BYTES`] bytes and `x_i8` at least [`QK`] values,
/// because both are read through raw pointers.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon,dotprod")]
unsafe fn dot_q8_0_block_sdot(block: &[u8], x_i8: &[i8]) -> i32 {
    debug_assert_eq!(block.len(), Q8_0_BLOCK_BYTES);
    debug_assert_eq!(x_i8.len(), QK);

    let w_ptr = block.as_ptr().add(2) as *const i8;
    let x_ptr = x_i8.as_ptr();

    let w0 = vld1q_s8(w_ptr);
    let x0 = vld1q_s8(x_ptr);
    let w1 = vld1q_s8(w_ptr.add(16));
    let x1 = vld1q_s8(x_ptr.add(16));

    // SDOT is emitted directly because the `vdotq_s32` intrinsic is behind an
    // unstable std::arch feature at the time of writing. Each instruction adds
    // four 4-way i8 dot products into the i32 lanes of `acc`. Assembling it
    // requires the `dotprod` target feature enabled on this function.
    let mut acc = vdupq_n_s32(0i32);
    core::arch::asm!(
        "sdot {0:v}.4s, {1:v}.16b, {2:v}.16b",
        inout(vreg) acc,
        in(vreg) w0,
        in(vreg) x0,
        options(nomem, nostack),
    );
    core::arch::asm!(
        "sdot {0:v}.4s, {1:v}.16b, {2:v}.16b",
        inout(vreg) acc,
        in(vreg) w1,
        in(vreg) x1,
        options(nomem, nostack),
    );
    vaddvq_s32(acc)
}

/// Dot product of a Q8_0 weight row with an activation row pre-quantized by
/// [`quantize_row_to_i8`]. Each block is accumulated in `i32` via `SDOT`.
///
/// Each block contributes `w_scale * x_scale * Σ w_j * x_j`, summed in `f32`.
///
/// # Safety
/// The CPU must implement NEON and the `dotprod` extension (Armv8.2 `SDOT`).
/// Calling this on a core without it executes an illegal instruction, so
/// guard calls with `is_aarch64_feature_detected!("dotprod")`.
///
/// # Panics
/// Panics if `x_i8` is shorter than `QK` times the number of whole blocks in
/// `row_blocks`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon,dotprod")]
pub unsafe fn dot_q8_0_row_sdot(row_blocks: &[u8], x_i8: &[i8], x_scale: f32) -> f32 {
    let mut acc = 0.0f32;
    let mut x_off = 0usize;
    for block in row_blocks.chunks_exact(Q8_0_BLOCK_BYTES) {
        let w_scale = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
        // `chunks_exact` guarantees a full block and the slice below is
        // bounds-checked, which satisfies the kernel's length contract.
        let dot = dot_q8_0_block_sdot(block, &x_i8[x_off..x_off + QK]);
        acc += w_scale * x_scale * dot as f32;
        x_off += QK;
    }
    acc
}
377
378pub fn dot_q8_0_row_i8_scalar(row_blocks: &[u8], x_i8: &[i8], x_scale: f32) -> f32 {
382 let mut acc = 0.0f32;
383 let mut x_off = 0usize;
384 for block in row_blocks.chunks_exact(Q8_0_BLOCK_BYTES) {
385 let w_scale = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
386 let w = &block[2..];
387 let dot: i32 = w[..QK]
388 .iter()
389 .zip(&x_i8[x_off..x_off + QK])
390 .map(|(&wi, &xi)| wi as i8 as i32 * xi as i32)
391 .sum();
392 acc += w_scale * x_scale * dot as f32;
393 x_off += QK;
394 }
395 acc
396}
397
398pub fn dot_q8_0_row_f32(row_blocks: &[u8], x: &[f32]) -> f32 {
403 #[cfg(target_arch = "x86_64")]
404 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
405 return unsafe { dot_q8_0_row_avx2(row_blocks, x) };
406 }
407 let k = x.len();
408 debug_assert_eq!(k % QK, 0);
409 let mut acc = 0.0f32;
410 for (b, chunk) in row_blocks.chunks_exact(Q8_0_BLOCK_BYTES).enumerate() {
411 acc += dot_q8_0_block_f32(chunk, &x[b * QK..b * QK + QK]);
412 }
413 acc
414}
415
416#[cfg(target_arch = "x86_64")]
421#[target_feature(enable = "avx2,fma")]
422unsafe fn dot_q8_0_row_avx2(row_blocks: &[u8], x: &[f32]) -> f32 {
423 let k = x.len();
424 debug_assert_eq!(k % QK, 0);
425 let mut row_acc = _mm256_setzero_ps();
426
427 for (b, block) in row_blocks.chunks_exact(Q8_0_BLOCK_BYTES).enumerate() {
428 let scale = half::f16::from_le_bytes([block[0], block[1]]).to_f32();
429 let q_ptr = block.as_ptr().add(2) as *const i32; let xp = x.as_ptr().add(b * QK);
431 let mut block_acc = _mm256_setzero_ps();
432
433 for g in 0..4usize {
435 let q_i32_4 = _mm_loadu_si32(q_ptr.add(2 * g) as *const _); let q_i32_4b = _mm_loadu_si32(q_ptr.add(2 * g + 1) as *const _);
437 let q_a = _mm256_cvtepi8_epi32(q_i32_4); let q_b = _mm256_cvtepi8_epi32(q_i32_4b); let xv_a = _mm256_loadu_ps(xp.add(g * 8));
440 let xv_b = _mm256_loadu_ps(xp.add(g * 8 + 4));
441 let qf_a = _mm256_cvtepi32_ps(q_a);
442 let qf_b = _mm256_cvtepi32_ps(q_b);
443 block_acc = _mm256_fmadd_ps(qf_a, xv_a, block_acc);
444 block_acc = _mm256_fmadd_ps(qf_b, xv_b, block_acc);
445 }
446 let scale_v = _mm256_set1_ps(scale);
448 row_acc = _mm256_fmadd_ps(block_acc, scale_v, row_acc);
449 }
450
451 let lo = _mm256_castps256_ps128(row_acc);
453 let hi = _mm256_extractf128_ps(row_acc, 1);
454 let sum4 = _mm_add_ps(lo, hi);
455 let shuf = _mm_movehdup_ps(sum4);
456 let sum2 = _mm_add_ps(sum4, shuf);
457 let sum1 = _mm_add_ss(sum2, _mm_movehl_ps(shuf, sum2));
458 _mm_cvtss_f32(sum1)
459}
460
/// Number of weights in one K-quant super-block.
pub const QK_K: usize = 256;
/// Q4_K super-block: `d` and `dmin` (`f16` each), 12 bytes of packed 6-bit
/// scales/mins, then `QK_K / 2` bytes of 4-bit quants.
pub const Q4_K_BLOCK_BYTES: usize = 2 + 2 + 12 + QK_K / 2;
/// Q5_K super-block: the same 16-byte header as Q4_K, `QK_K / 8` bytes of
/// high (5th) bits, then `QK_K / 2` bytes of low nibbles.
pub const Q5_K_BLOCK_BYTES: usize = 2 + 2 + 12 + QK_K / 8 + QK_K / 2;
/// Q6_K super-block: `QK_K / 2` bytes of low nibbles, `QK_K / 4` bytes of
/// 2-bit high pairs, `QK_K / 16` signed 8-bit scales, then an `f16` `d`.
pub const Q6_K_BLOCK_BYTES: usize = QK_K / 2 + QK_K / 4 + QK_K / 16 + 2;
471
/// Unpacks the 6-bit `(scale, min)` pair of sub-block `j` (0..8) from the
/// 12-byte packed scales field of a Q4_K / Q5_K super-block.
///
/// Sub-blocks 0..4 keep their scale and min in the low 6 bits of bytes `j`
/// and `j + 4`. Sub-blocks 4..8 take their low 4 bits from byte `j + 4`
/// (low nibble -> scale, high nibble -> min). Their top 2 bits come from the
/// high 2 bits of bytes `j - 4` (scale) and `j` (min).
#[inline(always)]
fn get_scale_min_k4(j: usize, scales: &[u8]) -> (u8, u8) {
    if j >= 4 {
        let packed = scales[j + 4];
        let scale = (packed & 0x0F) | ((scales[j - 4] >> 6) << 4);
        let min = (packed >> 4) | ((scales[j] >> 6) << 4);
        return (scale, min);
    }
    (scales[j] & 0x3F, scales[j + 4] & 0x3F)
}
485
486pub fn dot_q4_k_row_f32(row_data: &[u8], x: &[f32]) -> f32 {
496 #[cfg(target_arch = "aarch64")]
497 return unsafe { dot_q4_k_row_f32_neon(row_data, x) };
498 #[cfg(not(target_arch = "aarch64"))]
499 dot_q4_k_row_f32_scalar(row_data, x)
500}
501
502#[allow(dead_code)]
504fn dot_q4_k_row_f32_scalar(row_data: &[u8], x: &[f32]) -> f32 {
505 let mut acc = 0.0f32;
506 let mut x_off = 0usize;
507 for block in row_data.chunks_exact(Q4_K_BLOCK_BYTES) {
508 let d = f16::from_le_bytes([block[0], block[1]]).to_f32();
509 let dmin = f16::from_le_bytes([block[2], block[3]]).to_f32();
510 let scales = &block[4..16];
511 let qs = &block[16..Q4_K_BLOCK_BYTES];
512 let mut q_off = 0usize;
513 let mut is = 0usize;
514 for _ in 0..(QK_K / 64) {
515 let (sc1, m1) = get_scale_min_k4(is, scales);
516 let d1 = d * sc1 as f32;
517 let m1v = dmin * m1 as f32;
518 let (sc2, m2) = get_scale_min_k4(is + 1, scales);
519 let d2 = d * sc2 as f32;
520 let m2v = dmin * m2 as f32;
521 for l in 0..32 {
522 acc += (d1 * (qs[q_off + l] & 0x0F) as f32 - m1v) * x[x_off + l];
523 acc += (d2 * (qs[q_off + l] >> 4) as f32 - m2v) * x[x_off + l + 32];
524 }
525 x_off += 64;
526 q_off += 32;
527 is += 2;
528 }
529 }
530 acc
531}
532
/// NEON kernel behind `dot_q4_k_row_f32`.
///
/// It factors each 32-weight sub-block as
/// `Σ (d·sc·q − dmin·m)·x = d·sc·Σ q·x − dmin·m·Σ x`. Within a sub-block it
/// accumulates `Σ q·x` and `Σ x` in vectors and applies the scale and min
/// once afterwards. Because the sums are regrouped, results can differ from
/// the scalar kernel by floating-point rounding.
///
/// # Safety
/// The CPU must support NEON. Every pointer load goes through a subslice
/// created with bounds checks: `qs` is 128 bytes and `q_off + chunk * 8 + 8 <= 128`,
/// and `x_lo` / `x_hi` are 32-element slices loaded at `chunk * 8 (+ 4)`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dot_q4_k_row_f32_neon(row_data: &[u8], x: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    let mut x_off = 0usize;
    let mask4 = vdup_n_u8(0x0F);

    for block in row_data.chunks_exact(Q4_K_BLOCK_BYTES) {
        let d = f16::from_le_bytes([block[0], block[1]]).to_f32();
        let dmin = f16::from_le_bytes([block[2], block[3]]).to_f32();
        let scales = &block[4..16];
        let qs = &block[16..Q4_K_BLOCK_BYTES];

        let mut q_off = 0usize;
        let mut is = 0usize;

        // Each pass covers 64 weights: 32 quant bytes whose low nibbles feed
        // sub-block `is` and whose high nibbles feed sub-block `is + 1`.
        for _ in 0..(QK_K / 64) {
            let (sc1, m1) = get_scale_min_k4(is, scales);
            let (sc2, m2) = get_scale_min_k4(is + 1, scales);
            let d1 = d * sc1 as f32;
            let m1v = dmin * m1 as f32;
            let d2 = d * sc2 as f32;
            let m2v = dmin * m2 as f32;

            let x_lo = &x[x_off..x_off + 32];
            let x_hi = &x[x_off + 32..x_off + 64];

            // vsum_lo / vsum_hi accumulate Σ q·x; vsum_xl / vsum_xh accumulate Σ x.
            let mut vsum_lo = vdupq_n_f32(0.0);
            let mut vsum_hi = vdupq_n_f32(0.0);
            let mut vsum_xl = vdupq_n_f32(0.0);
            let mut vsum_xh = vdupq_n_f32(0.0);
            for chunk in 0..4usize {
                // 8 quant bytes -> 8 low nibbles and 8 high nibbles.
                let q8 = vld1_u8(qs.as_ptr().add(q_off + chunk * 8));
                let lo8 = vand_u8(q8, mask4);
                let hi8 = vshr_n_u8::<4>(q8);

                // Widen u8 -> u16 -> u32 -> f32 (two 4-lane vectors each).
                let lo16 = vmovl_u8(lo8);
                let lof0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(lo16)));
                let lof1 = vcvtq_f32_u32(vmovl_high_u16(lo16));

                let hi16 = vmovl_u8(hi8);
                let hif0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(hi16)));
                let hif1 = vcvtq_f32_u32(vmovl_high_u16(hi16));

                let xl0 = vld1q_f32(x_lo.as_ptr().add(chunk * 8));
                let xl1 = vld1q_f32(x_lo.as_ptr().add(chunk * 8 + 4));
                let xh0 = vld1q_f32(x_hi.as_ptr().add(chunk * 8));
                let xh1 = vld1q_f32(x_hi.as_ptr().add(chunk * 8 + 4));

                vsum_lo = vfmaq_f32(vsum_lo, lof0, xl0);
                vsum_lo = vfmaq_f32(vsum_lo, lof1, xl1);
                vsum_hi = vfmaq_f32(vsum_hi, hif0, xh0);
                vsum_hi = vfmaq_f32(vsum_hi, hif1, xh1);
                vsum_xl = vaddq_f32(vsum_xl, vaddq_f32(xl0, xl1));
                vsum_xh = vaddq_f32(vsum_xh, vaddq_f32(xh0, xh1));
            }

            // Apply scale and min once per sub-block (the factored form above).
            acc += d1 * vaddvq_f32(vsum_lo) - m1v * vaddvq_f32(vsum_xl);
            acc += d2 * vaddvq_f32(vsum_hi) - m2v * vaddvq_f32(vsum_xh);

            x_off += 64;
            q_off += 32;
            is += 2;
        }
    }
    acc
}
616
617pub fn dot_q5_k_row_f32(row_data: &[u8], x: &[f32]) -> f32 {
624 let mut acc = 0.0f32;
625 let mut x_off = 0usize;
626 for block in row_data.chunks_exact(Q5_K_BLOCK_BYTES) {
627 let d = f16::from_le_bytes([block[0], block[1]]).to_f32();
628 let dmin = f16::from_le_bytes([block[2], block[3]]).to_f32();
629 let scales = &block[4..16];
630 let qh = &block[16..48];
631 let ql = &block[48..Q5_K_BLOCK_BYTES];
632 let mut ql_off = 0usize;
633 let mut is = 0usize;
634 let mut u1: u8 = 1;
635 let mut u2: u8 = 2;
636 for _ in 0..(QK_K / 64) {
637 let (sc1, m1) = get_scale_min_k4(is, scales);
638 let d1 = d * sc1 as f32;
639 let m1v = dmin * m1 as f32;
640 let (sc2, m2) = get_scale_min_k4(is + 1, scales);
641 let d2 = d * sc2 as f32;
642 let m2v = dmin * m2 as f32;
643 let qh_byte = qh[is / 8];
644 for l in 0..32 {
645 let hi1 = if qh_byte & u1 != 0 { 16.0f32 } else { 0.0 };
646 let hi2 = if qh_byte & u2 != 0 { 16.0f32 } else { 0.0 };
647 acc += (d1 * ((ql[ql_off + l] & 0x0F) as f32 + hi1) - m1v) * x[x_off + l];
648 acc += (d2 * ((ql[ql_off + l] >> 4) as f32 + hi2) - m2v) * x[x_off + l + 32];
649 }
650 x_off += 64;
651 ql_off += 32;
652 is += 2;
653 if is % 8 == 0 {
654 u1 = 1;
655 u2 = 2;
656 } else {
657 u1 <<= 2;
658 u2 <<= 2;
659 }
660 }
661 }
662 acc
663}
664
665pub fn dot_q6_k_row_f32(row_data: &[u8], x: &[f32]) -> f32 {
673 let mut acc = 0.0f32;
674 let mut x_off = 0usize;
675 for block in row_data.chunks_exact(Q6_K_BLOCK_BYTES) {
676 let ql = &block[0..128];
677 let qh = &block[128..192];
678 let sc = &block[192..208];
679 let d = f16::from_le_bytes([block[208], block[209]]).to_f32();
680 let mut ql_off = 0usize;
681 let mut qh_off = 0usize;
682 let mut ib = 0usize;
683 for _ in 0..(QK_K / 128) {
684 for l in 0..32 {
685 let q1 =
686 (((ql[ql_off + l] & 0x0F) | ((qh[qh_off + l] & 3) << 4)) as i32 - 32) as f32;
687 let q2 = (((ql[ql_off + l + 32] & 0x0F) | (((qh[qh_off + l] >> 2) & 3) << 4))
688 as i32
689 - 32) as f32;
690 let q3 = (((ql[ql_off + l] >> 4) | (((qh[qh_off + l] >> 4) & 3) << 4)) as i32 - 32)
691 as f32;
692 let q4 = (((ql[ql_off + l + 32] >> 4) | (((qh[qh_off + l] >> 6) & 3) << 4)) as i32
693 - 32) as f32;
694 acc += d * sc[ib] as i8 as f32 * q1 * x[x_off + l];
695 acc += d * sc[ib + 1] as i8 as f32 * q2 * x[x_off + l + 32];
696 acc += d * sc[ib + 2] as i8 as f32 * q3 * x[x_off + l + 64];
697 acc += d * sc[ib + 3] as i8 as f32 * q4 * x[x_off + l + 96];
698 }
699 x_off += 128;
700 ql_off += 64;
701 qh_off += 32;
702 ib += 4;
703 }
704 }
705 acc
706}
707
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic xorshift64 sequence of `n` values in `[-1, 1)`.
    fn seq(n: usize) -> Vec<f32> {
        let mut state: u64 = 0x9E3779B97F4A7C15;
        let mut out = Vec::with_capacity(n);
        for _ in 0..n {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            // Top 24 bits -> [0, 1), then rescaled to [-1, 1).
            let unit = (state >> 40) as f32 / (1u32 << 24) as f32;
            out.push(unit * 2.0 - 1.0);
        }
        out
    }

    /// Dequantizes a whole Q4_0 row of `len` weights into a new vector.
    fn dequantize_row(blocks: &[u8], len: usize) -> Vec<f32> {
        let mut w_hat = vec![0.0f32; len];
        for (dst, block) in w_hat
            .chunks_exact_mut(QK)
            .zip(blocks.chunks_exact(Q4_0_BLOCK_BYTES))
        {
            dequantize_q4_0_block(block, dst);
        }
        w_hat
    }

    #[test]
    fn q4_0_on_the_fly_dot_matches_dequantized_reference() {
        let k = 256;
        let w = seq(k);
        let x: Vec<f32> = seq(k).into_iter().map(|v| v * 0.5).collect();

        let blocks = quantize_q4_0_row(&w);
        assert_eq!(blocks.len(), k / QK * Q4_0_BLOCK_BYTES);

        let w_hat = dequantize_row(&blocks, k);
        let reference: f32 = w_hat.iter().zip(&x).map(|(a, b)| a * b).sum();

        let on_the_fly = dot_q4_0_row_f32(&blocks, &x);
        assert!(
            (on_the_fly - reference).abs() < 1e-3,
            "on-the-fly {on_the_fly} vs reference {reference}"
        );
    }

    #[test]
    fn q4_0_quantization_error_is_bounded() {
        let w = seq(QK * 4);
        let blocks = quantize_q4_0_row(&w);
        let w_hat = dequantize_row(&blocks, w.len());
        let mut max_err = 0.0f32;
        for (a, b) in w.iter().zip(&w_hat) {
            max_err = f32::max(max_err, (a - b).abs());
        }
        assert!(max_err < 0.2, "max quant error {max_err} too large");
    }
}