// ruvector_cnn/layers/quantized_linear.rs
use crate::{CnnError, CnnResult, Tensor};

use super::Linear;

12#[derive(Debug, Clone)]
17pub struct QuantizedLinear {
18 weights_q: Vec<i8>,
20
21 weight_scales: Vec<f32>,
23
24 bias_q: Vec<i32>,
26
27 bias_f32: Vec<f32>,
29
30 in_features: usize,
32 out_features: usize,
33}
34
35impl QuantizedLinear {
36 pub fn from_fp32(linear: &Linear, input_scale: f32) -> Self {
42 let in_features = linear.in_features();
43 let out_features = linear.out_features();
44 let weights = linear.weight();
45
46 let mut weight_scales = vec![0.0f32; out_features];
48
49 for of in 0..out_features {
50 let mut max_abs = 0.0f32;
51 for if_ in 0..in_features {
52 let idx = of * in_features + if_;
53 max_abs = max_abs.max(weights[idx].abs());
54 }
55 weight_scales[of] = if max_abs > 0.0 {
56 max_abs / 127.0
57 } else {
58 1.0
59 };
60 }
61
62 let mut weights_q = vec![0i8; weights.len()];
64 for of in 0..out_features {
65 let scale = weight_scales[of];
66 for if_ in 0..in_features {
67 let idx = of * in_features + if_;
68 let w_q = (weights[idx] / scale).round().clamp(-127.0, 127.0) as i8;
69 weights_q[idx] = w_q;
70 }
71 }
72
73 let bias_f32 = linear.bias()
75 .map(|b| b.to_vec())
76 .unwrap_or_else(|| vec![0.0; out_features]);
77 let mut bias_q = vec![0i32; out_features];
78
79 for of in 0..out_features {
80 let combined_scale = input_scale * weight_scales[of];
81 bias_q[of] = if combined_scale > 0.0 {
82 (bias_f32[of] / combined_scale).round() as i32
83 } else {
84 0
85 };
86 }
87
88 Self {
89 weights_q,
90 weight_scales,
91 bias_q,
92 bias_f32,
93 in_features,
94 out_features,
95 }
96 }
97
98 pub fn forward_int8(
106 &self,
107 input: &[u8],
108 batch_size: usize,
109 input_scale: f32,
110 input_zero_point: u8,
111 ) -> CnnResult<Tensor> {
112 if input.len() != batch_size * self.in_features {
113 return Err(CnnError::invalid_shape(
114 format!("input size {}", batch_size * self.in_features),
115 format!("size {}", input.len())
116 ));
117 }
118
119 let mut output_i32 = vec![0i32; batch_size * self.out_features];
120
121 let mut weight_sums = vec![0i32; self.out_features];
123 for of in 0..self.out_features {
124 let mut sum = 0i32;
125 for if_ in 0..self.in_features {
126 sum += self.weights_q[of * self.in_features + if_] as i32;
127 }
128 weight_sums[of] = sum;
129 }
130
131 for b in 0..batch_size {
133 for of in 0..self.out_features {
134 let mut acc = self.bias_q[of] - (input_zero_point as i32) * weight_sums[of];
136
137 for if_ in 0..self.in_features {
139 let input_val = input[b * self.in_features + if_] as i32;
140 let weight_val = self.weights_q[of * self.in_features + if_] as i32;
141 acc += input_val * weight_val;
142 }
143
144 output_i32[b * self.out_features + of] = acc;
145 }
146 }
147
148 let output_f32 = self.dequantize_output(&output_i32, input_scale);
150
151 Tensor::from_data(output_f32, &[batch_size, self.out_features])
152 }
153
154 fn dequantize_output(&self, acc: &[i32], input_scale: f32) -> Vec<f32> {
156 let mut output = vec![0.0f32; acc.len()];
157
158 for (i, &val) in acc.iter().enumerate() {
159 let of = i % self.out_features;
160 let scale = input_scale * self.weight_scales[of];
161 output[i] = val as f32 * scale;
162 }
163
164 output
165 }
166
167 pub fn in_features(&self) -> usize {
169 self.in_features
170 }
171
172 pub fn out_features(&self) -> usize {
174 self.out_features
175 }
176}
177
178#[cfg(test)]
179mod tests {
180 use super::*;
181
182 #[test]
183 fn test_quantized_linear_creation() {
184 let linear = Linear::new(128, 64, true).unwrap();
185 let qlinear = QuantizedLinear::from_fp32(&linear, 0.01);
186
187 assert_eq!(qlinear.in_features(), 128);
188 assert_eq!(qlinear.out_features(), 64);
189 }
190
191 #[test]
192 fn test_quantized_linear_forward() {
193 let linear = Linear::new(32, 16, true).unwrap();
194 let qlinear = QuantizedLinear::from_fp32(&linear, 0.01);
195
196 let batch_size = 4;
197 let input = vec![128u8; batch_size * 32];
198
199 let output = qlinear.forward_int8(&input, batch_size, 0.01, 128).unwrap();
200
201 assert_eq!(output.shape(), &[batch_size, 16]);
202 }
203
204 #[test]
205 fn test_quantized_linear_zero_point_correction() {
206 let linear = Linear::new(8, 4, true).unwrap();
207 let qlinear = QuantizedLinear::from_fp32(&linear, 0.01);
208
209 let input = vec![200u8; 1 * 8];
211 let output = qlinear.forward_int8(&input, 1, 0.01, 128).unwrap();
212
213 assert_eq!(output.shape(), &[1, 4]);
214 }
215}