Skip to main content

tensorlogic_trustformers/quantization/
linear.rs

1//! [`QuantizedLinear`]: an INT8 weight matrix with per-channel or per-tensor
2//! scale/zero_point.
3//!
4//! ## Design
5//!
6//! Weights are stored as `Array2<i8>` with shape `(out_features, in_features)`.
7//! The forward pass dequantizes the weight matrix on every call and then
8//! performs a standard f64 `matmul`.  This is the *CPU-first honest cut*;
9//! integer-matmul (packed int8 GEMM) is a future follow-up.
10//!
11//! For `PerChannel` granularity each row (`output channel`) has its own
12//! `scale[c]` and `zero_point[c]`.  For `PerTensor` a single pair applies to
13//! all elements.
14
15use ndarray::{Array1, Array2, Axis};
16use tensorlogic_scirs_backend::quantization::{
17    QuantizationGranularity, QuantizationParams, QuantizationType, QuantizedTensor,
18};
19
/// Failure modes when building or configuring a quantized linear layer.
///
/// Each variant carries a human-readable detail message; the [`Display`]
/// output prefixes it with a short lowercase tag naming the variant.
///
/// [`Display`]: std::fmt::Display
#[derive(Debug)]
pub enum QuantizationError {
    /// A weight, bias, or parameter vector has a length/shape that does not
    /// fit the layer.
    ShapeMismatch(String),
    /// The quantization parameters are unsupported or inconsistent.
    InvalidParams(String),
}

impl std::fmt::Display for QuantizationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Borrow the payload in place; each arm writes "<tag>: <detail>".
        match *self {
            Self::ShapeMismatch(ref msg) => write!(f, "shape mismatch: {msg}"),
            Self::InvalidParams(ref msg) => write!(f, "invalid params: {msg}"),
        }
    }
}

