// Source file: torsh_distributed/gradient_compression.rs
1//! Gradient Compression for Distributed Training
2//!
3//! This module implements various gradient compression techniques to reduce
4//! communication overhead in distributed training. Includes quantization,
5//! sparsification, and error feedback mechanisms.
6
7// Framework infrastructure - components designed for future use
8#![allow(dead_code)]
9use crate::{TorshDistributedError, TorshResult};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use torsh_tensor::Tensor;
13use tracing::{debug, info};
14
/// Gradient compression configuration
///
/// Bundles the method selection and the error-feedback / warmup knobs used
/// by [`GradientCompressor`]. `Default` yields Top-K (10%) sparsification
/// with error feedback enabled and a 100-step warmup.
#[derive(Debug, Clone)]
pub struct CompressionConfig {
    /// Compression method to use
    pub method: CompressionMethod,
    /// Compression ratio (0.0 to 1.0)
    // NOTE(review): not read by the visible compression paths — the
    // per-method parameters (e.g. TopK `k`) drive the actual ratio; confirm
    // intended use before relying on it.
    pub compression_ratio: f32,
    /// Whether to use error feedback (re-inject compression residuals into
    /// the next step's gradient)
    pub error_feedback: bool,
    /// Momentum for error feedback (scales the stored residual before it is
    /// added back)
    pub error_feedback_momentum: f32,
    /// Whether to use memory-efficient compression
    // NOTE(review): not consulted by the code visible in this module.
    pub memory_efficient: bool,
    /// Warmup steps before applying compression (gradients pass through
    /// uncompressed while `step_count < warmup_steps`)
    pub warmup_steps: usize,
}
31
32impl Default for CompressionConfig {
33    fn default() -> Self {
34        Self {
35            method: CompressionMethod::TopK { k: 0.1 },
36            compression_ratio: 0.1,
37            error_feedback: true,
38            error_feedback_momentum: 0.9,
39            memory_efficient: true,
40            warmup_steps: 100,
41        }
42    }
43}
44
/// Supported compression methods
///
/// Each variant carries its own tuning parameters; ratios are fractions in
/// (0.0, 1.0] of the elements (or bits) retained.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CompressionMethod {
    /// Top-K sparsification (keep the top `k` fraction of gradients by magnitude)
    TopK { k: f32 },
    /// Random sparsification (keep a `k` fraction chosen independently of magnitude)
    RandomK { k: f32 },
    /// Threshold-based sparsification (keep entries with |value| >= threshold)
    Threshold { threshold: f32 },
    /// Uniform quantization to the given number of bits (codes stored as `u8`)
    Quantization { bits: u8 },
    /// Sign-based compression (SignSGD): 1 sign bit per entry plus the norm
    SignSGD,
    /// Gradient sketching using Count-Sketch (simplified in this module)
    Sketching { sketch_size: usize },
    /// PowerSGD low-rank approximation (requires 2D gradients)
    PowerSGD { rank: usize },
    /// Ternary quantization (-1, 0, +1) with a norm-derived scale
    TernaryQuant { threshold: f32 },
    /// Bimodal quantization (uniform binning over [min, max])
    BimodalQuant { num_bins: usize },
    /// Natural compression (codebook built from the gradient's value distribution)
    NaturalCompression { compression_factor: f32 },
    /// Layerwise adaptive compression (Top-K with a per-layer scaled ratio)
    LayerwiseAdaptive { base_ratio: f32, sensitivity: f32 },
    /// EF21 compression with momentum-weighted error feedback
    EF21 {
        compression_ratio: f32,
        momentum: f32,
    },
    /// No compression (baseline / warmup passthrough)
    None,
}
78
/// Compressed gradient representation
///
/// Self-contained payload produced by [`GradientCompressor::compress`]:
/// it records the method, the method-specific data, and everything needed
/// to reconstruct a dense tensor of the original shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressedGradient {
    /// Compression method used to produce `data`
    pub method: CompressionMethod,
    /// Compressed data (variant matches `method`)
    pub data: CompressedData,
    /// Original gradient shape (dims of the dense tensor to reconstruct)
    pub original_shape: Vec<usize>,
    /// Metadata for decompression and statistics tracking
    pub metadata: CompressionMetadata,
}
91
/// Compressed data types
///
/// One variant per compression family; each holds exactly the fields the
/// corresponding `decompress_*` routine needs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CompressedData {
    /// Sparse representation: flat indices into the flattened gradient and
    /// the surviving (signed) values
    Sparse {
        indices: Vec<usize>,
        values: Vec<f32>,
    },
    /// Uniformly quantized values: `original ≈ (code - zero_point) * scale`
    Quantized {
        values: Vec<u8>,
        scale: f32,
        zero_point: u8,
    },
    /// Sign representation: `true` = non-negative; `norm` rescales on decode
    Signs { signs: Vec<bool>, norm: f32 },
    /// Low-rank factors for A ≈ P @ Q^T (PowerSGD)
    LowRank {
        left_factor: Vec<f32>,
        right_factor: Vec<f32>,
        rank: usize,
    },
    /// Sketch representation with the two hash tables used to build it
    Sketch {
        sketch: Vec<f32>,
        hash_a: Vec<u32>,
        hash_b: Vec<u32>,
    },
    /// Ternary representation (-1, 0, +1) plus the magnitude scale
    Ternary { values: Vec<i8>, scale: f32 },
    /// Bimodal quantization: per-entry bin index plus the bin-center table
    Bimodal {
        bin_indices: Vec<u8>,
        bin_centers: Vec<f32>,
    },
    /// Natural compression: per-entry codebook indices (stored as f32),
    /// codebook entry frequencies, and the codebook itself
    Natural {
        values: Vec<f32>,
        frequencies: Vec<u32>,
        codebook: Vec<f32>,
    },
    /// EF21 compressed representation with its carried error-feedback state
    EF21 {
        compressed_values: Vec<f32>,
        error_feedback: Vec<f32>,
    },
}
139
/// Metadata for compression/decompression
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionMetadata {
    /// Compression ratio achieved (fraction of original size retained)
    pub compression_ratio: f32,
    /// Error norm introduced by compression
    // NOTE(review): several methods still record 0.0 here as a placeholder.
    pub error_norm: f32,
    /// Original gradient L2 norm (before compression)
    pub original_norm: f32,
    /// Timestamp in whole seconds since the UNIX epoch
    pub timestamp: u64,
}
152
/// Gradient compressor
///
/// Stateful: tracks the global step (for warmup), per-parameter error
/// feedback residuals, and aggregate compression statistics.
pub struct GradientCompressor {
    /// Configuration (method and error-feedback/warmup settings)
    config: CompressionConfig,
    /// Error feedback buffers, keyed by parameter name
    error_buffers: HashMap<String, Tensor>,
    /// Step counter, compared against `config.warmup_steps`
    step_count: usize,
    /// Running compression statistics
    stats: CompressionStats,
}
164
/// Compression statistics
///
/// Aggregates accumulated across all `compress` calls of one compressor.
#[derive(Debug, Clone, Default)]
pub struct CompressionStats {
    /// Total number of compressions performed
    pub total_compressions: u64,
    /// Average compression ratio (running mean)
    pub avg_compression_ratio: f64,
    /// Total communication reduction (bytes)
    pub total_communication_reduction: u64,
    /// Average error norm introduced (running mean)
    pub avg_error_norm: f64,
    /// Cumulative time spent on compression (ms)
    pub compression_time_ms: f64,
}
179
180impl GradientCompressor {
181    /// Create a new gradient compressor
182    pub fn new(config: CompressionConfig) -> Self {
183        info!(
184            "Initializing gradient compressor with method: {:?}",
185            config.method
186        );
187
188        Self {
189            config,
190            error_buffers: HashMap::new(),
191            step_count: 0,
192            stats: CompressionStats::default(),
193        }
194    }
195
196    /// Compress a gradient tensor
197    pub fn compress(
198        &mut self,
199        gradient: &Tensor,
200        param_name: &str,
201    ) -> TorshResult<CompressedGradient> {
202        let start_time = std::time::Instant::now();
203
204        // Skip compression during warmup
205        if self.step_count < self.config.warmup_steps {
206            return self.no_compression(gradient, param_name);
207        }
208
209        // Apply error feedback if enabled
210        let adjusted_gradient = if self.config.error_feedback {
211            self.apply_error_feedback(gradient, param_name)?
212        } else {
213            gradient.clone()
214        };
215
216        let compressed = match &self.config.method {
217            CompressionMethod::TopK { k } => self.compress_top_k(&adjusted_gradient, *k)?,
218            CompressionMethod::RandomK { k } => self.compress_random_k(&adjusted_gradient, *k)?,
219            CompressionMethod::Threshold { threshold } => {
220                self.compress_threshold(&adjusted_gradient, *threshold)?
221            }
222            CompressionMethod::Quantization { bits } => {
223                self.compress_quantization(&adjusted_gradient, *bits)?
224            }
225            CompressionMethod::SignSGD => self.compress_sign_sgd(&adjusted_gradient)?,
226            CompressionMethod::Sketching { sketch_size } => {
227                self.compress_sketching(&adjusted_gradient, *sketch_size)?
228            }
229            CompressionMethod::PowerSGD { rank } => {
230                self.compress_power_sgd(&adjusted_gradient, *rank)?
231            }
232            CompressionMethod::TernaryQuant { threshold } => {
233                self.compress_ternary(&adjusted_gradient, *threshold)?
234            }
235            CompressionMethod::BimodalQuant { num_bins } => {
236                self.compress_bimodal(&adjusted_gradient, *num_bins)?
237            }
238            CompressionMethod::NaturalCompression { compression_factor } => {
239                self.compress_natural(&adjusted_gradient, *compression_factor)?
240            }
241            CompressionMethod::LayerwiseAdaptive {
242                base_ratio,
243                sensitivity,
244            } => self.compress_layerwise_adaptive(
245                &adjusted_gradient,
246                *base_ratio,
247                *sensitivity,
248                param_name,
249            )?,
250            CompressionMethod::EF21 {
251                compression_ratio,
252                momentum,
253            } => self.compress_ef21(
254                &adjusted_gradient,
255                *compression_ratio,
256                *momentum,
257                param_name,
258            )?,
259            CompressionMethod::None => return self.no_compression(gradient, param_name),
260        };
261
262        // Store compression error for error feedback
263        if self.config.error_feedback {
264            self.update_error_feedback(&compressed, gradient, param_name)?;
265        }
266
267        // Update statistics
268        let compression_time = start_time.elapsed().as_millis() as f64;
269        self.update_stats(&compressed, compression_time);
270
271        self.step_count += 1;
272        Ok(compressed)
273    }
274
275    /// Decompress a gradient
276    pub fn decompress(&self, compressed: &CompressedGradient) -> TorshResult<Tensor> {
277        match &compressed.data {
278            CompressedData::Sparse { indices, values } => {
279                self.decompress_sparse(indices, values, &compressed.original_shape)
280            }
281            CompressedData::Quantized {
282                values,
283                scale,
284                zero_point,
285            } => self.decompress_quantized(values, *scale, *zero_point, &compressed.original_shape),
286            CompressedData::Signs { signs, norm } => {
287                self.decompress_sign_sgd(signs, *norm, &compressed.original_shape)
288            }
289            CompressedData::LowRank {
290                left_factor,
291                right_factor,
292                rank,
293            } => self.decompress_power_sgd(
294                left_factor,
295                right_factor,
296                *rank,
297                &compressed.original_shape,
298            ),
299            CompressedData::Sketch {
300                sketch,
301                hash_a,
302                hash_b,
303            } => self.decompress_sketching(sketch, hash_a, hash_b, &compressed.original_shape),
304            CompressedData::Ternary { values, scale } => {
305                self.decompress_ternary(values, *scale, &compressed.original_shape)
306            }
307            CompressedData::Bimodal {
308                bin_indices,
309                bin_centers,
310            } => self.decompress_bimodal(bin_indices, bin_centers, &compressed.original_shape),
311            CompressedData::Natural {
312                values,
313                frequencies: _,
314                codebook,
315            } => self.decompress_natural(values, codebook, &compressed.original_shape),
316            CompressedData::EF21 {
317                compressed_values,
318                error_feedback: _,
319            } => self.decompress_ef21(compressed_values, &compressed.original_shape),
320        }
321    }
322
323    /// Apply error feedback
324    fn apply_error_feedback(&mut self, gradient: &Tensor, param_name: &str) -> TorshResult<Tensor> {
325        if let Some(error_buffer) = self.error_buffers.get(param_name) {
326            // adjusted_gradient = gradient + momentum * error_buffer
327            let scaled_error = error_buffer.mul_scalar(self.config.error_feedback_momentum)?;
328            Ok(gradient.add(&scaled_error)?)
329        } else {
330            Ok(gradient.clone())
331        }
332    }
333
334    /// Update error feedback buffer
335    fn update_error_feedback(
336        &mut self,
337        compressed: &CompressedGradient,
338        original: &Tensor,
339        param_name: &str,
340    ) -> TorshResult<()> {
341        let decompressed = self.decompress(compressed)?;
342        let error = original.sub(&decompressed)?;
343        self.error_buffers.insert(param_name.to_string(), error);
344        Ok(())
345    }
346
347    /// Top-K sparsification
348    fn compress_top_k(&self, gradient: &Tensor, k: f32) -> TorshResult<CompressedGradient> {
349        let flat_grad = gradient.flatten()?;
350        let numel = flat_grad.numel();
351        let k_elements = ((numel as f32) * k).ceil() as usize;
352
353        // Get absolute values for sorting
354        let abs_grad = flat_grad.abs()?;
355        let grad_data = flat_grad.to_vec()?;
356        let abs_data = abs_grad.to_vec()?;
357
358        // Find top-k indices
359        let mut indexed_values: Vec<(usize, f32)> =
360            abs_data.iter().enumerate().map(|(i, &v)| (i, v)).collect();
361        indexed_values.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
362
363        let mut indices = Vec::new();
364        let mut values = Vec::new();
365
366        for &(idx, _) in indexed_values.iter().take(k_elements) {
367            indices.push(idx);
368            values.push(grad_data[idx]);
369        }
370
371        debug!("Top-K compression: kept {}/{} elements", k_elements, numel);
372
373        let original_norm = gradient.norm()?.item()?;
374        let compression_ratio = k;
375
376        Ok(CompressedGradient {
377            method: CompressionMethod::TopK { k },
378            data: CompressedData::Sparse { indices, values },
379            original_shape: gradient.shape().dims().to_vec(),
380            metadata: CompressionMetadata {
381                compression_ratio,
382                error_norm: 0.0, // Would calculate actual error
383                original_norm,
384                timestamp: std::time::SystemTime::now()
385                    .duration_since(std::time::UNIX_EPOCH)
386                    .expect("time should be after UNIX_EPOCH")
387                    .as_secs(),
388            },
389        })
390    }
391
392    /// Random-K sparsification
393    fn compress_random_k(&self, gradient: &Tensor, k: f32) -> TorshResult<CompressedGradient> {
394        let flat_grad = gradient.flatten()?;
395        let numel = flat_grad.numel();
396        let k_elements = ((numel as f32) * k).ceil() as usize;
397
398        let grad_data = flat_grad.to_vec()?;
399
400        // Random sampling of indices
401        let mut indices = Vec::new();
402        let mut values = Vec::new();
403
404        // Simple deterministic "random" selection for reproducibility
405        let step = numel / k_elements.max(1);
406        for i in (0..numel).step_by(step).take(k_elements) {
407            indices.push(i);
408            values.push(grad_data[i]);
409        }
410
411        debug!(
412            "Random-K compression: kept {}/{} elements",
413            k_elements, numel
414        );
415
416        let original_norm = gradient.norm()?.item()?;
417
418        Ok(CompressedGradient {
419            method: CompressionMethod::RandomK { k },
420            data: CompressedData::Sparse { indices, values },
421            original_shape: gradient.shape().dims().to_vec(),
422            metadata: CompressionMetadata {
423                compression_ratio: k,
424                error_norm: 0.0,
425                original_norm,
426                timestamp: std::time::SystemTime::now()
427                    .duration_since(std::time::UNIX_EPOCH)
428                    .expect("time should be after UNIX_EPOCH")
429                    .as_secs(),
430            },
431        })
432    }
433
434    /// Threshold-based sparsification
435    fn compress_threshold(
436        &self,
437        gradient: &Tensor,
438        threshold: f32,
439    ) -> TorshResult<CompressedGradient> {
440        let flat_grad = gradient.flatten()?;
441        let grad_data = flat_grad.to_vec()?;
442
443        let mut indices = Vec::new();
444        let mut values = Vec::new();
445
446        for (i, &value) in grad_data.iter().enumerate() {
447            if value.abs() >= threshold {
448                indices.push(i);
449                values.push(value);
450            }
451        }
452
453        let compression_ratio = indices.len() as f32 / grad_data.len() as f32;
454        debug!(
455            "Threshold compression: kept {}/{} elements",
456            indices.len(),
457            grad_data.len()
458        );
459
460        let original_norm = gradient.norm()?.item()?;
461
462        Ok(CompressedGradient {
463            method: CompressionMethod::Threshold { threshold },
464            data: CompressedData::Sparse { indices, values },
465            original_shape: gradient.shape().dims().to_vec(),
466            metadata: CompressionMetadata {
467                compression_ratio,
468                error_norm: 0.0,
469                original_norm,
470                timestamp: std::time::SystemTime::now()
471                    .duration_since(std::time::UNIX_EPOCH)
472                    .expect("time should be after UNIX_EPOCH")
473                    .as_secs(),
474            },
475        })
476    }
477
478    /// Quantization compression
479    fn compress_quantization(
480        &self,
481        gradient: &Tensor,
482        bits: u8,
483    ) -> TorshResult<CompressedGradient> {
484        let flat_grad = gradient.flatten()?;
485        let grad_data = flat_grad.to_vec()?;
486
487        // Simple uniform quantization
488        let min_val = grad_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
489        let max_val = grad_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
490
491        let levels = (1 << bits) - 1;
492        let scale = (max_val - min_val) / levels as f32;
493        let zero_point = (-min_val / scale).round() as u8;
494
495        let mut quantized_values = Vec::new();
496        for &value in &grad_data {
497            let quantized = ((value / scale) + zero_point as f32)
498                .round()
499                .clamp(0.0, levels as f32) as u8;
500            quantized_values.push(quantized);
501        }
502
503        debug!("Quantization: {} bits, {} levels", bits, levels);
504
505        let original_norm = gradient.norm()?.item()?;
506        let compression_ratio = (bits as f32) / 32.0; // Assuming original is fp32
507
508        Ok(CompressedGradient {
509            method: CompressionMethod::Quantization { bits },
510            data: CompressedData::Quantized {
511                values: quantized_values,
512                scale,
513                zero_point,
514            },
515            original_shape: gradient.shape().dims().to_vec(),
516            metadata: CompressionMetadata {
517                compression_ratio,
518                error_norm: 0.0,
519                original_norm,
520                timestamp: std::time::SystemTime::now()
521                    .duration_since(std::time::UNIX_EPOCH)
522                    .expect("time should be after UNIX_EPOCH")
523                    .as_secs(),
524            },
525        })
526    }
527
528    /// Sign SGD compression
529    fn compress_sign_sgd(&self, gradient: &Tensor) -> TorshResult<CompressedGradient> {
530        let flat_grad = gradient.flatten()?;
531        let grad_data = flat_grad.to_vec()?;
532        let norm = gradient.norm()?.item()?;
533
534        let signs: Vec<bool> = grad_data.iter().map(|&x| x >= 0.0).collect();
535
536        debug!(
537            "SignSGD compression: {} elements -> {} bits",
538            grad_data.len(),
539            signs.len()
540        );
541
542        Ok(CompressedGradient {
543            method: CompressionMethod::SignSGD,
544            data: CompressedData::Signs { signs, norm },
545            original_shape: gradient.shape().dims().to_vec(),
546            metadata: CompressionMetadata {
547                compression_ratio: 1.0 / 32.0, // 1 bit vs 32 bits
548                error_norm: 0.0,
549                original_norm: norm,
550                timestamp: std::time::SystemTime::now()
551                    .duration_since(std::time::UNIX_EPOCH)
552                    .expect("time should be after UNIX_EPOCH")
553                    .as_secs(),
554            },
555        })
556    }
557
558    /// Sketching compression (simplified)
559    fn compress_sketching(
560        &self,
561        gradient: &Tensor,
562        sketch_size: usize,
563    ) -> TorshResult<CompressedGradient> {
564        let flat_grad = gradient.flatten()?;
565        let grad_data = flat_grad.to_vec()?;
566
567        // Simple sketch: just take first sketch_size elements
568        let sketch: Vec<f32> = grad_data.iter().take(sketch_size).copied().collect();
569
570        // Mock hash functions
571        let hash_a: Vec<u32> = (0..grad_data.len()).map(|i| (i * 17 + 23) as u32).collect();
572        let hash_b: Vec<u32> = (0..grad_data.len()).map(|i| (i * 37 + 41) as u32).collect();
573
574        let compression_ratio = sketch_size as f32 / grad_data.len() as f32;
575        let original_norm = gradient.norm()?.item()?;
576
577        debug!(
578            "Sketching compression: {} -> {} elements",
579            grad_data.len(),
580            sketch_size
581        );
582
583        Ok(CompressedGradient {
584            method: CompressionMethod::Sketching { sketch_size },
585            data: CompressedData::Sketch {
586                sketch,
587                hash_a,
588                hash_b,
589            },
590            original_shape: gradient.shape().dims().to_vec(),
591            metadata: CompressionMetadata {
592                compression_ratio,
593                error_norm: 0.0,
594                original_norm,
595                timestamp: std::time::SystemTime::now()
596                    .duration_since(std::time::UNIX_EPOCH)
597                    .expect("time should be after UNIX_EPOCH")
598                    .as_secs(),
599            },
600        })
601    }
602
603    /// PowerSGD compression (simplified)
604    fn compress_power_sgd(
605        &self,
606        gradient: &Tensor,
607        rank: usize,
608    ) -> TorshResult<CompressedGradient> {
609        let shape_obj = gradient.shape();
610        let shape = shape_obj.dims();
611        if shape.len() != 2 {
612            return Err(TorshDistributedError::invalid_argument(
613                "gradient",
614                format!("PowerSGD requires 2D tensors, got {}D tensor", shape.len()),
615                "2D tensor with shape [rows, cols]",
616            ));
617        }
618
619        let rows = shape[0];
620        let cols = shape[1];
621
622        // Mock low-rank factorization: A ≈ P @ Q^T
623        let left_factor_size = rows * rank;
624        let right_factor_size = cols * rank;
625
626        let flat_grad = gradient.flatten()?;
627        let grad_data = flat_grad.to_vec()?;
628
629        // Simplified: just take portions of the gradient as factors
630        let left_factor: Vec<f32> = grad_data.iter().take(left_factor_size).copied().collect();
631        let right_factor: Vec<f32> = grad_data
632            .iter()
633            .skip(left_factor_size)
634            .take(right_factor_size)
635            .copied()
636            .collect();
637
638        let compression_ratio =
639            (left_factor_size + right_factor_size) as f32 / grad_data.len() as f32;
640        let original_norm = gradient.norm()?.item()?;
641
642        debug!(
643            "PowerSGD compression: rank {}, ratio {:.3}",
644            rank, compression_ratio
645        );
646
647        Ok(CompressedGradient {
648            method: CompressionMethod::PowerSGD { rank },
649            data: CompressedData::LowRank {
650                left_factor,
651                right_factor,
652                rank,
653            },
654            original_shape: gradient.shape().dims().to_vec(),
655            metadata: CompressionMetadata {
656                compression_ratio,
657                error_norm: 0.0,
658                original_norm,
659                timestamp: std::time::SystemTime::now()
660                    .duration_since(std::time::UNIX_EPOCH)
661                    .expect("time should be after UNIX_EPOCH")
662                    .as_secs(),
663            },
664        })
665    }
666
667    /// Ternary quantization compression
668    fn compress_ternary(
669        &self,
670        gradient: &Tensor,
671        threshold: f32,
672    ) -> TorshResult<CompressedGradient> {
673        let flat_grad = gradient.flatten()?;
674        let grad_data = flat_grad.to_vec()?;
675        let original_norm = gradient.norm()?.item()?;
676
677        // Compute scaling factor based on gradient magnitude
678        let scale = original_norm / (grad_data.len() as f32).sqrt();
679
680        let mut ternary_values = Vec::new();
681        for &value in &grad_data {
682            let normalized = value / scale;
683            let ternary = if normalized > threshold {
684                1i8
685            } else if normalized < -threshold {
686                -1i8
687            } else {
688                0i8
689            };
690            ternary_values.push(ternary);
691        }
692
693        let compression_ratio = 2.0 / 32.0; // ~2 bits per value vs 32 bits
694        debug!(
695            "Ternary compression: threshold {}, scale {:.6}",
696            threshold, scale
697        );
698
699        Ok(CompressedGradient {
700            method: CompressionMethod::TernaryQuant { threshold },
701            data: CompressedData::Ternary {
702                values: ternary_values,
703                scale,
704            },
705            original_shape: gradient.shape().dims().to_vec(),
706            metadata: CompressionMetadata {
707                compression_ratio,
708                error_norm: 0.0,
709                original_norm,
710                timestamp: std::time::SystemTime::now()
711                    .duration_since(std::time::UNIX_EPOCH)
712                    .expect("time should be after UNIX_EPOCH")
713                    .as_secs(),
714            },
715        })
716    }
717
718    /// Bimodal quantization compression
719    fn compress_bimodal(
720        &self,
721        gradient: &Tensor,
722        num_bins: usize,
723    ) -> TorshResult<CompressedGradient> {
724        let flat_grad = gradient.flatten()?;
725        let grad_data = flat_grad.to_vec()?;
726        let original_norm = gradient.norm()?.item()?;
727
728        // Find min and max values for binning
729        let min_val = grad_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
730        let max_val = grad_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
731
732        // Create bin centers
733        let mut bin_centers = Vec::new();
734        for i in 0..num_bins {
735            let center = min_val + (max_val - min_val) * (i as f32 + 0.5) / (num_bins as f32);
736            bin_centers.push(center);
737        }
738
739        // Assign each value to nearest bin
740        let mut bin_indices = Vec::new();
741        for &value in &grad_data {
742            let mut best_bin = 0;
743            let mut best_distance = f32::INFINITY;
744
745            for (bin_idx, &center) in bin_centers.iter().enumerate() {
746                let distance = (value - center).abs();
747                if distance < best_distance {
748                    best_distance = distance;
749                    best_bin = bin_idx;
750                }
751            }
752            bin_indices.push(best_bin as u8);
753        }
754
755        let bits_per_bin = (num_bins as f32).log2().ceil();
756        let compression_ratio = bits_per_bin / 32.0;
757        debug!(
758            "Bimodal compression: {} bins, {:.1} bits/value",
759            num_bins, bits_per_bin
760        );
761
762        Ok(CompressedGradient {
763            method: CompressionMethod::BimodalQuant { num_bins },
764            data: CompressedData::Bimodal {
765                bin_indices,
766                bin_centers,
767            },
768            original_shape: gradient.shape().dims().to_vec(),
769            metadata: CompressionMetadata {
770                compression_ratio,
771                error_norm: 0.0,
772                original_norm,
773                timestamp: std::time::SystemTime::now()
774                    .duration_since(std::time::UNIX_EPOCH)
775                    .expect("time should be after UNIX_EPOCH")
776                    .as_secs(),
777            },
778        })
779    }
780
781    /// Natural compression based on gradient distribution
782    fn compress_natural(
783        &self,
784        gradient: &Tensor,
785        compression_factor: f32,
786    ) -> TorshResult<CompressedGradient> {
787        let flat_grad = gradient.flatten()?;
788        let grad_data = flat_grad.to_vec()?;
789        let original_norm = gradient.norm()?.item()?;
790
791        // Create frequency histogram for natural encoding
792        let num_unique = (grad_data.len() as f32 * compression_factor).ceil() as usize;
793        let mut value_counts: std::collections::HashMap<i32, u32> =
794            std::collections::HashMap::new();
795
796        // Quantize values for frequency counting
797        let scale = 10000.0; // Fixed point scale
798        for &value in &grad_data {
799            let quantized = (value * scale).round() as i32;
800            *value_counts.entry(quantized).or_insert(0) += 1;
801        }
802
803        // Get most frequent values
804        let mut sorted_values: Vec<_> = value_counts.into_iter().collect();
805        sorted_values.sort_by(|a, b| b.1.cmp(&a.1));
806        sorted_values.truncate(num_unique);
807
808        // Create codebook and compressed representation
809        let codebook: Vec<f32> = sorted_values
810            .iter()
811            .map(|(v, _)| *v as f32 / scale)
812            .collect();
813        let frequencies: Vec<u32> = sorted_values.iter().map(|(_, f)| *f).collect();
814
815        // Encode values using codebook
816        let mut compressed_values = Vec::new();
817        for &value in &grad_data {
818            // Find closest codebook entry
819            let mut best_idx = 0;
820            let mut best_distance = f32::INFINITY;
821            for (idx, &codebook_val) in codebook.iter().enumerate() {
822                let distance = (value - codebook_val).abs();
823                if distance < best_distance {
824                    best_distance = distance;
825                    best_idx = idx;
826                }
827            }
828            compressed_values.push(best_idx as f32);
829        }
830
831        debug!(
832            "Natural compression: {} unique values from {} total",
833            num_unique,
834            grad_data.len()
835        );
836
837        Ok(CompressedGradient {
838            method: CompressionMethod::NaturalCompression { compression_factor },
839            data: CompressedData::Natural {
840                values: compressed_values,
841                frequencies,
842                codebook,
843            },
844            original_shape: gradient.shape().dims().to_vec(),
845            metadata: CompressionMetadata {
846                compression_ratio: compression_factor,
847                error_norm: 0.0,
848                original_norm,
849                timestamp: std::time::SystemTime::now()
850                    .duration_since(std::time::UNIX_EPOCH)
851                    .expect("time should be after UNIX_EPOCH")
852                    .as_secs(),
853            },
854        })
855    }
856
857    /// Layerwise adaptive compression
858    fn compress_layerwise_adaptive(
859        &self,
860        gradient: &Tensor,
861        base_ratio: f32,
862        sensitivity: f32,
863        param_name: &str,
864    ) -> TorshResult<CompressedGradient> {
865        let _original_norm = gradient.norm()?.item();
866
867        // Adapt compression ratio based on layer sensitivity
868        let layer_sensitivity = if param_name.contains("weight") {
869            1.0
870        } else {
871            sensitivity
872        };
873        let adapted_ratio = base_ratio * layer_sensitivity;
874
875        // Use TopK with adapted ratio
876        self.compress_top_k(gradient, adapted_ratio)
877    }
878
879    /// EF21 compression with error feedback and momentum
880    fn compress_ef21(
881        &mut self,
882        gradient: &Tensor,
883        compression_ratio: f32,
884        momentum: f32,
885        param_name: &str,
886    ) -> TorshResult<CompressedGradient> {
887        let flat_grad = gradient.flatten()?;
888        let grad_data = flat_grad.to_vec()?;
889        let original_norm = gradient.norm()?.item()?;
890
891        // Get or create error feedback buffer
892        let error_key = format!("ef21_{}", param_name);
893        let error_feedback = if let Some(prev_error) = self.error_buffers.get(&error_key) {
894            prev_error.flatten()?.to_vec()?
895        } else {
896            vec![0.0; grad_data.len()]
897        };
898
899        // Apply momentum to error feedback
900        let mut adjusted_grad = Vec::new();
901        for (&grad_val, &error_val) in grad_data.iter().zip(error_feedback.iter()) {
902            adjusted_grad.push(grad_val + momentum * error_val);
903        }
904
905        // Compress using TopK
906        let k_elements = (grad_data.len() as f32 * compression_ratio).ceil() as usize;
907        let mut indexed_values: Vec<(usize, f32)> = adjusted_grad
908            .iter()
909            .enumerate()
910            .map(|(i, &v)| (i, v.abs()))
911            .collect();
912        indexed_values.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
913
914        let mut compressed_values = vec![0.0; grad_data.len()];
915        let mut new_error_feedback = adjusted_grad.clone();
916
917        // Keep top-k values
918        for &(idx, _) in indexed_values.iter().take(k_elements) {
919            compressed_values[idx] = adjusted_grad[idx];
920            new_error_feedback[idx] = 0.0; // Reset error for transmitted values
921        }
922
923        // Update error feedback buffer
924        let error_tensor = Tensor::from_vec(new_error_feedback.clone(), gradient.shape().dims())?;
925        self.error_buffers.insert(error_key, error_tensor);
926
927        debug!(
928            "EF21 compression: kept {}/{} elements with momentum {}",
929            k_elements,
930            grad_data.len(),
931            momentum
932        );
933
934        Ok(CompressedGradient {
935            method: CompressionMethod::EF21 {
936                compression_ratio,
937                momentum,
938            },
939            data: CompressedData::EF21 {
940                compressed_values,
941                error_feedback: new_error_feedback,
942            },
943            original_shape: gradient.shape().dims().to_vec(),
944            metadata: CompressionMetadata {
945                compression_ratio,
946                error_norm: 0.0,
947                original_norm,
948                timestamp: std::time::SystemTime::now()
949                    .duration_since(std::time::UNIX_EPOCH)
950                    .expect("time should be after UNIX_EPOCH")
951                    .as_secs(),
952            },
953        })
954    }
955
956    /// No compression (passthrough)
957    fn no_compression(
958        &self,
959        gradient: &Tensor,
960        _param_name: &str,
961    ) -> TorshResult<CompressedGradient> {
962        let flat_grad = gradient.flatten()?;
963        let grad_data = flat_grad.to_vec()?;
964        let indices: Vec<usize> = (0..grad_data.len()).collect();
965
966        let original_norm = gradient.norm()?.item()?;
967
968        Ok(CompressedGradient {
969            method: CompressionMethod::None,
970            data: CompressedData::Sparse {
971                indices,
972                values: grad_data,
973            },
974            original_shape: gradient.shape().dims().to_vec(),
975            metadata: CompressionMetadata {
976                compression_ratio: 1.0,
977                error_norm: 0.0,
978                original_norm,
979                timestamp: std::time::SystemTime::now()
980                    .duration_since(std::time::UNIX_EPOCH)
981                    .expect("time should be after UNIX_EPOCH")
982                    .as_secs(),
983            },
984        })
985    }
986
987    /// Decompress sparse representation
988    fn decompress_sparse(
989        &self,
990        indices: &[usize],
991        values: &[f32],
992        shape: &[usize],
993    ) -> TorshResult<Tensor> {
994        let total_elements: usize = shape.iter().product();
995        let mut data = vec![0.0; total_elements];
996
997        for (&idx, &val) in indices.iter().zip(values.iter()) {
998            if idx < total_elements {
999                data[idx] = val;
1000            }
1001        }
1002
1003        Ok(Tensor::from_vec(data, shape)?)
1004    }
1005
1006    /// Decompress quantized representation
1007    fn decompress_quantized(
1008        &self,
1009        values: &[u8],
1010        scale: f32,
1011        zero_point: u8,
1012        shape: &[usize],
1013    ) -> TorshResult<Tensor> {
1014        let data: Vec<f32> = values
1015            .iter()
1016            .map(|&q| (q as f32 - zero_point as f32) * scale)
1017            .collect();
1018
1019        Ok(Tensor::from_vec(data, shape)?)
1020    }
1021
1022    /// Decompress SignSGD representation
1023    fn decompress_sign_sgd(
1024        &self,
1025        signs: &[bool],
1026        norm: f32,
1027        shape: &[usize],
1028    ) -> TorshResult<Tensor> {
1029        let total_elements: usize = shape.iter().product();
1030        let magnitude = norm / (total_elements as f32).sqrt();
1031
1032        let data: Vec<f32> = signs
1033            .iter()
1034            .map(|&sign| if sign { magnitude } else { -magnitude })
1035            .collect();
1036
1037        Ok(Tensor::from_vec(data, shape)?)
1038    }
1039
1040    /// Decompress PowerSGD representation (simplified)
1041    fn decompress_power_sgd(
1042        &self,
1043        left_factor: &[f32],
1044        right_factor: &[f32],
1045        _rank: usize,
1046        shape: &[usize],
1047    ) -> TorshResult<Tensor> {
1048        // Simplified: just combine the factors somehow
1049        let total_elements: usize = shape.iter().product();
1050        let mut data = vec![0.0; total_elements];
1051
1052        let left_len = left_factor.len();
1053        let right_len = right_factor.len();
1054
1055        for i in 0..total_elements.min(left_len + right_len) {
1056            if i < left_len {
1057                data[i] = left_factor[i];
1058            } else {
1059                data[i] = right_factor[i - left_len];
1060            }
1061        }
1062
1063        Ok(Tensor::from_vec(data, shape)?)
1064    }
1065
1066    /// Decompress sketching representation (simplified)
1067    fn decompress_sketching(
1068        &self,
1069        sketch: &[f32],
1070        _hash_a: &[u32],
1071        _hash_b: &[u32],
1072        shape: &[usize],
1073    ) -> TorshResult<Tensor> {
1074        let total_elements: usize = shape.iter().product();
1075        let mut data = vec![0.0; total_elements];
1076
1077        // Simplified: just spread sketch values
1078        for (i, &val) in sketch.iter().enumerate() {
1079            if i < total_elements {
1080                data[i] = val;
1081            }
1082        }
1083
1084        Ok(Tensor::from_vec(data, shape)?)
1085    }
1086
1087    /// Decompress ternary representation
1088    fn decompress_ternary(
1089        &self,
1090        values: &[i8],
1091        scale: f32,
1092        shape: &[usize],
1093    ) -> TorshResult<Tensor> {
1094        let data: Vec<f32> = values
1095            .iter()
1096            .map(|&ternary| (ternary as f32) * scale)
1097            .collect();
1098
1099        Ok(Tensor::from_vec(data, shape)?)
1100    }
1101
1102    /// Decompress bimodal representation
1103    fn decompress_bimodal(
1104        &self,
1105        bin_indices: &[u8],
1106        bin_centers: &[f32],
1107        shape: &[usize],
1108    ) -> TorshResult<Tensor> {
1109        let data: Vec<f32> = bin_indices
1110            .iter()
1111            .map(|&bin_idx| bin_centers.get(bin_idx as usize).copied().unwrap_or(0.0))
1112            .collect();
1113
1114        Ok(Tensor::from_vec(data, shape)?)
1115    }
1116
1117    /// Decompress natural representation
1118    fn decompress_natural(
1119        &self,
1120        values: &[f32],
1121        codebook: &[f32],
1122        shape: &[usize],
1123    ) -> TorshResult<Tensor> {
1124        let data: Vec<f32> = values
1125            .iter()
1126            .map(|&idx| {
1127                let idx_usize = idx as usize;
1128                codebook.get(idx_usize).copied().unwrap_or(0.0)
1129            })
1130            .collect();
1131
1132        Ok(Tensor::from_vec(data, shape)?)
1133    }
1134
1135    /// Decompress EF21 representation
1136    fn decompress_ef21(&self, compressed_values: &[f32], shape: &[usize]) -> TorshResult<Tensor> {
1137        // For EF21, the compressed values are already in the correct format
1138        Ok(Tensor::from_vec(compressed_values.to_vec(), shape)?)
1139    }
1140
1141    /// Update compression statistics
1142    fn update_stats(&mut self, compressed: &CompressedGradient, compression_time: f64) {
1143        self.stats.total_compressions += 1;
1144        self.stats.avg_compression_ratio = (self.stats.avg_compression_ratio
1145            * (self.stats.total_compressions - 1) as f64
1146            + compressed.metadata.compression_ratio as f64)
1147            / self.stats.total_compressions as f64;
1148        self.stats.compression_time_ms += compression_time;
1149    }
1150
    /// Get compression statistics
    ///
    /// Returns a shared reference to the accumulated running statistics
    /// (compression count, average ratio, timing) maintained by
    /// `update_stats`.
    pub fn get_stats(&self) -> &CompressionStats {
        &self.stats
    }
1155
    /// Reset error feedback buffers
    ///
    /// Clears all per-parameter error-feedback state, so the next
    /// compression of each parameter starts from a zero residual.
    pub fn reset_error_feedback(&mut self) {
        self.error_buffers.clear();
    }
1160
    /// Get current step count
    ///
    /// NOTE(review): presumably incremented once per `compress` call and
    /// compared against `warmup_steps` — confirm against the compress path.
    pub fn step_count(&self) -> usize {
        self.step_count
    }
1165}
1166
#[cfg(test)]
mod tests {
    use super::*;

