Skip to main content

svod_tensor/nn/
quantize.rs

1//! Quantization operations (clamp-cast, quantized conv/matmul).
2
3use bon::bon;
4use svod_dtype::DType;
5
6use crate::Tensor;
7
8type Result<T> = crate::Result<T>;
9
10#[bon]
11impl Tensor {
12    /// Clamp to the representable range of `dtype`, then cast.
13    ///
14    /// Values outside the target type's range are saturated to its min/max
15    /// before casting, preventing overflow wrap-around.
16    ///
17    /// # Examples
18    ///
19    /// ```
20    /// # use svod_tensor::Tensor;
21    /// # use svod_dtype::DType;
22    /// let x = Tensor::from_slice([300.0f32, -10.0, 128.0]);
23    /// let mut y = x.clamp_cast(DType::UInt8).unwrap();
24    /// y.realize().unwrap();
25    /// let vals = y.as_vec::<u8>().unwrap();
26    /// assert_eq!(vals, vec![255, 0, 128]);
27    /// ```
28    pub fn clamp_cast(&self, dtype: DType) -> Result<Self> {
29        let min = Tensor::const_(dtype.min_value(), self.uop().dtype());
30        let max = Tensor::const_(dtype.max_value(), self.uop().dtype());
31        self.clamp().min(&min).max(&max).call()?.cast(dtype)
32    }
33
34    /// Quantized convolution: zero-point–adjust inputs, convolve in int32,
35    /// rescale and requantize to the output dtype.
36    ///
37    /// Implements the ONNX QLinearConv operator. The flow is:
38    /// 1. Subtract zero points from input and weights
39    /// 2. Perform integer convolution
40    /// 3. Rescale by `(x_scale * w_scale) / y_scale` and add `y_zero_point`
41    ///
42    /// # Examples
43    ///
44    /// ```
45    /// # use svod_tensor::Tensor;
46    /// # use svod_dtype::DType;
47    /// # use ndarray::Array4;
48    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 128u8));
49    /// let x_scale = Tensor::from_slice([0.1f32]);
50    /// let x_zp = Tensor::from_slice([128u8]);
51    /// let weight = Tensor::from_ndarray(&Array4::from_elem((1, 1, 1, 1), 128u8));
52    /// let w_scale = Tensor::from_slice([0.1f32]);
53    /// let w_zp = Tensor::from_slice([128u8]);
54    /// let y_scale = Tensor::from_slice([0.1f32]);
55    /// let y_zp = Tensor::from_slice([128u8]);
56    /// let y = x.qlinear_conv()
57    ///     .x_scale(&x_scale).x_zero_point(&x_zp)
58    ///     .weight(&weight).w_scale(&w_scale).w_zero_point(&w_zp)
59    ///     .y_scale(&y_scale).y_zero_point(&y_zp)
60    ///     .call()
61    ///     .unwrap();
62    /// let shape: Vec<usize> = y.shape().unwrap().iter()
63    ///     .map(|d| d.as_const().unwrap()).collect();
64    /// assert_eq!(shape, vec![1, 1, 3, 3]);
65    /// ```
66    #[builder]
67    pub fn qlinear_conv(
68        &self,
69        x_scale: &Tensor,
70        x_zero_point: &Tensor,
71        weight: &Tensor,
72        w_scale: &Tensor,
73        w_zero_point: &Tensor,
74        y_scale: &Tensor,
75        y_zero_point: &Tensor,
76        bias: Option<&Tensor>,
77        #[builder(default)] auto_pad: super::AutoPad,
78        #[builder(default = 1)] group: usize,
79        kernel_shape: Option<&[usize]>,
80        pads: Option<&[i64]>,
81        strides: Option<&[i64]>,
82        dilations: Option<&[i64]>,
83    ) -> Result<Tensor> {
84        let adj_x = self.cast(DType::Int32)?.try_sub(&x_zero_point.cast(DType::Int32)?)?;
85        let w_i32 = weight.cast(DType::Int32)?;
86        let w_zp = reshape_per_channel(&w_zero_point.cast(DType::Int32)?, w_i32.ndim()?)?;
87        let adj_w = w_i32.try_sub(&w_zp)?;
88        let conv_out = adj_x
89            .conv()
90            .weight(&adj_w)
91            .maybe_bias(bias)
92            .auto_pad(auto_pad)
93            .group(group)
94            .maybe_kernel_shape(kernel_shape)
95            .maybe_pads(pads)
96            .maybe_strides(strides)
97            .maybe_dilations(dilations)
98            .call()?;
99        requantize(&conv_out, &[x_scale, w_scale], y_scale, y_zero_point)
100    }
101
102    /// Integer convolution: zero-point–adjust inputs and convolve in int32.
103    /// No rescaling — returns raw int32 result.
104    ///
105    /// Implements the ONNX ConvInteger operator. Subtracts optional zero points
106    /// from input and weights, then convolves in int32. Unlike `qlinear_conv`,
107    /// no output rescaling is applied.
108    ///
109    /// # Examples
110    ///
111    /// ```
112    /// # use svod_tensor::Tensor;
113    /// # use svod_dtype::DType;
114    /// # use ndarray::Array4;
115    /// let x = Tensor::from_ndarray(&Array4::from_elem((1, 1, 3, 3), 10u8));
116    /// let weight = Tensor::from_ndarray(&Array4::from_elem((1, 1, 1, 1), 1u8));
117    /// let y = x.conv_integer().weight(&weight).call().unwrap();
118    /// let shape: Vec<usize> = y.shape().unwrap().iter()
119    ///     .map(|d| d.as_const().unwrap()).collect();
120    /// assert_eq!(shape, vec![1, 1, 3, 3]);
121    /// ```
122    #[builder]
123    pub fn conv_integer(
124        &self,
125        weight: &Tensor,
126        x_zero_point: Option<&Tensor>,
127        w_zero_point: Option<&Tensor>,
128        bias: Option<&Tensor>,
129        #[builder(default)] auto_pad: super::AutoPad,
130        #[builder(default = 1)] group: usize,
131        kernel_shape: Option<&[usize]>,
132        pads: Option<&[i64]>,
133        strides: Option<&[i64]>,
134        dilations: Option<&[i64]>,
135    ) -> Result<Tensor> {
136        let adj_x = if let Some(zp) = x_zero_point {
137            self.cast(DType::Int32)?.try_sub(&zp.cast(DType::Int32)?)?
138        } else {
139            self.cast(DType::Int32)?
140        };
141        let w_i32 = weight.cast(DType::Int32)?;
142        let adj_w = if let Some(zp) = w_zero_point {
143            let w_zp = reshape_per_channel(&zp.cast(DType::Int32)?, w_i32.ndim()?)?;
144            w_i32.try_sub(&w_zp)?
145        } else {
146            w_i32
147        };
148        adj_x
149            .conv()
150            .weight(&adj_w)
151            .maybe_bias(bias)
152            .auto_pad(auto_pad)
153            .group(group)
154            .maybe_kernel_shape(kernel_shape)
155            .maybe_pads(pads)
156            .maybe_strides(strides)
157            .maybe_dilations(dilations)
158            .call()
159    }
160
161    /// Quantized matrix multiplication: zero-point–adjust inputs, matmul in int32,
162    /// rescale and requantize to the output dtype.
163    ///
164    /// Implements the ONNX QLinearMatMul operator. The flow is:
165    /// 1. Subtract zero points from both inputs
166    /// 2. Perform integer matrix multiplication
167    /// 3. Rescale by `(a_scale * b_scale) / y_scale` and add `y_zero_point`
168    ///
169    /// # Examples
170    ///
171    /// ```
172    /// # use svod_tensor::Tensor;
173    /// # use svod_dtype::DType;
174    /// # use ndarray::Array2;
175    /// let a = Tensor::from_ndarray(&Array2::from_elem((2, 3), 128u8));
176    /// let a_scale = Tensor::from_slice([0.1f32]);
177    /// let a_zp = Tensor::from_slice([128u8]);
178    /// let b = Tensor::from_ndarray(&Array2::from_elem((3, 4), 128u8));
179    /// let b_scale = Tensor::from_slice([0.1f32]);
180    /// let b_zp = Tensor::from_slice([128u8]);
181    /// let y_scale = Tensor::from_slice([0.1f32]);
182    /// let y_zp = Tensor::from_slice([128u8]);
183    /// let y = a.qlinear_matmul()
184    ///     .a_scale(&a_scale).a_zero_point(&a_zp)
185    ///     .b(&b).b_scale(&b_scale).b_zero_point(&b_zp)
186    ///     .y_scale(&y_scale).y_zero_point(&y_zp)
187    ///     .call()
188    ///     .unwrap();
189    /// let shape: Vec<usize> = y.shape().unwrap().iter()
190    ///     .map(|d| d.as_const().unwrap()).collect();
191    /// assert_eq!(shape, vec![2, 4]);
192    /// ```
193    #[builder]
194    pub fn qlinear_matmul(
195        &self,
196        a_scale: &Tensor,
197        a_zero_point: &Tensor,
198        b: &Tensor,
199        b_scale: &Tensor,
200        b_zero_point: &Tensor,
201        y_scale: &Tensor,
202        y_zero_point: &Tensor,
203    ) -> Result<Tensor> {
204        let adj_a = self.cast(DType::Int32)?.try_sub(&a_zero_point.cast(DType::Int32)?)?;
205        let adj_b = b.cast(DType::Int32)?.try_sub(&b_zero_point.cast(DType::Int32)?)?;
206        let out = adj_a.matmul(&adj_b)?;
207        requantize(&out, &[a_scale, b_scale], y_scale, y_zero_point)
208    }
209}
210
211/// Reshape a per-channel zero point `(C,)` to broadcast against a weight
212/// tensor `(C, ...)` by appending singleton dimensions.
213fn reshape_per_channel(zp: &Tensor, target_ndim: usize) -> Result<Tensor> {
214    let zp_ndim = zp.ndim()?;
215    if zp_ndim == 0 || zp_ndim == target_ndim {
216        return Ok(zp.clone());
217    }
218    let mut shape: Vec<isize> = vec![-1];
219    shape.extend(std::iter::repeat_n(1, target_ndim - 1));
220    zp.try_reshape(&shape)
221}
222
223/// Rescale an integer result and requantize to the output zero-point's dtype.
224///
225/// No clamping: overflow means broken calibration — let it surface as garbage
226/// rather than silently saturating to boundary values.
227/// Round → Int32 → target dtype (int-to-int trunc wraps naturally).
228fn requantize(int_result: &Tensor, scales: &[&Tensor], out_scale: &Tensor, out_zero_point: &Tensor) -> Result<Tensor> {
229    let out_dtype = out_zero_point.uop().dtype();
230    let scale_dtype = out_scale.uop().dtype();
231    // Compute combined scale with explicit rounding to the scale's native
232    // dtype between operations. LLVM promotes _Float16 to float for
233    // arithmetic on x86 and may skip the intermediate fptrunc, keeping
234    // float32 precision. Roundtripping through float64→scale_dtype after
235    // each step forces correct intermediate rounding (matching numpy).
236    let mut combined = scales[0].cast(DType::Float64)?;
237    for s in &scales[1..] {
238        combined = combined.try_mul(&s.cast(DType::Float64)?)?.cast(scale_dtype.clone())?.cast(DType::Float64)?;
239    }
240    combined = combined.try_div(&out_scale.cast(DType::Float64)?)?.cast(scale_dtype.clone())?;
241    // Promote both operands to f64 for the final multiply (int32 * f16 → f64 in numpy)
242    let rescaled = int_result
243        .cast(DType::Float64)?
244        .try_mul(&combined.cast(DType::Float64)?)?
245        .try_add(&out_zero_point.cast(DType::Float64)?)?
246        .round()?;
247    // Float → Int32 (safe range) → target dtype (int trunc wraps)
248    rescaled.cast(DType::Int32)?.cast(out_dtype)
249}