tensorlogic_trustformers/quantization/linear.rs
1//! [`QuantizedLinear`]: an INT8 weight matrix with per-channel or per-tensor
2//! scale/zero_point.
3//!
4//! ## Design
5//!
6//! Weights are stored as `Array2<i8>` with shape `(out_features, in_features)`.
7//! The forward pass dequantizes the weight matrix on every call and then
8//! performs a standard f64 `matmul`. This is the *CPU-first honest cut*;
9//! integer-matmul (packed int8 GEMM) is a future follow-up.
10//!
11//! For `PerChannel` granularity each row (`output channel`) has its own
12//! `scale[c]` and `zero_point[c]`. For `PerTensor` a single pair applies to
13//! all elements.
14
15use ndarray::{Array1, Array2, Axis};
16use tensorlogic_scirs_backend::quantization::{
17 QuantizationGranularity, QuantizationParams, QuantizationType, QuantizedTensor,
18};
19
/// Error type for quantization operations on linear layers.
///
/// Derives `Clone`/`PartialEq`/`Eq` so callers can compare or store errors
/// (e.g. `assert_eq!(err, QuantizationError::ShapeMismatch(..))`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QuantizationError {
    /// Weight matrix shape does not match expectations.
    ShapeMismatch(String),
    /// Quantization parameters are inconsistent.
    InvalidParams(String),
}

impl std::fmt::Display for QuantizationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ShapeMismatch(msg) => write!(f, "shape mismatch: {msg}"),
            Self::InvalidParams(msg) => write!(f, "invalid params: {msg}"),
        }
    }
}

