// trustformers_optim/compression.rs

1use anyhow::Result;
2use std::collections::HashMap;
3use trustformers_core::tensor::Tensor;
4
/// Gradient compression strategies for distributed training.
///
/// Compressing gradients before they cross the network reduces the
/// communication cost of data-parallel training; each variant trades
/// reconstruction fidelity for bandwidth differently.
#[derive(Debug, Clone)]
pub enum CompressionMethod {
    /// Top-K sparsification: keep only the `k` largest-magnitude entries.
    TopK { k: usize },
    /// Random-K sparsification: keep `k` sampled entries
    /// (the current implementation samples with a deterministic stride).
    RandomK { k: usize },
    /// Threshold sparsification: keep entries whose |value| exceeds `threshold`.
    Threshold { threshold: f32 },
    /// Uniform quantization of every entry to `bits` bits of precision.
    Quantization { bits: u8 },
    /// SignSGD: transmit only the sign (+1/-1) of each entry.
    SignSGD,
    /// Wraps another method and carries the compression residual forward
    /// into the next round (error feedback).
    ErrorFeedback { base_method: Box<CompressionMethod> },
}
25
/// Stateful gradient compressor.
///
/// Holds the configured [`CompressionMethod`] plus the per-parameter
/// residual buffer needed by `ErrorFeedback` compression.
#[derive(Debug)]
pub struct GradientCompressor {
    method: CompressionMethod,
    // Bookkeeping for the achieved compression ratio, exposed via
    // `get_compression_ratio`.
    compression_ratio: f32,
    // Residuals keyed by parameter name; only `ErrorFeedback` reads/writes it.
    error_buffer: HashMap<String, Vec<f32>>,
}
32
/// Sparse/encoded representation of a single dense gradient.
#[derive(Debug, Clone)]
pub struct CompressedGradient {
    /// Positions of the transmitted values in the original dense gradient.
    pub indices: Vec<usize>,
    /// Transmitted values, parallel to `indices`.
    pub values: Vec<f32>,
    /// Length of the original dense gradient (needed to decompress).
    pub original_size: usize,
    /// Fraction of data retained; the exact definition is method-specific
    /// (element fraction for sparsifiers, bit fraction for quantizers).
    pub compression_ratio: f32,
}
40
41impl GradientCompressor {
42    pub fn new(method: CompressionMethod) -> Self {
43        Self {
44            method,
45            compression_ratio: 0.0,
46            error_buffer: HashMap::new(),
47        }
48    }
49
50    pub fn compress(
51        &mut self,
52        gradients: &HashMap<String, Tensor>,
53    ) -> Result<HashMap<String, CompressedGradient>> {
54        let mut compressed = HashMap::new();
55
56        for (name, gradient) in gradients.iter() {
57            let grad_data = gradient.data()?;
58            let compressed_grad = self.compress_single(&grad_data, name)?;
59            compressed.insert(name.clone(), compressed_grad);
60        }
61
62        Ok(compressed)
63    }
64
65    pub fn decompress(
66        &self,
67        compressed: &HashMap<String, CompressedGradient>,
68    ) -> Result<HashMap<String, Tensor>> {
69        let mut decompressed = HashMap::new();
70
71        for (name, compressed_grad) in compressed.iter() {
72            let grad_data = self.decompress_single(compressed_grad)?;
73            decompressed.insert(name.clone(), Tensor::new(grad_data)?);
74        }
75
76        Ok(decompressed)
77    }
78
79    fn compress_single(
80        &mut self,
81        gradient: &[f32],
82        param_name: &str,
83    ) -> Result<CompressedGradient> {
84        match self.method.clone() {
85            CompressionMethod::TopK { k } => self.compress_topk(gradient, k),
86            CompressionMethod::RandomK { k } => self.compress_randomk(gradient, k),
87            CompressionMethod::Threshold { threshold } => {
88                self.compress_threshold(gradient, threshold)
89            },
90            CompressionMethod::Quantization { bits } => self.compress_quantized(gradient, bits),
91            CompressionMethod::SignSGD => self.compress_signsgd(gradient),
92            CompressionMethod::ErrorFeedback { base_method } => {
93                self.compress_with_error_feedback(gradient, param_name, &base_method)
94            },
95        }
96    }
97
98    fn compress_topk(&self, gradient: &[f32], k: usize) -> Result<CompressedGradient> {
99        let mut indexed_grads: Vec<(usize, f32)> =
100            gradient.iter().enumerate().map(|(i, &val)| (i, val.abs())).collect();
101
102        // Sort by absolute value in descending order
103        indexed_grads.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
104
105        let k = k.min(gradient.len());
106        let mut indices = Vec::with_capacity(k);
107        let mut values = Vec::with_capacity(k);
108
109        for i in 0..k {
110            let (idx, _) = indexed_grads[i];
111            indices.push(idx);
112            values.push(gradient[idx]);
113        }
114
115        Ok(CompressedGradient {
116            indices,
117            values,
118            original_size: gradient.len(),
119            compression_ratio: k as f32 / gradient.len() as f32,
120        })
121    }
122
123    fn compress_randomk(&self, gradient: &[f32], k: usize) -> Result<CompressedGradient> {
124        use std::collections::HashSet;
125
126        let k = k.min(gradient.len());
127        let mut indices = Vec::with_capacity(k);
128        let mut values = Vec::with_capacity(k);
129        let mut selected_indices = HashSet::new();
130
131        // Simple random sampling (in practice, would use proper RNG)
132        let step = gradient.len() / k.max(1);
133        for i in (0..gradient.len()).step_by(step) {
134            if indices.len() < k && !selected_indices.contains(&i) {
135                indices.push(i);
136                values.push(gradient[i]);
137                selected_indices.insert(i);
138            }
139        }
140
141        Ok(CompressedGradient {
142            indices,
143            values,
144            original_size: gradient.len(),
145            compression_ratio: k as f32 / gradient.len() as f32,
146        })
147    }
148
149    fn compress_threshold(&self, gradient: &[f32], threshold: f32) -> Result<CompressedGradient> {
150        let mut indices = Vec::new();
151        let mut values = Vec::new();
152
153        for (i, &val) in gradient.iter().enumerate() {
154            if val.abs() > threshold {
155                indices.push(i);
156                values.push(val);
157            }
158        }
159
160        let compression_ratio = indices.len() as f32 / gradient.len() as f32;
161
162        Ok(CompressedGradient {
163            indices,
164            values,
165            original_size: gradient.len(),
166            compression_ratio,
167        })
168    }
169
170    fn compress_quantized(&self, gradient: &[f32], bits: u8) -> Result<CompressedGradient> {
171        let levels = (1 << bits) - 1;
172        let min_val = gradient.iter().fold(f32::INFINITY, |a, &b| a.min(b));
173        let max_val = gradient.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
174        let scale = (max_val - min_val) / levels as f32;
175
176        let mut quantized_values = Vec::new();
177        let mut indices = Vec::new();
178
179        for (i, &val) in gradient.iter().enumerate() {
180            let quantized = ((val - min_val) / scale).round() as i32;
181            let dequantized = min_val + quantized as f32 * scale;
182
183            indices.push(i);
184            quantized_values.push(dequantized);
185        }
186
187        Ok(CompressedGradient {
188            indices,
189            values: quantized_values,
190            original_size: gradient.len(),
191            compression_ratio: (bits as f32) / 32.0, // Assuming f32 gradients
192        })
193    }
194
195    fn compress_signsgd(&self, gradient: &[f32]) -> Result<CompressedGradient> {
196        let mut indices = Vec::new();
197        let mut values = Vec::new();
198
199        for (i, &val) in gradient.iter().enumerate() {
200            indices.push(i);
201            values.push(if val >= 0.0 { 1.0 } else { -1.0 });
202        }
203
204        Ok(CompressedGradient {
205            indices,
206            values,
207            original_size: gradient.len(),
208            compression_ratio: 1.0 / 32.0, // 1 bit vs 32 bits
209        })
210    }
211
212    fn compress_with_error_feedback(
213        &mut self,
214        gradient: &[f32],
215        param_name: &str,
216        base_method: &Box<CompressionMethod>,
217    ) -> Result<CompressedGradient> {
218        // Add accumulated error to current gradient
219        let mut corrected_gradient = gradient.to_vec();
220
221        if let Some(error) = self.error_buffer.get(param_name) {
222            for i in 0..corrected_gradient.len().min(error.len()) {
223                corrected_gradient[i] += error[i];
224            }
225        }
226
227        // Compress the corrected gradient
228        let mut temp_compressor = GradientCompressor::new((**base_method).clone());
229        let compressed = temp_compressor.compress_single(&corrected_gradient, param_name)?;
230
231        // Compute and store the new error
232        let decompressed = self.decompress_single(&compressed)?;
233        let mut new_error = vec![0.0; corrected_gradient.len()];
234
235        for i in 0..new_error.len() {
236            new_error[i] = corrected_gradient[i] - decompressed.get(i).copied().unwrap_or(0.0);
237        }
238
239        self.error_buffer.insert(param_name.to_string(), new_error);
240
241        Ok(compressed)
242    }
243
244    fn decompress_single(&self, compressed: &CompressedGradient) -> Result<Vec<f32>> {
245        let mut gradient = vec![0.0; compressed.original_size];
246
247        for (&i, &value) in compressed.indices.iter().zip(compressed.values.iter()) {
248            if i < gradient.len() {
249                gradient[i] = value;
250            }
251        }
252
253        Ok(gradient)
254    }
255
256    pub fn get_compression_ratio(&self) -> f32 {
257        self.compression_ratio
258    }
259
260    pub fn reset_error_buffer(&mut self) {
261        self.error_buffer.clear();
262    }
263}
264
/// Distributed gradient aggregator with compression support.
///
/// Wraps a [`GradientCompressor`] and performs a (currently simulated)
/// compressed all-reduce across `world_size` workers.
#[derive(Debug)]
pub struct CompressedAllReduce {
    compressor: GradientCompressor,
    // Number of participating workers; used to average aggregated gradients.
    world_size: usize,
}
271
272impl CompressedAllReduce {
273    pub fn new(compression_method: CompressionMethod, world_size: usize) -> Self {
274        Self {
275            compressor: GradientCompressor::new(compression_method),
276            world_size,
277        }
278    }
279
280    pub fn all_reduce(
281        &mut self,
282        gradients: &HashMap<String, Tensor>,
283    ) -> Result<HashMap<String, Tensor>> {
284        // Compress gradients
285        let compressed = self.compressor.compress(gradients)?;
286
287        // Simulate all-reduce operation (in practice, this would use MPI/NCCL)
288        let aggregated = self.simulate_all_reduce(&compressed)?;
289
290        // Decompress and average
291        let mut result = self.compressor.decompress(&aggregated)?;
292
293        // Average across all workers
294        for (_, gradient) in result.iter_mut() {
295            let mut data = gradient.data()?;
296            for val in data.iter_mut() {
297                *val /= self.world_size as f32;
298            }
299            *gradient = Tensor::new(data)?;
300        }
301
302        Ok(result)
303    }
304
305    fn simulate_all_reduce(
306        &self,
307        compressed: &HashMap<String, CompressedGradient>,
308    ) -> Result<HashMap<String, CompressedGradient>> {
309        // In a real implementation, this would:
310        // 1. Send compressed gradients to all other workers
311        // 2. Receive compressed gradients from all other workers
312        // 3. Aggregate the sparse representations
313        // 4. Return the aggregated result
314
315        // For simulation, just return the input scaled by world_size
316        let mut result = HashMap::new();
317
318        for (name, grad) in compressed.iter() {
319            let mut aggregated_values = grad.values.clone();
320            for val in aggregated_values.iter_mut() {
321                *val *= self.world_size as f32; // Simulate sum across workers
322            }
323
324            result.insert(
325                name.clone(),
326                CompressedGradient {
327                    indices: grad.indices.clone(),
328                    values: aggregated_values,
329                    original_size: grad.original_size,
330                    compression_ratio: grad.compression_ratio,
331                },
332            );
333        }
334
335        Ok(result)
336    }
337}
338
#[cfg(test)]
mod tests {
    use super::*;

