1use anyhow::Result;
2use std::collections::HashMap;
3use trustformers_core::tensor::Tensor;
4
/// Gradient-compression strategies for communication-efficient distributed
/// training. Each variant trades reconstruction fidelity for bandwidth.
#[derive(Debug, Clone)]
pub enum CompressionMethod {
    /// Keep only the `k` entries with the largest magnitude.
    TopK { k: usize },
    /// Keep `k` entries sampled across the tensor (implemented as a
    /// deterministic strided sample, not a true random draw).
    RandomK { k: usize },
    /// Keep entries whose absolute value exceeds `threshold`.
    Threshold { threshold: f32 },
    /// Uniform min/max quantization of every entry to `bits` bits.
    Quantization { bits: u8 },
    /// Transmit only the sign (+1.0 / -1.0) of each entry.
    SignSGD,
    /// Wrap `base_method` with error feedback: the residual lost to
    /// compression is accumulated and added back into the next gradient.
    ErrorFeedback { base_method: Box<CompressionMethod> },
}
25
/// Applies a [`CompressionMethod`] to gradient tensors and, for the
/// error-feedback variant, keeps per-parameter residuals between calls.
#[derive(Debug)]
pub struct GradientCompressor {
    // The configured compression strategy.
    method: CompressionMethod,
    // Intended to track the most recently achieved compression ratio
    // (exposed via get_compression_ratio()).
    compression_ratio: f32,
    // Maps parameter name -> residual left over from its previous compression.
    error_buffer: HashMap<String, Vec<f32>>, }
