ruvector_tiny_dancer_core/
optimization.rs

1//! Model optimization techniques (quantization, pruning, knowledge distillation)
2
3use crate::error::{Result, TinyDancerError};
4use ndarray::Array2;
5
/// Quantization configuration
///
/// Selects the numeric precision used when storing model weights.
// PartialEq/Eq derived so callers can compare configured modes directly
// (e.g. `mode == QuantizationMode::Int8`) instead of matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationMode {
    /// No quantization (FP32)
    None,
    /// INT8 quantization
    Int8,
    /// INT16 quantization
    Int16,
}
16
/// Quantization parameters
///
/// Records the affine mapping produced during quantization so weights can be
/// restored later via `w ≈ (q - zero_point) * scale + min_val`.
// All fields are Copy, so the struct itself can be Copy; PartialEq lets
// callers compare parameter sets (useful in tests and caching).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct QuantizationParams {
    /// Scale factor (real-valued width of one quantization step)
    pub scale: f32,
    /// Zero point (integer offset of the quantized range)
    pub zero_point: i32,
    /// Min value observed in the original weights
    pub min_val: f32,
    /// Max value observed in the original weights
    pub max_val: f32,
}
29
30/// Quantize a weight matrix to INT8
31pub fn quantize_to_int8(weights: &Array2<f32>) -> Result<(Vec<i8>, QuantizationParams)> {
32    let min_val = weights.iter().fold(f32::INFINITY, |a, &b| a.min(b));
33    let max_val = weights.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
34
35    if (max_val - min_val).abs() < f32::EPSILON {
36        return Err(TinyDancerError::InvalidInput(
37            "Cannot quantize constant weights".to_string(),
38        ));
39    }
40
41    // Calculate scale and zero point for symmetric quantization
42    let scale = (max_val - min_val) / 255.0;
43    let zero_point = -128;
44
45    let quantized: Vec<i8> = weights
46        .iter()
47        .map(|&w| {
48            let q = ((w - min_val) / scale) as i32 + zero_point;
49            q.clamp(-128, 127) as i8
50        })
51        .collect();
52
53    let params = QuantizationParams {
54        scale,
55        zero_point,
56        min_val,
57        max_val,
58    };
59
60    Ok((quantized, params))
61}
62
63/// Dequantize INT8 weights back to FP32
64pub fn dequantize_from_int8(
65    quantized: &[i8],
66    params: &QuantizationParams,
67    shape: (usize, usize),
68) -> Result<Array2<f32>> {
69    let weights: Vec<f32> = quantized
70        .iter()
71        .map(|&q| {
72            let dequantized = (q as i32 - params.zero_point) as f32 * params.scale + params.min_val;
73            dequantized
74        })
75        .collect();
76
77    Array2::from_shape_vec(shape, weights)
78        .map_err(|e| TinyDancerError::InvalidInput(format!("Shape error: {}", e)))
79}
80
81/// Apply magnitude-based pruning to weights
82pub fn prune_weights(weights: &mut Array2<f32>, sparsity: f32) -> Result<usize> {
83    if !(0.0..=1.0).contains(&sparsity) {
84        return Err(TinyDancerError::InvalidInput(
85            "Sparsity must be between 0.0 and 1.0".to_string(),
86        ));
87    }
88
89    let total_weights = weights.len();
90    let num_to_prune = (total_weights as f32 * sparsity) as usize;
91
92    // Get absolute values
93    let mut abs_weights: Vec<(usize, f32)> = weights
94        .iter()
95        .enumerate()
96        .map(|(i, &w)| (i, w.abs()))
97        .collect();
98
99    // Sort by magnitude
100    abs_weights.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
101
102    // Zero out smallest weights
103    let mut pruned_count = 0;
104    for i in 0..num_to_prune {
105        let idx = abs_weights[i].0;
106        let (row, col) = (idx / weights.ncols(), idx % weights.ncols());
107        weights[[row, col]] = 0.0;
108        pruned_count += 1;
109    }
110
111    Ok(pruned_count)
112}
113
/// Calculate model compression ratio
///
/// Returns `original_size / compressed_size`; values above 1.0 mean the
/// model shrank. IEEE-754 division semantics apply if `compressed_size` is 0.
pub fn compression_ratio(original_size: usize, compressed_size: usize) -> f32 {
    let before = original_size as f32;
    let after = compressed_size as f32;
    before / after
}
118
/// Calculate speedup from optimization
///
/// Returns `original_time_us / optimized_time_us`; values above 1.0 mean the
/// optimized variant is faster.
pub fn calculate_speedup(original_time_us: u64, optimized_time_us: u64) -> f32 {
    let baseline = original_time_us as f32;
    let optimized = optimized_time_us as f32;
    baseline / optimized
}
123
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    #[test]
    fn test_int8_quantization() {
        // A 2x2 matrix with a non-degenerate value range.
        let weights = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let (quantized, params) = quantize_to_int8(&weights).unwrap();

        assert_eq!(quantized.len(), 4);
        assert!(params.scale > 0.0);
    }

    #[test]
    fn test_quantization_dequantization() {
        let values: Vec<f32> = (1..=9).map(|v| v as f32).collect();
        let weights = Array2::from_shape_vec((3, 3), values).unwrap();
        let (quantized, params) = quantize_to_int8(&weights).unwrap();
        let dequantized = dequantize_from_int8(&quantized, &params, (3, 3)).unwrap();

        // The round trip should approximately preserve every weight.
        for (orig, deq) in weights.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 0.1);
        }
    }

    #[test]
    fn test_pruning() {
        let mut weights = Array2::from_shape_vec((2, 2), vec![1.0, 0.1, 0.2, 2.0]).unwrap();
        let pruned = prune_weights(&mut weights, 0.5).unwrap();
        assert_eq!(pruned, 2);

        // The two smallest-magnitude entries must now be zero.
        let zeroed = weights.iter().filter(|&&w| w == 0.0).count();
        assert_eq!(zeroed, 2);
    }

    #[test]
    fn test_compression_ratio() {
        assert_eq!(compression_ratio(1000, 250), 4.0);
    }
}