    // Top-K must keep exactly k entries and pick them by magnitude, not sign.
    #[test]
    fn test_topk_compression() {
        let mut compressor = GradientCompressor::new(CompressionMethod::TopK { k: 3 });
        let gradient = vec![0.1, 0.8, 0.2, -0.9, 0.3, -0.1];

        let compressed = compressor.compress_single(&gradient, "test").unwrap();

        assert_eq!(compressed.indices.len(), 3);
        assert_eq!(compressed.values.len(), 3);
        assert_eq!(compressed.original_size, 6);
        assert!(compressed.compression_ratio < 1.0);

        // Should include the largest magnitude values: -0.9, 0.8, 0.3
        // (original signs preserved, selection by |value|).
        assert!(compressed.values.contains(&-0.9));
        assert!(compressed.values.contains(&0.8));
        assert!(compressed.values.contains(&0.3));
    }

    // Threshold keeps only entries with |value| strictly above the cutoff.
    #[test]
    fn test_threshold_compression() {
        let mut compressor =
            GradientCompressor::new(CompressionMethod::Threshold { threshold: 0.5 });
        let gradient = vec![0.1, 0.8, 0.2, -0.9, 0.3, -0.1];

        let compressed = compressor.compress_single(&gradient, "test").unwrap();

        // Only values with abs > 0.5 should be included: 0.8, -0.9
        assert_eq!(compressed.values.len(), 2);
        assert!(compressed.values.contains(&0.8));
        assert!(compressed.values.contains(&-0.9));
    }

