1#![allow(dead_code)]
9use crate::{TorshDistributedError, TorshResult};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use torsh_tensor::Tensor;
13use tracing::{debug, info};
14
/// Configuration for gradient compression in distributed training.
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Compression algorithm to apply.
    pub method: CompressionMethod,
    /// Target ratio of compressed size to original size.
    pub compression_ratio: f32,
    /// Whether to accumulate the compression residual and re-inject it into
    /// the next gradient (error feedback).
    pub error_feedback: bool,
    /// Momentum factor applied to the stored residual before re-injection.
    pub error_feedback_momentum: f32,
    /// Prefer memory-efficient code paths where available.
    /// NOTE(review): not read anywhere in this file — confirm intended use.
    pub memory_efficient: bool,
    /// Number of initial steps during which gradients are sent uncompressed.
    pub warmup_steps: usize,
}
31
32impl Default for CompressionConfig {
33 fn default() -> Self {
34 Self {
35 method: CompressionMethod::TopK { k: 0.1 },
36 compression_ratio: 0.1,
37 error_feedback: true,
38 error_feedback_momentum: 0.9,
39 memory_efficient: true,
40 warmup_steps: 100,
41 }
42 }
43}
44
/// Available gradient compression algorithms.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CompressionMethod {
    /// Keep the fraction `k` of entries with the largest magnitude.
    TopK { k: f32 },
    /// Keep a fraction `k` of entries chosen by position.
    RandomK { k: f32 },
    /// Keep entries whose absolute value is at least `threshold`.
    Threshold { threshold: f32 },
    /// Uniform affine quantization to `bits` bits per value.
    Quantization { bits: u8 },
    /// Transmit only per-entry signs plus the gradient norm.
    SignSGD,
    /// Project the gradient into a sketch of `sketch_size` entries.
    Sketching { sketch_size: usize },
    /// Rank-`rank` factorization of 2D gradients (PowerSGD).
    PowerSGD { rank: usize },
    /// Quantize entries to {-1, 0, +1} using `threshold` on normalized values.
    TernaryQuant { threshold: f32 },
    /// Cluster values into `num_bins` uniformly spaced bins.
    BimodalQuant { num_bins: usize },
    /// Frequency-based codebook compression keeping a `compression_factor`
    /// fraction of distinct (quantized) values.
    NaturalCompression { compression_factor: f32 },
    /// Top-K with a per-layer keep ratio derived from `base_ratio` scaled by
    /// `sensitivity` for non-weight parameters.
    LayerwiseAdaptive { base_ratio: f32, sensitivity: f32 },
    /// EF21-style compression carrying an internal error-feedback residual
    /// with `momentum` applied on re-injection.
    EF21 {
        compression_ratio: f32,
        momentum: f32,
    },
    /// No compression; the gradient is passed through densely.
    None,
}
78
/// A gradient after compression, ready for serialization and transport.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressedGradient {
    /// Method used to produce `data` (also selects the decompression path).
    pub method: CompressionMethod,
    /// Method-specific compressed payload.
    pub data: CompressedData,
    /// Shape of the original tensor, needed for reconstruction.
    pub original_shape: Vec<usize>,
    /// Bookkeeping recorded at compression time.
    pub metadata: CompressionMetadata,
}
91
/// Method-specific compressed payloads.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CompressedData {
    /// Sparse (index, value) pairs; used by TopK/RandomK/Threshold/None.
    Sparse {
        indices: Vec<usize>,
        values: Vec<f32>,
    },
    /// Affine-quantized bytes: v ≈ (q - zero_point) * scale.
    Quantized {
        values: Vec<u8>,
        scale: f32,
        zero_point: u8,
    },
    /// Per-entry signs plus the original L2 norm (SignSGD).
    Signs { signs: Vec<bool>, norm: f32 },
    /// Low-rank factors (PowerSGD).
    LowRank {
        left_factor: Vec<f32>,
        right_factor: Vec<f32>,
        rank: usize,
    },
    /// Sketch values plus hash parameters (Sketching).
    Sketch {
        sketch: Vec<f32>,
        hash_a: Vec<u32>,
        hash_b: Vec<u32>,
    },
    /// Ternary levels in {-1, 0, 1} with a shared scale.
    Ternary { values: Vec<i8>, scale: f32 },
    /// Bin index per entry plus the bin-center codebook (BimodalQuant).
    Bimodal {
        bin_indices: Vec<u8>,
        bin_centers: Vec<f32>,
    },
    /// Codebook indices (stored as f32) with bucket frequencies
    /// (NaturalCompression).
    Natural {
        values: Vec<f32>,
        frequencies: Vec<u32>,
        codebook: Vec<f32>,
    },
    /// Dense compressed vector plus the residual carried by EF21.
    EF21 {
        compressed_values: Vec<f32>,
        error_feedback: Vec<f32>,
    },
}
139
/// Bookkeeping recorded alongside each compressed gradient.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionMetadata {
    /// Achieved (or nominal) compressed/original size ratio.
    pub compression_ratio: f32,
    /// Norm of the compression error.
    /// NOTE(review): compressors in this file currently write 0.0 here;
    /// residual tracking happens via the error-feedback buffers instead.
    pub error_norm: f32,
    /// L2 norm of the original gradient.
    pub original_norm: f32,
    /// Unix timestamp (seconds) of when the compression was performed.
    pub timestamp: u64,
}
152
/// Stateful gradient compressor with optional error feedback.
pub struct GradientCompressor {
    // Compression settings.
    config: CompressionConfig,
    // Per-parameter residuals (what previous compressions lost).
    error_buffers: HashMap<String, Tensor>,
    // Step counter driving the warmup phase.
    step_count: usize,
    // Aggregate statistics across all compressions.
    stats: CompressionStats,
}
164
/// Aggregate statistics across all compressions performed by a compressor.
#[derive(Debug, Clone, Default)]
pub struct CompressionStats {
    /// Number of gradients compressed.
    pub total_compressions: u64,
    /// Running mean of per-gradient compression ratios.
    pub avg_compression_ratio: f64,
    /// Estimated communication savings, in bytes.
    pub total_communication_reduction: u64,
    /// Running mean of recorded per-gradient error norms.
    pub avg_error_norm: f64,
    /// Total wall-clock time spent compressing, in milliseconds.
    pub compression_time_ms: f64,
}
179
180impl GradientCompressor {
181 pub fn new(config: CompressionConfig) -> Self {
183 info!(
184 "Initializing gradient compressor with method: {:?}",
185 config.method
186 );
187
188 Self {
189 config,
190 error_buffers: HashMap::new(),
191 step_count: 0,
192 stats: CompressionStats::default(),
193 }
194 }
195
196 pub fn compress(
198 &mut self,
199 gradient: &Tensor,
200 param_name: &str,
201 ) -> TorshResult<CompressedGradient> {
202 let start_time = std::time::Instant::now();
203
204 if self.step_count < self.config.warmup_steps {
206 return self.no_compression(gradient, param_name);
207 }
208
209 let adjusted_gradient = if self.config.error_feedback {
211 self.apply_error_feedback(gradient, param_name)?
212 } else {
213 gradient.clone()
214 };
215
216 let compressed = match &self.config.method {
217 CompressionMethod::TopK { k } => self.compress_top_k(&adjusted_gradient, *k)?,
218 CompressionMethod::RandomK { k } => self.compress_random_k(&adjusted_gradient, *k)?,
219 CompressionMethod::Threshold { threshold } => {
220 self.compress_threshold(&adjusted_gradient, *threshold)?
221 }
222 CompressionMethod::Quantization { bits } => {
223 self.compress_quantization(&adjusted_gradient, *bits)?
224 }
225 CompressionMethod::SignSGD => self.compress_sign_sgd(&adjusted_gradient)?,
226 CompressionMethod::Sketching { sketch_size } => {
227 self.compress_sketching(&adjusted_gradient, *sketch_size)?
228 }
229 CompressionMethod::PowerSGD { rank } => {
230 self.compress_power_sgd(&adjusted_gradient, *rank)?
231 }
232 CompressionMethod::TernaryQuant { threshold } => {
233 self.compress_ternary(&adjusted_gradient, *threshold)?
234 }
235 CompressionMethod::BimodalQuant { num_bins } => {
236 self.compress_bimodal(&adjusted_gradient, *num_bins)?
237 }
238 CompressionMethod::NaturalCompression { compression_factor } => {
239 self.compress_natural(&adjusted_gradient, *compression_factor)?
240 }
241 CompressionMethod::LayerwiseAdaptive {
242 base_ratio,
243 sensitivity,
244 } => self.compress_layerwise_adaptive(
245 &adjusted_gradient,
246 *base_ratio,
247 *sensitivity,
248 param_name,
249 )?,
250 CompressionMethod::EF21 {
251 compression_ratio,
252 momentum,
253 } => self.compress_ef21(
254 &adjusted_gradient,
255 *compression_ratio,
256 *momentum,
257 param_name,
258 )?,
259 CompressionMethod::None => return self.no_compression(gradient, param_name),
260 };
261
262 if self.config.error_feedback {
264 self.update_error_feedback(&compressed, gradient, param_name)?;
265 }
266
267 let compression_time = start_time.elapsed().as_millis() as f64;
269 self.update_stats(&compressed, compression_time);
270
271 self.step_count += 1;
272 Ok(compressed)
273 }
274
    /// Reconstruct a dense tensor from a compressed gradient.
    ///
    /// Dispatches on the payload variant; the inverse of `compress`.
    /// Reconstruction is lossy for every method except `None`.
    pub fn decompress(&self, compressed: &CompressedGradient) -> TorshResult<Tensor> {
        match &compressed.data {
            CompressedData::Sparse { indices, values } => {
                self.decompress_sparse(indices, values, &compressed.original_shape)
            }
            CompressedData::Quantized {
                values,
                scale,
                zero_point,
            } => self.decompress_quantized(values, *scale, *zero_point, &compressed.original_shape),
            CompressedData::Signs { signs, norm } => {
                self.decompress_sign_sgd(signs, *norm, &compressed.original_shape)
            }
            CompressedData::LowRank {
                left_factor,
                right_factor,
                rank,
            } => self.decompress_power_sgd(
                left_factor,
                right_factor,
                *rank,
                &compressed.original_shape,
            ),
            CompressedData::Sketch {
                sketch,
                hash_a,
                hash_b,
            } => self.decompress_sketching(sketch, hash_a, hash_b, &compressed.original_shape),
            CompressedData::Ternary { values, scale } => {
                self.decompress_ternary(values, *scale, &compressed.original_shape)
            }
            CompressedData::Bimodal {
                bin_indices,
                bin_centers,
            } => self.decompress_bimodal(bin_indices, bin_centers, &compressed.original_shape),
            // Frequencies are bookkeeping only; reconstruction needs just the
            // codebook.
            CompressedData::Natural {
                values,
                frequencies: _,
                codebook,
            } => self.decompress_natural(values, codebook, &compressed.original_shape),
            // The EF21 residual stays local; only the compressed values are
            // materialized.
            CompressedData::EF21 {
                compressed_values,
                error_feedback: _,
            } => self.decompress_ef21(compressed_values, &compressed.original_shape),
        }
    }
322
323 fn apply_error_feedback(&mut self, gradient: &Tensor, param_name: &str) -> TorshResult<Tensor> {
325 if let Some(error_buffer) = self.error_buffers.get(param_name) {
326 let scaled_error = error_buffer.mul_scalar(self.config.error_feedback_momentum)?;
328 Ok(gradient.add(&scaled_error)?)
329 } else {
330 Ok(gradient.clone())
331 }
332 }
333
334 fn update_error_feedback(
336 &mut self,
337 compressed: &CompressedGradient,
338 original: &Tensor,
339 param_name: &str,
340 ) -> TorshResult<()> {
341 let decompressed = self.decompress(compressed)?;
342 let error = original.sub(&decompressed)?;
343 self.error_buffers.insert(param_name.to_string(), error);
344 Ok(())
345 }
346
347 fn compress_top_k(&self, gradient: &Tensor, k: f32) -> TorshResult<CompressedGradient> {
349 let flat_grad = gradient.flatten()?;
350 let numel = flat_grad.numel();
351 let k_elements = ((numel as f32) * k).ceil() as usize;
352
353 let abs_grad = flat_grad.abs()?;
355 let grad_data = flat_grad.to_vec()?;
356 let abs_data = abs_grad.to_vec()?;
357
358 let mut indexed_values: Vec<(usize, f32)> =
360 abs_data.iter().enumerate().map(|(i, &v)| (i, v)).collect();
361 indexed_values.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
362
363 let mut indices = Vec::new();
364 let mut values = Vec::new();
365
366 for &(idx, _) in indexed_values.iter().take(k_elements) {
367 indices.push(idx);
368 values.push(grad_data[idx]);
369 }
370
371 debug!("Top-K compression: kept {}/{} elements", k_elements, numel);
372
373 let original_norm = gradient.norm()?.item()?;
374 let compression_ratio = k;
375
376 Ok(CompressedGradient {
377 method: CompressionMethod::TopK { k },
378 data: CompressedData::Sparse { indices, values },
379 original_shape: gradient.shape().dims().to_vec(),
380 metadata: CompressionMetadata {
381 compression_ratio,
382 error_norm: 0.0, original_norm,
384 timestamp: std::time::SystemTime::now()
385 .duration_since(std::time::UNIX_EPOCH)
386 .expect("time should be after UNIX_EPOCH")
387 .as_secs(),
388 },
389 })
390 }
391
392 fn compress_random_k(&self, gradient: &Tensor, k: f32) -> TorshResult<CompressedGradient> {
394 let flat_grad = gradient.flatten()?;
395 let numel = flat_grad.numel();
396 let k_elements = ((numel as f32) * k).ceil() as usize;
397
398 let grad_data = flat_grad.to_vec()?;
399
400 let mut indices = Vec::new();
402 let mut values = Vec::new();
403
404 let step = numel / k_elements.max(1);
406 for i in (0..numel).step_by(step).take(k_elements) {
407 indices.push(i);
408 values.push(grad_data[i]);
409 }
410
411 debug!(
412 "Random-K compression: kept {}/{} elements",
413 k_elements, numel
414 );
415
416 let original_norm = gradient.norm()?.item()?;
417
418 Ok(CompressedGradient {
419 method: CompressionMethod::RandomK { k },
420 data: CompressedData::Sparse { indices, values },
421 original_shape: gradient.shape().dims().to_vec(),
422 metadata: CompressionMetadata {
423 compression_ratio: k,
424 error_norm: 0.0,
425 original_norm,
426 timestamp: std::time::SystemTime::now()
427 .duration_since(std::time::UNIX_EPOCH)
428 .expect("time should be after UNIX_EPOCH")
429 .as_secs(),
430 },
431 })
432 }
433
434 fn compress_threshold(
436 &self,
437 gradient: &Tensor,
438 threshold: f32,
439 ) -> TorshResult<CompressedGradient> {
440 let flat_grad = gradient.flatten()?;
441 let grad_data = flat_grad.to_vec()?;
442
443 let mut indices = Vec::new();
444 let mut values = Vec::new();
445
446 for (i, &value) in grad_data.iter().enumerate() {
447 if value.abs() >= threshold {
448 indices.push(i);
449 values.push(value);
450 }
451 }
452
453 let compression_ratio = indices.len() as f32 / grad_data.len() as f32;
454 debug!(
455 "Threshold compression: kept {}/{} elements",
456 indices.len(),
457 grad_data.len()
458 );
459
460 let original_norm = gradient.norm()?.item()?;
461
462 Ok(CompressedGradient {
463 method: CompressionMethod::Threshold { threshold },
464 data: CompressedData::Sparse { indices, values },
465 original_shape: gradient.shape().dims().to_vec(),
466 metadata: CompressionMetadata {
467 compression_ratio,
468 error_norm: 0.0,
469 original_norm,
470 timestamp: std::time::SystemTime::now()
471 .duration_since(std::time::UNIX_EPOCH)
472 .expect("time should be after UNIX_EPOCH")
473 .as_secs(),
474 },
475 })
476 }
477
478 fn compress_quantization(
480 &self,
481 gradient: &Tensor,
482 bits: u8,
483 ) -> TorshResult<CompressedGradient> {
484 let flat_grad = gradient.flatten()?;
485 let grad_data = flat_grad.to_vec()?;
486
487 let min_val = grad_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
489 let max_val = grad_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
490
491 let levels = (1 << bits) - 1;
492 let scale = (max_val - min_val) / levels as f32;
493 let zero_point = (-min_val / scale).round() as u8;
494
495 let mut quantized_values = Vec::new();
496 for &value in &grad_data {
497 let quantized = ((value / scale) + zero_point as f32)
498 .round()
499 .clamp(0.0, levels as f32) as u8;
500 quantized_values.push(quantized);
501 }
502
503 debug!("Quantization: {} bits, {} levels", bits, levels);
504
505 let original_norm = gradient.norm()?.item()?;
506 let compression_ratio = (bits as f32) / 32.0; Ok(CompressedGradient {
509 method: CompressionMethod::Quantization { bits },
510 data: CompressedData::Quantized {
511 values: quantized_values,
512 scale,
513 zero_point,
514 },
515 original_shape: gradient.shape().dims().to_vec(),
516 metadata: CompressionMetadata {
517 compression_ratio,
518 error_norm: 0.0,
519 original_norm,
520 timestamp: std::time::SystemTime::now()
521 .duration_since(std::time::UNIX_EPOCH)
522 .expect("time should be after UNIX_EPOCH")
523 .as_secs(),
524 },
525 })
526 }
527
528 fn compress_sign_sgd(&self, gradient: &Tensor) -> TorshResult<CompressedGradient> {
530 let flat_grad = gradient.flatten()?;
531 let grad_data = flat_grad.to_vec()?;
532 let norm = gradient.norm()?.item()?;
533
534 let signs: Vec<bool> = grad_data.iter().map(|&x| x >= 0.0).collect();
535
536 debug!(
537 "SignSGD compression: {} elements -> {} bits",
538 grad_data.len(),
539 signs.len()
540 );
541
542 Ok(CompressedGradient {
543 method: CompressionMethod::SignSGD,
544 data: CompressedData::Signs { signs, norm },
545 original_shape: gradient.shape().dims().to_vec(),
546 metadata: CompressionMetadata {
547 compression_ratio: 1.0 / 32.0, error_norm: 0.0,
549 original_norm: norm,
550 timestamp: std::time::SystemTime::now()
551 .duration_since(std::time::UNIX_EPOCH)
552 .expect("time should be after UNIX_EPOCH")
553 .as_secs(),
554 },
555 })
556 }
557
    /// "Sketch" compression keeping `sketch_size` values.
    ///
    /// NOTE(review): this is a placeholder, not a real count-sketch — the
    /// sketch is a plain prefix of the flattened gradient, and the hash
    /// arrays below are deterministic affine sequences that
    /// `decompress_sketching` ignores. Confirm before relying on it.
    fn compress_sketching(
        &self,
        gradient: &Tensor,
        sketch_size: usize,
    ) -> TorshResult<CompressedGradient> {
        let flat_grad = gradient.flatten()?;
        let grad_data = flat_grad.to_vec()?;

        // Prefix of the flattened gradient serves as the "sketch".
        let sketch: Vec<f32> = grad_data.iter().take(sketch_size).copied().collect();

        // Pseudo hash parameters; currently unused during decompression.
        let hash_a: Vec<u32> = (0..grad_data.len()).map(|i| (i * 17 + 23) as u32).collect();
        let hash_b: Vec<u32> = (0..grad_data.len()).map(|i| (i * 37 + 41) as u32).collect();

        let compression_ratio = sketch_size as f32 / grad_data.len() as f32;
        let original_norm = gradient.norm()?.item()?;

        debug!(
            "Sketching compression: {} -> {} elements",
            grad_data.len(),
            sketch_size
        );

        Ok(CompressedGradient {
            method: CompressionMethod::Sketching { sketch_size },
            data: CompressedData::Sketch {
                sketch,
                hash_a,
                hash_b,
            },
            original_shape: gradient.shape().dims().to_vec(),
            metadata: CompressionMetadata {
                compression_ratio,
                error_norm: 0.0,
                original_norm,
                timestamp: std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .expect("time should be after UNIX_EPOCH")
                    .as_secs(),
            },
        })
    }
602
    /// "PowerSGD" compression for 2D gradients.
    ///
    /// NOTE(review): this is a placeholder — the left/right "factors" are
    /// just consecutive slices of the flattened gradient, not an actual
    /// low-rank factorization (no power iteration / orthogonalization).
    /// Confirm before relying on reconstruction quality.
    ///
    /// # Errors
    /// Returns an invalid-argument error for non-2D tensors.
    fn compress_power_sgd(
        &self,
        gradient: &Tensor,
        rank: usize,
    ) -> TorshResult<CompressedGradient> {
        let shape_obj = gradient.shape();
        let shape = shape_obj.dims();
        // PowerSGD is defined on matrices only.
        if shape.len() != 2 {
            return Err(TorshDistributedError::invalid_argument(
                "gradient",
                format!("PowerSGD requires 2D tensors, got {}D tensor", shape.len()),
                "2D tensor with shape [rows, cols]",
            ));
        }

        let rows = shape[0];
        let cols = shape[1];

        // Factor sizes for a rank-`rank` factorization of rows x cols.
        let left_factor_size = rows * rank;
        let right_factor_size = cols * rank;

        let flat_grad = gradient.flatten()?;
        let grad_data = flat_grad.to_vec()?;

        // Placeholder "factors": plain slices of the flattened gradient.
        let left_factor: Vec<f32> = grad_data.iter().take(left_factor_size).copied().collect();
        let right_factor: Vec<f32> = grad_data
            .iter()
            .skip(left_factor_size)
            .take(right_factor_size)
            .copied()
            .collect();

        let compression_ratio =
            (left_factor_size + right_factor_size) as f32 / grad_data.len() as f32;
        let original_norm = gradient.norm()?.item()?;

        debug!(
            "PowerSGD compression: rank {}, ratio {:.3}",
            rank, compression_ratio
        );

        Ok(CompressedGradient {
            method: CompressionMethod::PowerSGD { rank },
            data: CompressedData::LowRank {
                left_factor,
                right_factor,
                rank,
            },
            original_shape: gradient.shape().dims().to_vec(),
            metadata: CompressionMetadata {
                compression_ratio,
                error_norm: 0.0,
                original_norm,
                timestamp: std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .expect("time should be after UNIX_EPOCH")
                    .as_secs(),
            },
        })
    }
666
667 fn compress_ternary(
669 &self,
670 gradient: &Tensor,
671 threshold: f32,
672 ) -> TorshResult<CompressedGradient> {
673 let flat_grad = gradient.flatten()?;
674 let grad_data = flat_grad.to_vec()?;
675 let original_norm = gradient.norm()?.item()?;
676
677 let scale = original_norm / (grad_data.len() as f32).sqrt();
679
680 let mut ternary_values = Vec::new();
681 for &value in &grad_data {
682 let normalized = value / scale;
683 let ternary = if normalized > threshold {
684 1i8
685 } else if normalized < -threshold {
686 -1i8
687 } else {
688 0i8
689 };
690 ternary_values.push(ternary);
691 }
692
693 let compression_ratio = 2.0 / 32.0; debug!(
695 "Ternary compression: threshold {}, scale {:.6}",
696 threshold, scale
697 );
698
699 Ok(CompressedGradient {
700 method: CompressionMethod::TernaryQuant { threshold },
701 data: CompressedData::Ternary {
702 values: ternary_values,
703 scale,
704 },
705 original_shape: gradient.shape().dims().to_vec(),
706 metadata: CompressionMetadata {
707 compression_ratio,
708 error_norm: 0.0,
709 original_norm,
710 timestamp: std::time::SystemTime::now()
711 .duration_since(std::time::UNIX_EPOCH)
712 .expect("time should be after UNIX_EPOCH")
713 .as_secs(),
714 },
715 })
716 }
717
718 fn compress_bimodal(
720 &self,
721 gradient: &Tensor,
722 num_bins: usize,
723 ) -> TorshResult<CompressedGradient> {
724 let flat_grad = gradient.flatten()?;
725 let grad_data = flat_grad.to_vec()?;
726 let original_norm = gradient.norm()?.item()?;
727
728 let min_val = grad_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
730 let max_val = grad_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
731
732 let mut bin_centers = Vec::new();
734 for i in 0..num_bins {
735 let center = min_val + (max_val - min_val) * (i as f32 + 0.5) / (num_bins as f32);
736 bin_centers.push(center);
737 }
738
739 let mut bin_indices = Vec::new();
741 for &value in &grad_data {
742 let mut best_bin = 0;
743 let mut best_distance = f32::INFINITY;
744
745 for (bin_idx, ¢er) in bin_centers.iter().enumerate() {
746 let distance = (value - center).abs();
747 if distance < best_distance {
748 best_distance = distance;
749 best_bin = bin_idx;
750 }
751 }
752 bin_indices.push(best_bin as u8);
753 }
754
755 let bits_per_bin = (num_bins as f32).log2().ceil();
756 let compression_ratio = bits_per_bin / 32.0;
757 debug!(
758 "Bimodal compression: {} bins, {:.1} bits/value",
759 num_bins, bits_per_bin
760 );
761
762 Ok(CompressedGradient {
763 method: CompressionMethod::BimodalQuant { num_bins },
764 data: CompressedData::Bimodal {
765 bin_indices,
766 bin_centers,
767 },
768 original_shape: gradient.shape().dims().to_vec(),
769 metadata: CompressionMetadata {
770 compression_ratio,
771 error_norm: 0.0,
772 original_norm,
773 timestamp: std::time::SystemTime::now()
774 .duration_since(std::time::UNIX_EPOCH)
775 .expect("time should be after UNIX_EPOCH")
776 .as_secs(),
777 },
778 })
779 }
780
781 fn compress_natural(
783 &self,
784 gradient: &Tensor,
785 compression_factor: f32,
786 ) -> TorshResult<CompressedGradient> {
787 let flat_grad = gradient.flatten()?;
788 let grad_data = flat_grad.to_vec()?;
789 let original_norm = gradient.norm()?.item()?;
790
791 let num_unique = (grad_data.len() as f32 * compression_factor).ceil() as usize;
793 let mut value_counts: std::collections::HashMap<i32, u32> =
794 std::collections::HashMap::new();
795
796 let scale = 10000.0; for &value in &grad_data {
799 let quantized = (value * scale).round() as i32;
800 *value_counts.entry(quantized).or_insert(0) += 1;
801 }
802
803 let mut sorted_values: Vec<_> = value_counts.into_iter().collect();
805 sorted_values.sort_by(|a, b| b.1.cmp(&a.1));
806 sorted_values.truncate(num_unique);
807
808 let codebook: Vec<f32> = sorted_values
810 .iter()
811 .map(|(v, _)| *v as f32 / scale)
812 .collect();
813 let frequencies: Vec<u32> = sorted_values.iter().map(|(_, f)| *f).collect();
814
815 let mut compressed_values = Vec::new();
817 for &value in &grad_data {
818 let mut best_idx = 0;
820 let mut best_distance = f32::INFINITY;
821 for (idx, &codebook_val) in codebook.iter().enumerate() {
822 let distance = (value - codebook_val).abs();
823 if distance < best_distance {
824 best_distance = distance;
825 best_idx = idx;
826 }
827 }
828 compressed_values.push(best_idx as f32);
829 }
830
831 debug!(
832 "Natural compression: {} unique values from {} total",
833 num_unique,
834 grad_data.len()
835 );
836
837 Ok(CompressedGradient {
838 method: CompressionMethod::NaturalCompression { compression_factor },
839 data: CompressedData::Natural {
840 values: compressed_values,
841 frequencies,
842 codebook,
843 },
844 original_shape: gradient.shape().dims().to_vec(),
845 metadata: CompressionMetadata {
846 compression_ratio: compression_factor,
847 error_norm: 0.0,
848 original_norm,
849 timestamp: std::time::SystemTime::now()
850 .duration_since(std::time::UNIX_EPOCH)
851 .expect("time should be after UNIX_EPOCH")
852 .as_secs(),
853 },
854 })
855 }
856
857 fn compress_layerwise_adaptive(
859 &self,
860 gradient: &Tensor,
861 base_ratio: f32,
862 sensitivity: f32,
863 param_name: &str,
864 ) -> TorshResult<CompressedGradient> {
865 let _original_norm = gradient.norm()?.item();
866
867 let layer_sensitivity = if param_name.contains("weight") {
869 1.0
870 } else {
871 sensitivity
872 };
873 let adapted_ratio = base_ratio * layer_sensitivity;
874
875 self.compress_top_k(gradient, adapted_ratio)
877 }
878
    /// EF21-style compression: carry a per-parameter residual with momentum,
    /// add it to the incoming gradient, then keep the top
    /// `compression_ratio` fraction of entries by magnitude. The dropped
    /// remainder becomes the next step's residual.
    fn compress_ef21(
        &mut self,
        gradient: &Tensor,
        compression_ratio: f32,
        momentum: f32,
        param_name: &str,
    ) -> TorshResult<CompressedGradient> {
        let flat_grad = gradient.flatten()?;
        let grad_data = flat_grad.to_vec()?;
        let original_norm = gradient.norm()?.item()?;

        // Residual carried over from the previous step (zeros on first use).
        let error_key = format!("ef21_{}", param_name);
        let error_feedback = if let Some(prev_error) = self.error_buffers.get(&error_key) {
            prev_error.flatten()?.to_vec()?
        } else {
            vec![0.0; grad_data.len()]
        };

        // Inject the momentum-scaled residual into the gradient.
        let mut adjusted_grad = Vec::new();
        for (&grad_val, &error_val) in grad_data.iter().zip(error_feedback.iter()) {
            adjusted_grad.push(grad_val + momentum * error_val);
        }

        // Rank entries by magnitude to select the kept fraction.
        let k_elements = (grad_data.len() as f32 * compression_ratio).ceil() as usize;
        let mut indexed_values: Vec<(usize, f32)> = adjusted_grad
            .iter()
            .enumerate()
            .map(|(i, &v)| (i, v.abs()))
            .collect();
        indexed_values.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Kept entries go into the payload; everything else stays in the
        // residual for the next step.
        let mut compressed_values = vec![0.0; grad_data.len()];
        let mut new_error_feedback = adjusted_grad.clone();

        for &(idx, _) in indexed_values.iter().take(k_elements) {
            compressed_values[idx] = adjusted_grad[idx];
            new_error_feedback[idx] = 0.0;
        }

        // Persist the new residual for this parameter.
        let error_tensor = Tensor::from_vec(new_error_feedback.clone(), gradient.shape().dims())?;
        self.error_buffers.insert(error_key, error_tensor);

        debug!(
            "EF21 compression: kept {}/{} elements with momentum {}",
            k_elements,
            grad_data.len(),
            momentum
        );

        Ok(CompressedGradient {
            method: CompressionMethod::EF21 {
                compression_ratio,
                momentum,
            },
            data: CompressedData::EF21 {
                compressed_values,
                error_feedback: new_error_feedback,
            },
            original_shape: gradient.shape().dims().to_vec(),
            metadata: CompressionMetadata {
                compression_ratio,
                error_norm: 0.0,
                original_norm,
                timestamp: std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .expect("time should be after UNIX_EPOCH")
                    .as_secs(),
            },
        })
    }
955
956 fn no_compression(
958 &self,
959 gradient: &Tensor,
960 _param_name: &str,
961 ) -> TorshResult<CompressedGradient> {
962 let flat_grad = gradient.flatten()?;
963 let grad_data = flat_grad.to_vec()?;
964 let indices: Vec<usize> = (0..grad_data.len()).collect();
965
966 let original_norm = gradient.norm()?.item()?;
967
968 Ok(CompressedGradient {
969 method: CompressionMethod::None,
970 data: CompressedData::Sparse {
971 indices,
972 values: grad_data,
973 },
974 original_shape: gradient.shape().dims().to_vec(),
975 metadata: CompressionMetadata {
976 compression_ratio: 1.0,
977 error_norm: 0.0,
978 original_norm,
979 timestamp: std::time::SystemTime::now()
980 .duration_since(std::time::UNIX_EPOCH)
981 .expect("time should be after UNIX_EPOCH")
982 .as_secs(),
983 },
984 })
985 }
986
987 fn decompress_sparse(
989 &self,
990 indices: &[usize],
991 values: &[f32],
992 shape: &[usize],
993 ) -> TorshResult<Tensor> {
994 let total_elements: usize = shape.iter().product();
995 let mut data = vec![0.0; total_elements];
996
997 for (&idx, &val) in indices.iter().zip(values.iter()) {
998 if idx < total_elements {
999 data[idx] = val;
1000 }
1001 }
1002
1003 Ok(Tensor::from_vec(data, shape)?)
1004 }
1005
1006 fn decompress_quantized(
1008 &self,
1009 values: &[u8],
1010 scale: f32,
1011 zero_point: u8,
1012 shape: &[usize],
1013 ) -> TorshResult<Tensor> {
1014 let data: Vec<f32> = values
1015 .iter()
1016 .map(|&q| (q as f32 - zero_point as f32) * scale)
1017 .collect();
1018
1019 Ok(Tensor::from_vec(data, shape)?)
1020 }
1021
1022 fn decompress_sign_sgd(
1024 &self,
1025 signs: &[bool],
1026 norm: f32,
1027 shape: &[usize],
1028 ) -> TorshResult<Tensor> {
1029 let total_elements: usize = shape.iter().product();
1030 let magnitude = norm / (total_elements as f32).sqrt();
1031
1032 let data: Vec<f32> = signs
1033 .iter()
1034 .map(|&sign| if sign { magnitude } else { -magnitude })
1035 .collect();
1036
1037 Ok(Tensor::from_vec(data, shape)?)
1038 }
1039
    /// Reconstruct from the placeholder PowerSGD payload.
    ///
    /// NOTE(review): mirrors the placeholder in `compress_power_sgd` — the
    /// factors are concatenated back into the leading entries of the dense
    /// buffer rather than multiplied as L·Rᵀ; remaining entries stay zero.
    /// Confirm before relying on reconstruction quality.
    fn decompress_power_sgd(
        &self,
        left_factor: &[f32],
        right_factor: &[f32],
        _rank: usize,
        shape: &[usize],
    ) -> TorshResult<Tensor> {
        let total_elements: usize = shape.iter().product();
        let mut data = vec![0.0; total_elements];

        let left_len = left_factor.len();
        let right_len = right_factor.len();

        // Copy the left factor first, then the right, truncating to fit.
        for i in 0..total_elements.min(left_len + right_len) {
            if i < left_len {
                data[i] = left_factor[i];
            } else {
                data[i] = right_factor[i - left_len];
            }
        }

        Ok(Tensor::from_vec(data, shape)?)
    }
1065
1066 fn decompress_sketching(
1068 &self,
1069 sketch: &[f32],
1070 _hash_a: &[u32],
1071 _hash_b: &[u32],
1072 shape: &[usize],
1073 ) -> TorshResult<Tensor> {
1074 let total_elements: usize = shape.iter().product();
1075 let mut data = vec![0.0; total_elements];
1076
1077 for (i, &val) in sketch.iter().enumerate() {
1079 if i < total_elements {
1080 data[i] = val;
1081 }
1082 }
1083
1084 Ok(Tensor::from_vec(data, shape)?)
1085 }
1086
1087 fn decompress_ternary(
1089 &self,
1090 values: &[i8],
1091 scale: f32,
1092 shape: &[usize],
1093 ) -> TorshResult<Tensor> {
1094 let data: Vec<f32> = values
1095 .iter()
1096 .map(|&ternary| (ternary as f32) * scale)
1097 .collect();
1098
1099 Ok(Tensor::from_vec(data, shape)?)
1100 }
1101
1102 fn decompress_bimodal(
1104 &self,
1105 bin_indices: &[u8],
1106 bin_centers: &[f32],
1107 shape: &[usize],
1108 ) -> TorshResult<Tensor> {
1109 let data: Vec<f32> = bin_indices
1110 .iter()
1111 .map(|&bin_idx| bin_centers.get(bin_idx as usize).copied().unwrap_or(0.0))
1112 .collect();
1113
1114 Ok(Tensor::from_vec(data, shape)?)
1115 }
1116
1117 fn decompress_natural(
1119 &self,
1120 values: &[f32],
1121 codebook: &[f32],
1122 shape: &[usize],
1123 ) -> TorshResult<Tensor> {
1124 let data: Vec<f32> = values
1125 .iter()
1126 .map(|&idx| {
1127 let idx_usize = idx as usize;
1128 codebook.get(idx_usize).copied().unwrap_or(0.0)
1129 })
1130 .collect();
1131
1132 Ok(Tensor::from_vec(data, shape)?)
1133 }
1134
1135 fn decompress_ef21(&self, compressed_values: &[f32], shape: &[usize]) -> TorshResult<Tensor> {
1137 Ok(Tensor::from_vec(compressed_values.to_vec(), shape)?)
1139 }
1140
1141 fn update_stats(&mut self, compressed: &CompressedGradient, compression_time: f64) {
1143 self.stats.total_compressions += 1;
1144 self.stats.avg_compression_ratio = (self.stats.avg_compression_ratio
1145 * (self.stats.total_compressions - 1) as f64
1146 + compressed.metadata.compression_ratio as f64)
1147 / self.stats.total_compressions as f64;
1148 self.stats.compression_time_ms += compression_time;
1149 }
1150
    /// Aggregate statistics collected across all compressions so far.
    pub fn get_stats(&self) -> &CompressionStats {
        &self.stats
    }

    /// Drop all stored error-feedback residuals (e.g. between epochs).
    pub fn reset_error_feedback(&mut self) {
        self.error_buffers.clear();
    }

    /// Current step counter (drives the warmup phase).
    pub fn step_count(&self) -> usize {
        self.step_count
    }
1165}
1166
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compression_config() {
        let config = CompressionConfig::default();
        assert_eq!(config.compression_ratio, 0.1);
        assert!(config.error_feedback);
        assert_eq!(config.warmup_steps, 100);
    }

    #[test]
    fn test_compression_methods() {
        assert_ne!(
            CompressionMethod::TopK { k: 0.1 },
            CompressionMethod::SignSGD
        );
        assert_ne!(
            CompressionMethod::Quantization { bits: 8 },
            CompressionMethod::None
        );
    }

    // Fix: these tests are fully synchronous — the former #[tokio::test]
    // attribute spun up an async runtime for nothing since no future is
    // awaited anywhere in them.
    #[test]
    fn test_gradient_compressor_creation() {
        let config = CompressionConfig::default();
        let compressor = GradientCompressor::new(config);

        assert_eq!(compressor.step_count(), 0);
        assert_eq!(compressor.get_stats().total_compressions, 0);
    }

    #[test]
    fn test_top_k_compression() -> TorshResult<()> {
        let config = CompressionConfig {
            method: CompressionMethod::TopK { k: 0.5 },
            warmup_steps: 0,
            ..Default::default()
        };
        let mut compressor = GradientCompressor::new(config);

        let gradient = torsh_tensor::creation::randn(&[10, 10])?;
        let compressed = compressor.compress(&gradient, "test_param")?;

        match &compressed.data {
            CompressedData::Sparse { indices, values } => {
                assert_eq!(indices.len(), values.len());
                // 50% of 100 elements.
                assert!(indices.len() <= 50);
            }
            _ => panic!("Expected sparse compression for TopK"),
        }

        let decompressed = compressor.decompress(&compressed)?;
        assert_eq!(decompressed.shape().dims(), gradient.shape().dims());

        Ok(())
    }

    #[test]
    fn test_sign_sgd_compression() -> TorshResult<()> {
        let config = CompressionConfig {
            method: CompressionMethod::SignSGD,
            warmup_steps: 0,
            ..Default::default()
        };
        let mut compressor = GradientCompressor::new(config);

        let gradient = torsh_tensor::creation::randn(&[5, 5])?;
        let compressed = compressor.compress(&gradient, "test_param")?;

        match &compressed.data {
            CompressedData::Signs { signs, norm } => {
                // One sign per element of the 5x5 gradient.
                assert_eq!(signs.len(), 25);
                assert!(*norm > 0.0);
            }
            _ => panic!("Expected sign compression for SignSGD"),
        }

        let decompressed = compressor.decompress(&compressed)?;
        assert_eq!(decompressed.shape().dims(), gradient.shape().dims());

        Ok(())
    }

    #[test]
    fn test_quantization_compression() -> TorshResult<()> {
        let config = CompressionConfig {
            method: CompressionMethod::Quantization { bits: 8 },
            warmup_steps: 0,
            ..Default::default()
        };
        let mut compressor = GradientCompressor::new(config);

        let gradient = torsh_tensor::creation::randn(&[4, 4])?;
        let compressed = compressor.compress(&gradient, "test_param")?;

        match &compressed.data {
            CompressedData::Quantized {
                values,
                scale,
                zero_point: _,
            } => {
                // One byte per element of the 4x4 gradient.
                assert_eq!(values.len(), 16);
                assert!(*scale > 0.0);
            }
            _ => panic!("Expected quantized compression"),
        }

        let decompressed = compressor.decompress(&compressed)?;
        assert_eq!(decompressed.shape().dims(), gradient.shape().dims());

        Ok(())
    }

    #[test]
    fn test_no_compression() -> TorshResult<()> {
        let config = CompressionConfig {
            method: CompressionMethod::None,
            warmup_steps: 0,
            ..Default::default()
        };
        let mut compressor = GradientCompressor::new(config);

        let gradient = torsh_tensor::creation::randn(&[3, 3])?;
        let compressed = compressor.compress(&gradient, "test_param")?;

        assert_eq!(compressed.metadata.compression_ratio, 1.0);

        let decompressed = compressor.decompress(&compressed)?;
        assert_eq!(decompressed.shape().dims(), gradient.shape().dims());

        Ok(())
    }

    #[test]
    fn test_compression_stats() {
        let stats = CompressionStats {
            total_compressions: 100,
            avg_compression_ratio: 0.25,
            total_communication_reduction: 1024 * 1024,
            avg_error_norm: 0.01,
            compression_time_ms: 250.5,
        };

        assert_eq!(stats.total_compressions, 100);
        assert_eq!(stats.avg_compression_ratio, 0.25);
        assert_eq!(stats.total_communication_reduction, 1024 * 1024);
    }
}