// ruvector_sparse_inference/precision/quantizers.rs

//! Quantizers for 3/5/7-bit precision lanes
//!
//! Implements pack/unpack operations for each precision lane with
//! per-block or per-channel scaling.

use super::lanes::PrecisionLane;
use serde::{Deserialize, Serialize};

9/// Quantized block with scale factor
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct QuantizedBlock {
12    /// Quantized data
13    pub data: Vec<i8>,
14    /// Scale factor for dequantization
15    pub scale: f32,
16    /// Zero point offset
17    pub zero_point: i8,
18    /// Block size
19    pub block_size: usize,
20    /// Precision lane
21    pub lane: PrecisionLane,
22}
23
24impl QuantizedBlock {
25    /// Create a new quantized block
26    pub fn new(lane: PrecisionLane, block_size: usize) -> Self {
27        Self {
28            data: Vec::with_capacity(block_size),
29            scale: lane.default_scale(),
30            zero_point: 0,
31            block_size,
32            lane,
33        }
34    }
35
36    /// Dequantize to f32 values
37    pub fn dequantize(&self) -> Vec<f32> {
38        self.data.iter()
39            .map(|&q| ((q as i32 - self.zero_point as i32) as f32) * self.scale)
40            .collect()
41    }
42
43    /// Get memory size in bytes
44    pub fn size_bytes(&self) -> usize {
45        self.data.len() + 4 + 1 // data + scale + zero_point
46    }
47}
48
49/// 3-bit quantizer for reflex signals
50///
51/// Uses signed int4 container with values restricted to -4..3.
52/// Optimized for LUT-based activation.
53#[derive(Debug, Clone)]
54pub struct Quantizer3Bit {
55    /// Per-block scale factors
56    pub scales: Vec<f32>,
57    /// Block size (typically 32)
58    pub block_size: usize,
59    /// LUT for activation (optional)
60    pub activation_lut: Option<[f32; 8]>,
61}
62
63impl Quantizer3Bit {
64    /// Create a new 3-bit quantizer
65    pub fn new(block_size: usize) -> Self {
66        Self {
67            scales: Vec::new(),
68            block_size,
69            activation_lut: None,
70        }
71    }
72
73    /// Set activation LUT (e.g., for ReLU)
74    pub fn with_activation_lut(mut self, lut: [f32; 8]) -> Self {
75        self.activation_lut = Some(lut);
76        self
77    }
78
79    /// Quantize f32 values to 3-bit
80    pub fn quantize(&mut self, values: &[f32]) -> Vec<u8> {
81        let num_blocks = (values.len() + self.block_size - 1) / self.block_size;
82        self.scales = Vec::with_capacity(num_blocks);
83
84        let mut result = Vec::with_capacity((values.len() + 1) / 2); // Pack 2 values per byte
85
86        for block in values.chunks(self.block_size) {
87            // Find scale for this block
88            let max_abs = block.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
89            let scale = if max_abs > 0.0 { max_abs / 3.0 } else { 1.0 }; // 3-bit max is 3
90            self.scales.push(scale);
91
92            // Quantize values
93            for pair in block.chunks(2) {
94                let q0 = Self::quantize_value(pair[0], scale);
95                let q1 = if pair.len() > 1 {
96                    Self::quantize_value(pair[1], scale)
97                } else {
98                    0
99                };
100                // Pack two 4-bit values into one byte
101                result.push(((q1 as u8) << 4) | (q0 as u8 & 0x0F));
102            }
103        }
104
105        result
106    }
107
108    /// Quantize single value to 3-bit
109    fn quantize_value(value: f32, scale: f32) -> i8 {
110        let scaled = (value / scale).round() as i8;
111        scaled.clamp(-4, 3)
112    }
113
114    /// Dequantize 3-bit values to f32
115    pub fn dequantize(&self, data: &[u8], num_values: usize) -> Vec<f32> {
116        let mut result = Vec::with_capacity(num_values);
117        let mut value_idx = 0;
118        let mut block_idx = 0;
119
120        for &byte in data {
121            if value_idx >= num_values {
122                break;
123            }
124
125            let scale = self.scales.get(block_idx).copied().unwrap_or(1.0);
126
127            // Unpack first value (lower 4 bits)
128            let q0 = (byte & 0x0F) as i8;
129            let q0 = if q0 > 7 { q0 - 16 } else { q0 }; // Sign extend
130            let v0 = (q0 as f32) * scale;
131
132            // Apply activation LUT if present
133            let v0 = if let Some(ref lut) = self.activation_lut {
134                lut[(q0 + 4) as usize]
135            } else {
136                v0
137            };
138
139            result.push(v0);
140            value_idx += 1;
141
142            if value_idx >= num_values {
143                break;
144            }
145
146            // Unpack second value (upper 4 bits)
147            let q1 = ((byte >> 4) & 0x0F) as i8;
148            let q1 = if q1 > 7 { q1 - 16 } else { q1 };
149            let v1 = (q1 as f32) * scale;
150
151            let v1 = if let Some(ref lut) = self.activation_lut {
152                lut[(q1 + 4) as usize]
153            } else {
154                v1
155            };
156
157            result.push(v1);
158            value_idx += 1;
159
160            // Update block index
161            if value_idx % self.block_size == 0 {
162                block_idx += 1;
163            }
164        }
165
166        result
167    }
168}
169
170/// 5-bit quantizer for streaming embeddings
171///
172/// Uses signed int8 container with values in -16..15.
173/// Per-channel or per-block scale for stable streaming updates.
174#[derive(Debug, Clone)]
175pub struct Quantizer5Bit {
176    /// Per-block scale factors
177    pub scales: Vec<f32>,
178    /// Block size
179    pub block_size: usize,
180    /// Use per-channel scaling (instead of per-block)
181    pub per_channel: bool,
182}
183
184impl Quantizer5Bit {
185    /// Create a new 5-bit quantizer
186    pub fn new(block_size: usize) -> Self {
187        Self {
188            scales: Vec::new(),
189            block_size,
190            per_channel: false,
191        }
192    }
193
194    /// Enable per-channel scaling
195    pub fn with_per_channel(mut self) -> Self {
196        self.per_channel = true;
197        self
198    }
199
200    /// Quantize f32 values to 5-bit (stored in int8)
201    pub fn quantize(&mut self, values: &[f32]) -> Vec<i8> {
202        if self.per_channel {
203            self.quantize_per_channel(values)
204        } else {
205            self.quantize_per_block(values)
206        }
207    }
208
209    fn quantize_per_block(&mut self, values: &[f32]) -> Vec<i8> {
210        let num_blocks = (values.len() + self.block_size - 1) / self.block_size;
211        self.scales = Vec::with_capacity(num_blocks);
212
213        let mut result = Vec::with_capacity(values.len());
214
215        for block in values.chunks(self.block_size) {
216            let max_abs = block.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
217            let scale = if max_abs > 0.0 { max_abs / 15.0 } else { 1.0 }; // 5-bit max is 15
218            self.scales.push(scale);
219
220            for &value in block {
221                let q = (value / scale).round() as i8;
222                result.push(q.clamp(-16, 15));
223            }
224        }
225
226        result
227    }
228
229    fn quantize_per_channel(&mut self, values: &[f32]) -> Vec<i8> {
230        self.scales = Vec::with_capacity(values.len());
231
232        values.iter().map(|&value| {
233            let max_abs = value.abs();
234            let scale = if max_abs > 0.0 { max_abs / 15.0 } else { 1.0 };
235            self.scales.push(scale);
236            let q = (value / scale).round() as i8;
237            q.clamp(-16, 15)
238        }).collect()
239    }
240
241    /// Dequantize 5-bit values to f32
242    pub fn dequantize(&self, data: &[i8]) -> Vec<f32> {
243        if self.per_channel {
244            data.iter().zip(self.scales.iter())
245                .map(|(&q, &scale)| (q as f32) * scale)
246                .collect()
247        } else {
248            let mut result = Vec::with_capacity(data.len());
249            let mut block_idx = 0;
250
251            for (i, &q) in data.iter().enumerate() {
252                let scale = self.scales.get(block_idx).copied().unwrap_or(1.0);
253                result.push((q as f32) * scale);
254
255                if (i + 1) % self.block_size == 0 {
256                    block_idx += 1;
257                }
258            }
259
260            result
261        }
262    }
263}
264
265/// 7-bit quantizer for reasoning
266///
267/// Uses signed int8 container with values in -64..63.
268/// Stable accumulators, close to int8 quality.
269#[derive(Debug, Clone)]
270pub struct Quantizer7Bit {
271    /// Per-block scale factors
272    pub scales: Vec<f32>,
273    /// Block size
274    pub block_size: usize,
275}
276
277impl Quantizer7Bit {
278    /// Create a new 7-bit quantizer
279    pub fn new(block_size: usize) -> Self {
280        Self {
281            scales: Vec::new(),
282            block_size,
283        }
284    }
285
286    /// Quantize f32 values to 7-bit (stored in int8)
287    pub fn quantize(&mut self, values: &[f32]) -> Vec<i8> {
288        let num_blocks = (values.len() + self.block_size - 1) / self.block_size;
289        self.scales = Vec::with_capacity(num_blocks);
290
291        let mut result = Vec::with_capacity(values.len());
292
293        for block in values.chunks(self.block_size) {
294            let max_abs = block.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
295            let scale = if max_abs > 0.0 { max_abs / 63.0 } else { 1.0 }; // 7-bit max is 63
296            self.scales.push(scale);
297
298            for &value in block {
299                let q = (value / scale).round() as i8;
300                result.push(q.clamp(-64, 63));
301            }
302        }
303
304        result
305    }
306
307    /// Dequantize 7-bit values to f32
308    pub fn dequantize(&self, data: &[i8]) -> Vec<f32> {
309        let mut result = Vec::with_capacity(data.len());
310        let mut block_idx = 0;
311
312        for (i, &q) in data.iter().enumerate() {
313            let scale = self.scales.get(block_idx).copied().unwrap_or(1.0);
314            result.push((q as f32) * scale);
315
316            if (i + 1) % self.block_size == 0 {
317                block_idx += 1;
318            }
319        }
320
321        result
322    }
323
324    /// Apply micro-LoRA delta (in 7-bit precision)
325    pub fn apply_lora_delta(&mut self, base: &[i8], delta: &[i8], alpha: f32) -> Vec<i8> {
326        base.iter().zip(delta.iter()).map(|(&b, &d)| {
327            let result = (b as f32) + (d as f32) * alpha;
328            (result.round() as i8).clamp(-64, 63)
329        }).collect()
330    }
331}
332
333/// Unified quantizer that selects appropriate implementation
334#[derive(Debug, Clone)]
335pub enum LaneQuantizer {
336    Bit3(Quantizer3Bit),
337    Bit5(Quantizer5Bit),
338    Bit7(Quantizer7Bit),
339}
340
341impl LaneQuantizer {
342    /// Create quantizer for a specific lane
343    pub fn for_lane(lane: PrecisionLane, block_size: usize) -> Self {
344        match lane {
345            PrecisionLane::Bit3 => Self::Bit3(Quantizer3Bit::new(block_size)),
346            PrecisionLane::Bit5 => Self::Bit5(Quantizer5Bit::new(block_size)),
347            PrecisionLane::Bit7 => Self::Bit7(Quantizer7Bit::new(block_size)),
348            PrecisionLane::Float32 => Self::Bit7(Quantizer7Bit::new(block_size)), // Fallback
349        }
350    }
351
352    /// Get the precision lane
353    pub fn lane(&self) -> PrecisionLane {
354        match self {
355            Self::Bit3(_) => PrecisionLane::Bit3,
356            Self::Bit5(_) => PrecisionLane::Bit5,
357            Self::Bit7(_) => PrecisionLane::Bit7,
358        }
359    }
360}
361
362#[cfg(test)]
363mod tests {
364    use super::*;
365
366    #[test]
367    fn test_3bit_roundtrip() {
368        let mut quantizer = Quantizer3Bit::new(32);
369        let values: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.1).collect();
370
371        let quantized = quantizer.quantize(&values);
372        let dequantized = quantizer.dequantize(&quantized, values.len());
373
374        assert_eq!(dequantized.len(), values.len());
375
376        // Check error is bounded (3-bit is very lossy - only 8 levels)
377        // With range ~6.4 (-3.2 to 3.2), each level is ~0.8, so max error is ~0.4
378        // But with grouping, it can be higher
379        for (orig, deq) in values.iter().zip(dequantized.iter()) {
380            let error = (orig - deq).abs();
381            assert!(error < 1.0, "Error too large: {} vs {}", orig, deq);
382        }
383    }
384
385    #[test]
386    fn test_5bit_roundtrip() {
387        let mut quantizer = Quantizer5Bit::new(32);
388        let values: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.1).collect();
389
390        let quantized = quantizer.quantize(&values);
391        let dequantized = quantizer.dequantize(&quantized);
392
393        assert_eq!(dequantized.len(), values.len());
394
395        for (orig, deq) in values.iter().zip(dequantized.iter()) {
396            let error = (orig - deq).abs();
397            assert!(error < 0.2, "Error too large: {} vs {}", orig, deq);
398        }
399    }
400
401    #[test]
402    fn test_7bit_roundtrip() {
403        let mut quantizer = Quantizer7Bit::new(32);
404        let values: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.1).collect();
405
406        let quantized = quantizer.quantize(&values);
407        let dequantized = quantizer.dequantize(&quantized);
408
409        assert_eq!(dequantized.len(), values.len());
410
411        for (orig, deq) in values.iter().zip(dequantized.iter()) {
412            let error = (orig - deq).abs();
413            assert!(error < 0.1, "Error too large: {} vs {}", orig, deq);
414        }
415    }
416
417    #[test]
418    fn test_7bit_lora_delta() {
419        let mut quantizer = Quantizer7Bit::new(32);
420        let base: Vec<i8> = vec![10, 20, 30, 40];
421        let delta: Vec<i8> = vec![1, 2, 3, 4];
422
423        let result = quantizer.apply_lora_delta(&base, &delta, 0.5);
424
425        assert_eq!(result[0], 11); // 10 + 1*0.5 = 10.5 -> 11
426        assert_eq!(result[1], 21); // 20 + 2*0.5 = 21
427        assert_eq!(result[2], 32); // 30 + 3*0.5 = 31.5 -> 32
428        assert_eq!(result[3], 42); // 40 + 4*0.5 = 42
429    }
430}