//! Quantization primitives for precision lanes
//! (ruvector_sparse_inference/precision/quantizers.rs).

use super::lanes::PrecisionLane;
use serde::{Deserialize, Serialize};

/// A block of quantized values together with the parameters needed to decode it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedBlock {
    /// Quantized values, one signed byte per element.
    pub data: Vec<i8>,
    /// Multiplier applied to each (zero-point-adjusted) value on dequantization.
    pub scale: f32,
    /// Offset subtracted from each stored value before scaling.
    pub zero_point: i8,
    /// Nominal number of values in this block (used as the capacity hint in `new`).
    pub block_size: usize,
    /// Precision lane this block was quantized for.
    pub lane: PrecisionLane,
}

24impl QuantizedBlock {
25 pub fn new(lane: PrecisionLane, block_size: usize) -> Self {
27 Self {
28 data: Vec::with_capacity(block_size),
29 scale: lane.default_scale(),
30 zero_point: 0,
31 block_size,
32 lane,
33 }
34 }
35
36 pub fn dequantize(&self) -> Vec<f32> {
38 self.data.iter()
39 .map(|&q| ((q as i32 - self.zero_point as i32) as f32) * self.scale)
40 .collect()
41 }
42
43 pub fn size_bytes(&self) -> usize {
45 self.data.len() + 4 + 1 }
47}
48
/// Block-wise symmetric 3-bit quantizer.
///
/// Each block of `block_size` values shares one scale (`max_abs / 3.0`) and is
/// quantized to signed levels in `[-4, 3]`. Levels are packed two per byte,
/// low nibble first; every block starts on a fresh byte, so a block with an
/// odd number of values leaves its final high nibble as zero padding.
#[derive(Debug, Clone)]
pub struct Quantizer3Bit {
    /// Per-block scales recorded by the most recent call to `quantize`.
    pub scales: Vec<f32>,
    /// Number of values that share one scale.
    pub block_size: usize,
    /// Optional activation lookup table indexed by `level + 4` (8 entries).
    /// When set, dequantization returns `lut[level + 4]` instead of
    /// `level * scale`.
    pub activation_lut: Option<[f32; 8]>,
}

impl Quantizer3Bit {
    /// Creates a quantizer with the given block size and no activation LUT.
    pub fn new(block_size: usize) -> Self {
        Self {
            scales: Vec::new(),
            block_size,
            activation_lut: None,
        }
    }

    /// Installs an activation lookup table (builder style).
    pub fn with_activation_lut(mut self, lut: [f32; 8]) -> Self {
        self.activation_lut = Some(lut);
        self
    }

    /// Quantizes `values`, storing one scale per block in `self.scales` and
    /// returning the nibble-packed data (low nibble = first value of a pair).
    pub fn quantize(&mut self, values: &[f32]) -> Vec<u8> {
        let num_blocks = (values.len() + self.block_size - 1) / self.block_size;
        let bytes_per_block = (self.block_size + 1) / 2;
        self.scales = Vec::with_capacity(num_blocks);

        // Capacity accounts for per-block padding bytes (odd block_size).
        let mut result = Vec::with_capacity(num_blocks * bytes_per_block);
        for block in values.chunks(self.block_size) {
            let max_abs = block.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
            // Symmetric scale: the block's peak magnitude maps to level 3.
            let scale = if max_abs > 0.0 { max_abs / 3.0 } else { 1.0 };
            self.scales.push(scale);

            for pair in block.chunks(2) {
                let q0 = Self::quantize_value(pair[0], scale);
                // An odd trailing value is padded with level 0 in the high nibble.
                let q1 = pair.get(1).map_or(0, |&v| Self::quantize_value(v, scale));
                result.push(((q1 as u8 & 0x0F) << 4) | (q0 as u8 & 0x0F));
            }
        }

        result
    }

    /// Maps one value to a signed 3-bit level in `[-4, 3]`.
    fn quantize_value(value: f32, scale: f32) -> i8 {
        let scaled = (value / scale).round() as i8;
        scaled.clamp(-4, 3)
    }

