sensorlm/quantization/int8.rs
1//! INT8 post-training quantisation (PTQ).
2//!
3//! # Overview
4//!
5//! Post-training quantisation converts the FP32 weights of a trained model to
6//! INT8 with minimal accuracy degradation. The procedure is:
7//!
8//! 1. **Calibration** – run the model on a small representative dataset and
9//! collect the per-layer minimum and maximum activation / weight values.
//! 2. **Quantise** – for each `Linear` weight matrix compute a symmetric
//!    INT8 scale (the zero-point is fixed at 0 in a symmetric scheme), then
//!    represent the weights as `i8` values.
12//! 3. **Dequantise at runtime** – before the matrix multiplication, convert
13//! `i8` back to `f32` using the stored scale.
14//!
15//! # Quantisation formula (symmetric per-tensor)
16//!
17//! ```text
18//! scale = max(|W|) / 127
19//! W_q = round(W / scale) ∈ [-127, 127]
20//! W_dq = W_q * scale (approx. original W)
21//! ```
22//!
23//! # INT8 linear layer
24//!
//! The [`QuantizedLinear`] module stores weights as `i8` with a per-tensor
//! scale and reconstructs `f32` on the fly. Storing one byte per weight
//! instead of four cuts weight memory to roughly a quarter of the FP32
//! footprint, and can accelerate inference when hardware INT8 GEMM is
//! available.
29//!
30//! # Limitations
31//!
32//! * This is a **weight-only** quantisation scheme (activations remain FP32).
33//! * Full activation quantisation would require inserting `QuantizeAct` nodes
34//! throughout the graph – left as a future extension.
35//! * The Burn framework does not yet expose native INT8 GEMM kernels; the
//!   dequantise-then-multiply approach used here demonstrates correctness
//!   but does not provide a runtime speedup until WGPU INT8 kernels land.
38
39use std::{fs, path::Path};
40
41
42use serde::{Deserialize, Serialize};
43
44// ---------------------------------------------------------------------------
45// Per-layer calibration statistics
46// ---------------------------------------------------------------------------
47
48/// Calibration data collected from one `Linear` layer.
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct LayerCalibration {
51 /// Layer name (e.g. `"sensor_encoder.blocks.0.attn.q_proj"`).
52 pub name: String,
53 /// Minimum observed weight value.
54 pub w_min: f32,
55 /// Maximum observed weight value.
56 pub w_max: f32,
57 /// Minimum observed input activation value.
58 pub act_min: f32,
59 /// Maximum observed input activation value.
60 pub act_max: f32,
61}
62
63impl LayerCalibration {
64 /// Compute the symmetric INT8 weight scale.
65 ///
66 /// `scale = max(|w_min|, |w_max|) / 127`
67 pub fn weight_scale(&self) -> f32 {
68 self.w_min.abs().max(self.w_max.abs()) / 127.0
69 }
70}
71
72// ---------------------------------------------------------------------------
73// Quantised weight representation
74// ---------------------------------------------------------------------------
75
76/// INT8 quantised weights for a single `Linear` layer.
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct QuantizedWeights {
79 /// Layer name.
80 pub name: String,
81 /// Quantised weight values in row-major order (i8 serialised as i32 for
82 /// serde compatibility).
83 pub weights_i8: Vec<i32>,
84 /// Weight tensor shape `[out_features, in_features]`.
85 pub shape: Vec<usize>,
86 /// Per-tensor scale factor.
87 pub scale: f32,
88 /// Bias (kept in FP32).
89 pub bias: Option<Vec<f32>>,
90}
91
92impl QuantizedWeights {
93 /// Quantise an FP32 weight matrix.
94 ///
95 /// # Arguments
96 ///
97 /// * `name` – Layer name.
98 /// * `w` – FP32 weights in row-major layout.
99 /// * `shape` – `[out_features, in_features]`.
100 /// * `bias` – Optional FP32 bias vector.
101 /// * `scale` – If `None`, computed from the max absolute weight.
102 pub fn from_f32(
103 name: String,
104 w: &[f32],
105 shape: Vec<usize>,
106 bias: Option<Vec<f32>>,
107 scale: Option<f32>,
108 ) -> Self {
109 let max_abs = w.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
110 let s = scale.unwrap_or(max_abs / 127.0).max(1e-8);
111
112 let weights_i8: Vec<i32> = w
113 .iter()
114 .map(|&x| (x / s).round().clamp(-127.0, 127.0) as i32)
115 .collect();
116
117 Self { name, weights_i8, shape, scale: s, bias }
118 }
119
120 /// Dequantise to FP32.
121 pub fn dequantize(&self) -> Vec<f32> {
122 self.weights_i8
123 .iter()
124 .map(|&q| q as f32 * self.scale)
125 .collect()
126 }
127
128 /// Memory used by the quantised weights in bytes.
129 pub fn size_bytes(&self) -> usize {
130 self.weights_i8.len() // i8: 1 byte each
131 }
132
133 /// Memory the original FP32 weights would have used.
134 pub fn original_size_bytes(&self) -> usize {
135 self.weights_i8.len() * 4
136 }
137
138 /// Compression ratio vs. FP32.
139 pub fn compression_ratio(&self) -> f32 {
140 self.original_size_bytes() as f32 / self.size_bytes() as f32
141 }
142}
143
144// ---------------------------------------------------------------------------
145// Quantised model manifest
146// ---------------------------------------------------------------------------
147
148/// Collection of quantised weights for an entire model.
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct QuantizedModel {
151 /// Model configuration JSON (serialised [`crate::config::SensorLMConfig`]).
152 pub config_json: String,
153 /// Quantised weights for every Linear layer in the model.
154 pub layers: Vec<QuantizedWeights>,
155 /// Quantisation scheme used.
156 pub scheme: String,
157 /// Total size of quantised weights in bytes.
158 pub total_quantized_bytes: usize,
159 /// Total size of original FP32 weights in bytes.
160 pub total_fp32_bytes: usize,
161}
162
163impl QuantizedModel {
164 /// Overall compression ratio (FP32 → INT8).
165 pub fn compression_ratio(&self) -> f32 {
166 self.total_fp32_bytes as f32 / self.total_quantized_bytes.max(1) as f32
167 }
168
169 /// Save the quantised model to a JSON file.
170 pub fn save(&self, path: &Path) -> crate::error::Result<()> {
171 let json = serde_json::to_string_pretty(self)?;
172 fs::write(path, json)?;
173 Ok(())
174 }
175
176 /// Load a quantised model from a JSON file.
177 pub fn load(path: &Path) -> crate::error::Result<Self> {
178 let json = fs::read_to_string(path)?;
179 let model = serde_json::from_str(&json)?;
180 Ok(model)
181 }
182}
183
184// ---------------------------------------------------------------------------
185// PTQ pipeline
186// ---------------------------------------------------------------------------
187
188/// Calibration pass over a small dataset to collect weight statistics.
189///
190/// In a production implementation this would hook into the model's forward
191/// pass via observers. Here we collect statistics directly from the stored
192/// weight tensors – sufficient for symmetric per-tensor weight quantisation
193/// since we do not require input activation statistics.
194pub struct Calibrator {
195 calibrations: Vec<LayerCalibration>,
196}
197
198impl Calibrator {
199 /// Create a new calibrator.
200 pub fn new() -> Self {
201 Self { calibrations: Vec::new() }
202 }
203
204 /// Record a weight tensor from one linear layer.
205 pub fn record_layer(&mut self, name: String, weights: &[f32]) {
206 let w_min = weights.iter().cloned().fold(f32::INFINITY, f32::min);
207 let w_max = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
208 self.calibrations.push(LayerCalibration {
209 name,
210 w_min,
211 w_max,
212 act_min: -1.0, // placeholder (weight-only PTQ)
213 act_max: 1.0,
214 });
215 }
216
217 /// Consume the calibrator and produce per-layer scale factors.
218 pub fn finish(self) -> Vec<LayerCalibration> {
219 self.calibrations
220 }
221}
222
223impl Default for Calibrator {
224 fn default() -> Self {
225 Self::new()
226 }
227}
228
229// ---------------------------------------------------------------------------
230// Quantise a flat representation of model weights
231// ---------------------------------------------------------------------------
232
233/// Quantise a named list of FP32 weight tensors.
234///
235/// This function takes the flat `(name, weights, shape, bias)` representation
236/// that would be extracted from a `SensorLMModel` and produces a
237/// [`QuantizedModel`].
238///
239/// # Arguments
240///
241/// * `config_json` – JSON string of the model config.
242/// * `layers` – Iterator of `(name, fp32_weights, shape, optional_bias)`.
243pub fn quantize_model_weights(
244 config_json: String,
245 layers: impl IntoIterator<Item = (String, Vec<f32>, Vec<usize>, Option<Vec<f32>>)>,
246) -> QuantizedModel {
247 let mut quantized_layers = Vec::new();
248 let mut total_fp32_bytes = 0usize;
249 let mut total_quantized_bytes = 0usize;
250
251 for (name, weights, shape, bias) in layers {
252 let qw = QuantizedWeights::from_f32(name, &weights, shape, bias, None);
253 total_fp32_bytes += qw.original_size_bytes();
254 total_quantized_bytes += qw.size_bytes();
255 quantized_layers.push(qw);
256 }
257
258 QuantizedModel {
259 config_json,
260 layers: quantized_layers,
261 scheme: "symmetric_per_tensor_int8".to_string(),
262 total_quantized_bytes,
263 total_fp32_bytes,
264 }
265}
266
267// ---------------------------------------------------------------------------
268// FP16 export
269// ---------------------------------------------------------------------------
270
271/// Convert an FP32 weight vector to FP16 (f16 represented as u16 bits).
272///
273/// FP16 halves the model's memory footprint with minimal accuracy loss and
274/// is natively supported by most modern GPUs.
275pub fn fp32_to_fp16_bits(weights: &[f32]) -> Vec<u16> {
276 weights
277 .iter()
278 .map(|&x| half::f16::from_f32(x).to_bits())
279 .collect()
280}
281
282/// Convert an FP16 weight vector back to FP32.
283pub fn fp16_bits_to_fp32(bits: &[u16]) -> Vec<f32> {
284 bits.iter()
285 .map(|&b| half::f16::from_bits(b).to_f32())
286 .collect()
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantize_roundtrip() {
        let weights: Vec<f32> = (0..100).map(|i| i as f32 * 0.01 - 0.5).collect();
        let qw = QuantizedWeights::from_f32("test".to_string(), &weights, vec![10, 10], None, None);
        let dq = qw.dequantize();
        assert_eq!(dq.len(), weights.len());
        // Round-to-nearest bounds the per-element error by half a quantisation step.
        let max_err = weights
            .iter()
            .zip(&dq)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(
            max_err <= qw.scale / 2.0 + 1e-6,
            "Max quant error {max_err} > scale/2 = {}",
            qw.scale / 2.0
        );
    }

    #[test]
    fn test_compression_ratio() {
        let weights = vec![1.0f32; 1024];
        let qw = QuantizedWeights::from_f32("test".into(), &weights, vec![32, 32], None, None);
        assert!(
            (qw.compression_ratio() - 4.0).abs() < 1e-5,
            "INT8 should compress ~4x vs FP32"
        );
    }

    #[test]
    fn test_explicit_scale_rounds_and_saturates() {
        let qw = QuantizedWeights::from_f32("s".into(), &[10.0, -300.0, 3.0], vec![1, 3], None, Some(2.0));
        assert_eq!(qw.scale, 2.0);
        // 10/2 = 5, -300/2 = -150 clamps to -127, 3/2 = 1.5 rounds away from zero to 2.
        assert_eq!(qw.weights_i8, vec![5, -127, 2]);
    }

    #[test]
    fn test_all_zero_weights_use_positive_scale() {
        let qw = QuantizedWeights::from_f32("z".into(), &[0.0; 8], vec![2, 4], None, None);
        assert!(qw.scale > 0.0);
        assert!(qw.weights_i8.iter().all(|&q| q == 0));
        assert!(qw.dequantize().iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_calibrator_records_weight_range() {
        let mut cal = Calibrator::new();
        cal.record_layer("l0".into(), &[-3.0, 1.0, 2.0]);
        let stats = cal.finish();
        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].name, "l0");
        assert_eq!((stats[0].w_min, stats[0].w_max), (-3.0, 2.0));
        assert!((stats[0].weight_scale() - 3.0 / 127.0).abs() < 1e-9);
    }

    #[test]
    fn test_quantize_model_weights_totals() {
        let qm = quantize_model_weights(
            "{}".into(),
            vec![
                ("a".to_string(), vec![1.0f32; 6], vec![2usize, 3], None),
                ("b".to_string(), vec![-0.5f32; 4], vec![4usize, 1], Some(vec![0.0f32; 4])),
            ],
        );
        assert_eq!(qm.layers.len(), 2);
        assert_eq!(qm.layers[0].name, "a");
        assert_eq!(qm.layers[1].name, "b");
        assert_eq!(qm.total_quantized_bytes, 10);
        assert_eq!(qm.total_fp32_bytes, 40);
        assert!((qm.compression_ratio() - 4.0).abs() < 1e-6);
        assert_eq!(qm.scheme, "symmetric_per_tensor_int8");
    }

    #[test]
    fn test_fp16_roundtrip_exact_values() {
        // All of these are exactly representable in FP16.
        let values = [0.0f32, 1.0, -2.5, 0.125, 1024.0];
        let back = fp16_bits_to_fp32(&fp32_to_fp16_bits(&values));
        assert_eq!(back, values.to_vec());
    }

    #[test]
    fn test_save_load_roundtrip() {
        let qm = QuantizedModel {
            config_json: "{}".into(),
            layers: vec![QuantizedWeights {
                name: "l1".into(),
                weights_i8: vec![1, -2, 3],
                shape: vec![1, 3],
                scale: 0.01,
                bias: None,
            }],
            scheme: "symmetric_per_tensor_int8".into(),
            total_quantized_bytes: 3,
            total_fp32_bytes: 12,
        };
        let tmp = tempfile::NamedTempFile::new().unwrap();
        qm.save(tmp.path()).unwrap();
        let loaded = QuantizedModel::load(tmp.path()).unwrap();

        assert_eq!(loaded.config_json, "{}");
        assert_eq!(loaded.scheme, "symmetric_per_tensor_int8");
        assert_eq!(loaded.total_quantized_bytes, 3);
        assert_eq!(loaded.total_fp32_bytes, 12);
        assert_eq!(loaded.layers.len(), 1);
        let layer = &loaded.layers[0];
        assert_eq!(layer.name, "l1");
        assert_eq!(layer.weights_i8, vec![1, -2, 3]);
        assert_eq!(layer.shape, vec![1, 3]);
        assert!((layer.scale - 0.01).abs() < 1e-9);
        assert!(layer.bias.is_none());
    }
}