Skip to main content

ruvector_temporal_tensor/
quantizer.rs

//! Groupwise symmetric quantization with f16 scales.
//!
//! For each group of `group_len` values:
//! - `scale = max(|v_i|) / qmax`
//! - `q_i = round(v_i / scale)`, clamped to `[-qmax, +qmax]`
//! - `u_i = q_i + qmax` (bias to unsigned for packing)
7
8use crate::bitpack::qmax_from_bits;
9use crate::f16;
10
11/// Compute f16 group scales for a frame.
12///
13/// Returns one f16-encoded scale per group of `group_len` elements.
14/// Each scale is `max(|v|) / qmax` for that group, stored as IEEE 754 half-precision.
15#[inline]
16pub fn compute_scales(frame: &[f32], group_len: usize, bits: u8) -> Vec<u16> {
17    let qmax = qmax_from_bits(bits);
18    if qmax == 0 {
19        return Vec::new();
20    }
21    let qmax_f = qmax as f32;
22    let num_groups = frame.len().div_ceil(group_len);
23    let mut scales = Vec::with_capacity(num_groups);
24
25    for chunk in frame.chunks(group_len) {
26        let mut max_abs = 0.0f32;
27        for &v in chunk {
28            if v.is_finite() {
29                let a = v.abs();
30                if a > max_abs {
31                    max_abs = a;
32                }
33            }
34        }
35
36        let scale = if max_abs == 0.0 { 0.0 } else { max_abs / qmax_f };
37        scales.push(f16::f32_to_f16_bits(scale));
38    }
39
40    scales
41}
42
43/// Pre-convert f16 scales to f32 for hot-path use.
44#[inline]
45pub fn scales_to_f32(scales_f16: &[u16]) -> Vec<f32> {
46    scales_f16.iter().map(|&s| f16::f16_bits_to_f32(s)).collect()
47}
48
49/// Check if a frame fits within existing scales (within drift tolerance).
50///
51/// Uses pre-converted f32 scales to avoid repeated f16 conversion.
52/// Returns `false` if any group's max absolute value exceeds
53/// `scale * qmax * drift_factor`.
54pub fn frame_fits_scales_f32(
55    frame: &[f32],
56    scales_f32: &[f32],
57    group_len: usize,
58    bits: u8,
59    drift_factor: f32,
60) -> bool {
61    let qmax = qmax_from_bits(bits);
62    if qmax == 0 || scales_f32.is_empty() {
63        return false;
64    }
65    let qmax_f = qmax as f32;
66
67    for (group_idx, chunk) in frame.chunks(group_len).enumerate() {
68        if group_idx >= scales_f32.len() {
69            return false;
70        }
71        let allowed = scales_f32[group_idx] * qmax_f * drift_factor;
72
73        for &v in chunk {
74            if v.is_finite() && v.abs() > allowed {
75                return false;
76            }
77        }
78    }
79
80    true
81}
82
83/// Quantize a frame using pre-computed f32 scales and pack into bitstream.
84///
85/// Appends packed bytes to `out`. Pre-reserves the expected output size
86/// to avoid reallocations.
87///
88/// For 8-bit quantization, writes bytes directly without bit accumulation
89/// since each quantized value maps 1:1 to a u8.
90#[inline]
91pub fn quantize_and_pack_f32(
92    frame: &[f32],
93    scales_f32: &[f32],
94    group_len: usize,
95    bits: u8,
96    out: &mut Vec<u8>,
97) {
98    let qmax = qmax_from_bits(bits);
99    if qmax == 0 {
100        return;
101    }
102
103    // Fast path: 8-bit quantization writes bytes directly, no bit accumulator.
104    if bits == 8 {
105        out.reserve(frame.len());
106        for (group_idx, chunk) in frame.chunks(group_len).enumerate() {
107            let scale = if group_idx < scales_f32.len() {
108                scales_f32[group_idx]
109            } else {
110                0.0
111            };
112            let inv_scale = if scale == 0.0 { 0.0 } else { 1.0 / scale };
113            for &v in chunk {
114                let mut q: i32 = 0;
115                if v.is_finite() {
116                    let scaled = v * inv_scale;
117                    q = if scaled >= 0.0 { (scaled + 0.5) as i32 } else { (scaled - 0.5) as i32 };
118                    q = q.clamp(-127, 127);
119                }
120                out.push((q + 127) as u8);
121            }
122        }
123        return;
124    }
125
126    // Fast path: 5-bit quantization packs 8 values into 5 bytes.
127    // 8 values * 5 bits = 40 bits = 5 bytes exactly, avoiding the bit accumulator.
128    // LSB-first packing layout for 8 values in 5 bytes:
129    //   byte0 = v0 | (v1 << 5)
130    //   byte1 = (v1 >> 3) | (v2 << 2) | (v3 << 7)
131    //   byte2 = (v3 >> 1) | (v4 << 4)
132    //   byte3 = (v4 >> 4) | (v5 << 1) | (v6 << 6)
133    //   byte4 = (v6 >> 2) | (v7 << 3)
134    #[inline]
135    fn pack_5bit_group(chunk: &[f32], inv_scale: f32, out: &mut Vec<u8>) {
136        let quantize = |v: f32| -> u32 {
137            let mut q: i32 = 0;
138            if v.is_finite() {
139                let scaled = v * inv_scale;
140                q = if scaled >= 0.0 {
141                    (scaled + 0.5) as i32
142                } else {
143                    (scaled - 0.5) as i32
144                };
145                q = q.clamp(-15, 15);
146            }
147            (q + 15) as u32
148        };
149        let v0 = quantize(chunk[0]);
150        let v1 = quantize(chunk[1]);
151        let v2 = quantize(chunk[2]);
152        let v3 = quantize(chunk[3]);
153        let v4 = quantize(chunk[4]);
154        let v5 = quantize(chunk[5]);
155        let v6 = quantize(chunk[6]);
156        let v7 = quantize(chunk[7]);
157
158        out.push((v0 | (v1 << 5)) as u8);
159        out.push(((v1 >> 3) | (v2 << 2) | (v3 << 7)) as u8);
160        out.push(((v3 >> 1) | (v4 << 4)) as u8);
161        out.push(((v4 >> 4) | (v5 << 1) | (v6 << 6)) as u8);
162        out.push(((v6 >> 2) | (v7 << 3)) as u8);
163    }
164    if bits == 5 {
165        let needed_bytes = (frame.len() * 5).div_ceil(8);
166        out.reserve(needed_bytes);
167
168        let mut acc: u64 = 0;
169        let mut acc_bits: u32 = 0;
170
171        for (group_idx, chunk) in frame.chunks(group_len).enumerate() {
172            let scale = if group_idx < scales_f32.len() {
173                scales_f32[group_idx]
174            } else {
175                0.0
176            };
177            let inv_scale = if scale == 0.0 { 0.0 } else { 1.0 / scale };
178
179            let mut i = 0;
180            // Process 8 values at a time into 5 bytes when byte-aligned
181            while acc_bits == 0 && i + 8 <= chunk.len() {
182                pack_5bit_group(&chunk[i..i + 8], inv_scale, out);
183                i += 8;
184            }
185            // Remainder (or misaligned) with bit accumulator
186            while i < chunk.len() {
187                let mut q: i32 = 0;
188                if chunk[i].is_finite() {
189                    let scaled = chunk[i] * inv_scale;
190                    q = if scaled >= 0.0 {
191                        (scaled + 0.5) as i32
192                    } else {
193                        (scaled - 0.5) as i32
194                    };
195                    q = q.clamp(-15, 15);
196                }
197                let u = (q + 15) as u32;
198                acc |= (u as u64) << acc_bits;
199                acc_bits += 5;
200                while acc_bits >= 8 {
201                    out.push((acc & 0xFF) as u8);
202                    acc >>= 8;
203                    acc_bits -= 8;
204                }
205                i += 1;
206            }
207        }
208
209        if acc_bits > 0 {
210            out.push((acc & 0xFF) as u8);
211        }
212        return;
213    }
214
215    // Generic path for sub-byte bit widths.
216    let qmax_i = qmax;
217    let bias = qmax;
218    let bits_u32 = bits as u32;
219
220    let needed_bytes = (frame.len() * bits as usize).div_ceil(8);
221    out.reserve(needed_bytes);
222
223    let mut acc: u64 = 0;
224    let mut acc_bits: u32 = 0;
225
226    for (group_idx, chunk) in frame.chunks(group_len).enumerate() {
227        let scale = if group_idx < scales_f32.len() {
228            scales_f32[group_idx]
229        } else {
230            0.0
231        };
232        let inv_scale = if scale == 0.0 { 0.0 } else { 1.0 / scale };
233
234        for &v in chunk {
235            let mut q: i32 = 0;
236            if v.is_finite() {
237                let scaled = v * inv_scale;
238                q = if scaled >= 0.0 { (scaled + 0.5) as i32 } else { (scaled - 0.5) as i32 };
239                q = q.clamp(-qmax_i, qmax_i);
240            }
241
242            let u = (q + bias) as u32;
243            acc |= (u as u64) << acc_bits;
244            acc_bits += bits_u32;
245
246            while acc_bits >= 8 {
247                out.push((acc & 0xFF) as u8);
248                acc >>= 8;
249                acc_bits -= 8;
250            }
251        }
252    }
253
254    if acc_bits > 0 {
255        out.push((acc & 0xFF) as u8);
256    }
257}
258
259/// Dequantize packed codes using f32 scales, writing f32 values.
260///
261/// Iterates by frame then by group to avoid per-value modulo/division
262/// and caches the f32 scale per group.
263///
264/// For 8-bit data, reads bytes directly without bit accumulation.
265#[inline]
266pub fn dequantize_f32(
267    data: &[u8],
268    scales_f32: &[f32],
269    group_len: usize,
270    bits: u8,
271    tensor_len: usize,
272    frame_count: usize,
273    out: &mut Vec<f32>,
274) {
275    let qmax = qmax_from_bits(bits);
276    if qmax == 0 {
277        return;
278    }
279
280    let total = tensor_len * frame_count;
281    out.resize(total, 0.0);
282
283    // Fast path: 8-bit dequantization reads bytes directly, no bit accumulator.
284    if bits == 8 {
285        let mut out_idx = 0usize;
286        let mut byte_idx = 0usize;
287        for _frame in 0..frame_count {
288            let mut pos = 0usize;
289            let mut group_idx = 0usize;
290            while pos < tensor_len {
291                let group_end = (pos + group_len).min(tensor_len);
292                let scale = if group_idx < scales_f32.len() {
293                    scales_f32[group_idx]
294                } else {
295                    0.0
296                };
297                while pos < group_end && byte_idx < data.len() {
298                    let u = data[byte_idx] as i32;
299                    let q = u - 127;
300                    out[out_idx] = (q as f32) * scale;
301                    out_idx += 1;
302                    byte_idx += 1;
303                    pos += 1;
304                }
305                group_idx += 1;
306            }
307        }
308        return;
309    }
310
311    // Fast path: 3-bit dequantization processes 8 values from 3 bytes.
312    // 8 values * 3 bits = 24 bits = 3 bytes exactly, avoiding the bit accumulator.
313    // LSB-first packing layout for 8 values in 3 bytes:
314    //   byte0 = v0 | (v1 << 3) | ((v2 & 0x3) << 6)
315    //   byte1 = (v2 >> 2) | (v3 << 1) | (v4 << 4) | ((v5 & 0x1) << 7)
316    //   byte2 = (v5 >> 1) | (v6 << 2) | (v7 << 5)
317    if bits == 3 {
318        let bias = 3i32; // qmax for 3-bit
319        let mut out_idx = 0usize;
320        let mut byte_idx = 0usize;
321        for _frame in 0..frame_count {
322            let mut pos = 0usize;
323            let mut group_idx = 0usize;
324            while pos < tensor_len {
325                let group_end = (pos + group_len).min(tensor_len);
326                let scale = if group_idx < scales_f32.len() {
327                    scales_f32[group_idx]
328                } else {
329                    0.0
330                };
331                // Process 8 values at a time from 3 bytes
332                while pos + 8 <= group_end && byte_idx + 3 <= data.len() {
333                    let b0 = data[byte_idx] as u32;
334                    let b1 = data[byte_idx + 1] as u32;
335                    let b2 = data[byte_idx + 2] as u32;
336                    byte_idx += 3;
337
338                    out[out_idx]     = ((b0 & 0x7) as i32 - bias) as f32 * scale;
339                    out[out_idx + 1] = (((b0 >> 3) & 0x7) as i32 - bias) as f32 * scale;
340                    out[out_idx + 2] = ((((b0 >> 6) | (b1 << 2)) & 0x7) as i32 - bias) as f32 * scale;
341                    out[out_idx + 3] = (((b1 >> 1) & 0x7) as i32 - bias) as f32 * scale;
342                    out[out_idx + 4] = (((b1 >> 4) & 0x7) as i32 - bias) as f32 * scale;
343                    out[out_idx + 5] = ((((b1 >> 7) | (b2 << 1)) & 0x7) as i32 - bias) as f32 * scale;
344                    out[out_idx + 6] = (((b2 >> 2) & 0x7) as i32 - bias) as f32 * scale;
345                    out[out_idx + 7] = (((b2 >> 5) & 0x7) as i32 - bias) as f32 * scale;
346                    out_idx += 8;
347                    pos += 8;
348                }
349                // Handle remaining values (< 8) with a local bit accumulator
350                if pos < group_end {
351                    let remaining = group_end - pos;
352                    let mut acc: u64 = 0;
353                    let mut acc_bits: u32 = 0;
354                    while acc_bits < (remaining as u32) * 3 && byte_idx < data.len() {
355                        acc |= (data[byte_idx] as u64) << acc_bits;
356                        acc_bits += 8;
357                        byte_idx += 1;
358                    }
359                    for _ in 0..remaining {
360                        if acc_bits < 3 {
361                            break;
362                        }
363                        let u = (acc & 0x7) as i32;
364                        acc >>= 3;
365                        acc_bits -= 3;
366                        out[out_idx] = (u - bias) as f32 * scale;
367                        out_idx += 1;
368                        pos += 1;
369                    }
370                }
371                group_idx += 1;
372            }
373        }
374        return;
375    }
376
377    // Fast path: 7-bit dequantization processes 8 values from 7 bytes.
378    // 8 values * 7 bits = 56 bits = 7 bytes exactly, avoiding the bit accumulator.
379    // LSB-first packing layout for 8 values in 7 bytes:
380    //   v0 = b0 & 0x7F
381    //   v1 = ((b0 >> 7) | (b1 << 1)) & 0x7F
382    //   v2 = ((b1 >> 6) | (b2 << 2)) & 0x7F
383    //   v3 = ((b2 >> 5) | (b3 << 3)) & 0x7F
384    //   v4 = ((b3 >> 4) | (b4 << 4)) & 0x7F
385    //   v5 = ((b4 >> 3) | (b5 << 5)) & 0x7F
386    //   v6 = ((b5 >> 2) | (b6 << 6)) & 0x7F
387    //   v7 = (b6 >> 1) & 0x7F
388    if bits == 7 {
389        let bias = 63i32; // qmax for 7-bit
390        let mut out_idx = 0usize;
391        let mut byte_idx = 0usize;
392        for _frame in 0..frame_count {
393            let mut pos = 0usize;
394            let mut group_idx = 0usize;
395            while pos < tensor_len {
396                let group_end = (pos + group_len).min(tensor_len);
397                let scale = if group_idx < scales_f32.len() {
398                    scales_f32[group_idx]
399                } else {
400                    0.0
401                };
402                // Process 8 values at a time from 7 bytes
403                #[inline]
404                fn unpack_7bit(out: &mut [f32], out_idx: usize, data: &[u8], byte_idx: usize, bias: i32, scale: f32) {
405                    let b0 = data[byte_idx] as u32;
406                    let b1 = data[byte_idx + 1] as u32;
407                    let b2 = data[byte_idx + 2] as u32;
408                    let b3 = data[byte_idx + 3] as u32;
409                    let b4 = data[byte_idx + 4] as u32;
410                    let b5 = data[byte_idx + 5] as u32;
411                    let b6 = data[byte_idx + 6] as u32;
412
413                    out[out_idx]     = ((b0 & 0x7F) as i32 - bias) as f32 * scale;
414                    out[out_idx + 1] = ((((b0 >> 7) | (b1 << 1)) & 0x7F) as i32 - bias) as f32 * scale;
415                    out[out_idx + 2] = ((((b1 >> 6) | (b2 << 2)) & 0x7F) as i32 - bias) as f32 * scale;
416                    out[out_idx + 3] = ((((b2 >> 5) | (b3 << 3)) & 0x7F) as i32 - bias) as f32 * scale;
417                    out[out_idx + 4] = ((((b3 >> 4) | (b4 << 4)) & 0x7F) as i32 - bias) as f32 * scale;
418                    out[out_idx + 5] = ((((b4 >> 3) | (b5 << 5)) & 0x7F) as i32 - bias) as f32 * scale;
419                    out[out_idx + 6] = ((((b5 >> 2) | (b6 << 6)) & 0x7F) as i32 - bias) as f32 * scale;
420                    out[out_idx + 7] = (((b6 >> 1) & 0x7F) as i32 - bias) as f32 * scale;
421                }
422                while pos + 8 <= group_end && byte_idx + 7 <= data.len() {
423                    unpack_7bit(out, out_idx, data, byte_idx, bias, scale);
424                    byte_idx += 7;
425                    out_idx += 8;
426                    pos += 8;
427                }
428                // Handle remaining values (< 8) with a local bit accumulator
429                if pos < group_end {
430                    let remaining = group_end - pos;
431                    let mut acc: u64 = 0;
432                    let mut acc_bits: u32 = 0;
433                    while acc_bits < (remaining as u32) * 7 && byte_idx < data.len() {
434                        acc |= (data[byte_idx] as u64) << acc_bits;
435                        acc_bits += 8;
436                        byte_idx += 1;
437                    }
438                    for _ in 0..remaining {
439                        if acc_bits < 7 {
440                            break;
441                        }
442                        let u = (acc & 0x7F) as i32;
443                        acc >>= 7;
444                        acc_bits -= 7;
445                        out[out_idx] = (u - bias) as f32 * scale;
446                        out_idx += 1;
447                        pos += 1;
448                    }
449                }
450                group_idx += 1;
451            }
452        }
453        return;
454    }
455
456    // Fast path: 5-bit dequantization processes 8 values from 5 bytes.
457    // 8 values * 5 bits = 40 bits = 5 bytes exactly, avoiding the bit accumulator.
458    // LSB-first packing layout for 8 values in 5 bytes:
459    //   v0 = b0 & 0x1F
460    //   v1 = ((b0 >> 5) | (b1 << 3)) & 0x1F
461    //   v2 = (b1 >> 2) & 0x1F
462    //   v3 = ((b1 >> 7) | (b2 << 1)) & 0x1F
463    //   v4 = ((b2 >> 4) | (b3 << 4)) & 0x1F
464    //   v5 = (b3 >> 1) & 0x1F
465    //   v6 = ((b3 >> 6) | (b4 << 2)) & 0x1F
466    //   v7 = (b4 >> 3) & 0x1F
467    if bits == 5 {
468        let bias = 15i32; // qmax for 5-bit
469        let mut out_idx = 0usize;
470        let mut byte_idx = 0usize;
471        for _frame in 0..frame_count {
472            let mut pos = 0usize;
473            let mut group_idx = 0usize;
474            while pos < tensor_len {
475                let group_end = (pos + group_len).min(tensor_len);
476                let scale = if group_idx < scales_f32.len() {
477                    scales_f32[group_idx]
478                } else {
479                    0.0
480                };
481                // Process 8 values at a time from 5 bytes
482                #[inline]
483                fn unpack_5bit(out: &mut [f32], out_idx: usize, data: &[u8], byte_idx: usize, bias: i32, scale: f32) {
484                    let b0 = data[byte_idx] as u32;
485                    let b1 = data[byte_idx + 1] as u32;
486                    let b2 = data[byte_idx + 2] as u32;
487                    let b3 = data[byte_idx + 3] as u32;
488                    let b4 = data[byte_idx + 4] as u32;
489
490                    out[out_idx]     = ((b0 & 0x1F) as i32 - bias) as f32 * scale;
491                    out[out_idx + 1] = ((((b0 >> 5) | (b1 << 3)) & 0x1F) as i32 - bias) as f32 * scale;
492                    out[out_idx + 2] = (((b1 >> 2) & 0x1F) as i32 - bias) as f32 * scale;
493                    out[out_idx + 3] = ((((b1 >> 7) | (b2 << 1)) & 0x1F) as i32 - bias) as f32 * scale;
494                    out[out_idx + 4] = ((((b2 >> 4) | (b3 << 4)) & 0x1F) as i32 - bias) as f32 * scale;
495                    out[out_idx + 5] = (((b3 >> 1) & 0x1F) as i32 - bias) as f32 * scale;
496                    out[out_idx + 6] = ((((b3 >> 6) | (b4 << 2)) & 0x1F) as i32 - bias) as f32 * scale;
497                    out[out_idx + 7] = (((b4 >> 3) & 0x1F) as i32 - bias) as f32 * scale;
498                }
499                while pos + 8 <= group_end && byte_idx + 5 <= data.len() {
500                    unpack_5bit(out, out_idx, data, byte_idx, bias, scale);
501                    byte_idx += 5;
502                    out_idx += 8;
503                    pos += 8;
504                }
505                // Handle remaining values (< 8) with a local bit accumulator
506                if pos < group_end {
507                    let remaining = group_end - pos;
508                    let mut acc: u64 = 0;
509                    let mut acc_bits: u32 = 0;
510                    while acc_bits < (remaining as u32) * 5 && byte_idx < data.len() {
511                        acc |= (data[byte_idx] as u64) << acc_bits;
512                        acc_bits += 8;
513                        byte_idx += 1;
514                    }
515                    for _ in 0..remaining {
516                        if acc_bits < 5 {
517                            break;
518                        }
519                        let u = (acc & 0x1F) as i32;
520                        acc >>= 5;
521                        acc_bits -= 5;
522                        out[out_idx] = (u - bias) as f32 * scale;
523                        out_idx += 1;
524                        pos += 1;
525                    }
526                }
527                group_idx += 1;
528            }
529        }
530        return;
531    }
532
533    // Generic path for sub-byte bit widths.
534    let bias = qmax;
535    let bits_u32 = bits as u32;
536    let mask = (1u64 << bits_u32) - 1;
537
538    let mut acc: u64 = 0;
539    let mut acc_bits: u32 = 0;
540    let mut byte_idx = 0usize;
541    let mut out_idx = 0usize;
542
543    for _frame in 0..frame_count {
544        let mut pos = 0usize;
545        let mut group_idx = 0usize;
546
547        while pos < tensor_len {
548            let group_end = (pos + group_len).min(tensor_len);
549            let scale = if group_idx < scales_f32.len() {
550                scales_f32[group_idx]
551            } else {
552                0.0
553            };
554
555            while pos < group_end {
556                while acc_bits < bits_u32 && byte_idx < data.len() {
557                    acc |= (data[byte_idx] as u64) << acc_bits;
558                    acc_bits += 8;
559                    byte_idx += 1;
560                }
561                if acc_bits < bits_u32 {
562                    return;
563                }
564
565                let u = (acc & mask) as u32;
566                acc >>= bits_u32;
567                acc_bits -= bits_u32;
568
569                let q = (u as i32) - bias;
570                out[out_idx] = (q as f32) * scale;
571                out_idx += 1;
572                pos += 1;
573            }
574
575            group_idx += 1;
576        }
577    }
578}
579
// --- Legacy API (delegates to f32 variants) ---
581
582/// Check if a frame fits within existing f16 scales (within drift tolerance).
583pub fn frame_fits_scales(
584    frame: &[f32],
585    scales: &[u16],
586    group_len: usize,
587    bits: u8,
588    drift_factor: f32,
589) -> bool {
590    let scales_f32 = scales_to_f32(scales);
591    frame_fits_scales_f32(frame, &scales_f32, group_len, bits, drift_factor)
592}
593
594/// Quantize a frame using pre-computed f16 scales and pack into bitstream.
595pub fn quantize_and_pack(
596    frame: &[f32],
597    scales: &[u16],
598    group_len: usize,
599    bits: u8,
600    out: &mut Vec<u8>,
601) {
602    let scales_f32 = scales_to_f32(scales);
603    quantize_and_pack_f32(frame, &scales_f32, group_len, bits, out)
604}
605
606/// Dequantize packed codes using f16 scales, writing f32 values.
607pub fn dequantize(
608    data: &[u8],
609    scales: &[u16],
610    group_len: usize,
611    bits: u8,
612    tensor_len: usize,
613    frame_count: usize,
614    out: &mut Vec<f32>,
615) {
616    let scales_f32 = scales_to_f32(scales);
617    dequantize_f32(data, &scales_f32, group_len, bits, tensor_len, frame_count, out)
618}
619
#[cfg(test)]
mod tests {
    use super::*;