    // Default configuration should match the values in `Default for
    // CompressionConfig`: TopK(0.1), error feedback on, 100 warmup steps.
    #[test]
    fn test_compression_config() {
        let config = CompressionConfig::default();
        assert_eq!(config.compression_ratio, 0.1);
        assert!(config.error_feedback);
        assert_eq!(config.warmup_steps, 100);
    }

    // Sanity check that distinct `CompressionMethod` variants compare
    // unequal (relies on the derived PartialEq).
    #[test]
    fn test_compression_methods() {
        assert_ne!(
            CompressionMethod::TopK { k: 0.1 },
            CompressionMethod::SignSGD
        );
        assert_ne!(
            CompressionMethod::Quantization { bits: 8 },
            CompressionMethod::None
        );
    }

    // A freshly built compressor starts with zeroed counters.
    #[tokio::test]
    async fn test_gradient_compressor_creation() {
        let config = CompressionConfig::default();
        let compressor = GradientCompressor::new(config);

        assert_eq!(compressor.step_count(), 0);
        assert_eq!(compressor.get_stats().total_compressions, 0);
    }

    // TopK(0.5) on a 10x10 gradient must yield a sparse payload keeping at
    // most half the elements, and decompression must restore the shape.
    // warmup_steps: 0 ensures compression is applied on the first step.
    #[tokio::test]
    async fn test_top_k_compression() -> TorshResult<()> {
        let config = CompressionConfig {
            method: CompressionMethod::TopK { k: 0.5 },
            warmup_steps: 0,
            ..Default::default()
        };
        let mut compressor = GradientCompressor::new(config);

        let gradient = torsh_tensor::creation::randn(&[10, 10])?;
        let compressed = compressor.compress(&gradient, "test_param")?;

        match &compressed.data {
            CompressedData::Sparse { indices, values } => {
                assert_eq!(indices.len(), values.len());
                assert!(indices.len() <= 50); // Top 50% of 100 elements
            }
            _ => panic!("Expected sparse compression for TopK"),
        }

        let decompressed = compressor.decompress(&compressed)?;
        assert_eq!(decompressed.shape().dims(), gradient.shape().dims());

        Ok(())
    }