impl std::error::Error for QuantizationError {}
39
40/// A weight matrix quantized to i8, with per-channel or per-tensor
41/// `scale`/`zero_point`.
42///
43/// Forward pass: dequantize weights on the fly, then f64 matmul.
44///
45/// ## Layout
46///
47/// `weight_q` has shape `(out_features, in_features)` — the same convention
48/// as `LinearExpert::weights` (see `moe/expert.rs`).
49pub struct QuantizedLinear {
50 /// Quantized weights of shape `(out_features, in_features)`.
51 weight_q: Array2<i8>,
52 /// Scale factor per channel (length 1 for PerTensor, `out_features` for PerChannel).
53 scale: Vec<f64>,
54 /// Zero-point per channel.
55 zero_point: Vec<i32>,
56 /// Granularity used during quantization.
57 granularity: QuantizationGranularity,
58 /// Optional bias of length `out_features`.
59 bias: Option<Array1<f64>>,
60}
61
62impl QuantizedLinear {
63 /// Quantize an existing f64 weight matrix using the provided params.
64 ///
65 /// Only `Int8` quantization type is supported. Use
66 /// [`crate::quantization::calibrate_linear`] to produce `params`.
67 ///
68 /// # Errors
69 ///
70 /// - [`QuantizationError::InvalidParams`] if `qtype != Int8`.
71 /// - [`QuantizationError::ShapeMismatch`] if the weight is not 2-D or the
72 /// scale/zero_point vectors are the wrong length for `PerChannel`.
73 pub fn from_fp(
74 weight: &Array2<f64>,
75 params: &QuantizationParams,
76 ) -> Result<Self, QuantizationError> {
77 if params.qtype != QuantizationType::Int8 {
78 return Err(QuantizationError::InvalidParams(format!(
79 "only Int8 is supported, got {:?}",
80 params.qtype
81 )));
82 }
83
84 let (out_features, _in_features) = weight.dim();
85
86 // Validate per-channel scale length.
87 if params.granularity == QuantizationGranularity::PerChannel
88 && params.scale.len() != out_features
89 {
90 return Err(QuantizationError::ShapeMismatch(format!(
91 "PerChannel: scale.len()={} but out_features={}",
92 params.scale.len(),
93 out_features,
94 )));
95 }
96
97 // Call scirs-backend to get the quantized f64 array.
98 let weight_dyn = weight.clone().into_dyn();
99 let qt = QuantizedTensor::quantize(&weight_dyn, params.clone());
100
101 // Cast f64 quantized values to i8.
102 let weight_i8 = qt
103 .data
104 .mapv(|x| x as i8)
105 .into_dimensionality::<ndarray::Ix2>()
106 .map_err(|e| {
107 QuantizationError::ShapeMismatch(format!("dimensionality cast failed: {e}"))
108 })?;
109
110 Ok(Self {
111 weight_q: weight_i8,
112 scale: params.scale.clone(),
113 zero_point: params.zero_point.clone(),
114 granularity: params.granularity,
115 bias: None,
116 })
117 }
118
119 /// Attach a bias vector of length `out_features`.
120 ///
121 /// # Errors
122 ///
123 /// [`QuantizationError::ShapeMismatch`] if `bias.len() != out_features`.
124 pub fn with_bias(mut self, bias: Array1<f64>) -> Result<Self, QuantizationError> {
125 let out_features = self.weight_q.nrows();
126 if bias.len() != out_features {
127 return Err(QuantizationError::ShapeMismatch(format!(
128 "bias.len()={} but out_features={}",
129 bias.len(),
130 out_features
131 )));
132 }
133 self.bias = Some(bias);
134 Ok(self)
135 }
136
137 /// Dequantize and run matmul.
138 ///
139 /// Input `x` must have shape `[batch, in_features]`.
140 /// Output has shape `[batch, out_features]`.
141 ///
142 /// `fp = (q - zero_point[c]) * scale[c]` per element, where `c` is the
143 /// output channel (row) index when `granularity == PerChannel`, or `0`
144 /// for `PerTensor`.
145 pub fn forward(&self, x: &Array2<f64>) -> Array2<f64> {
146 let weight_fp = self.dequantize();
147 // matmul: x @ weight_fp.t() gives [batch, out_features]
148 let out = x.dot(&weight_fp.t());
149 match &self.bias {
150 Some(b) => out + b,
151 None => out,
152 }
153 }
154
155 /// Dequantize the stored i8 weights back to f64.
156 ///
157 /// For `PerTensor`: all elements use `scale[0]` / `zero_point[0]`.
158 /// For `PerChannel`: each row `c` uses `scale[c]` / `zero_point[c]`.
159 pub fn dequantize(&self) -> Array2<f64> {
160 let (out_features, in_features) = self.weight_q.dim();
161 let mut fp = Array2::<f64>::zeros((out_features, in_features));
162
163 match self.granularity {
164 QuantizationGranularity::PerTensor => {
165 let s = self.scale[0];
166 let zp = self.zero_point[0] as f64;
167 for (q_row, mut fp_row) in self
168 .weight_q
169 .axis_iter(Axis(0))
170 .zip(fp.axis_iter_mut(Axis(0)))
171 {
172 for (q_val, fp_val) in q_row.iter().zip(fp_row.iter_mut()) {
173 *fp_val = (*q_val as f64 - zp) * s;
174 }
175 }
176 }
177 QuantizationGranularity::PerChannel => {
178 for (c, (q_row, mut fp_row)) in self
179 .weight_q
180 .axis_iter(Axis(0))
181 .zip(fp.axis_iter_mut(Axis(0)))
182 .enumerate()
183 {
184 let s = self.scale.get(c).copied().unwrap_or(self.scale[0]);
185 let zp = self.zero_point.get(c).copied().unwrap_or(0) as f64;
186 for (q_val, fp_val) in q_row.iter().zip(fp_row.iter_mut()) {
187 *fp_val = (*q_val as f64 - zp) * s;
188 }
189 }
190 }
191 }
192
193 fp
194 }
195
196 /// Return the output feature dimension.
197 pub fn out_features(&self) -> usize {
198 self.weight_q.nrows()
199 }
200
201 /// Return the input feature dimension.
202 pub fn in_features(&self) -> usize {
203 self.weight_q.ncols()
204 }
205
206 /// Return the quantization granularity.
207 pub fn granularity(&self) -> QuantizationGranularity {
208 self.granularity
209 }
210
211 /// Return the scale factors (length 1 for PerTensor, `out_features` for PerChannel).
212 pub fn scales(&self) -> &[f64] {
213 &self.scale
214 }
215}