    /// Decodes a low nibble into a signed level via two's complement.
    fn decode_nibble(nibble: u8) -> i8 {
        let q = nibble as i8;
        if q > 7 { q - 16 } else { q }
    }

    /// Dequantizes nibble-packed `data` back into `num_values` floats.
    ///
    /// Blocks are read on byte boundaries — `(block_size + 1) / 2` bytes per
    /// block — matching the layout produced by `quantize`, so the padding
    /// nibble of an odd-sized block is skipped rather than emitted as a value,
    /// and each block's own scale is applied to all of its values.
    /// Missing scales default to 1.0.
    pub fn dequantize(&self, data: &[u8], num_values: usize) -> Vec<f32> {
        let bytes_per_block = (self.block_size + 1) / 2;
        let mut result = Vec::with_capacity(num_values);

        for (block_idx, block_bytes) in data.chunks(bytes_per_block).enumerate() {
            if result.len() >= num_values {
                break;
            }
            let scale = self.scales.get(block_idx).copied().unwrap_or(1.0);
            // The final block may hold fewer than `block_size` real values.
            let block_values = (num_values - result.len()).min(self.block_size);
            let mut produced = 0;

            'bytes: for &byte in block_bytes {
                for nibble in [byte & 0x0F, (byte >> 4) & 0x0F] {
                    if produced >= block_values {
                        break 'bytes;
                    }
                    let level = Self::decode_nibble(nibble);
                    let value = match &self.activation_lut {
                        // Clamp guards against out-of-range levels from corrupt
                        // data; valid levels [-4, 3] map to indices 0..=7.
                        Some(lut) => lut[(i32::from(level) + 4).clamp(0, 7) as usize],
                        None => (level as f32) * scale,
                    };
                    result.push(value);
                    produced += 1;
                }
            }
        }

        result
    }
}

/// Block-wise symmetric 5-bit quantizer (signed levels in `[-16, 15]`).
#[derive(Debug, Clone)]
pub struct Quantizer5Bit {
    /// Scales recorded by the most recent `quantize` call: one per block, or
    /// one per value when `per_channel` is enabled.
    pub scales: Vec<f32>,
    /// Number of values that share one scale in per-block mode.
    pub block_size: usize,
    /// When true, every value gets its own scale.
    pub per_channel: bool,
}

impl Quantizer5Bit {
    /// Creates a per-block quantizer with the given block size.
    pub fn new(block_size: usize) -> Self {
        Self {
            per_channel: false,
            scales: Vec::new(),
            block_size,
        }
    }

    /// Switches to per-channel (per-value) scaling (builder style).
    pub fn with_per_channel(mut self) -> Self {
        self.per_channel = true;
        self
    }

    /// Quantizes `values` to signed 5-bit levels, recording the scales used
    /// in `self.scales`.
    pub fn quantize(&mut self, values: &[f32]) -> Vec<i8> {
        if self.per_channel {
            self.quantize_per_channel(values)
        } else {
            self.quantize_per_block(values)
        }
    }

    /// One symmetric scale per block: the block's peak magnitude maps to 15.
    fn quantize_per_block(&mut self, values: &[f32]) -> Vec<i8> {
        let block_count = (values.len() + self.block_size - 1) / self.block_size;
        self.scales = Vec::with_capacity(block_count);

        let mut quantized = Vec::with_capacity(values.len());
        for chunk in values.chunks(self.block_size) {
            let peak = chunk.iter().fold(0.0f32, |acc, v| acc.max(v.abs()));
            let scale = if peak > 0.0 { peak / 15.0 } else { 1.0 };
            self.scales.push(scale);
            quantized.extend(
                chunk
                    .iter()
                    .map(|&v| ((v / scale).round() as i8).clamp(-16, 15)),
            );
        }
        quantized
    }