32
/// Sparse/transformed representation of a single flat gradient tensor.
#[derive(Debug, Clone)]
pub struct CompressedGradient {
    /// Positions of the retained entries in the original dense tensor.
    pub indices: Vec<usize>,
    /// Retained values, parallel to `indices` (may be quantized or signs,
    /// depending on the method that produced them).
    pub values: Vec<f32>,
    /// Element count of the original dense tensor.
    pub original_size: usize,
    /// Achieved ratio: fraction of elements kept for sparsifying methods,
    /// or bits/32 for quantization and sign compression.
    pub compression_ratio: f32,
}
40
41impl GradientCompressor {
42 pub fn new(method: CompressionMethod) -> Self {
43 Self {
44 method,
45 compression_ratio: 0.0,
46 error_buffer: HashMap::new(),
47 }
48 }
49
50 pub fn compress(
51 &mut self,
52 gradients: &HashMap<String, Tensor>,
53 ) -> Result<HashMap<String, CompressedGradient>> {
54 let mut compressed = HashMap::new();
55
56 for (name, gradient) in gradients.iter() {
57 let grad_data = gradient.data()?;
58 let compressed_grad = self.compress_single(&grad_data, name)?;
59 compressed.insert(name.clone(), compressed_grad);
60 }
61
62 Ok(compressed)
63 }
64
65 pub fn decompress(
66 &self,
67 compressed: &HashMap<String, CompressedGradient>,
68 ) -> Result<HashMap<String, Tensor>> {
69 let mut decompressed = HashMap::new();
70
71 for (name, compressed_grad) in compressed.iter() {
72 let grad_data = self.decompress_single(compressed_grad)?;
73 decompressed.insert(name.clone(), Tensor::new(grad_data)?);
74 }
75
76 Ok(decompressed)
77 }
78
79 fn compress_single(
80 &mut self,
81 gradient: &[f32],
82 param_name: &str,
83 ) -> Result<CompressedGradient> {
84 match self.method.clone() {
85 CompressionMethod::TopK { k } => self.compress_topk(gradient, k),
86 CompressionMethod::RandomK { k } => self.compress_randomk(gradient, k),
87 CompressionMethod::Threshold { threshold } => {
88 self.compress_threshold(gradient, threshold)
89 },
90 CompressionMethod::Quantization { bits } => self.compress_quantized(gradient, bits),
91 CompressionMethod::SignSGD => self.compress_signsgd(gradient),
92 CompressionMethod::ErrorFeedback { base_method } => {
93 self.compress_with_error_feedback(gradient, param_name, &base_method)
94 },
95 }
96 }
97
98 fn compress_topk(&self, gradient: &[f32], k: usize) -> Result<CompressedGradient> {
99 let mut indexed_grads: Vec<(usize, f32)> =
100 gradient.iter().enumerate().map(|(i, &val)| (i, val.abs())).collect();
101
102 indexed_grads.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
104
105 let k = k.min(gradient.len());
106 let mut indices = Vec::with_capacity(k);
107 let mut values = Vec::with_capacity(k);
108
109 for i in 0..k {
110 let (idx, _) = indexed_grads[i];
111 indices.push(idx);
112 values.push(gradient[idx]);
113 }
114
115 Ok(CompressedGradient {
116 indices,
117 values,
118 original_size: gradient.len(),
119 compression_ratio: k as f32 / gradient.len() as f32,
120 })
121 }
122
123 fn compress_randomk(&self, gradient: &[f32], k: usize) -> Result<CompressedGradient> {
124 use std::collections::HashSet;
125
126 let k = k.min(gradient.len());
127 let mut indices = Vec::with_capacity(k);
128 let mut values = Vec::with_capacity(k);
129 let mut selected_indices = HashSet::new();
130
131 let step = gradient.len() / k.max(1);
133 for i in (0..gradient.len()).step_by(step) {
134 if indices.len() < k && !selected_indices.contains(&i) {
135 indices.push(i);
136 values.push(gradient[i]);
137 selected_indices.insert(i);
138 }
139 }
140
141 Ok(CompressedGradient {
142 indices,
143 values,
144 original_size: gradient.len(),
145 compression_ratio: k as f32 / gradient.len() as f32,
146 })
147 }
148
149 fn compress_threshold(&self, gradient: &[f32], threshold: f32) -> Result<CompressedGradient> {
150 let mut indices = Vec::new();
151 let mut values = Vec::new();
152
153 for (i, &val) in gradient.iter().enumerate() {
154 if val.abs() > threshold {
155 indices.push(i);
156 values.push(val);
157 }
158 }
159
160 let compression_ratio = indices.len() as f32 / gradient.len() as f32;
161
162 Ok(CompressedGradient {
163 indices,
164 values,
165 original_size: gradient.len(),
166 compression_ratio,
167 })
168 }
169
170 fn compress_quantized(&self, gradient: &[f32], bits: u8) -> Result<CompressedGradient> {
171 let levels = (1 << bits) - 1;
172 let min_val = gradient.iter().fold(f32::INFINITY, |a, &b| a.min(b));
173 let max_val = gradient.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
174 let scale = (max_val - min_val) / levels as f32;
175
176 let mut quantized_values = Vec::new();
177 let mut indices = Vec::new();
178
179 for (i, &val) in gradient.iter().enumerate() {
180 let quantized = ((val - min_val) / scale).round() as i32;
181 let dequantized = min_val + quantized as f32 * scale;
182
183 indices.push(i);
184 quantized_values.push(dequantized);
185 }
186
187 Ok(CompressedGradient {
188 indices,
189 values: quantized_values,
190 original_size: gradient.len(),
191 compression_ratio: (bits as f32) / 32.0, })
193 }
194
195 fn compress_signsgd(&self, gradient: &[f32]) -> Result<CompressedGradient> {
196 let mut indices = Vec::new();
197 let mut values = Vec::new();
198
199 for (i, &val) in gradient.iter().enumerate() {
200 indices.push(i);
201 values.push(if val >= 0.0 { 1.0 } else { -1.0 });
202 }
203
204 Ok(CompressedGradient {
205 indices,
206 values,
207 original_size: gradient.len(),
208 compression_ratio: 1.0 / 32.0, })
210 }
211
212 fn compress_with_error_feedback(
213 &mut self,
214 gradient: &[f32],
215 param_name: &str,
216 base_method: &Box<CompressionMethod>,
217 ) -> Result<CompressedGradient> {
218 let mut corrected_gradient = gradient.to_vec();
220
221 if let Some(error) = self.error_buffer.get(param_name) {
222 for i in 0..corrected_gradient.len().min(error.len()) {
223 corrected_gradient[i] += error[i];
224 }
225 }
226
227 let mut temp_compressor = GradientCompressor::new((**base_method).clone());
229 let compressed = temp_compressor.compress_single(&corrected_gradient, param_name)?;
230
231 let decompressed = self.decompress_single(&compressed)?;
233 let mut new_error = vec![0.0; corrected_gradient.len()];
234
235 for i in 0..new_error.len() {
236 new_error[i] = corrected_gradient[i] - decompressed.get(i).copied().unwrap_or(0.0);
237 }
238
239 self.error_buffer.insert(param_name.to_string(), new_error);
240
241 Ok(compressed)
242 }
243
244 fn decompress_single(&self, compressed: &CompressedGradient) -> Result<Vec<f32>> {
245 let mut gradient = vec![0.0; compressed.original_size];
246
247 for (&i, &value) in compressed.indices.iter().zip(compressed.values.iter()) {
248 if i < gradient.len() {
249 gradient[i] = value;
250 }
251 }
252
253 Ok(gradient)
254 }
255
256 pub fn get_compression_ratio(&self) -> f32 {
257 self.compression_ratio
258 }
259
260 pub fn reset_error_buffer(&mut self) {
261 self.error_buffer.clear();
262 }
263}
264
/// Wraps a [`GradientCompressor`] to perform a (simulated) bandwidth-efficient
/// all-reduce: compress, aggregate across ranks, decompress, average.
#[derive(Debug)]
pub struct CompressedAllReduce {
    // Compresses gradients before the collective and restores them after.
    compressor: GradientCompressor,
    // Number of participating ranks; used to scale the sum into a mean.
    world_size: usize,
}
271
272impl CompressedAllReduce {
273 pub fn new(compression_method: CompressionMethod, world_size: usize) -> Self {
274 Self {
275 compressor: GradientCompressor::new(compression_method),
276 world_size,
277 }
278 }
279
280 pub fn all_reduce(
281 &mut self,
282 gradients: &HashMap<String, Tensor>,
283 ) -> Result<HashMap<String, Tensor>> {
284 let compressed = self.compressor.compress(gradients)?;
286
287 let aggregated = self.simulate_all_reduce(&compressed)?;
289
290 let mut result = self.compressor.decompress(&aggregated)?;
292
293 for (_, gradient) in result.iter_mut() {
295 let mut data = gradient.data()?;
296 for val in data.iter_mut() {
297 *val /= self.world_size as f32;
298 }
299 *gradient = Tensor::new(data)?;
300 }
301
302 Ok(result)
303 }
304
305 fn simulate_all_reduce(
306 &self,
307 compressed: &HashMap<String, CompressedGradient>,
308 ) -> Result<HashMap<String, CompressedGradient>> {
309 let mut result = HashMap::new();
317
318 for (name, grad) in compressed.iter() {
319 let mut aggregated_values = grad.values.clone();
320 for val in aggregated_values.iter_mut() {
321 *val *= self.world_size as f32; }
323
324 result.insert(
325 name.clone(),
326 CompressedGradient {
327 indices: grad.indices.clone(),
328 values: aggregated_values,
329 original_size: grad.original_size,
330 compression_ratio: grad.compression_ratio,
331 },
332 );
333 }
334
335 Ok(result)
336 }
337}
338
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_topk_compression() {
        let grad = vec![0.1, 0.8, 0.2, -0.9, 0.3, -0.1];
        let mut comp = GradientCompressor::new(CompressionMethod::TopK { k: 3 });

        let out = comp.compress_single(&grad, "test").unwrap();

        assert_eq!(out.indices.len(), 3);
        assert_eq!(out.values.len(), 3);
        assert_eq!(out.original_size, 6);
        assert!(out.compression_ratio < 1.0);

        // The three largest-magnitude entries must survive.
        for expected in [-0.9f32, 0.8, 0.3] {
            assert!(out.values.contains(&expected));
        }
    }

    #[test]
    fn test_threshold_compression() {
        let grad = vec![0.1, 0.8, 0.2, -0.9, 0.3, -0.1];
        let mut comp =
            GradientCompressor::new(CompressionMethod::Threshold { threshold: 0.5 });

        let out = comp.compress_single(&grad, "test").unwrap();

        // Only 0.8 and -0.9 exceed the 0.5 magnitude cutoff.
        assert_eq!(out.values.len(), 2);
        assert!(out.values.contains(&0.8));
        assert!(out.values.contains(&-0.9));
    }

    #[test]
    fn test_signsgd_compression() {
        let grad = vec![0.1, -0.8, 0.2, -0.9, 0.3, -0.1];
        let mut comp = GradientCompressor::new(CompressionMethod::SignSGD);

        let out = comp.compress_single(&grad, "test").unwrap();

        assert_eq!(out.values.len(), grad.len());
        assert_eq!(out.compression_ratio, 1.0 / 32.0);

        // Every transmitted value is exactly a unit sign.
        assert!(out.values.iter().all(|&v| v == 1.0 || v == -1.0));
    }

    #[test]
    fn test_compression_decompression_roundtrip() {
        let mut comp = GradientCompressor::new(CompressionMethod::TopK { k: 3 });
        let grad_data = vec![0.1, 0.8, 0.2, -0.9, 0.3, -0.1];

        let mut grads = HashMap::new();
        grads.insert(
            "param1".to_string(),
            Tensor::new(grad_data.clone()).unwrap(),
        );

        let packed = comp.compress(&grads).unwrap();
        let unpacked = comp.decompress(&packed).unwrap();

        let restored = unpacked.get("param1").unwrap().data().unwrap();
        assert_eq!(restored.len(), grad_data.len());

        // Dominant entries survive the round trip.
        assert!(restored.contains(&0.8));
        assert!(restored.contains(&-0.9));
    }

    #[test]
    fn test_compressed_all_reduce() {
        let mut all_reduce =
            CompressedAllReduce::new(CompressionMethod::TopK { k: 2 }, 4);

        let mut grads = HashMap::new();
        grads.insert(
            "param1".to_string(),
            Tensor::new(vec![0.4, 0.8, 0.2, -0.6]).unwrap(),
        );

        let reduced = all_reduce.all_reduce(&grads).unwrap();
        let data = reduced.get("param1").unwrap().data().unwrap();

        assert_eq!(data.len(), 4);

        // Sum-then-average keeps every surviving value within [-1, 1].
        for &v in &data {
            if v != 0.0 {
                assert!(v.abs() <= 1.0);
            }
        }
    }
}