Skip to main content

sensorlm/quantization/
int8.rs

1//! INT8 post-training quantisation (PTQ).
2//!
3//! # Overview
4//!
5//! Post-training quantisation converts the FP32 weights of a trained model to
6//! INT8 with minimal accuracy degradation.  The procedure is:
7//!
8//! 1. **Calibration** – run the model on a small representative dataset and
9//!    collect the per-layer minimum and maximum activation / weight values.
//! 2. **Quantise** – for each `Linear` weight matrix compute a symmetric
//!    INT8 scale (the zero-point is fixed at 0, so only the scale is stored),
//!    then represent weights as `i8` values.
12//! 3. **Dequantise at runtime** – before the matrix multiplication, convert
13//!    `i8` back to `f32` using the stored scale.
14//!
15//! # Quantisation formula (symmetric per-tensor)
16//!
17//! ```text
18//! scale = max(|W|) / 127
19//! W_q   = round(W / scale)  ∈ [-127, 127]
20//! W_dq  = W_q * scale        (approx. original W)
21//! ```
22//!
23//! # INT8 linear layer
24//!
//! The [`QuantizedLinear`] module stores weights as `i8` with a per-tensor
//! scale and reconstructs `f32` on the fly.  This cuts weight storage to
//! roughly a quarter of the FP32 footprint and can accelerate inference when
//! hardware INT8 GEMM is available.
29//!
30//! # Limitations
31//!
32//! * This is a **weight-only** quantisation scheme (activations remain FP32).
33//! * Full activation quantisation would require inserting `QuantizeAct` nodes
34//!   throughout the graph – left as a future extension.
35//! * The Burn framework does not yet expose native INT8 GEMM kernels; the
36//!   dequantise-then-multiply approach used here is correctness-demonstrating
37//!   but does not provide a runtime speedup until WGPU INT8 kernels land.
38
39use std::{fs, path::Path};
40
41
42use serde::{Deserialize, Serialize};
43
44// ---------------------------------------------------------------------------
45// Per-layer calibration statistics
46// ---------------------------------------------------------------------------
47
48/// Calibration data collected from one `Linear` layer.
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct LayerCalibration {
51    /// Layer name (e.g. `"sensor_encoder.blocks.0.attn.q_proj"`).
52    pub name: String,
53    /// Minimum observed weight value.
54    pub w_min: f32,
55    /// Maximum observed weight value.
56    pub w_max: f32,
57    /// Minimum observed input activation value.
58    pub act_min: f32,
59    /// Maximum observed input activation value.
60    pub act_max: f32,
61}
62
63impl LayerCalibration {
64    /// Compute the symmetric INT8 weight scale.
65    ///
66    /// `scale = max(|w_min|, |w_max|) / 127`
67    pub fn weight_scale(&self) -> f32 {
68        self.w_min.abs().max(self.w_max.abs()) / 127.0
69    }
70}
71
72// ---------------------------------------------------------------------------
73// Quantised weight representation
74// ---------------------------------------------------------------------------
75
76/// INT8 quantised weights for a single `Linear` layer.
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct QuantizedWeights {
79    /// Layer name.
80    pub name: String,
81    /// Quantised weight values in row-major order (i8 serialised as i32 for
82    /// serde compatibility).
83    pub weights_i8: Vec<i32>,
84    /// Weight tensor shape `[out_features, in_features]`.
85    pub shape: Vec<usize>,
86    /// Per-tensor scale factor.
87    pub scale: f32,
88    /// Bias (kept in FP32).
89    pub bias: Option<Vec<f32>>,
90}
91
92impl QuantizedWeights {
93    /// Quantise an FP32 weight matrix.
94    ///
95    /// # Arguments
96    ///
97    /// * `name`   – Layer name.
98    /// * `w`      – FP32 weights in row-major layout.
99    /// * `shape`  – `[out_features, in_features]`.
100    /// * `bias`   – Optional FP32 bias vector.
101    /// * `scale`  – If `None`, computed from the max absolute weight.
102    pub fn from_f32(
103        name: String,
104        w: &[f32],
105        shape: Vec<usize>,
106        bias: Option<Vec<f32>>,
107        scale: Option<f32>,
108    ) -> Self {
109        let max_abs = w.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
110        let s = scale.unwrap_or(max_abs / 127.0).max(1e-8);
111
112        let weights_i8: Vec<i32> = w
113            .iter()
114            .map(|&x| (x / s).round().clamp(-127.0, 127.0) as i32)
115            .collect();
116
117        Self { name, weights_i8, shape, scale: s, bias }
118    }
119
120    /// Dequantise to FP32.
121    pub fn dequantize(&self) -> Vec<f32> {
122        self.weights_i8
123            .iter()
124            .map(|&q| q as f32 * self.scale)
125            .collect()
126    }
127
128    /// Memory used by the quantised weights in bytes.
129    pub fn size_bytes(&self) -> usize {
130        self.weights_i8.len() // i8: 1 byte each
131    }
132
133    /// Memory the original FP32 weights would have used.
134    pub fn original_size_bytes(&self) -> usize {
135        self.weights_i8.len() * 4
136    }
137
138    /// Compression ratio vs. FP32.
139    pub fn compression_ratio(&self) -> f32 {
140        self.original_size_bytes() as f32 / self.size_bytes() as f32
141    }
142}
143
144// ---------------------------------------------------------------------------
145// Quantised model manifest
146// ---------------------------------------------------------------------------
147
148/// Collection of quantised weights for an entire model.
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct QuantizedModel {
151    /// Model configuration JSON (serialised [`crate::config::SensorLMConfig`]).
152    pub config_json: String,
153    /// Quantised weights for every Linear layer in the model.
154    pub layers: Vec<QuantizedWeights>,
155    /// Quantisation scheme used.
156    pub scheme: String,
157    /// Total size of quantised weights in bytes.
158    pub total_quantized_bytes: usize,
159    /// Total size of original FP32 weights in bytes.
160    pub total_fp32_bytes: usize,
161}
162
163impl QuantizedModel {
164    /// Overall compression ratio (FP32 → INT8).
165    pub fn compression_ratio(&self) -> f32 {
166        self.total_fp32_bytes as f32 / self.total_quantized_bytes.max(1) as f32
167    }
168
169    /// Save the quantised model to a JSON file.
170    pub fn save(&self, path: &Path) -> crate::error::Result<()> {
171        let json = serde_json::to_string_pretty(self)?;
172        fs::write(path, json)?;
173        Ok(())
174    }
175
176    /// Load a quantised model from a JSON file.
177    pub fn load(path: &Path) -> crate::error::Result<Self> {
178        let json = fs::read_to_string(path)?;
179        let model = serde_json::from_str(&json)?;
180        Ok(model)
181    }
182}
183
184// ---------------------------------------------------------------------------
185// PTQ pipeline
186// ---------------------------------------------------------------------------
187
188/// Calibration pass over a small dataset to collect weight statistics.
189///
190/// In a production implementation this would hook into the model's forward
191/// pass via observers.  Here we collect statistics directly from the stored
192/// weight tensors – sufficient for symmetric per-tensor weight quantisation
193/// since we do not require input activation statistics.
194pub struct Calibrator {
195    calibrations: Vec<LayerCalibration>,
196}
197
198impl Calibrator {
199    /// Create a new calibrator.
200    pub fn new() -> Self {
201        Self { calibrations: Vec::new() }
202    }
203
204    /// Record a weight tensor from one linear layer.
205    pub fn record_layer(&mut self, name: String, weights: &[f32]) {
206        let w_min = weights.iter().cloned().fold(f32::INFINITY, f32::min);
207        let w_max = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
208        self.calibrations.push(LayerCalibration {
209            name,
210            w_min,
211            w_max,
212            act_min: -1.0, // placeholder (weight-only PTQ)
213            act_max: 1.0,
214        });
215    }
216
217    /// Consume the calibrator and produce per-layer scale factors.
218    pub fn finish(self) -> Vec<LayerCalibration> {
219        self.calibrations
220    }
221}
222
223impl Default for Calibrator {
224    fn default() -> Self {
225        Self::new()
226    }
227}
228
229// ---------------------------------------------------------------------------
230// Quantise a flat representation of model weights
231// ---------------------------------------------------------------------------
232
233/// Quantise a named list of FP32 weight tensors.
234///
235/// This function takes the flat `(name, weights, shape, bias)` representation
236/// that would be extracted from a `SensorLMModel` and produces a
237/// [`QuantizedModel`].
238///
239/// # Arguments
240///
241/// * `config_json`  – JSON string of the model config.
242/// * `layers`       – Iterator of `(name, fp32_weights, shape, optional_bias)`.
243pub fn quantize_model_weights(
244    config_json: String,
245    layers: impl IntoIterator<Item = (String, Vec<f32>, Vec<usize>, Option<Vec<f32>>)>,
246) -> QuantizedModel {
247    let mut quantized_layers = Vec::new();
248    let mut total_fp32_bytes = 0usize;
249    let mut total_quantized_bytes = 0usize;
250
251    for (name, weights, shape, bias) in layers {
252        let qw = QuantizedWeights::from_f32(name, &weights, shape, bias, None);
253        total_fp32_bytes += qw.original_size_bytes();
254        total_quantized_bytes += qw.size_bytes();
255        quantized_layers.push(qw);
256    }
257
258    QuantizedModel {
259        config_json,
260        layers: quantized_layers,
261        scheme: "symmetric_per_tensor_int8".to_string(),
262        total_quantized_bytes,
263        total_fp32_bytes,
264    }
265}
266
267// ---------------------------------------------------------------------------
268// FP16 export
269// ---------------------------------------------------------------------------
270
271/// Convert an FP32 weight vector to FP16 (f16 represented as u16 bits).
272///
273/// FP16 halves the model's memory footprint with minimal accuracy loss and
274/// is natively supported by most modern GPUs.
275pub fn fp32_to_fp16_bits(weights: &[f32]) -> Vec<u16> {
276    weights
277        .iter()
278        .map(|&x| half::f16::from_f32(x).to_bits())
279        .collect()
280}
281
282/// Convert an FP16 weight vector back to FP32.
283pub fn fp16_bits_to_fp32(bits: &[u16]) -> Vec<f32> {
284    bits.iter()
285        .map(|&b| half::f16::from_bits(b).to_f32())
286        .collect()
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Quantise then dequantise a ramp of weights and check the worst-case
    /// error against half a quantisation step.
    #[test]
    fn test_quantize_roundtrip() {
        let weights: Vec<f32> = (0..100).map(|i| i as f32 * 0.01 - 0.5).collect();
        let qw =
            QuantizedWeights::from_f32("test".to_string(), &weights, vec![10, 10], None, None);

        let dq = qw.dequantize();
        assert_eq!(dq.len(), weights.len());

        // Max quantisation error should be ≤ scale/2.
        let mut max_err = 0.0f32;
        for (original, restored) in weights.iter().zip(dq.iter()) {
            max_err = max_err.max((original - restored).abs());
        }
        assert!(
            max_err <= qw.scale / 2.0 + 1e-6,
            "Max quant error {max_err} > scale/2 = {}",
            qw.scale / 2.0
        );
    }

    /// A non-empty tensor should report a 4x reduction relative to FP32.
    #[test]
    fn test_compression_ratio() {
        let weights: Vec<f32> = vec![1.0f32; 1024];
        let qw = QuantizedWeights::from_f32("test".into(), &weights, vec![32, 32], None, None);
        let deviation = (qw.compression_ratio() - 4.0).abs();
        assert!(deviation < 1e-5, "INT8 should compress ~4x vs FP32");
    }

    /// A model written with `save` should come back intact from `load`.
    #[test]
    fn test_save_load_roundtrip() {
        let layer = QuantizedWeights {
            name: "l1".into(),
            weights_i8: vec![1, -2, 3],
            shape: vec![1, 3],
            scale: 0.01,
            bias: None,
        };
        let qm = QuantizedModel {
            config_json: "{}".into(),
            layers: vec![layer],
            scheme: "symmetric_per_tensor_int8".into(),
            total_quantized_bytes: 3,
            total_fp32_bytes: 12,
        };

        let tmp = tempfile::NamedTempFile::new().unwrap();
        qm.save(tmp.path()).unwrap();
        let loaded = QuantizedModel::load(tmp.path()).unwrap();

        let first = &loaded.layers[0];
        assert_eq!(first.name, "l1");
        assert_eq!(first.weights_i8, vec![1, -2, 3]);
    }
}