    /// Computes scales for `frame` and packs it, returning `(scales, packed)`.
    fn encode(frame: &[f32], group_len: usize, bits: u8) -> (Vec<u16>, Vec<u8>) {
        let scales = compute_scales(frame, group_len, bits);
        let mut packed = Vec::new();
        quantize_and_pack(frame, &scales, group_len, bits, &mut packed);
        (scales, packed)
    }

    /// Encodes `frame` as a single frame and decodes it again.
    fn roundtrip(frame: &[f32], group_len: usize, bits: u8) -> Vec<f32> {
        let (scales, packed) = encode(frame, group_len, bits);
        let mut decoded = Vec::new();
        dequantize(
            &packed,
            &scales,
            group_len,
            bits,
            frame.len(),
            1,
            &mut decoded,
        );
        decoded
    }

    /// Linear ramp of `len` values centred on zero: `(i - len / 2) * step`.
    fn ramp(len: usize, step: f32) -> Vec<f32> {
        let centre = (len / 2) as f32;
        (0..len).map(|i| (i as f32 - centre) * step).collect()
    }

    /// Asserts a 2% relative error bound (0.1 absolute for near-zero values).
    fn assert_relative_error(frame: &[f32], decoded: &[f32]) {
        for (i, (&orig, &dec)) in frame.iter().zip(decoded.iter()).enumerate() {
            let err = (orig - dec).abs();
            let max_err = if orig.abs() > 0.01 { orig.abs() * 0.02 } else { 0.1 };
            assert!(err < max_err, "i={i}, orig={orig}, dec={dec}, err={err}");
        }
    }