/// Marker impl so the error composes with `?`, `Box<dyn Error>`, etc.
/// No underlying `source()` is exposed.
impl std::error::Error for QuantizationError {}
39
40/// A weight matrix quantized to i8, with per-channel or per-tensor
41/// `scale`/`zero_point`.
42///
43/// Forward pass: dequantize weights on the fly, then f64 matmul.
44///
45/// ## Layout
46///
47/// `weight_q` has shape `(out_features, in_features)` — the same convention
48/// as `LinearExpert::weights` (see `moe/expert.rs`).
49pub struct QuantizedLinear {
50    /// Quantized weights of shape `(out_features, in_features)`.
51    weight_q: Array2<i8>,
52    /// Scale factor per channel (length 1 for PerTensor, `out_features` for PerChannel).
53    scale: Vec<f64>,
54    /// Zero-point per channel.
55    zero_point: Vec<i32>,
56    /// Granularity used during quantization.
57    granularity: QuantizationGranularity,
58    /// Optional bias of length `out_features`.
59    bias: Option<Array1<f64>>,
60}
61
62impl QuantizedLinear {
63    /// Quantize an existing f64 weight matrix using the provided params.
64    ///
65    /// Only `Int8` quantization type is supported.  Use
66    /// [`crate::quantization::calibrate_linear`] to produce `params`.
67    ///
68    /// # Errors
69    ///
70    /// - [`QuantizationError::InvalidParams`] if `qtype != Int8`.
71    /// - [`QuantizationError::ShapeMismatch`] if the weight is not 2-D or the
72    ///   scale/zero_point vectors are the wrong length for `PerChannel`.
73    pub fn from_fp(
74        weight: &Array2<f64>,
75        params: &QuantizationParams,
76    ) -> Result<Self, QuantizationError> {
77        if params.qtype != QuantizationType::Int8 {
78            return Err(QuantizationError::InvalidParams(format!(
79                "only Int8 is supported, got {:?}",
80                params.qtype
81            )));
82        }
83
84        let (out_features, _in_features) = weight.dim();
85
86        // Validate per-channel scale length.
87        if params.granularity == QuantizationGranularity::PerChannel
88            && params.scale.len() != out_features
89        {
90            return Err(QuantizationError::ShapeMismatch(format!(
91                "PerChannel: scale.len()={} but out_features={}",
92                params.scale.len(),
93                out_features,
94            )));
95        }
96
97        // Call scirs-backend to get the quantized f64 array.
98        let weight_dyn = weight.clone().into_dyn();
99        let qt = QuantizedTensor::quantize(&weight_dyn, params.clone());
100
101        // Cast f64 quantized values to i8.
102        let weight_i8 = qt
103            .data
104            .mapv(|x| x as i8)
105            .into_dimensionality::<ndarray::Ix2>()
106            .map_err(|e| {
107                QuantizationError::ShapeMismatch(format!("dimensionality cast failed: {e}"))
108            })?;
109
110        Ok(Self {
111            weight_q: weight_i8,
112            scale: params.scale.clone(),
113            zero_point: params.zero_point.clone(),
114            granularity: params.granularity,
115            bias: None,
116        })
117    }
118
119    /// Attach a bias vector of length `out_features`.
120    ///
121    /// # Errors
122    ///
123    /// [`QuantizationError::ShapeMismatch`] if `bias.len() != out_features`.
124    pub fn with_bias(mut self, bias: Array1<f64>) -> Result<Self, QuantizationError> {
125        let out_features = self.weight_q.nrows();
126        if bias.len() != out_features {
127            return Err(QuantizationError::ShapeMismatch(format!(
128                "bias.len()={} but out_features={}",
129                bias.len(),
130                out_features
131            )));
132        }
133        self.bias = Some(bias);
134        Ok(self)
135    }
136
137    /// Dequantize and run matmul.
138    ///
139    /// Input `x` must have shape `[batch, in_features]`.
140    /// Output has shape `[batch, out_features]`.
141    ///
142    /// `fp = (q - zero_point[c]) * scale[c]` per element, where `c` is the
143    /// output channel (row) index when `granularity == PerChannel`, or `0`
144    /// for `PerTensor`.
145    pub fn forward(&self, x: &Array2<f64>) -> Array2<f64> {
146        let weight_fp = self.dequantize();
147        // matmul: x @ weight_fp.t() gives [batch, out_features]
148        let out = x.dot(&weight_fp.t());
149        match &self.bias {
150            Some(b) => out + b,
151            None => out,
152        }
153    }
154
155    /// Dequantize the stored i8 weights back to f64.
156    ///
157    /// For `PerTensor`: all elements use `scale[0]` / `zero_point[0]`.
158    /// For `PerChannel`: each row `c` uses `scale[c]` / `zero_point[c]`.
159    pub fn dequantize(&self) -> Array2<f64> {
160        let (out_features, in_features) = self.weight_q.dim();
161        let mut fp = Array2::<f64>::zeros((out_features, in_features));
162
163        match self.granularity {
164            QuantizationGranularity::PerTensor => {
165                let s = self.scale[0];
166                let zp = self.zero_point[0] as f64;
167                for (q_row, mut fp_row) in self
168                    .weight_q
169                    .axis_iter(Axis(0))
170                    .zip(fp.axis_iter_mut(Axis(0)))
171                {
172                    for (q_val, fp_val) in q_row.iter().zip(fp_row.iter_mut()) {
173                        *fp_val = (*q_val as f64 - zp) * s;
174                    }
175                }
176            }
177            QuantizationGranularity::PerChannel => {
178                for (c, (q_row, mut fp_row)) in self
179                    .weight_q
180                    .axis_iter(Axis(0))
181                    .zip(fp.axis_iter_mut(Axis(0)))
182                    .enumerate()
183                {
184                    let s = self.scale.get(c).copied().unwrap_or(self.scale[0]);
185                    let zp = self.zero_point.get(c).copied().unwrap_or(0) as f64;
186                    for (q_val, fp_val) in q_row.iter().zip(fp_row.iter_mut()) {
187                        *fp_val = (*q_val as f64 - zp) * s;
188                    }
189                }
190            }
191        }
192
193        fp
194    }
195
196    /// Return the output feature dimension.
197    pub fn out_features(&self) -> usize {
198        self.weight_q.nrows()
199    }
200
201    /// Return the input feature dimension.
202    pub fn in_features(&self) -> usize {
203        self.weight_q.ncols()
204    }
205
206    /// Return the quantization granularity.
207    pub fn granularity(&self) -> QuantizationGranularity {
208        self.granularity
209    }
210
211    /// Return the scale factors (length 1 for PerTensor, `out_features` for PerChannel).
212    pub fn scales(&self) -> &[f64] {
213        &self.scale
214    }
215}