Skip to main content

ruvector_temporal_tensor/
quantizer.rs

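//! Groupwise symmetric quantization with f16 scales.
//!
//! For each group of `group_len` values:
//! - `scale = max(|v_i|) / qmax`, taken over the finite `v_i` and stored as an f16
//! - `q_i = round(v_i / scale)` (half away from zero), clamped to `[-qmax, +qmax]`;
//!   non-finite `v_i`, and every value of a group whose scale is zero, encode as `q_i = 0`
//! - `u_i = q_i + qmax` (bias to unsigned for packing)
//!
//! The `u_i` of one frame are packed LSB-first at `bits` bits each, continuously
//! across group boundaries, and the frame is zero-padded to a whole byte.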
7
8use crate::bitpack::qmax_from_bits;
9use crate::f16;
10
11/// Compute f16 group scales for a frame.
12///
13/// Returns one f16-encoded scale per group of `group_len` elements.
14/// Each scale is `max(|v|) / qmax` for that group, stored as IEEE 754 half-precision.
15#[inline]
16pub fn compute_scales(frame: &[f32], group_len: usize, bits: u8) -> Vec<u16> {
17    let qmax = qmax_from_bits(bits);
18    if qmax == 0 {
19        return Vec::new();
20    }
21    let qmax_f = qmax as f32;
22    let num_groups = frame.len().div_ceil(group_len);
23    let mut scales = Vec::with_capacity(num_groups);
24
25    for chunk in frame.chunks(group_len) {
26        let mut max_abs = 0.0f32;
27        for &v in chunk {
28            if v.is_finite() {
29                let a = v.abs();
30                if a > max_abs {
31                    max_abs = a;
32                }
33            }
34        }
35
36        let scale = if max_abs == 0.0 {
37            0.0
38        } else {
39            max_abs / qmax_f
40        };
41        scales.push(f16::f32_to_f16_bits(scale));
42    }
43
44    scales
45}
46
47/// Pre-convert f16 scales to f32 for hot-path use.
48#[inline]
49pub fn scales_to_f32(scales_f16: &[u16]) -> Vec<f32> {
50    scales_f16
51        .iter()
52        .map(|&s| f16::f16_bits_to_f32(s))
53        .collect()
54}
55
56/// Check if a frame fits within existing scales (within drift tolerance).
57///
58/// Uses pre-converted f32 scales to avoid repeated f16 conversion.
59/// Returns `false` if any group's max absolute value exceeds
60/// `scale * qmax * drift_factor`.
61pub fn frame_fits_scales_f32(
62    frame: &[f32],
63    scales_f32: &[f32],
64    group_len: usize,
65    bits: u8,
66    drift_factor: f32,
67) -> bool {
68    let qmax = qmax_from_bits(bits);
69    if qmax == 0 || scales_f32.is_empty() {
70        return false;
71    }
72    let qmax_f = qmax as f32;
73
74    for (group_idx, chunk) in frame.chunks(group_len).enumerate() {
75        if group_idx >= scales_f32.len() {
76            return false;
77        }
78        let allowed = scales_f32[group_idx] * qmax_f * drift_factor;
79
80        for &v in chunk {
81            if v.is_finite() && v.abs() > allowed {
82                return false;
83            }
84        }
85    }
86
87    true
88}
89
90/// Quantize a frame using pre-computed f32 scales and pack into bitstream.
91///
92/// Appends packed bytes to `out`. Pre-reserves the expected output size
93/// to avoid reallocations.
94///
95/// For 8-bit quantization, writes bytes directly without bit accumulation
96/// since each quantized value maps 1:1 to a u8.
97#[inline]
98pub fn quantize_and_pack_f32(
99    frame: &[f32],
100    scales_f32: &[f32],
101    group_len: usize,
102    bits: u8,
103    out: &mut Vec<u8>,
104) {
105    let qmax = qmax_from_bits(bits);
106    if qmax == 0 {
107        return;
108    }
109
110    // Fast path: 8-bit quantization writes bytes directly, no bit accumulator.
111    if bits == 8 {
112        out.reserve(frame.len());
113        for (group_idx, chunk) in frame.chunks(group_len).enumerate() {
114            let scale = if group_idx < scales_f32.len() {
115                scales_f32[group_idx]
116            } else {
117                0.0
118            };
119            let inv_scale = if scale == 0.0 { 0.0 } else { 1.0 / scale };
120            for &v in chunk {
121                let mut q: i32 = 0;
122                if v.is_finite() {
123                    let scaled = v * inv_scale;
124                    q = if scaled >= 0.0 {
125                        (scaled + 0.5) as i32
126                    } else {
127                        (scaled - 0.5) as i32
128                    };
129                    q = q.clamp(-127, 127);
130                }
131                out.push((q + 127) as u8);
132            }
133        }
134        return;
135    }
136
137    // Fast path: 5-bit quantization packs 8 values into 5 bytes.
138    // 8 values * 5 bits = 40 bits = 5 bytes exactly, avoiding the bit accumulator.
139    // LSB-first packing layout for 8 values in 5 bytes:
140    //   byte0 = v0 | (v1 << 5)
141    //   byte1 = (v1 >> 3) | (v2 << 2) | (v3 << 7)
142    //   byte2 = (v3 >> 1) | (v4 << 4)
143    //   byte3 = (v4 >> 4) | (v5 << 1) | (v6 << 6)
144    //   byte4 = (v6 >> 2) | (v7 << 3)
145    #[inline]
146    fn pack_5bit_group(chunk: &[f32], inv_scale: f32, out: &mut Vec<u8>) {
147        let quantize = |v: f32| -> u32 {
148            let mut q: i32 = 0;
149            if v.is_finite() {
150                let scaled = v * inv_scale;
151                q = if scaled >= 0.0 {
152                    (scaled + 0.5) as i32
153                } else {
154                    (scaled - 0.5) as i32
155                };
156                q = q.clamp(-15, 15);
157            }
158            (q + 15) as u32
159        };
160        let v0 = quantize(chunk[0]);
161        let v1 = quantize(chunk[1]);
162        let v2 = quantize(chunk[2]);
163        let v3 = quantize(chunk[3]);
164        let v4 = quantize(chunk[4]);
165        let v5 = quantize(chunk[5]);
166        let v6 = quantize(chunk[6]);
167        let v7 = quantize(chunk[7]);
168
169        out.push((v0 | (v1 << 5)) as u8);
170        out.push(((v1 >> 3) | (v2 << 2) | (v3 << 7)) as u8);
171        out.push(((v3 >> 1) | (v4 << 4)) as u8);
172        out.push(((v4 >> 4) | (v5 << 1) | (v6 << 6)) as u8);
173        out.push(((v6 >> 2) | (v7 << 3)) as u8);
174    }
175    if bits == 5 {
176        let needed_bytes = (frame.len() * 5).div_ceil(8);
177        out.reserve(needed_bytes);
178
179        let mut acc: u64 = 0;
180        let mut acc_bits: u32 = 0;
181
182        for (group_idx, chunk) in frame.chunks(group_len).enumerate() {
183            let scale = if group_idx < scales_f32.len() {
184                scales_f32[group_idx]
185            } else {
186                0.0
187            };
188            let inv_scale = if scale == 0.0 { 0.0 } else { 1.0 / scale };
189
190            let mut i = 0;
191            // Process 8 values at a time into 5 bytes when byte-aligned
192            while acc_bits == 0 && i + 8 <= chunk.len() {
193                pack_5bit_group(&chunk[i..i + 8], inv_scale, out);
194                i += 8;
195            }
196            // Remainder (or misaligned) with bit accumulator
197            while i < chunk.len() {
198                let mut q: i32 = 0;
199                if chunk[i].is_finite() {
200                    let scaled = chunk[i] * inv_scale;
201                    q = if scaled >= 0.0 {
202                        (scaled + 0.5) as i32
203                    } else {
204                        (scaled - 0.5) as i32
205                    };
206                    q = q.clamp(-15, 15);
207                }
208                let u = (q + 15) as u32;
209                acc |= (u as u64) << acc_bits;
210                acc_bits += 5;
211                while acc_bits >= 8 {
212                    out.push((acc & 0xFF) as u8);
213                    acc >>= 8;
214                    acc_bits -= 8;
215                }
216                i += 1;
217            }
218        }
219
220        if acc_bits > 0 {
221            out.push((acc & 0xFF) as u8);
222        }
223        return;
224    }
225
226    // Generic path for sub-byte bit widths.
227    let qmax_i = qmax;
228    let bias = qmax;
229    let bits_u32 = bits as u32;
230
231    let needed_bytes = (frame.len() * bits as usize).div_ceil(8);
232    out.reserve(needed_bytes);
233
234    let mut acc: u64 = 0;
235    let mut acc_bits: u32 = 0;
236
237    for (group_idx, chunk) in frame.chunks(group_len).enumerate() {
238        let scale = if group_idx < scales_f32.len() {
239            scales_f32[group_idx]
240        } else {
241            0.0
242        };
243        let inv_scale = if scale == 0.0 { 0.0 } else { 1.0 / scale };
244
245        for &v in chunk {
246            let mut q: i32 = 0;
247            if v.is_finite() {
248                let scaled = v * inv_scale;
249                q = if scaled >= 0.0 {
250                    (scaled + 0.5) as i32
251                } else {
252                    (scaled - 0.5) as i32
253                };
254                q = q.clamp(-qmax_i, qmax_i);
255            }
256
257            let u = (q + bias) as u32;
258            acc |= (u as u64) << acc_bits;
259            acc_bits += bits_u32;
260
261            while acc_bits >= 8 {
262                out.push((acc & 0xFF) as u8);
263                acc >>= 8;
264                acc_bits -= 8;
265            }
266        }
267    }
268
269    if acc_bits > 0 {
270        out.push((acc & 0xFF) as u8);
271    }
272}
273
274/// Dequantize packed codes using f32 scales, writing f32 values.
275///
276/// Iterates by frame then by group to avoid per-value modulo/division
277/// and caches the f32 scale per group.
278///
279/// For 8-bit data, reads bytes directly without bit accumulation.
280#[inline]
281pub fn dequantize_f32(
282    data: &[u8],
283    scales_f32: &[f32],
284    group_len: usize,
285    bits: u8,
286    tensor_len: usize,
287    frame_count: usize,
288    out: &mut Vec<f32>,
289) {
290    let qmax = qmax_from_bits(bits);
291    if qmax == 0 {
292        return;
293    }
294
295    let total = tensor_len * frame_count;
296    out.resize(total, 0.0);
297
298    // Fast path: 8-bit dequantization reads bytes directly, no bit accumulator.
299    if bits == 8 {
300        let mut out_idx = 0usize;
301        let mut byte_idx = 0usize;
302        for _frame in 0..frame_count {
303            let mut pos = 0usize;
304            let mut group_idx = 0usize;
305            while pos < tensor_len {
306                let group_end = (pos + group_len).min(tensor_len);
307                let scale = if group_idx < scales_f32.len() {
308                    scales_f32[group_idx]
309                } else {
310                    0.0
311                };
312                while pos < group_end && byte_idx < data.len() {
313                    let u = data[byte_idx] as i32;
314                    let q = u - 127;
315                    out[out_idx] = (q as f32) * scale;
316                    out_idx += 1;
317                    byte_idx += 1;
318                    pos += 1;
319                }
320                group_idx += 1;
321            }
322        }
323        return;
324    }
325
326    // Fast path: 3-bit dequantization processes 8 values from 3 bytes.
327    // 8 values * 3 bits = 24 bits = 3 bytes exactly, avoiding the bit accumulator.
328    // LSB-first packing layout for 8 values in 3 bytes:
329    //   byte0 = v0 | (v1 << 3) | ((v2 & 0x3) << 6)
330    //   byte1 = (v2 >> 2) | (v3 << 1) | (v4 << 4) | ((v5 & 0x1) << 7)
331    //   byte2 = (v5 >> 1) | (v6 << 2) | (v7 << 5)
332    if bits == 3 {
333        let bias = 3i32; // qmax for 3-bit
334        let mut out_idx = 0usize;
335        let mut byte_idx = 0usize;
336        for _frame in 0..frame_count {
337            let mut pos = 0usize;
338            let mut group_idx = 0usize;
339            while pos < tensor_len {
340                let group_end = (pos + group_len).min(tensor_len);
341                let scale = if group_idx < scales_f32.len() {
342                    scales_f32[group_idx]
343                } else {
344                    0.0
345                };
346                // Process 8 values at a time from 3 bytes
347                while pos + 8 <= group_end && byte_idx + 3 <= data.len() {
348                    let b0 = data[byte_idx] as u32;
349                    let b1 = data[byte_idx + 1] as u32;
350                    let b2 = data[byte_idx + 2] as u32;
351                    byte_idx += 3;
352
353                    out[out_idx] = ((b0 & 0x7) as i32 - bias) as f32 * scale;
354                    out[out_idx + 1] = (((b0 >> 3) & 0x7) as i32 - bias) as f32 * scale;
355                    out[out_idx + 2] =
356                        ((((b0 >> 6) | (b1 << 2)) & 0x7) as i32 - bias) as f32 * scale;
357                    out[out_idx + 3] = (((b1 >> 1) & 0x7) as i32 - bias) as f32 * scale;
358                    out[out_idx + 4] = (((b1 >> 4) & 0x7) as i32 - bias) as f32 * scale;
359                    out[out_idx + 5] =
360                        ((((b1 >> 7) | (b2 << 1)) & 0x7) as i32 - bias) as f32 * scale;
361                    out[out_idx + 6] = (((b2 >> 2) & 0x7) as i32 - bias) as f32 * scale;
362                    out[out_idx + 7] = (((b2 >> 5) & 0x7) as i32 - bias) as f32 * scale;
363                    out_idx += 8;
364                    pos += 8;
365                }
366                // Handle remaining values (< 8) with a local bit accumulator
367                if pos < group_end {
368                    let remaining = group_end - pos;
369                    let mut acc: u64 = 0;
370                    let mut acc_bits: u32 = 0;
371                    while acc_bits < (remaining as u32) * 3 && byte_idx < data.len() {
372                        acc |= (data[byte_idx] as u64) << acc_bits;
373                        acc_bits += 8;
374                        byte_idx += 1;
375                    }
376                    for _ in 0..remaining {
377                        if acc_bits < 3 {
378                            break;
379                        }
380                        let u = (acc & 0x7) as i32;
381                        acc >>= 3;
382                        acc_bits -= 3;
383                        out[out_idx] = (u - bias) as f32 * scale;
384                        out_idx += 1;
385                        pos += 1;
386                    }
387                }
388                group_idx += 1;
389            }
390        }
391        return;
392    }
393
394    // Fast path: 7-bit dequantization processes 8 values from 7 bytes.
395    // 8 values * 7 bits = 56 bits = 7 bytes exactly, avoiding the bit accumulator.
396    // LSB-first packing layout for 8 values in 7 bytes:
397    //   v0 = b0 & 0x7F
398    //   v1 = ((b0 >> 7) | (b1 << 1)) & 0x7F
399    //   v2 = ((b1 >> 6) | (b2 << 2)) & 0x7F
400    //   v3 = ((b2 >> 5) | (b3 << 3)) & 0x7F
401    //   v4 = ((b3 >> 4) | (b4 << 4)) & 0x7F
402    //   v5 = ((b4 >> 3) | (b5 << 5)) & 0x7F
403    //   v6 = ((b5 >> 2) | (b6 << 6)) & 0x7F
404    //   v7 = (b6 >> 1) & 0x7F
405    if bits == 7 {
406        let bias = 63i32; // qmax for 7-bit
407        let mut out_idx = 0usize;
408        let mut byte_idx = 0usize;
409        for _frame in 0..frame_count {
410            let mut pos = 0usize;
411            let mut group_idx = 0usize;
412            while pos < tensor_len {
413                let group_end = (pos + group_len).min(tensor_len);
414                let scale = if group_idx < scales_f32.len() {
415                    scales_f32[group_idx]
416                } else {
417                    0.0
418                };
419                // Process 8 values at a time from 7 bytes
420                #[inline]
421                fn unpack_7bit(
422                    out: &mut [f32],
423                    out_idx: usize,
424                    data: &[u8],
425                    byte_idx: usize,
426                    bias: i32,
427                    scale: f32,
428                ) {
429                    let b0 = data[byte_idx] as u32;
430                    let b1 = data[byte_idx + 1] as u32;
431                    let b2 = data[byte_idx + 2] as u32;
432                    let b3 = data[byte_idx + 3] as u32;
433                    let b4 = data[byte_idx + 4] as u32;
434                    let b5 = data[byte_idx + 5] as u32;
435                    let b6 = data[byte_idx + 6] as u32;
436
437                    out[out_idx] = ((b0 & 0x7F) as i32 - bias) as f32 * scale;
438                    out[out_idx + 1] =
439                        ((((b0 >> 7) | (b1 << 1)) & 0x7F) as i32 - bias) as f32 * scale;
440                    out[out_idx + 2] =
441                        ((((b1 >> 6) | (b2 << 2)) & 0x7F) as i32 - bias) as f32 * scale;
442                    out[out_idx + 3] =
443                        ((((b2 >> 5) | (b3 << 3)) & 0x7F) as i32 - bias) as f32 * scale;
444                    out[out_idx + 4] =
445                        ((((b3 >> 4) | (b4 << 4)) & 0x7F) as i32 - bias) as f32 * scale;
446                    out[out_idx + 5] =
447                        ((((b4 >> 3) | (b5 << 5)) & 0x7F) as i32 - bias) as f32 * scale;
448                    out[out_idx + 6] =
449                        ((((b5 >> 2) | (b6 << 6)) & 0x7F) as i32 - bias) as f32 * scale;
450                    out[out_idx + 7] = (((b6 >> 1) & 0x7F) as i32 - bias) as f32 * scale;
451                }
452                while pos + 8 <= group_end && byte_idx + 7 <= data.len() {
453                    unpack_7bit(out, out_idx, data, byte_idx, bias, scale);
454                    byte_idx += 7;
455                    out_idx += 8;
456                    pos += 8;
457                }
458                // Handle remaining values (< 8) with a local bit accumulator
459                if pos < group_end {
460                    let remaining = group_end - pos;
461                    let mut acc: u64 = 0;
462                    let mut acc_bits: u32 = 0;
463                    while acc_bits < (remaining as u32) * 7 && byte_idx < data.len() {
464                        acc |= (data[byte_idx] as u64) << acc_bits;
465                        acc_bits += 8;
466                        byte_idx += 1;
467                    }
468                    for _ in 0..remaining {
469                        if acc_bits < 7 {
470                            break;
471                        }
472                        let u = (acc & 0x7F) as i32;
473                        acc >>= 7;
474                        acc_bits -= 7;
475                        out[out_idx] = (u - bias) as f32 * scale;
476                        out_idx += 1;
477                        pos += 1;
478                    }
479                }
480                group_idx += 1;
481            }
482        }
483        return;
484    }
485
486    // Fast path: 5-bit dequantization processes 8 values from 5 bytes.
487    // 8 values * 5 bits = 40 bits = 5 bytes exactly, avoiding the bit accumulator.
488    // LSB-first packing layout for 8 values in 5 bytes:
489    //   v0 = b0 & 0x1F
490    //   v1 = ((b0 >> 5) | (b1 << 3)) & 0x1F
491    //   v2 = (b1 >> 2) & 0x1F
492    //   v3 = ((b1 >> 7) | (b2 << 1)) & 0x1F
493    //   v4 = ((b2 >> 4) | (b3 << 4)) & 0x1F
494    //   v5 = (b3 >> 1) & 0x1F
495    //   v6 = ((b3 >> 6) | (b4 << 2)) & 0x1F
496    //   v7 = (b4 >> 3) & 0x1F
497    if bits == 5 {
498        let bias = 15i32; // qmax for 5-bit
499        let mut out_idx = 0usize;
500        let mut byte_idx = 0usize;
501        for _frame in 0..frame_count {
502            let mut pos = 0usize;
503            let mut group_idx = 0usize;
504            while pos < tensor_len {
505                let group_end = (pos + group_len).min(tensor_len);
506                let scale = if group_idx < scales_f32.len() {
507                    scales_f32[group_idx]
508                } else {
509                    0.0
510                };
511                // Process 8 values at a time from 5 bytes
512                #[inline]
513                fn unpack_5bit(
514                    out: &mut [f32],
515                    out_idx: usize,
516                    data: &[u8],
517                    byte_idx: usize,
518                    bias: i32,
519                    scale: f32,
520                ) {
521                    let b0 = data[byte_idx] as u32;
522                    let b1 = data[byte_idx + 1] as u32;
523                    let b2 = data[byte_idx + 2] as u32;
524                    let b3 = data[byte_idx + 3] as u32;
525                    let b4 = data[byte_idx + 4] as u32;
526
527                    out[out_idx] = ((b0 & 0x1F) as i32 - bias) as f32 * scale;
528                    out[out_idx + 1] =
529                        ((((b0 >> 5) | (b1 << 3)) & 0x1F) as i32 - bias) as f32 * scale;
530                    out[out_idx + 2] = (((b1 >> 2) & 0x1F) as i32 - bias) as f32 * scale;
531                    out[out_idx + 3] =
532                        ((((b1 >> 7) | (b2 << 1)) & 0x1F) as i32 - bias) as f32 * scale;
533                    out[out_idx + 4] =
534                        ((((b2 >> 4) | (b3 << 4)) & 0x1F) as i32 - bias) as f32 * scale;
535                    out[out_idx + 5] = (((b3 >> 1) & 0x1F) as i32 - bias) as f32 * scale;
536                    out[out_idx + 6] =
537                        ((((b3 >> 6) | (b4 << 2)) & 0x1F) as i32 - bias) as f32 * scale;
538                    out[out_idx + 7] = (((b4 >> 3) & 0x1F) as i32 - bias) as f32 * scale;
539                }
540                while pos + 8 <= group_end && byte_idx + 5 <= data.len() {
541                    unpack_5bit(out, out_idx, data, byte_idx, bias, scale);
542                    byte_idx += 5;
543                    out_idx += 8;
544                    pos += 8;
545                }
546                // Handle remaining values (< 8) with a local bit accumulator
547                if pos < group_end {
548                    let remaining = group_end - pos;
549                    let mut acc: u64 = 0;
550                    let mut acc_bits: u32 = 0;
551                    while acc_bits < (remaining as u32) * 5 && byte_idx < data.len() {
552                        acc |= (data[byte_idx] as u64) << acc_bits;
553                        acc_bits += 8;
554                        byte_idx += 1;
555                    }
556                    for _ in 0..remaining {
557                        if acc_bits < 5 {
558                            break;
559                        }
560                        let u = (acc & 0x1F) as i32;
561                        acc >>= 5;
562                        acc_bits -= 5;
563                        out[out_idx] = (u - bias) as f32 * scale;
564                        out_idx += 1;
565                        pos += 1;
566                    }
567                }
568                group_idx += 1;
569            }
570        }
571        return;
572    }
573
574    // Generic path for sub-byte bit widths.
575    let bias = qmax;
576    let bits_u32 = bits as u32;
577    let mask = (1u64 << bits_u32) - 1;
578
579    let mut acc: u64 = 0;
580    let mut acc_bits: u32 = 0;
581    let mut byte_idx = 0usize;
582    let mut out_idx = 0usize;
583
584    for _frame in 0..frame_count {
585        let mut pos = 0usize;
586        let mut group_idx = 0usize;
587
588        while pos < tensor_len {
589            let group_end = (pos + group_len).min(tensor_len);
590            let scale = if group_idx < scales_f32.len() {
591                scales_f32[group_idx]
592            } else {
593                0.0
594            };
595
596            while pos < group_end {
597                while acc_bits < bits_u32 && byte_idx < data.len() {
598                    acc |= (data[byte_idx] as u64) << acc_bits;
599                    acc_bits += 8;
600                    byte_idx += 1;
601                }
602                if acc_bits < bits_u32 {
603                    return;
604                }
605
606                let u = (acc & mask) as u32;
607                acc >>= bits_u32;
608                acc_bits -= bits_u32;
609
610                let q = (u as i32) - bias;
611                out[out_idx] = (q as f32) * scale;
612                out_idx += 1;
613                pos += 1;
614            }
615
616            group_idx += 1;
617        }
618    }
619}
620
621// --- Legacy API (delegates to f32 variants) ---
622
623/// Check if a frame fits within existing f16 scales (within drift tolerance).
624pub fn frame_fits_scales(
625    frame: &[f32],
626    scales: &[u16],
627    group_len: usize,
628    bits: u8,
629    drift_factor: f32,
630) -> bool {
631    let scales_f32 = scales_to_f32(scales);
632    frame_fits_scales_f32(frame, &scales_f32, group_len, bits, drift_factor)
633}
634
635/// Quantize a frame using pre-computed f16 scales and pack into bitstream.
636pub fn quantize_and_pack(
637    frame: &[f32],
638    scales: &[u16],
639    group_len: usize,
640    bits: u8,
641    out: &mut Vec<u8>,
642) {
643    let scales_f32 = scales_to_f32(scales);
644    quantize_and_pack_f32(frame, &scales_f32, group_len, bits, out)
645}
646
647/// Dequantize packed codes using f16 scales, writing f32 values.
648pub fn dequantize(
649    data: &[u8],
650    scales: &[u16],
651    group_len: usize,
652    bits: u8,
653    tensor_len: usize,
654    frame_count: usize,
655    out: &mut Vec<f32>,
656) {
657    let scales_f32 = scales_to_f32(scales);
658    dequantize_f32(
659        data,
660        &scales_f32,
661        group_len,
662        bits,
663        tensor_len,
664        frame_count,
665        out,
666    )
667}
668
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantize_roundtrip_8bit() {
        let frame: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) * 0.1).collect();
        let scales = compute_scales(&frame, 64, 8);
        let mut packed = Vec::new();
        quantize_and_pack(&frame, &scales, 64, 8, &mut packed);

        let mut decoded = Vec::new();
        dequantize(&packed, &scales, 64, 8, frame.len(), 1, &mut decoded);

        assert_eq!(decoded.len(), frame.len());
        for (i, (&orig, &dec)) in frame.iter().zip(decoded.iter()).enumerate() {
            let err = (orig - dec).abs();
            let max_err = if orig.abs() > 0.01 {
                orig.abs() * 0.02
            } else {
                0.1
            };
            assert!(err < max_err, "i={i}, orig={orig}, dec={dec}, err={err}");
        }
    }

    #[test]
    fn test_quantize_roundtrip_3bit() {
        let frame: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.5).collect();
        let scales = compute_scales(&frame, 64, 3);
        let mut packed = Vec::new();
        quantize_and_pack(&frame, &scales, 64, 3, &mut packed);

        let mut decoded = Vec::new();
        dequantize(&packed, &scales, 64, 3, frame.len(), 1, &mut decoded);

        let max_val = frame.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
        for (&orig, &dec) in frame.iter().zip(decoded.iter()) {
            let err = (orig - dec).abs();
            assert!(err < max_val * 0.35, "orig={orig}, dec={dec}, err={err}");
        }
    }

    #[test]
    fn test_quantize_roundtrip_5bit() {
        let frame: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) * 0.05).collect();
        let scales = compute_scales(&frame, 64, 5);
        let mut packed = Vec::new();
        quantize_and_pack(&frame, &scales, 64, 5, &mut packed);

        let mut decoded = Vec::new();
        dequantize(&packed, &scales, 64, 5, frame.len(), 1, &mut decoded);

        let max_val = frame.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
        for (&orig, &dec) in frame.iter().zip(decoded.iter()) {
            let err = (orig - dec).abs();
            assert!(err < max_val * 0.08, "orig={orig}, dec={dec}, err={err}");
        }
    }

    #[test]
    fn test_quantize_roundtrip_7bit() {
        let frame: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) * 0.05).collect();
        let scales = compute_scales(&frame, 64, 7);
        let mut packed = Vec::new();
        quantize_and_pack(&frame, &scales, 64, 7, &mut packed);

        let mut decoded = Vec::new();
        dequantize(&packed, &scales, 64, 7, frame.len(), 1, &mut decoded);

        for (i, (&orig, &dec)) in frame.iter().zip(decoded.iter()).enumerate() {
            let err = (orig - dec).abs();
            let max_err = if orig.abs() > 0.01 {
                orig.abs() * 0.02
            } else {
                0.1
            };
            assert!(err < max_err, "i={i}, orig={orig}, dec={dec}, err={err}");
        }
    }

    #[test]
    fn test_quantize_roundtrip_generic_widths() {
        // Bits 2, 4 and 6 have no dedicated fast path. They exercise the
        // generic bit-accumulator code on both the pack and unpack side.
        let frame: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.25).collect();
        for &bits in &[2u8, 4, 6] {
            // Skip any width that qmax_from_bits rejects.
            if crate::bitpack::qmax_from_bits(bits) == 0 {
                continue;
            }
            let scales = compute_scales(&frame, 64, bits);
            let mut packed = Vec::new();
            quantize_and_pack(&frame, &scales, 64, bits, &mut packed);
            assert_eq!(packed.len(), (frame.len() * bits as usize).div_ceil(8));

            let mut decoded = Vec::new();
            dequantize(&packed, &scales, 64, bits, frame.len(), 1, &mut decoded);
            assert_eq!(decoded.len(), frame.len());

            // Rounding to the nearest level bounds the error by half a step.
            let step = scales_to_f32(&scales)[0];
            for (i, (&orig, &dec)) in frame.iter().zip(decoded.iter()).enumerate() {
                let err = (orig - dec).abs();
                assert!(
                    err <= step * 0.5 + 1e-3,
                    "bits={bits} i={i} orig={orig} dec={dec} err={err}"
                );
            }
        }
    }

    #[test]
    fn test_multi_frame_roundtrip() {
        // Three frames share one set of scales and are packed back to back.
        // With 64 values per frame, every frame is a whole number of bytes.
        let frames: Vec<Vec<f32>> = (0..3)
            .map(|f| {
                (0..64)
                    .map(|i| ((i * 7 + f * 13) % 41) as f32 * 0.1 - 2.0)
                    .collect()
            })
            .collect();
        for &bits in &[8u8, 7, 5, 3] {
            // Frame 0 contains -2.0, the largest magnitude in any frame.
            let scales = compute_scales(&frames[0], 64, bits);
            let mut packed = Vec::new();
            for frame in &frames {
                quantize_and_pack(frame, &scales, 64, bits, &mut packed);
            }

            let mut decoded = Vec::new();
            dequantize(&packed, &scales, 64, bits, 64, frames.len(), &mut decoded);
            assert_eq!(decoded.len(), 64 * frames.len());

            let step = scales_to_f32(&scales)[0];
            for (i, (&orig, &dec)) in frames.iter().flatten().zip(decoded.iter()).enumerate() {
                let err = (orig - dec).abs();
                assert!(
                    err <= step * 0.5 + 1e-3,
                    "bits={bits} i={i} orig={orig} dec={dec} err={err}"
                );
            }
        }
    }

    #[test]
    fn test_drift_detection() {
        let frame1: Vec<f32> = vec![1.0; 64];
        let frame2: Vec<f32> = vec![1.05; 64];
        let frame3: Vec<f32> = vec![2.0; 64];

        let scales = compute_scales(&frame1, 64, 8);
        let drift_factor = 1.0 + 26.0 / 256.0;

        assert!(frame_fits_scales(&frame2, &scales, 64, 8, drift_factor));
        assert!(!frame_fits_scales(&frame3, &scales, 64, 8, drift_factor));
    }

    #[test]
    fn test_zero_frame() {
        let frame = vec![0.0f32; 128];
        let scales = compute_scales(&frame, 64, 8);
        let mut packed = Vec::new();
        quantize_and_pack(&frame, &scales, 64, 8, &mut packed);

        let mut decoded = Vec::new();
        dequantize(&packed, &scales, 64, 8, 128, 1, &mut decoded);

        for &v in &decoded {
            assert_eq!(v, 0.0);
        }
    }

    #[test]
    fn test_non_finite_values() {
        let mut frame = vec![1.0f32; 64];
        frame[10] = f32::NAN;
        frame[20] = f32::INFINITY;
        frame[30] = f32::NEG_INFINITY;

        let scales = compute_scales(&frame, 64, 8);
        let mut packed = Vec::new();
        quantize_and_pack(&frame, &scales, 64, 8, &mut packed);

        let mut decoded = Vec::new();
        dequantize(&packed, &scales, 64, 8, 64, 1, &mut decoded);

        assert_eq!(decoded[10], 0.0);
        assert_eq!(decoded[20], 0.0);
        assert_eq!(decoded[30], 0.0);
        assert!((decoded[0] - 1.0).abs() < 0.02);
    }

    #[test]
    fn test_single_element_group() {
        let frame = vec![3.14f32; 16];
        let scales = compute_scales(&frame, 1, 8);
        assert_eq!(scales.len(), 16);

        let mut packed = Vec::new();
        quantize_and_pack(&frame, &scales, 1, 8, &mut packed);

        let mut decoded = Vec::new();
        dequantize(&packed, &scales, 1, 8, 16, 1, &mut decoded);

        for (i, &v) in decoded.iter().enumerate() {
            let err = (v - 3.14).abs();
            assert!(err < 0.03, "i={i} v={v} err={err}");
        }
    }

    #[test]
    fn test_compression_ratio() {
        let frame = vec![1.0f32; 512];
        for &(bits, min_ratio) in &[(8u8, 3.5f32), (7, 4.0), (5, 5.5), (3, 8.5)] {
            let scales = compute_scales(&frame, 64, bits);
            let mut packed = Vec::new();
            quantize_and_pack(&frame, &scales, 64, bits, &mut packed);

            let raw_bytes = frame.len() * 4;
            let compressed = packed.len() + scales.len() * 2;
            let ratio = raw_bytes as f32 / compressed as f32;

            assert!(
                ratio >= min_ratio,
                "bits={bits}: ratio {ratio:.2}x < expected {min_ratio}x"
            );
        }
    }
}