// ruvector_tiny_dancer_core/optimization.rs
use crate::error::{Result, TinyDancerError};
4use ndarray::Array2;
5
/// Quantization mode selector.
///
/// The functions visible in this file do not take this enum as a parameter;
/// the variant descriptions below are inferred from the variant names.
/// Equality and hashing are derived so a mode can be compared directly or
/// used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantizationMode {
    /// No quantization (presumably weights stay in full `f32` precision).
    None,
    /// 8-bit signed integer quantization.
    Int8,
    /// 16-bit signed integer quantization.
    Int16,
}
16
/// Parameters describing a linear quantization of `f32` weights.
///
/// A quantized level `q` is mapped back to a float as
/// `(q - zero_point) * scale + min_val`.
///
/// `PartialEq` is derived so parameter sets can be compared directly (for
/// example after a save/load round trip). `Eq` is intentionally absent because
/// the struct contains floating-point fields.
#[derive(Debug, Clone, PartialEq)]
pub struct QuantizationParams {
    /// Width of one quantization step, in the units of the original weights.
    pub scale: f32,
    /// Integer level that corresponds to `min_val`.
    pub zero_point: i32,
    /// Smallest weight value covered by the quantization range.
    pub min_val: f32,
    /// Largest weight value covered by the quantization range.
    pub max_val: f32,
}
29
30pub fn quantize_to_int8(weights: &Array2<f32>) -> Result<(Vec<i8>, QuantizationParams)> {
32 let min_val = weights.iter().fold(f32::INFINITY, |a, &b| a.min(b));
33 let max_val = weights.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
34
35 if (max_val - min_val).abs() < f32::EPSILON {
36 return Err(TinyDancerError::InvalidInput(
37 "Cannot quantize constant weights".to_string(),
38 ));
39 }
40
41 let scale = (max_val - min_val) / 255.0;
43 let zero_point = -128;
44
45 let quantized: Vec<i8> = weights
46 .iter()
47 .map(|&w| {
48 let q = ((w - min_val) / scale) as i32 + zero_point;
49 q.clamp(-128, 127) as i8
50 })
51 .collect();
52
53 let params = QuantizationParams {
54 scale,
55 zero_point,
56 min_val,
57 max_val,
58 };
59
60 Ok((quantized, params))
61}
62
63pub fn dequantize_from_int8(
65 quantized: &[i8],
66 params: &QuantizationParams,
67 shape: (usize, usize),
68) -> Result<Array2<f32>> {
69 let weights: Vec<f32> = quantized
70 .iter()
71 .map(|&q| {
72 let dequantized = (q as i32 - params.zero_point) as f32 * params.scale + params.min_val;
73 dequantized
74 })
75 .collect();
76
77 Array2::from_shape_vec(shape, weights)
78 .map_err(|e| TinyDancerError::InvalidInput(format!("Shape error: {}", e)))
79}
80
81pub fn prune_weights(weights: &mut Array2<f32>, sparsity: f32) -> Result<usize> {
83 if !(0.0..=1.0).contains(&sparsity) {
84 return Err(TinyDancerError::InvalidInput(
85 "Sparsity must be between 0.0 and 1.0".to_string(),
86 ));
87 }
88
89 let total_weights = weights.len();
90 let num_to_prune = (total_weights as f32 * sparsity) as usize;
91
92 let mut abs_weights: Vec<(usize, f32)> = weights
94 .iter()
95 .enumerate()
96 .map(|(i, &w)| (i, w.abs()))
97 .collect();
98
99 abs_weights.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
101
102 let mut pruned_count = 0;
104 for i in 0..num_to_prune {
105 let idx = abs_weights[i].0;
106 let (row, col) = (idx / weights.ncols(), idx % weights.ncols());
107 weights[[row, col]] = 0.0;
108 pruned_count += 1;
109 }
110
111 Ok(pruned_count)
112}
113
/// Returns how many times smaller the compressed representation is.
///
/// Computed as `original_size / compressed_size` in `f32` arithmetic. A
/// `compressed_size` of zero does not panic: the result is infinite, or NaN
/// when `original_size` is zero as well.
pub fn compression_ratio(original_size: usize, compressed_size: usize) -> f32 {
    let before = original_size as f32;
    let after = compressed_size as f32;
    before / after
}
118
/// Returns the speedup factor of an optimized run over the original run.
///
/// Both durations are given in microseconds and the result is
/// `original_time_us / optimized_time_us` in `f32` arithmetic, so values above
/// `1.0` mean the optimized run was faster. An `optimized_time_us` of zero does
/// not panic: the result is infinite, or NaN when both durations are zero.
pub fn calculate_speedup(original_time_us: u64, optimized_time_us: u64) -> f32 {
    let baseline = original_time_us as f32;
    let improved = optimized_time_us as f32;
    baseline / improved
}
123
/// Unit tests for the quantization, pruning and ratio helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    #[test]
    fn test_int8_quantization() {
        let weights = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let (quantized, params) = quantize_to_int8(&weights).unwrap();

        assert_eq!(quantized.len(), 4);
        assert!(params.scale > 0.0);
        // The extremes of the input must land on the extremes of the i8 range.
        assert_eq!(quantized[0], -128);
        assert_eq!(quantized[3], 127);
    }

    #[test]
    fn test_quantize_rejects_constant_weights() {
        let weights = Array2::from_shape_vec((2, 2), vec![1.5, 1.5, 1.5, 1.5]).unwrap();
        assert!(quantize_to_int8(&weights).is_err());
    }

    #[test]
    fn test_quantization_dequantization() {
        let weights =
            Array2::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
                .unwrap();
        let (quantized, params) = quantize_to_int8(&weights).unwrap();
        let dequantized = dequantize_from_int8(&quantized, &params, (3, 3)).unwrap();

        // A round trip is lossy, but the error must stay well below the value spacing.
        for (orig, deq) in weights.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 0.1);
        }
    }

    #[test]
    fn test_dequantize_rejects_wrong_shape() {
        let weights = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let (quantized, params) = quantize_to_int8(&weights).unwrap();

        // Four values cannot fill a 3x3 matrix.
        assert!(dequantize_from_int8(&quantized, &params, (3, 3)).is_err());
    }

    #[test]
    fn test_pruning() {
        let mut weights = Array2::from_shape_vec((2, 2), vec![1.0, 0.1, 0.2, 2.0]).unwrap();
        let pruned = prune_weights(&mut weights, 0.5).unwrap();

        assert_eq!(pruned, 2);
        let zero_count = weights.iter().filter(|&&w| w == 0.0).count();
        assert_eq!(zero_count, 2);

        // The two smallest magnitudes are removed; the largest ones survive.
        assert_eq!(weights[[0, 1]], 0.0);
        assert_eq!(weights[[1, 0]], 0.0);
        assert_eq!(weights[[0, 0]], 1.0);
        assert_eq!(weights[[1, 1]], 2.0);
    }

    #[test]
    fn test_pruning_rejects_invalid_sparsity() {
        let mut weights = Array2::from_shape_vec((2, 2), vec![1.0, 0.1, 0.2, 2.0]).unwrap();

        assert!(prune_weights(&mut weights, 1.5).is_err());
        assert!(prune_weights(&mut weights, -0.1).is_err());
        assert!(prune_weights(&mut weights, f32::NAN).is_err());
    }

    #[test]
    fn test_compression_ratio() {
        let ratio = compression_ratio(1000, 250);
        assert_eq!(ratio, 4.0);
    }

    #[test]
    fn test_calculate_speedup() {
        let speedup = calculate_speedup(1000, 250);
        assert_eq!(speedup, 4.0);
    }
}