    /// Asserts every error is below `fraction` of the frame's largest magnitude.
    fn assert_error_within_fraction(frame: &[f32], decoded: &[f32], fraction: f32) {
        let max_val = frame.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
        for (&orig, &dec) in frame.iter().zip(decoded.iter()) {
            let err = (orig - dec).abs();
            assert!(err < max_val * fraction, "orig={orig}, dec={dec}, err={err}");
        }
    }

    #[test]
    fn test_quantize_roundtrip_8bit() {
        let frame = ramp(128, 0.1);
        let decoded = roundtrip(&frame, 64, 8);

        assert_eq!(decoded.len(), frame.len());
        assert_relative_error(&frame, &decoded);
    }

    #[test]
    fn test_quantize_roundtrip_3bit() {
        let frame = ramp(64, 0.5);
        let decoded = roundtrip(&frame, 64, 3);

        assert_error_within_fraction(&frame, &decoded, 0.35);
    }

    #[test]
    fn test_quantize_roundtrip_5bit() {
        let frame = ramp(256, 0.05);
        let decoded = roundtrip(&frame, 64, 5);

        assert_error_within_fraction(&frame, &decoded, 0.08);
    }

    #[test]
    fn test_quantize_roundtrip_7bit() {
        let frame = ramp(256, 0.05);
        let decoded = roundtrip(&frame, 64, 7);

        assert_relative_error(&frame, &decoded);
    }