    // SignSGD keeps every position but collapses each value to +/-1.0.
    #[test]
    fn test_signsgd_compression() {
        let mut compressor = GradientCompressor::new(CompressionMethod::SignSGD);
        let gradient = vec![0.1, -0.8, 0.2, -0.9, 0.3, -0.1];

        let compressed = compressor.compress_single(&gradient, "test").unwrap();

        assert_eq!(compressed.values.len(), gradient.len());
        assert_eq!(compressed.compression_ratio, 1.0 / 32.0);

        // All values should be either 1.0 or -1.0
        for &val in &compressed.values {
            assert!(val == 1.0 || val == -1.0);
        }
    }

    // Round trip: decompression restores the dense length, with the kept
    // (largest) values at their original positions and zeros elsewhere.
    #[test]
    fn test_compression_decompression_roundtrip() {
        let mut compressor = GradientCompressor::new(CompressionMethod::TopK { k: 3 });
        let mut gradients = HashMap::new();

        let grad_data = vec![0.1, 0.8, 0.2, -0.9, 0.3, -0.1];
        gradients.insert(
            "param1".to_string(),
            Tensor::new(grad_data.clone()).unwrap(),
        );

        let compressed = compressor.compress(&gradients).unwrap();
        let decompressed = compressor.decompress(&compressed).unwrap();

        let result_data = decompressed.get("param1").unwrap().data().unwrap();
        assert_eq!(result_data.len(), grad_data.len());

        // Check that the largest values are preserved
        assert!(result_data.contains(&0.8));
        assert!(result_data.contains(&-0.9));
    }

    // End-to-end compressed all-reduce: the simulated sum multiplies by
    // world_size and the averaging divides it back out, so surviving values
    // should stay in a sane range.
    #[test]
    fn test_compressed_all_reduce() {
        let mut all_reduce = CompressedAllReduce::new(
            CompressionMethod::TopK { k: 2 },
            4, // 4 workers
        );

        let mut gradients = HashMap::new();
        let grad_data = vec![0.4, 0.8, 0.2, -0.6];
        gradients.insert("param1".to_string(), Tensor::new(grad_data).unwrap());

        let result = all_reduce.all_reduce(&gradients).unwrap();

        let result_data = result.get("param1").unwrap().data().unwrap();
        assert_eq!(result_data.len(), 4);

        // Values should be averaged across workers (divided by world_size)
        for &val in &result_data {
            if val != 0.0 {
                // Non-zero values should be the original values (since we simulated identity operation)
                assert!(val.abs() <= 1.0); // Reasonable bound
            }
        }
    }
}