// ruvector_cnn/layers/quantized_linear.rs
//! Quantized Linear (Fully Connected) Layer
//!
//! INT8 quantized linear layer with:
//! - GEMM-based forward pass
//! - Fused bias and requantization
//! - Per-channel or per-tensor quantization

use crate::{CnnError, CnnResult, Tensor};

use super::Linear;
11
/// Quantized Linear Layer
///
/// Performs matrix multiplication in INT8:
/// output = (input @ weights^T + bias) * scale
#[derive(Debug, Clone)]
pub struct QuantizedLinear {
    /// Quantized weights: [out_features, in_features] in i8.
    /// Symmetric quantization (values map linearly through
    /// `weight_scales`, no weight zero point).
    weights_q: Vec<i8>,

    /// Per-output-feature weight scales (per-channel quantization);
    /// length == out_features.
    weight_scales: Vec<f32>,

    /// Bias pre-quantized into the i32 accumulator space,
    /// i.e. bias / (input_scale * weight_scale) per output feature.
    bias_q: Vec<i32>,

    /// Original FP32 bias, retained from the source layer.
    /// NOTE(review): not read by the visible forward path — presumably
    /// kept for re-quantization with a different input scale; confirm.
    bias_f32: Vec<f32>,

    /// Number of input features per sample.
    in_features: usize,
    /// Number of output features per sample.
    out_features: usize,
}
34
35impl QuantizedLinear {
36    /// Create from FP32 Linear layer
37    ///
38    /// # Arguments
39    /// * `linear` - FP32 linear layer to quantize
40    /// * `input_scale` - Expected input activation scale
41    pub fn from_fp32(linear: &Linear, input_scale: f32) -> Self {
42        let in_features = linear.in_features();
43        let out_features = linear.out_features();
44        let weights = linear.weight();
45
46        // Compute per-output-feature weight scales
47        let mut weight_scales = vec![0.0f32; out_features];
48
49        for of in 0..out_features {
50            let mut max_abs = 0.0f32;
51            for if_ in 0..in_features {
52                let idx = of * in_features + if_;
53                max_abs = max_abs.max(weights[idx].abs());
54            }
55            weight_scales[of] = if max_abs > 0.0 {
56                max_abs / 127.0
57            } else {
58                1.0
59            };
60        }
61
62        // Quantize weights
63        let mut weights_q = vec![0i8; weights.len()];
64        for of in 0..out_features {
65            let scale = weight_scales[of];
66            for if_ in 0..in_features {
67                let idx = of * in_features + if_;
68                let w_q = (weights[idx] / scale).round().clamp(-127.0, 127.0) as i8;
69                weights_q[idx] = w_q;
70            }
71        }
72
73        // Pre-compute bias in i32 accumulator space
74        let bias_f32 = linear.bias()
75            .map(|b| b.to_vec())
76            .unwrap_or_else(|| vec![0.0; out_features]);
77        let mut bias_q = vec![0i32; out_features];
78
79        for of in 0..out_features {
80            let combined_scale = input_scale * weight_scales[of];
81            bias_q[of] = if combined_scale > 0.0 {
82                (bias_f32[of] / combined_scale).round() as i32
83            } else {
84                0
85            };
86        }
87
88        Self {
89            weights_q,
90            weight_scales,
91            bias_q,
92            bias_f32,
93            in_features,
94            out_features,
95        }
96    }
97
98    /// Forward pass with INT8 computation
99    ///
100    /// # Arguments
101    /// * `input` - Quantized u8 input tensor [batch, in_features]
102    /// * `batch_size` - Batch size
103    /// * `input_scale` - Input quantization scale
104    /// * `input_zero_point` - Input quantization zero point
105    pub fn forward_int8(
106        &self,
107        input: &[u8],
108        batch_size: usize,
109        input_scale: f32,
110        input_zero_point: u8,
111    ) -> CnnResult<Tensor> {
112        if input.len() != batch_size * self.in_features {
113            return Err(CnnError::invalid_shape(
114                format!("input size {}", batch_size * self.in_features),
115                format!("size {}", input.len())
116            ));
117        }
118
119        let mut output_i32 = vec![0i32; batch_size * self.out_features];
120
121        // Pre-compute weight sums for zero-point correction
122        let mut weight_sums = vec![0i32; self.out_features];
123        for of in 0..self.out_features {
124            let mut sum = 0i32;
125            for if_ in 0..self.in_features {
126                sum += self.weights_q[of * self.in_features + if_] as i32;
127            }
128            weight_sums[of] = sum;
129        }
130
131        // GEMM: output = input @ weights^T + bias
132        for b in 0..batch_size {
133            for of in 0..self.out_features {
134                // Initialize with bias and zero-point correction
135                let mut acc = self.bias_q[of] - (input_zero_point as i32) * weight_sums[of];
136
137                // Dot product
138                for if_ in 0..self.in_features {
139                    let input_val = input[b * self.in_features + if_] as i32;
140                    let weight_val = self.weights_q[of * self.in_features + if_] as i32;
141                    acc += input_val * weight_val;
142                }
143
144                output_i32[b * self.out_features + of] = acc;
145            }
146        }
147
148        // Dequantize to f32
149        let output_f32 = self.dequantize_output(&output_i32, input_scale);
150
151        Tensor::from_data(output_f32, &[batch_size, self.out_features])
152    }
153
154    /// Dequantize i32 accumulator to f32
155    fn dequantize_output(&self, acc: &[i32], input_scale: f32) -> Vec<f32> {
156        let mut output = vec![0.0f32; acc.len()];
157
158        for (i, &val) in acc.iter().enumerate() {
159            let of = i % self.out_features;
160            let scale = input_scale * self.weight_scales[of];
161            output[i] = val as f32 * scale;
162        }
163
164        output
165    }
166
167    /// Get input features
168    pub fn in_features(&self) -> usize {
169        self.in_features
170    }
171
172    /// Get output features
173    pub fn out_features(&self) -> usize {
174        self.out_features
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantized_linear_creation() {
        // Quantizing should preserve the layer's dimensions.
        let fp32_layer = Linear::new(128, 64, true).unwrap();
        let quantized = QuantizedLinear::from_fp32(&fp32_layer, 0.01);

        assert_eq!(quantized.in_features(), 128);
        assert_eq!(quantized.out_features(), 64);
    }

    #[test]
    fn test_quantized_linear_forward() {
        let fp32_layer = Linear::new(32, 16, true).unwrap();
        let quantized = QuantizedLinear::from_fp32(&fp32_layer, 0.01);

        // A batch of inputs sitting exactly at the zero point.
        let batch = 4;
        let activations = vec![128u8; batch * 32];

        let result = quantized
            .forward_int8(&activations, batch, 0.01, 128)
            .unwrap();

        assert_eq!(result.shape(), &[batch, 16]);
    }

    #[test]
    fn test_quantized_linear_zero_point_correction() {
        let fp32_layer = Linear::new(8, 4, true).unwrap();
        let quantized = QuantizedLinear::from_fp32(&fp32_layer, 0.01);

        // Inputs offset from the zero point exercise the correction term.
        let activations = vec![200u8; 8];
        let result = quantized.forward_int8(&activations, 1, 0.01, 128).unwrap();

        assert_eq!(result.shape(), &[1, 4]);
    }
}