    #[test]
    fn test_drift_detection() {
        let baseline: Vec<f32> = vec![1.0; 64];
        let slightly_larger: Vec<f32> = vec![1.05; 64];
        let much_larger: Vec<f32> = vec![2.0; 64];

        let scales = compute_scales(&baseline, 64, 8);
        let drift_factor = 1.0 + 26.0 / 256.0;

        assert!(frame_fits_scales(&slightly_larger, &scales, 64, 8, drift_factor));
        assert!(!frame_fits_scales(&much_larger, &scales, 64, 8, drift_factor));
    }

    #[test]
    fn test_zero_frame() {
        let frame = vec![0.0f32; 128];
        let decoded = roundtrip(&frame, 64, 8);

        for &v in &decoded {
            assert_eq!(v, 0.0);
        }
    }

    #[test]
    fn test_non_finite_values() {
        let mut frame = vec![1.0f32; 64];
        frame[10] = f32::NAN;
        frame[20] = f32::INFINITY;
        frame[30] = f32::NEG_INFINITY;

        let decoded = roundtrip(&frame, 64, 8);

        // Non-finite inputs are stored as the zero code.
        for idx in [10usize, 20, 30] {
            assert_eq!(decoded[idx], 0.0);
        }
        assert!((decoded[0] - 1.0).abs() < 0.02);
    }

    #[test]
    fn test_single_element_group() {
        let frame = vec![3.14f32; 16];
        let (scales, packed) = encode(&frame, 1, 8);
        assert_eq!(scales.len(), 16);

        let mut decoded = Vec::new();
        dequantize(&packed, &scales, 1, 8, 16, 1, &mut decoded);

        for (i, &v) in decoded.iter().enumerate() {
            let err = (v - 3.14).abs();
            assert!(err < 0.03, "i={i} v={v} err={err}");
        }
    }

    #[test]
    fn test_compression_ratio() {
        let frame = vec![1.0f32; 512];
        let expectations = [(8u8, 3.5f32), (7, 4.0), (5, 5.5), (3, 8.5)];

        for &(bits, min_ratio) in &expectations {
            let (scales, packed) = encode(&frame, 64, bits);

            // Raw f32 payload versus packed codes plus 2 bytes per f16 scale.
            let raw_bytes = frame.len() * 4;
            let compressed = packed.len() + scales.len() * 2;
            let ratio = raw_bytes as f32 / compressed as f32;

            assert!(
                ratio >= min_ratio,
                "bits={bits}: ratio {ratio:.2}x < expected {min_ratio}x"
            );
        }
    }
}
777}