    // SignSGD must emit one sign per element plus a positive norm, and
    // decompression must restore the original shape.
    #[tokio::test]
    async fn test_sign_sgd_compression() -> TorshResult<()> {
        let config = CompressionConfig {
            method: CompressionMethod::SignSGD,
            warmup_steps: 0,
            ..Default::default()
        };
        let mut compressor = GradientCompressor::new(config);

        let gradient = torsh_tensor::creation::randn(&[5, 5])?;
        let compressed = compressor.compress(&gradient, "test_param")?;

        match &compressed.data {
            CompressedData::Signs { signs, norm } => {
                assert_eq!(signs.len(), 25); // 5x5 = 25 elements
                assert!(*norm > 0.0);
            }
            _ => panic!("Expected sign compression for SignSGD"),
        }

        let decompressed = compressor.decompress(&compressed)?;
        assert_eq!(decompressed.shape().dims(), gradient.shape().dims());

        Ok(())
    }

    // 8-bit quantization must emit one byte per element and a positive
    // scale; round-trip must preserve the tensor shape.
    #[tokio::test]
    async fn test_quantization_compression() -> TorshResult<()> {
        let config = CompressionConfig {
            method: CompressionMethod::Quantization { bits: 8 },
            warmup_steps: 0,
            ..Default::default()
        };
        let mut compressor = GradientCompressor::new(config);

        let gradient = torsh_tensor::creation::randn(&[4, 4])?;
        let compressed = compressor.compress(&gradient, "test_param")?;

        match &compressed.data {
            CompressedData::Quantized {
                values,
                scale,
                zero_point: _,
            } => {
                assert_eq!(values.len(), 16); // 4x4 = 16 elements
                assert!(*scale > 0.0);
                // zero_point is u8, always <= 255
            }
            _ => panic!("Expected quantized compression"),
        }

        let decompressed = compressor.decompress(&compressed)?;
        assert_eq!(decompressed.shape().dims(), gradient.shape().dims());

        Ok(())
    }

    // Passthrough mode reports a ratio of 1.0 and round-trips the shape.
    #[tokio::test]
    async fn test_no_compression() -> TorshResult<()> {
        let config = CompressionConfig {
            method: CompressionMethod::None,
            warmup_steps: 0,
            ..Default::default()
        };
        let mut compressor = GradientCompressor::new(config);

        let gradient = torsh_tensor::creation::randn(&[3, 3])?;
        let compressed = compressor.compress(&gradient, "test_param")?;

        assert_eq!(compressed.metadata.compression_ratio, 1.0);

        let decompressed = compressor.decompress(&compressed)?;
        assert_eq!(decompressed.shape().dims(), gradient.shape().dims());

        Ok(())
    }

    // Plain struct-construction check of the stats fields.
    #[test]
    fn test_compression_stats() {
        let stats = CompressionStats {
            total_compressions: 100,
            avg_compression_ratio: 0.25,
            total_communication_reduction: 1024 * 1024, // 1MB
            avg_error_norm: 0.01,
            compression_time_ms: 250.5,
        };

        assert_eq!(stats.total_compressions, 100);
        assert_eq!(stats.avg_compression_ratio, 0.25);
        assert_eq!(stats.total_communication_reduction, 1024 * 1024);
    }
}