// Source: ruvector_cnn/layers/quantized_residual.rs

//! Quantized Residual Addition
//!
//! Residual connections over 8-bit (`u8`) quantized tensors, with:
//! - Requantization that aligns the scales of the two branches
//! - Per-tensor scale and zero-point alignment
//! - Support for branches whose scales differ

8use crate::{CnnError, CnnResult, Tensor};
9
/// Quantized residual addition of two `u8` tensors.
///
/// Adds two quantized tensors (`output = input1 + input2` in the real-valued
/// domain) whose scales and zero points may differ, requantizing the result
/// into a single shared output scale and zero point.
#[derive(Debug, Clone, PartialEq)]
pub struct QuantizedResidualAdd {
    /// Output scale: the geometric mean `sqrt(scale1 * scale2)` of the input
    /// scales given to `new`.
    output_scale: f32,

    /// Output zero point (`new` always sets this to 128, the centre of the
    /// `u8` range).
    output_zero_point: u8,
}

25impl QuantizedResidualAdd {
26    /// Create a new quantized residual add layer
27    ///
28    /// # Arguments
29    /// * `scale1` - Scale of first input
30    /// * `scale2` - Scale of second input
31    pub fn new(scale1: f32, scale2: f32) -> Self {
32        // Use geometric mean of scales as output scale
33        let output_scale = (scale1 * scale2).sqrt();
34
35        // Assume symmetric distribution around 128
36        let output_zero_point = 128u8;
37
38        Self {
39            output_scale,
40            output_zero_point,
41        }
42    }
43
44    /// Forward pass with INT8 inputs
45    ///
46    /// # Arguments
47    /// * `input1` - First quantized u8 input
48    /// * `scale1` - Scale of first input
49    /// * `zero_point1` - Zero point of first input
50    /// * `input2` - Second quantized u8 input
51    /// * `scale2` - Scale of second input
52    /// * `zero_point2` - Zero point of second input
53    /// * `shape` - Tensor shape (must be identical for both inputs)
54    ///
55    /// # Returns
56    /// (output, output_scale, output_zero_point)
57    pub fn forward_int8(
58        &self,
59        input1: &[u8],
60        scale1: f32,
61        zero_point1: u8,
62        input2: &[u8],
63        scale2: f32,
64        zero_point2: u8,
65        shape: &[usize],
66    ) -> CnnResult<(Vec<u8>, f32, u8)> {
67        if input1.len() != input2.len() {
68            return Err(CnnError::invalid_shape(
69                format!("input size {}", input1.len()),
70                format!("size {}", input2.len())
71            ));
72        }
73
74        let mut output = vec![self.output_zero_point; input1.len()];
75
76        // Compute scale factors for requantization
77        // output = (input1_dequant + input2_dequant) / output_scale + output_zero_point
78        //        = ((q1 - zp1) * s1 + (q2 - zp2) * s2) / s_out + zp_out
79
80        let scale_factor1 = scale1 / self.output_scale;
81        let scale_factor2 = scale2 / self.output_scale;
82
83        for i in 0..input1.len() {
84            // Dequantize to floating point domain
85            let val1 = (input1[i] as f32 - zero_point1 as f32) * scale_factor1;
86            let val2 = (input2[i] as f32 - zero_point2 as f32) * scale_factor2;
87
88            // Add in floating point
89            let sum = val1 + val2;
90
91            // Requantize to output
92            let output_q = (sum + self.output_zero_point as f32).round().clamp(0.0, 255.0);
93            output[i] = output_q as u8;
94        }
95
96        Ok((output, self.output_scale, self.output_zero_point))
97    }
98
99    /// Forward pass with scale alignment (i16 intermediate precision)
100    ///
101    /// More accurate version using i16 intermediate precision.
102    pub fn forward_int8_i16(
103        &self,
104        input1: &[u8],
105        scale1: f32,
106        zero_point1: u8,
107        input2: &[u8],
108        scale2: f32,
109        zero_point2: u8,
110        shape: &[usize],
111    ) -> CnnResult<(Vec<u8>, f32, u8)> {
112        if input1.len() != input2.len() {
113            return Err(CnnError::invalid_shape(
114                format!("input size {}", input1.len()),
115                format!("size {}", input2.len())
116            ));
117        }
118
119        let mut output = vec![self.output_zero_point; input1.len()];
120
121        // Compute integer scale factors (multiplier and shift)
122        let (mult1, shift1) = Self::quantize_scale(scale1 / self.output_scale);
123        let (mult2, shift2) = Self::quantize_scale(scale2 / self.output_scale);
124
125        for i in 0..input1.len() {
126            // Subtract zero points
127            let val1 = input1[i] as i16 - zero_point1 as i16;
128            let val2 = input2[i] as i16 - zero_point2 as i16;
129
130            // Scale using fixed-point arithmetic
131            let scaled1 = Self::multiply_by_quantized_multiplier(val1 as i32, mult1, shift1);
132            let scaled2 = Self::multiply_by_quantized_multiplier(val2 as i32, mult2, shift2);
133
134            // Add and requantize
135            let sum = scaled1 + scaled2 + self.output_zero_point as i32;
136            output[i] = sum.clamp(0, 255) as u8;
137        }
138
139        Ok((output, self.output_scale, self.output_zero_point))
140    }
141
142    /// Quantize a floating-point scale to (multiplier, shift) format
143    ///
144    /// Represents scale as: multiplier * 2^(-shift)
145    /// where multiplier is in [0.5, 1.0) as i32 in Q31 format
146    fn quantize_scale(scale: f32) -> (i32, i32) {
147        if scale <= 0.0 {
148            return (0, 0);
149        }
150
151        // Find the shift such that scale * 2^shift is in [0.5, 1.0)
152        let mut shift = 0i32;
153        let mut scaled = scale;
154
155        while scaled < 0.5 {
156            scaled *= 2.0;
157            shift += 1;
158        }
159        while scaled >= 1.0 {
160            scaled *= 0.5;
161            shift -= 1;
162        }
163
164        // Quantize to Q31 format (31 fractional bits)
165        let multiplier = (scaled * 2147483648.0) as i32; // 2^31
166
167        (multiplier, shift)
168    }
169
170    /// Multiply by quantized multiplier with rounding
171    fn multiply_by_quantized_multiplier(value: i32, multiplier: i32, shift: i32) -> i32 {
172        // Perform multiplication in i64 to avoid overflow
173        let total = (value as i64) * (multiplier as i64);
174
175        // Apply shift with rounding
176        let result = if shift >= 0 {
177            (total + (1i64 << (shift - 1))) >> shift
178        } else {
179            total << (-shift)
180        };
181
182        result as i32
183    }
184
185    /// Get output scale
186    pub fn output_scale(&self) -> f32 {
187        self.output_scale
188    }
189
190    /// Get output zero point
191    pub fn output_zero_point(&self) -> u8 {
192        self.output_zero_point
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantized_residual_add_same_scale() {
        let scale = 0.01f32;
        let residual = QuantizedResidualAdd::new(scale, scale);

        let input1 = vec![128u8; 16]; // real value 0
        let input2 = vec![138u8; 16]; // +10 in quantized domain
        let shape = &[4, 4];

        let (output, out_scale, out_zp) = residual
            .forward_int8(&input1, scale, 128, &input2, scale, 128, shape)
            .unwrap();

        assert_eq!(out_scale, residual.output_scale());
        assert_eq!(out_zp, residual.output_zero_point());
        // Equal input scales give an output scale matching them, so the sum is
        // exactly 0 + 10 steps above the zero point.
        assert_eq!(output, vec![138u8; 16]);
    }

    #[test]
    fn test_quantized_residual_add_different_scales() {
        let scale1 = 0.01f32;
        let scale2 = 0.02f32;
        let residual = QuantizedResidualAdd::new(scale1, scale2);

        let input1 = vec![128u8; 16]; // real value 0
        let input2 = vec![133u8; 16]; // +5 in quantized domain, but scale2 is 2x
        let shape = &[4, 4];

        let (output, _out_scale, _out_zp) = residual
            .forward_int8(&input1, scale1, 128, &input2, scale2, 128, shape)
            .unwrap();

        assert_eq!(output.len(), 16);
        // Real sum is 5 * 0.02 = 0.1; the output scale is sqrt(0.01 * 0.02) ~= 0.01414,
        // i.e. ~7.07 steps above the zero point -> 128 + 7.
        assert_eq!(output, vec![135u8; 16]);
    }

    #[test]
    fn test_quantized_residual_add_i16_precision() {
        let scale = 0.01f32;
        let residual = QuantizedResidualAdd::new(scale, scale);

        let input1 = vec![100u8; 8];
        let input2 = vec![150u8; 8];
        let shape = &[2, 4];

        let (output, _, _) = residual
            .forward_int8_i16(&input1, scale, 128, &input2, scale, 128, shape)
            .unwrap();

        assert_eq!(output.len(), 8);
        // Identical inputs everywhere must produce an identical result everywhere.
        assert!(output.iter().all(|&q| q == output[0]));
    }

    #[test]
    fn test_quantized_residual_add_length_mismatch() {
        let residual = QuantizedResidualAdd::new(0.01, 0.01);
        let short = vec![128u8; 4];
        let long = vec![128u8; 5];

        assert!(residual
            .forward_int8(&short, 0.01, 128, &long, 0.01, 128, &[4])
            .is_err());
        assert!(residual
            .forward_int8_i16(&short, 0.01, 128, &long, 0.01, 128, &[4])
            .is_err());
    }

    #[test]
    fn test_quantize_scale() {
        // Encoding is (multiplier / 2^31) * 2^(-shift) with the Q31 mantissa in [0.5, 1).
        assert_eq!(QuantizedResidualAdd::quantize_scale(0.5), (1 << 30, 0));
        assert_eq!(QuantizedResidualAdd::quantize_scale(0.25), (1 << 30, 1));
        assert_eq!(QuantizedResidualAdd::quantize_scale(1.0), (1 << 30, -1));
        assert_eq!(QuantizedResidualAdd::quantize_scale(0.0), (0, 0));
    }
}