1use crate::{CnnError, CnnResult, Tensor};
9
/// Element-wise addition of two `u8`-quantized buffers (a residual/skip connection),
/// producing a `u8`-quantized result with its own scale and zero point.
#[derive(Debug, Clone)]
pub struct QuantizedResidualAdd {
    /// Scale of the quantized output; `new` sets it to `sqrt(scale1 * scale2)`.
    output_scale: f32,

    /// Zero point of the quantized output; `new` always sets it to 128.
    output_zero_point: u8,
}
24
25impl QuantizedResidualAdd {
26 pub fn new(scale1: f32, scale2: f32) -> Self {
32 let output_scale = (scale1 * scale2).sqrt();
34
35 let output_zero_point = 128u8;
37
38 Self {
39 output_scale,
40 output_zero_point,
41 }
42 }
43
44 pub fn forward_int8(
58 &self,
59 input1: &[u8],
60 scale1: f32,
61 zero_point1: u8,
62 input2: &[u8],
63 scale2: f32,
64 zero_point2: u8,
65 shape: &[usize],
66 ) -> CnnResult<(Vec<u8>, f32, u8)> {
67 if input1.len() != input2.len() {
68 return Err(CnnError::invalid_shape(
69 format!("input size {}", input1.len()),
70 format!("size {}", input2.len())
71 ));
72 }
73
74 let mut output = vec![self.output_zero_point; input1.len()];
75
76 let scale_factor1 = scale1 / self.output_scale;
81 let scale_factor2 = scale2 / self.output_scale;
82
83 for i in 0..input1.len() {
84 let val1 = (input1[i] as f32 - zero_point1 as f32) * scale_factor1;
86 let val2 = (input2[i] as f32 - zero_point2 as f32) * scale_factor2;
87
88 let sum = val1 + val2;
90
91 let output_q = (sum + self.output_zero_point as f32).round().clamp(0.0, 255.0);
93 output[i] = output_q as u8;
94 }
95
96 Ok((output, self.output_scale, self.output_zero_point))
97 }
98
99 pub fn forward_int8_i16(
103 &self,
104 input1: &[u8],
105 scale1: f32,
106 zero_point1: u8,
107 input2: &[u8],
108 scale2: f32,
109 zero_point2: u8,
110 shape: &[usize],
111 ) -> CnnResult<(Vec<u8>, f32, u8)> {
112 if input1.len() != input2.len() {
113 return Err(CnnError::invalid_shape(
114 format!("input size {}", input1.len()),
115 format!("size {}", input2.len())
116 ));
117 }
118
119 let mut output = vec![self.output_zero_point; input1.len()];
120
121 let (mult1, shift1) = Self::quantize_scale(scale1 / self.output_scale);
123 let (mult2, shift2) = Self::quantize_scale(scale2 / self.output_scale);
124
125 for i in 0..input1.len() {
126 let val1 = input1[i] as i16 - zero_point1 as i16;
128 let val2 = input2[i] as i16 - zero_point2 as i16;
129
130 let scaled1 = Self::multiply_by_quantized_multiplier(val1 as i32, mult1, shift1);
132 let scaled2 = Self::multiply_by_quantized_multiplier(val2 as i32, mult2, shift2);
133
134 let sum = scaled1 + scaled2 + self.output_zero_point as i32;
136 output[i] = sum.clamp(0, 255) as u8;
137 }
138
139 Ok((output, self.output_scale, self.output_zero_point))
140 }
141
142 fn quantize_scale(scale: f32) -> (i32, i32) {
147 if scale <= 0.0 {
148 return (0, 0);
149 }
150
151 let mut shift = 0i32;
153 let mut scaled = scale;
154
155 while scaled < 0.5 {
156 scaled *= 2.0;
157 shift += 1;
158 }
159 while scaled >= 1.0 {
160 scaled *= 0.5;
161 shift -= 1;
162 }
163
164 let multiplier = (scaled * 2147483648.0) as i32; (multiplier, shift)
168 }
169
170 fn multiply_by_quantized_multiplier(value: i32, multiplier: i32, shift: i32) -> i32 {
172 let total = (value as i64) * (multiplier as i64);
174
175 let result = if shift >= 0 {
177 (total + (1i64 << (shift - 1))) >> shift
178 } else {
179 total << (-shift)
180 };
181
182 result as i32
183 }
184
185 pub fn output_scale(&self) -> f32 {
187 self.output_scale
188 }
189
190 pub fn output_zero_point(&self) -> u8 {
192 self.output_zero_point
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    /// Equal scales: 0 plus +10 (relative to zero point 128) should land near 138.
    #[test]
    fn test_quantized_residual_add_same_scale() {
        let scale = 0.01_f32;
        let op = QuantizedResidualAdd::new(scale, scale);

        let lhs = [128u8; 16];
        let rhs = [138u8; 16];

        let (output, _out_scale, _out_zp) = op
            .forward_int8(&lhs, scale, 128, &rhs, scale, 128, &[4, 4])
            .unwrap();

        assert_eq!(output.len(), 16);
        assert!((135..=141).contains(&output[0]));
    }

    /// Different scales: the result only needs to stay within a loose band around
    /// the zero point.
    #[test]
    fn test_quantized_residual_add_different_scales() {
        let (scale1, scale2) = (0.01_f32, 0.02_f32);
        let op = QuantizedResidualAdd::new(scale1, scale2);

        let lhs = [128u8; 16];
        let rhs = [133u8; 16];

        let (output, _out_scale, _out_zp) = op
            .forward_int8(&lhs, scale1, 128, &rhs, scale2, 128, &[4, 4])
            .unwrap();

        assert_eq!(output.len(), 16);
        assert!((120..=140).contains(&output[0]));
    }

    /// The integer path must produce one output element per input element.
    #[test]
    fn test_quantized_residual_add_i16_precision() {
        let scale = 0.01_f32;
        let op = QuantizedResidualAdd::new(scale, scale);

        let lhs = [100u8; 8];
        let rhs = [150u8; 8];

        let (output, _, _) = op
            .forward_int8_i16(&lhs, scale, 128, &rhs, scale, 128, &[2, 4])
            .unwrap();

        assert_eq!(output.len(), 8);
    }

    /// Pins the sign convention of the shift returned by `quantize_scale`.
    #[test]
    fn test_quantize_scale() {
        let (half_mult, half_shift) = QuantizedResidualAdd::quantize_scale(0.5);
        assert!(half_mult > 0);
        assert_eq!(half_shift, 0);

        let (quarter_mult, quarter_shift) = QuantizedResidualAdd::quantize_scale(0.25);
        assert!(quarter_mult > 0);
        assert_eq!(quarter_shift, 1);
    }
}