    /// One scale per value: each nonzero value lands exactly on level ±15.
    fn quantize_per_channel(&mut self, values: &[f32]) -> Vec<i8> {
        self.scales = Vec::with_capacity(values.len());

        let mut quantized = Vec::with_capacity(values.len());
        for &value in values {
            let magnitude = value.abs();
            let scale = if magnitude > 0.0 { magnitude / 15.0 } else { 1.0 };
            self.scales.push(scale);
            quantized.push(((value / scale).round() as i8).clamp(-16, 15));
        }
        quantized
    }

    /// Reconstructs floats by multiplying each level with its stored scale.
    /// In per-block mode, missing scales default to 1.0.
    pub fn dequantize(&self, data: &[i8]) -> Vec<f32> {
        if self.per_channel {
            data.iter()
                .zip(&self.scales)
                .map(|(&q, &scale)| f32::from(q) * scale)
                .collect()
        } else {
            data.iter()
                .enumerate()
                .map(|(i, &q)| {
                    let scale = self.scales.get(i / self.block_size).copied().unwrap_or(1.0);
                    f32::from(q) * scale
                })
                .collect()
        }
    }
}

/// Block-wise symmetric 7-bit quantizer (signed levels in `[-64, 63]`).
#[derive(Debug, Clone)]
pub struct Quantizer7Bit {
    /// One scale per block, recorded by the most recent `quantize` call.
    pub scales: Vec<f32>,
    /// Number of values that share a single scale.
    pub block_size: usize,
}

impl Quantizer7Bit {
    /// Creates a quantizer with the given block size and no recorded scales.
    pub fn new(block_size: usize) -> Self {
        Self {
            scales: Vec::new(),
            block_size,
        }
    }

    /// Quantizes `values` to signed 7-bit levels; per-block scales land in
    /// `self.scales`.
    pub fn quantize(&mut self, values: &[f32]) -> Vec<i8> {
        let block_count = (values.len() + self.block_size - 1) / self.block_size;
        self.scales = Vec::with_capacity(block_count);

        let mut quantized = Vec::with_capacity(values.len());
        for chunk in values.chunks(self.block_size) {
            // Symmetric scale: the block's peak magnitude maps to level 63.
            let peak = chunk.iter().fold(0.0f32, |acc, v| acc.max(v.abs()));
            let scale = if peak > 0.0 { peak / 63.0 } else { 1.0 };
            self.scales.push(scale);
            quantized.extend(
                chunk
                    .iter()
                    .map(|&v| ((v / scale).round() as i8).clamp(-64, 63)),
            );
        }
        quantized
    }

    /// Reconstructs floats by multiplying each level with its block's scale.
    /// Missing scales default to 1.0.
    pub fn dequantize(&self, data: &[i8]) -> Vec<f32> {
        data.iter()
            .enumerate()
            .map(|(i, &q)| {
                let scale = self.scales.get(i / self.block_size).copied().unwrap_or(1.0);
                f32::from(q) * scale
            })
            .collect()
    }

    /// Merges a LoRA delta into quantized base weights:
    /// `round(base + delta * alpha)` clamped back to `[-64, 63]`.
    /// Extra elements of the longer slice are ignored (zip semantics).
    pub fn apply_lora_delta(&mut self, base: &[i8], delta: &[i8], alpha: f32) -> Vec<i8> {
        base.iter()
            .zip(delta.iter())
            .map(|(&b, &d)| {
                let merged = f32::from(b) + f32::from(d) * alpha;
                (merged.round() as i8).clamp(-64, 63)
            })
            .collect()
    }
}

/// Dispatch wrapper selecting the concrete quantizer for a precision lane.
#[derive(Debug, Clone)]
pub enum LaneQuantizer {
    /// Packed 3-bit quantization.
    Bit3(Quantizer3Bit),
    /// 5-bit quantization.
    Bit5(Quantizer5Bit),
    /// 7-bit quantization.
    Bit7(Quantizer7Bit),
}

