svod_tensor/nn/quantize.rs
1//! Quantization operations (clamp-cast, quantized conv/matmul).
2
3use bon::bon;
4use svod_dtype::DType;
5
6use crate::Tensor;
7
8type Result<T> = crate::Result<T>;
9
10#[bon]
11impl Tensor {
12 /// Clamp to the representable range of `dtype`, then cast.
13 ///
14 /// Values outside the target type's range are saturated to its min/max
15 /// before casting, preventing overflow wrap-around.
16 ///
17 /// # Examples
18 ///
19 /// ```
20 /// # use svod_tensor::Tensor;
21 /// # use svod_dtype::DType;
22 /// let x = Tensor::from_slice([300.0f32, -10.0, 128.0]);
23 /// let mut y = x.clamp_cast(DType::UInt8).unwrap();
24 /// y.realize().unwrap();
25 /// let vals = y.as_vec::<u8>().unwrap();
26 /// assert_eq!(vals, vec![255, 0, 128]);
27 /// ```
28 pub fn clamp_cast(&self, dtype: DType) -> Result<Self> {
29 let min = Tensor::const_(dtype.min_value(), self.uop().dtype());
30 let max = Tensor::const_(dtype.max_value(), self.uop().dtype());
31 self.clamp().min(&min).max(&max).call()?.cast(dtype)
32 }
33
34 /// Quantized convolution: zero-point–adjust inputs, convolve in int32,
35 /// rescale and requantize to the output dtype.
36 ///
37 /// Implements the ONNX QLinearConv operator. The flow is:
38 /// 1. Subtract zero points from input and weights
39 /// 2. Perform integer convolution
40 /// 3. Rescale by `(x_scale * w_scale) / y_scale` and add `y_zero_point`
41 ///
42 /// # Examples
43 ///
44 /// ```
45 /// # use svod_tensor::Tensor;
46 /// # use svod_dtype::DType;
47 /// # use ndarray::Array4;
48 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 128u8));
49 /// let x_scale = Tensor::from_slice([0.1f32]);
50 /// let x_zp = Tensor::from_slice([128u8]);
51 /// let weight = Tensor::from_ndarray(&Array4::from_elem((1, 1, 1, 1), 128u8));
52 /// let w_scale = Tensor::from_slice([0.1f32]);
53 /// let w_zp = Tensor::from_slice([128u8]);
54 /// let y_scale = Tensor::from_slice([0.1f32]);
55 /// let y_zp = Tensor::from_slice([128u8]);
56 /// let y = x.qlinear_conv()
57 /// .x_scale(&x_scale).x_zero_point(&x_zp)
58 /// .weight(&weight).w_scale(&w_scale).w_zero_point(&w_zp)
59 /// .y_scale(&y_scale).y_zero_point(&y_zp)
60 /// .call()
61 /// .unwrap();
62 /// let shape: Vec<usize> = y.shape().unwrap().iter()
63 /// .map(|d| d.as_const().unwrap()).collect();
64 /// assert_eq!(shape, vec![1, 1, 3, 3]);
65 /// ```
66 #[builder]
67 pub fn qlinear_conv(
68 &self,
69 x_scale: &Tensor,
70 x_zero_point: &Tensor,
71 weight: &Tensor,
72 w_scale: &Tensor,
73 w_zero_point: &Tensor,
74 y_scale: &Tensor,
75 y_zero_point: &Tensor,
76 bias: Option<&Tensor>,
77 #[builder(default)] auto_pad: super::AutoPad,
78 #[builder(default = 1)] group: usize,
79 kernel_shape: Option<&[usize]>,
80 pads: Option<&[i64]>,
81 strides: Option<&[i64]>,
82 dilations: Option<&[i64]>,
83 ) -> Result<Tensor> {
84 let adj_x = self.cast(DType::Int32)?.try_sub(&x_zero_point.cast(DType::Int32)?)?;
85 let w_i32 = weight.cast(DType::Int32)?;
86 let w_zp = reshape_per_channel(&w_zero_point.cast(DType::Int32)?, w_i32.ndim()?)?;
87 let adj_w = w_i32.try_sub(&w_zp)?;
88 let conv_out = adj_x
89 .conv()
90 .weight(&adj_w)
91 .maybe_bias(bias)
92 .auto_pad(auto_pad)
93 .group(group)
94 .maybe_kernel_shape(kernel_shape)
95 .maybe_pads(pads)
96 .maybe_strides(strides)
97 .maybe_dilations(dilations)
98 .call()?;
99 requantize(&conv_out, &[x_scale, w_scale], y_scale, y_zero_point)
100 }
101
102 /// Integer convolution: zero-point–adjust inputs and convolve in int32.
103 /// No rescaling — returns raw int32 result.
104 ///
105 /// Implements the ONNX ConvInteger operator. Subtracts optional zero points
106 /// from input and weights, then convolves in int32. Unlike `qlinear_conv`,
107 /// no output rescaling is applied.
108 ///
109 /// # Examples
110 ///
111 /// ```
112 /// # use svod_tensor::Tensor;
113 /// # use svod_dtype::DType;
114 /// # use ndarray::Array4;
115 /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 10u8));
116 /// let weight = Tensor::from_ndarray(&Array4::from_elem((1, 1, 1, 1), 1u8));
117 /// let y = x.conv_integer().weight(&weight).call().unwrap();
118 /// let shape: Vec<usize> = y.shape().unwrap().iter()
119 /// .map(|d| d.as_const().unwrap()).collect();
120 /// assert_eq!(shape, vec![1, 1, 3, 3]);
121 /// ```
122 #[builder]
123 pub fn conv_integer(
124 &self,
125 weight: &Tensor,
126 x_zero_point: Option<&Tensor>,
127 w_zero_point: Option<&Tensor>,
128 bias: Option<&Tensor>,
129 #[builder(default)] auto_pad: super::AutoPad,
130 #[builder(default = 1)] group: usize,
131 kernel_shape: Option<&[usize]>,
132 pads: Option<&[i64]>,
133 strides: Option<&[i64]>,
134 dilations: Option<&[i64]>,
135 ) -> Result<Tensor> {
136 let adj_x = if let Some(zp) = x_zero_point {
137 self.cast(DType::Int32)?.try_sub(&zp.cast(DType::Int32)?)?
138 } else {
139 self.cast(DType::Int32)?
140 };
141 let w_i32 = weight.cast(DType::Int32)?;
142 let adj_w = if let Some(zp) = w_zero_point {
143 let w_zp = reshape_per_channel(&zp.cast(DType::Int32)?, w_i32.ndim()?)?;
144 w_i32.try_sub(&w_zp)?
145 } else {
146 w_i32
147 };
148 adj_x
149 .conv()
150 .weight(&adj_w)
151 .maybe_bias(bias)
152 .auto_pad(auto_pad)
153 .group(group)
154 .maybe_kernel_shape(kernel_shape)
155 .maybe_pads(pads)
156 .maybe_strides(strides)
157 .maybe_dilations(dilations)
158 .call()
159 }
160
161 /// Quantized matrix multiplication: zero-point–adjust inputs, matmul in int32,
162 /// rescale and requantize to the output dtype.
163 ///
164 /// Implements the ONNX QLinearMatMul operator. The flow is:
165 /// 1. Subtract zero points from both inputs
166 /// 2. Perform integer matrix multiplication
167 /// 3. Rescale by `(a_scale * b_scale) / y_scale` and add `y_zero_point`
168 ///
169 /// # Examples
170 ///
171 /// ```
172 /// # use svod_tensor::Tensor;
173 /// # use svod_dtype::DType;
174 /// # use ndarray::Array2;
175 /// let a = Tensor::from_ndarray(&Array2::from_elem((2, 3), 128u8));
176 /// let a_scale = Tensor::from_slice([0.1f32]);
177 /// let a_zp = Tensor::from_slice([128u8]);
178 /// let b = Tensor::from_ndarray(&Array2::from_elem((3, 4), 128u8));
179 /// let b_scale = Tensor::from_slice([0.1f32]);
180 /// let b_zp = Tensor::from_slice([128u8]);
181 /// let y_scale = Tensor::from_slice([0.1f32]);
182 /// let y_zp = Tensor::from_slice([128u8]);
183 /// let y = a.qlinear_matmul()
184 /// .a_scale(&a_scale).a_zero_point(&a_zp)
185 /// .b(&b).b_scale(&b_scale).b_zero_point(&b_zp)
186 /// .y_scale(&y_scale).y_zero_point(&y_zp)
187 /// .call()
188 /// .unwrap();
189 /// let shape: Vec<usize> = y.shape().unwrap().iter()
190 /// .map(|d| d.as_const().unwrap()).collect();
191 /// assert_eq!(shape, vec![2, 4]);
192 /// ```
193 #[builder]
194 pub fn qlinear_matmul(
195 &self,
196 a_scale: &Tensor,
197 a_zero_point: &Tensor,
198 b: &Tensor,
199 b_scale: &Tensor,
200 b_zero_point: &Tensor,
201 y_scale: &Tensor,
202 y_zero_point: &Tensor,
203 ) -> Result<Tensor> {
204 let adj_a = self.cast(DType::Int32)?.try_sub(&a_zero_point.cast(DType::Int32)?)?;
205 let adj_b = b.cast(DType::Int32)?.try_sub(&b_zero_point.cast(DType::Int32)?)?;
206 let out = adj_a.matmul(&adj_b)?;
207 requantize(&out, &[a_scale, b_scale], y_scale, y_zero_point)
208 }
209}
210
211/// Reshape a per-channel zero point `(C,)` to broadcast against a weight
212/// tensor `(C, ...)` by appending singleton dimensions.
213fn reshape_per_channel(zp: &Tensor, target_ndim: usize) -> Result<Tensor> {
214 let zp_ndim = zp.ndim()?;
215 if zp_ndim == 0 || zp_ndim == target_ndim {
216 return Ok(zp.clone());
217 }
218 let mut shape: Vec<isize> = vec![-1];
219 shape.extend(std::iter::repeat_n(1, target_ndim - 1));
220 zp.try_reshape(&shape)
221}
222
223/// Rescale an integer result and requantize to the output zero-point's dtype.
224///
225/// No clamping: overflow means broken calibration — let it surface as garbage
226/// rather than silently saturating to boundary values.
227/// Round → Int32 → target dtype (int-to-int trunc wraps naturally).
228fn requantize(int_result: &Tensor, scales: &[&Tensor], out_scale: &Tensor, out_zero_point: &Tensor) -> Result<Tensor> {
229 let out_dtype = out_zero_point.uop().dtype();
230 let scale_dtype = out_scale.uop().dtype();
231 // Compute combined scale with explicit rounding to the scale's native
232 // dtype between operations. LLVM promotes _Float16 to float for
233 // arithmetic on x86 and may skip the intermediate fptrunc, keeping
234 // float32 precision. Roundtripping through float64→scale_dtype after
235 // each step forces correct intermediate rounding (matching numpy).
236 let mut combined = scales[0].cast(DType::Float64)?;
237 for s in &scales[1..] {
238 combined = combined.try_mul(&s.cast(DType::Float64)?)?.cast(scale_dtype.clone())?.cast(DType::Float64)?;
239 }
240 combined = combined.try_div(&out_scale.cast(DType::Float64)?)?.cast(scale_dtype.clone())?;
241 // Promote both operands to f64 for the final multiply (int32 * f16 → f64 in numpy)
242 let rescaled = int_result
243 .cast(DType::Float64)?
244 .try_mul(&combined.cast(DType::Float64)?)?
245 .try_add(&out_zero_point.cast(DType::Float64)?)?
246 .round()?;
247 // Float → Int32 (safe range) → target dtype (int trunc wraps)
248 rescaled.cast(DType::Int32)?.cast(out_dtype)
249}