impl LaneQuantizer {
    /// Builds the quantizer matching `lane`.
    ///
    /// NOTE(review): `Float32` has no dedicated variant and falls back to the
    /// 7-bit quantizer, which is lossy — confirm this fallback is intended.
    pub fn for_lane(lane: PrecisionLane, block_size: usize) -> Self {
        match lane {
            PrecisionLane::Bit3 => Self::Bit3(Quantizer3Bit::new(block_size)),
            PrecisionLane::Bit5 => Self::Bit5(Quantizer5Bit::new(block_size)),
            PrecisionLane::Bit7 => Self::Bit7(Quantizer7Bit::new(block_size)),
            // Float32 reuses the highest-precision integer lane.
            PrecisionLane::Float32 => Self::Bit7(Quantizer7Bit::new(block_size)),
        }
    }

    /// Reports which precision lane this quantizer serves. Because of the
    /// `Float32` fallback in `for_lane`, this never returns `Float32`.
    pub fn lane(&self) -> PrecisionLane {
        match self {
            Self::Bit3(_) => PrecisionLane::Bit3,
            Self::Bit5(_) => PrecisionLane::Bit5,
            Self::Bit7(_) => PrecisionLane::Bit7,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // 3-bit round trip: with max_abs ~= 3.2 per block the step size is
    // ~1.07, so a reconstruction error below 1.0 is the expected bound.
    #[test]
    fn test_3bit_roundtrip() {
        let mut quantizer = Quantizer3Bit::new(32);
        let values: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.1).collect();

        let quantized = quantizer.quantize(&values);
        let dequantized = quantizer.dequantize(&quantized, values.len());

        assert_eq!(dequantized.len(), values.len());

        for (orig, deq) in values.iter().zip(dequantized.iter()) {
            let error = (orig - deq).abs();
            assert!(error < 1.0, "Error too large: {} vs {}", orig, deq);
        }
    }

    // 5-bit round trip: step size ~= 3.2 / 15 ~= 0.21, hence the 0.2 bound
    // (half-step error plus float slack).
    #[test]
    fn test_5bit_roundtrip() {
        let mut quantizer = Quantizer5Bit::new(32);
        let values: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.1).collect();

        let quantized = quantizer.quantize(&values);
        let dequantized = quantizer.dequantize(&quantized);

        assert_eq!(dequantized.len(), values.len());

        for (orig, deq) in values.iter().zip(dequantized.iter()) {
            let error = (orig - deq).abs();
            assert!(error < 0.2, "Error too large: {} vs {}", orig, deq);
        }
    }

    // 7-bit round trip: step size ~= 3.2 / 63 ~= 0.05, well under 0.1.
    #[test]
    fn test_7bit_roundtrip() {
        let mut quantizer = Quantizer7Bit::new(32);
        let values: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.1).collect();

        let quantized = quantizer.quantize(&values);
        let dequantized = quantizer.dequantize(&quantized);

        assert_eq!(dequantized.len(), values.len());

        for (orig, deq) in values.iter().zip(dequantized.iter()) {
            let error = (orig - deq).abs();
            assert!(error < 0.1, "Error too large: {} vs {}", orig, deq);
        }
    }

    // LoRA merge: result = round(base + delta * alpha), clamped to [-64, 63].
    #[test]
    fn test_7bit_lora_delta() {
        let mut quantizer = Quantizer7Bit::new(32);
        let base: Vec<i8> = vec![10, 20, 30, 40];
        let delta: Vec<i8> = vec![1, 2, 3, 4];

        let result = quantizer.apply_lora_delta(&base, &delta, 0.5);

        // 10 + 0.5 -> 11 (rounds half away from zero), 20 + 1 -> 21,
        // 30 + 1.5 -> 32, 40 + 2 -> 42.
        assert_eq!(result[0], 11);
        assert_eq!(result[1], 21);
        assert_eq!(result[2], 32);
        assert_eq!(result[3], 42);
    }
}