Skip to main content

torsh_core/
compression.rs

//! Tensor compression schemes and pruning metadata for ToRSh Core
//!
//! This module provides comprehensive tensor compression techniques including:
//! - Magnitude-based pruning with threshold detection
//! - Structured pruning patterns (block-wise, channel-wise, attention head-wise)
//! - Compression encoding schemes (run-length, Huffman, delta encoding)
//! - Quantization-aware compression metadata
8
9use crate::dtype::DType;
10use crate::shape::Shape;
11
12#[cfg(not(feature = "std"))]
13use alloc::{vec, vec::Vec};
14
/// Pruning strategies for tensor compression.
///
/// Each variant carries the parameters that configure one way of choosing
/// which weights to remove.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PruningStrategy {
    /// Magnitude-based pruning: remove weights below threshold.
    ///
    /// `threshold_percentile` is interpreted as a percentage; values above
    /// 100 are treated as 100 by [`PruningStrategy::expected_sparsity`].
    Magnitude { threshold_percentile: u8 },

    /// Structured pruning: remove entire blocks
    BlockWise { block_size: (usize, usize) },

    /// Channel-wise pruning for convolutional layers
    ChannelWise { channels_to_prune: usize },

    /// Attention head pruning for transformer models
    AttentionHead { heads_to_prune: usize },

    /// Movement pruning: prune based on weight movement during training
    Movement { sensitivity: f32 },

    /// Gradual magnitude pruning with schedule
    GradualMagnitude {
        initial_sparsity: f32,
        final_sparsity: f32,
    },
}

impl PruningStrategy {
    /// Get the expected sparsity for this strategy as a fraction in `[0.0, 1.0]`.
    ///
    /// - `Magnitude`: `threshold_percentile / 100`, with the percentile capped
    ///   at 100 so the result never exceeds `1.0`.
    /// - `GradualMagnitude`: `final_sparsity`, clamped to `[0.0, 1.0]`
    ///   (a NaN input stays NaN).
    /// - Every other strategy: a fixed default of `0.5`.
    pub fn expected_sparsity(&self) -> f32 {
        match *self {
            Self::Magnitude {
                threshold_percentile,
            } => f32::from(threshold_percentile.min(100)) / 100.0,
            Self::GradualMagnitude { final_sparsity, .. } => final_sparsity.clamp(0.0, 1.0),
            // Listed explicitly (no `_` arm) so that adding a variant forces a
            // decision about its expected sparsity.
            Self::BlockWise { .. }
            | Self::ChannelWise { .. }
            | Self::AttentionHead { .. }
            | Self::Movement { .. } => 0.5,
        }
    }

    /// Check if this is a structured pruning strategy.
    ///
    /// Returns `true` for `BlockWise`, `ChannelWise` and `AttentionHead`,
    /// and `false` for every other variant.
    pub fn is_structured(&self) -> bool {
        matches!(
            self,
            Self::BlockWise { .. } | Self::ChannelWise { .. } | Self::AttentionHead { .. }
        )
    }
}
60
61/// Pruning metadata tracking which elements/blocks are pruned
62#[derive(Debug, Clone)]
63pub struct PruningMetadata {
64    /// Pruning strategy used
65    strategy: PruningStrategy,
66
67    /// Pruned element indices (for unstructured pruning)
68    pruned_indices: Option<Vec<usize>>,
69
70    /// Pruned block indices (for structured pruning)
71    pruned_blocks: Option<Vec<(usize, usize)>>,
72
73    /// Pruned channel indices (for channel-wise pruning)
74    pruned_channels: Option<Vec<usize>>,
75
76    /// Actual sparsity achieved
77    achieved_sparsity: f32,
78
79    /// Original tensor shape before pruning
80    original_shape: Shape,
81
82    /// Threshold value used (for magnitude pruning)
83    threshold_value: Option<f32>,
84
85    /// Compression ratio (original_size / pruned_size)
86    compression_ratio: f32,
87}
88
89impl PruningMetadata {
90    /// Create new pruning metadata
91    pub fn new(strategy: PruningStrategy, original_shape: Shape, achieved_sparsity: f32) -> Self {
92        let compression_ratio = 1.0 / (1.0 - achieved_sparsity);
93
94        Self {
95            strategy,
96            pruned_indices: None,
97            pruned_blocks: None,
98            pruned_channels: None,
99            achieved_sparsity,
100            original_shape,
101            threshold_value: None,
102            compression_ratio,
103        }
104    }
105
106    /// Set pruned indices for unstructured pruning
107    pub fn with_indices(mut self, indices: Vec<usize>) -> Self {
108        self.pruned_indices = Some(indices);
109        self
110    }
111
112    /// Set pruned blocks for structured pruning
113    pub fn with_blocks(mut self, blocks: Vec<(usize, usize)>) -> Self {
114        self.pruned_blocks = Some(blocks);
115        self
116    }
117
118    /// Set pruned channels
119    pub fn with_channels(mut self, channels: Vec<usize>) -> Self {
120        self.pruned_channels = Some(channels);
121        self
122    }
123
124    /// Set threshold value
125    pub fn with_threshold(mut self, threshold: f32) -> Self {
126        self.threshold_value = Some(threshold);
127        self
128    }
129
130    /// Get pruning strategy
131    pub fn strategy(&self) -> PruningStrategy {
132        self.strategy
133    }
134
135    /// Get achieved sparsity
136    pub fn sparsity(&self) -> f32 {
137        self.achieved_sparsity
138    }
139
140    /// Get compression ratio
141    pub fn compression_ratio(&self) -> f32 {
142        self.compression_ratio
143    }
144
145    /// Get number of pruned elements
146    pub fn num_pruned_elements(&self) -> usize {
147        if let Some(ref indices) = self.pruned_indices {
148            indices.len()
149        } else if let Some(ref blocks) = self.pruned_blocks {
150            blocks.len()
151        } else if let Some(ref channels) = self.pruned_channels {
152            channels.len()
153        } else {
154            0
155        }
156    }
157
158    /// Check if an element is pruned
159    pub fn is_element_pruned(&self, index: usize) -> bool {
160        if let Some(ref indices) = self.pruned_indices {
161            indices.binary_search(&index).is_ok()
162        } else {
163            false
164        }
165    }
166
167    /// Get memory savings in bytes
168    pub fn memory_savings(&self, dtype: DType) -> usize {
169        let total_elements = self.original_shape.numel();
170        let pruned_elements = (total_elements as f32 * self.achieved_sparsity) as usize;
171        pruned_elements * dtype.size()
172    }
173}
174
/// Encoding schemes that can be applied to the indices of a sparse tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionEncoding {
    /// Indices are stored as-is, without any encoding
    Raw,

    /// Consecutive indices are collapsed into runs
    RunLength,

    /// Sequential indices are stored as differences
    Delta,

    /// Variable-length codes (Huffman)
    Huffman,

    /// One bit per position, suited to dense regions
    Bitmap,

    /// A combination of several schemes
    Hybrid,
}

impl CompressionEncoding {
    /// Nominal compression ratio assumed for this encoding on typical sparse
    /// tensors. These are fixed constants, not measurements.
    pub fn expected_compression_ratio(&self) -> f32 {
        match *self {
            Self::Hybrid => 3.5,
            Self::Bitmap => 3.0,
            Self::Huffman => 2.5,
            Self::RunLength => 2.0,
            Self::Delta => 1.5,
            Self::Raw => 1.0,
        }
    }

    /// Whether the encoding needs its input indices to be sorted.
    ///
    /// True for run-length, delta and hybrid encodings; false otherwise.
    pub fn requires_sorted_indices(&self) -> bool {
        match *self {
            Self::RunLength | Self::Delta | Self::Hybrid => true,
            Self::Raw | Self::Huffman | Self::Bitmap => false,
        }
    }
}
215
/// Run-length encoded index sequence.
///
/// A run is a maximal stretch of indices where each one is exactly one greater
/// than its predecessor; it is stored as a `(start, length)` pair.
#[derive(Debug, Clone)]
pub struct RunLengthEncoded {
    /// Starting indices of runs
    start_indices: Vec<usize>,

    /// Lengths of runs (parallel to `start_indices`)
    run_lengths: Vec<usize>,

    /// Total number of elements encoded
    total_elements: usize,
}

impl RunLengthEncoded {
    /// Create new run-length encoding from indices.
    ///
    /// The encoding compresses best when `indices` is sorted, but any order is
    /// accepted and round-trips through [`RunLengthEncoded::decode`]. An empty
    /// slice produces an encoding with no runs.
    pub fn encode(indices: &[usize]) -> Self {
        let mut start_indices = Vec::new();
        let mut run_lengths = Vec::new();

        let mut remaining = indices.iter().copied();
        if let Some(first) = remaining.next() {
            let mut run_start = first;
            let mut run_length = 1usize;
            let mut previous = first;

            for index in remaining {
                // `checked_add` keeps `usize::MAX` from overflowing; an index
                // can never continue a run that already ends at the maximum.
                if previous.checked_add(1) == Some(index) {
                    run_length += 1;
                } else {
                    start_indices.push(run_start);
                    run_lengths.push(run_length);
                    run_start = index;
                    run_length = 1;
                }
                previous = index;
            }

            // Flush the run that was still open when the input ended.
            start_indices.push(run_start);
            run_lengths.push(run_length);
        }

        Self {
            start_indices,
            run_lengths,
            total_elements: indices.len(),
        }
    }

    /// Decode run-length encoding back to indices, in their original order.
    pub fn decode(&self) -> Vec<usize> {
        let mut indices = Vec::with_capacity(self.total_elements);

        for (&start, &length) in self.start_indices.iter().zip(&self.run_lengths) {
            // `start + offset` is used rather than `start..start + length`
            // because the exclusive end would overflow for a run that ends at
            // `usize::MAX`.
            indices.extend((0..length).map(|offset| start + offset));
        }

        indices
    }

    /// Get compression ratio (original_size / compressed_size).
    ///
    /// Sizes are counted in bytes of `usize` storage: one word per original
    /// index versus two words per run. Returns `1.0` for an empty encoding.
    pub fn compression_ratio(&self) -> f32 {
        if self.start_indices.is_empty() {
            return 1.0;
        }

        // `core::mem` so the module also builds without `std`.
        let word = core::mem::size_of::<usize>();
        let original_size = self.total_elements * word;
        let compressed_size = (self.start_indices.len() + self.run_lengths.len()) * word;

        original_size as f32 / compressed_size as f32
    }

    /// Get number of runs
    pub fn num_runs(&self) -> usize {
        self.start_indices.len()
    }
}
301
/// Delta-encoded index sequence.
///
/// The first index is stored verbatim; every following index is stored as the
/// difference from its predecessor. Differences that fit in an `i32` live in
/// `deltas`; the rare ones that do not are kept exactly in `wide_deltas`, so
/// encoding is lossless for any input.
#[derive(Debug, Clone)]
pub struct DeltaEncoded {
    /// Base index (first element)
    base_index: usize,

    /// Delta values (differences between consecutive indices). Slots whose
    /// real delta is recorded in `wide_deltas` hold a placeholder `0`.
    deltas: Vec<i32>,

    /// `(position in deltas, exact delta)` for gaps outside the `i32` range,
    /// ordered by position
    wide_deltas: Vec<(usize, i128)>,

    /// Total number of elements
    total_elements: usize,
}

impl DeltaEncoded {
    /// Create new delta encoding from indices.
    ///
    /// Sorted, closely spaced indices compress best, but any order and any gap
    /// size round-trips through [`DeltaEncoded::decode`]. An empty slice
    /// produces an empty encoding with a base index of `0`.
    pub fn encode(indices: &[usize]) -> Self {
        let mut deltas = Vec::with_capacity(indices.len().saturating_sub(1));
        let mut wide_deltas = Vec::new();

        for (position, pair) in indices.windows(2).enumerate() {
            // Every `usize` fits in `i128`, so this difference is exact and
            // cannot overflow.
            let exact = pair[1] as i128 - pair[0] as i128;
            match i32::try_from(exact) {
                Ok(delta) => deltas.push(delta),
                Err(_) => {
                    // Keep `deltas` aligned with positions; the real value is
                    // looked up in the side table during decoding.
                    deltas.push(0);
                    wide_deltas.push((position, exact));
                }
            }
        }

        Self {
            base_index: indices.first().copied().unwrap_or(0),
            deltas,
            wide_deltas,
            total_elements: indices.len(),
        }
    }

    /// Decode delta encoding back to the original indices.
    pub fn decode(&self) -> Vec<usize> {
        if self.total_elements == 0 {
            return Vec::new();
        }

        let mut indices = Vec::with_capacity(self.total_elements);
        indices.push(self.base_index);

        let mut current = self.base_index as i128;
        // Cursor into `wide_deltas`, which is ordered by position.
        let mut next_wide = 0;

        for (position, &delta) in self.deltas.iter().enumerate() {
            let step = match self.wide_deltas.get(next_wide) {
                Some(&(wide_position, exact)) if wide_position == position => {
                    next_wide += 1;
                    exact
                }
                _ => i128::from(delta),
            };
            current += step;
            // `current` always equals one of the original `usize` indices, so
            // the cast back is lossless.
            indices.push(current as usize);
        }

        indices
    }

    /// Get compression ratio (original_size / compressed_size), in bytes.
    ///
    /// The compressed size counts the base index, the `i32` deltas and any
    /// wide-delta entries. Returns `1.0` for an empty encoding.
    pub fn compression_ratio(&self) -> f32 {
        if self.total_elements == 0 {
            return 1.0;
        }

        // `core::mem` so the module also builds without `std`.
        let original_size = self.total_elements * core::mem::size_of::<usize>();
        let compressed_size = core::mem::size_of::<usize>()
            + self.deltas.len() * core::mem::size_of::<i32>()
            + self.wide_deltas.len() * core::mem::size_of::<(usize, i128)>();

        original_size as f32 / compressed_size as f32
    }
}
372
/// Bitmap encoding for dense index regions.
///
/// Covers the half-open range `[start, end)` with one bit per position; a set
/// bit means the corresponding index is present.
#[derive(Debug, Clone)]
pub struct BitmapEncoded {
    /// Starting index of bitmap
    start_index: usize,

    /// Bitmap (1 bit per element), packed into 64-bit words, lowest bit first
    bitmap: Vec<u64>,

    /// Number of elements (bit positions) covered by the bitmap
    num_elements: usize,

    /// Number of set bits, i.e. distinct indices stored
    num_set_bits: usize,
}

impl BitmapEncoded {
    /// Create bitmap encoding from indices within the range `[start, end)`.
    ///
    /// Indices outside the range are ignored and duplicates are counted once.
    /// If `end <= start` the range is empty and nothing is stored.
    pub fn encode(indices: &[usize], start: usize, end: usize) -> Self {
        // Saturate so an inverted range yields an empty bitmap instead of
        // underflowing.
        let num_elements = end.saturating_sub(start);
        // Ceiling division without the `+ 63` that could overflow.
        let num_words = num_elements / 64 + usize::from(num_elements % 64 != 0);
        let mut bitmap = vec![0u64; num_words];
        let mut num_set_bits = 0;

        for &idx in indices {
            if (start..end).contains(&idx) {
                let bit_pos = idx - start;
                let mask = 1u64 << (bit_pos % 64);
                let word = &mut bitmap[bit_pos / 64];
                // Count a bit only the first time it is set so duplicate
                // indices do not inflate `num_set_bits`.
                if *word & mask == 0 {
                    *word |= mask;
                    num_set_bits += 1;
                }
            }
        }

        Self {
            start_index: start,
            bitmap,
            num_elements,
            num_set_bits,
        }
    }

    /// Decode bitmap back to indices, in ascending order.
    pub fn decode(&self) -> Vec<usize> {
        let mut indices = Vec::with_capacity(self.num_set_bits);

        for (word_idx, &word) in self.bitmap.iter().enumerate() {
            let mut pending = word;
            while pending != 0 {
                let bit_pos = word_idx * 64 + pending.trailing_zeros() as usize;
                // Guard against stray bits in the padding of the last word.
                if bit_pos < self.num_elements {
                    indices.push(self.start_index + bit_pos);
                }
                // Clear the lowest set bit and move on to the next one.
                pending &= pending - 1;
            }
        }

        indices
    }

    /// Get compression ratio (original_size / compressed_size), in bytes.
    ///
    /// The original size is one `usize` per stored index; the compressed size
    /// is the start index plus the bitmap words. Returns `1.0` when no bits
    /// are set.
    pub fn compression_ratio(&self) -> f32 {
        if self.num_set_bits == 0 {
            return 1.0;
        }

        // `core::mem` so the module also builds without `std`.
        let original_size = self.num_set_bits * core::mem::size_of::<usize>();
        let compressed_size =
            core::mem::size_of::<usize>() + self.bitmap.len() * core::mem::size_of::<u64>();

        original_size as f32 / compressed_size as f32
    }

    /// Get density (ratio of set bits to total bits).
    ///
    /// Returns `0.0` for an empty range rather than dividing by zero.
    pub fn density(&self) -> f32 {
        if self.num_elements == 0 {
            return 0.0;
        }
        self.num_set_bits as f32 / self.num_elements as f32
    }
}
455
456/// Compression statistics and analysis
457#[derive(Debug, Clone)]
458pub struct CompressionAnalysis {
459    /// Original size in bytes
460    pub original_size: usize,
461
462    /// Compressed size in bytes
463    pub compressed_size: usize,
464
465    /// Compression ratio
466    pub compression_ratio: f32,
467
468    /// Space savings in bytes
469    pub space_savings: usize,
470
471    /// Encoding used
472    pub encoding: CompressionEncoding,
473
474    /// Sparsity of data
475    pub sparsity: f32,
476
477    /// Compression efficiency score (0-100)
478    pub efficiency_score: u8,
479}
480
481impl CompressionAnalysis {
482    /// Create compression analysis
483    pub fn new(
484        original_size: usize,
485        compressed_size: usize,
486        encoding: CompressionEncoding,
487        sparsity: f32,
488    ) -> Self {
489        let compression_ratio = if compressed_size > 0 {
490            original_size as f32 / compressed_size as f32
491        } else {
492            1.0
493        };
494
495        let space_savings = original_size.saturating_sub(compressed_size);
496
497        // Calculate efficiency score (how close to theoretical maximum)
498        let theoretical_max = encoding.expected_compression_ratio();
499        let efficiency_score = ((compression_ratio / theoretical_max) * 100.0).min(100.0) as u8;
500
501        Self {
502            original_size,
503            compressed_size,
504            compression_ratio,
505            space_savings,
506            encoding,
507            sparsity,
508            efficiency_score,
509        }
510    }
511
512    /// Check if compression is beneficial
513    pub fn is_beneficial(&self) -> bool {
514        self.compression_ratio > 1.1 // At least 10% savings
515    }
516
517    /// Get space savings percentage
518    pub fn savings_percentage(&self) -> f32 {
519        (self.space_savings as f32 / self.original_size as f32) * 100.0
520    }
521}
522
523/// Compression strategy selector based on data characteristics
524#[derive(Debug, Clone)]
525pub struct CompressionSelector {
526    /// Sparsity threshold for using compression
527    sparsity_threshold: f32,
528
529    /// Preferred encoding methods in priority order
530    preferred_encodings: Vec<CompressionEncoding>,
531}
532
533impl CompressionSelector {
534    /// Create new compression selector
535    pub fn new() -> Self {
536        Self {
537            sparsity_threshold: 0.3, // 30% sparsity minimum
538            preferred_encodings: vec![
539                CompressionEncoding::Hybrid,
540                CompressionEncoding::Huffman,
541                CompressionEncoding::Bitmap,
542                CompressionEncoding::RunLength,
543                CompressionEncoding::Delta,
544            ],
545        }
546    }
547
548    /// Set sparsity threshold
549    pub fn with_sparsity_threshold(mut self, threshold: f32) -> Self {
550        self.sparsity_threshold = threshold;
551        self
552    }
553
554    /// Get preferred encodings
555    pub fn preferred_encodings(&self) -> &[CompressionEncoding] {
556        &self.preferred_encodings
557    }
558
559    /// Select best compression encoding for indices
560    pub fn select_encoding(&self, indices: &[usize], total_size: usize) -> CompressionEncoding {
561        if indices.is_empty() {
562            return CompressionEncoding::Raw;
563        }
564
565        let sparsity = 1.0 - (indices.len() as f32 / total_size as f32);
566
567        // Don't compress if below threshold
568        if sparsity < self.sparsity_threshold {
569            return CompressionEncoding::Raw;
570        }
571
572        // Check for consecutive patterns (good for RLE)
573        let consecutive_ratio = self.calculate_consecutive_ratio(indices);
574        if consecutive_ratio > 0.7 {
575            return CompressionEncoding::RunLength;
576        }
577
578        // Check for small deltas (good for delta encoding)
579        let avg_delta = self.calculate_average_delta(indices);
580        if avg_delta < 10.0 {
581            return CompressionEncoding::Delta;
582        }
583
584        // Check for dense regions (good for bitmap)
585        if self.has_dense_regions(indices) {
586            return CompressionEncoding::Bitmap;
587        }
588
589        // Default to hybrid for complex patterns
590        CompressionEncoding::Hybrid
591    }
592
593    fn calculate_consecutive_ratio(&self, indices: &[usize]) -> f32 {
594        if indices.len() < 2 {
595            return 0.0;
596        }
597
598        let mut consecutive_count = 0;
599        for i in 1..indices.len() {
600            if indices[i] == indices[i - 1] + 1 {
601                consecutive_count += 1;
602            }
603        }
604
605        consecutive_count as f32 / (indices.len() - 1) as f32
606    }
607
608    fn calculate_average_delta(&self, indices: &[usize]) -> f32 {
609        if indices.len() < 2 {
610            return 0.0;
611        }
612
613        let mut total_delta = 0i64;
614        for i in 1..indices.len() {
615            total_delta += (indices[i] as i64 - indices[i - 1] as i64).abs();
616        }
617
618        total_delta as f32 / (indices.len() - 1) as f32
619    }
620
621    fn has_dense_regions(&self, indices: &[usize]) -> bool {
622        if indices.len() < 10 {
623            return false;
624        }
625
626        // Check if there's a region with >80% density
627        let min_idx = *indices.iter().min().expect("reduction should succeed");
628        let max_idx = *indices.iter().max().expect("reduction should succeed");
629        let range = max_idx - min_idx + 1;
630
631        if range == 0 {
632            return false;
633        }
634
635        let density = indices.len() as f32 / range as f32;
636        density > 0.8
637    }
638}
639
640impl Default for CompressionSelector {
641    fn default() -> Self {
642        Self::new()
643    }
644}
645
/// Magnitude threshold calculator for pruning.
///
/// A stateless collection of helpers that turn a slice of weights into a
/// single threshold value.
#[derive(Debug, Clone)]
pub struct MagnitudeThresholdCalculator;

impl MagnitudeThresholdCalculator {
    /// Absolute values of `values`, sorted ascending, with NaNs dropped.
    fn sorted_magnitudes(values: &[f32]) -> Vec<f32> {
        let mut magnitudes: Vec<f32> = values
            .iter()
            .map(|v| v.abs())
            .filter(|m| !m.is_nan())
            .collect();
        // `total_cmp` is a total order on floats, so sorting cannot panic.
        magnitudes.sort_unstable_by(f32::total_cmp);
        magnitudes
    }

    /// Calculate threshold from percentile.
    ///
    /// Sorts the absolute values ascending and returns the one at position
    /// `percentile / 100 * len` (truncated, capped at the last element).
    /// NaN inputs are ignored. Returns `0.0` when no usable values remain.
    pub fn from_percentile(values: &[f32], percentile: u8) -> f32 {
        let magnitudes = Self::sorted_magnitudes(values);
        if magnitudes.is_empty() {
            return 0.0;
        }

        let index = ((f32::from(percentile) / 100.0) * magnitudes.len() as f32) as usize;
        magnitudes[index.min(magnitudes.len() - 1)]
    }

    /// Calculate threshold from top-k selection.
    ///
    /// Returns the `k`-th largest absolute value; if `k` exceeds the number of
    /// usable values, the smallest one is returned. NaN inputs are ignored.
    /// Returns `0.0` when `k` is zero or no usable values remain.
    pub fn from_top_k(values: &[f32], k: usize) -> f32 {
        if k == 0 {
            return 0.0;
        }

        let magnitudes = Self::sorted_magnitudes(values);
        if magnitudes.is_empty() {
            return 0.0;
        }

        // In ascending order the k-th largest sits `k` places from the end.
        let k = k.min(magnitudes.len());
        magnitudes[magnitudes.len() - k]
    }

    /// Calculate threshold from standard deviation.
    ///
    /// Computes the mean and population standard deviation of the raw (signed)
    /// values and returns `|mean| - num_std_dev * std_dev`. The result can be
    /// negative. Returns `0.0` for empty input.
    pub fn from_std_dev(values: &[f32], num_std_dev: f32) -> f32 {
        if values.is_empty() {
            return 0.0;
        }

        let count = values.len() as f32;
        let mean = values.iter().sum::<f32>() / count;
        let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / count;
        let std_dev = variance.sqrt();

        mean.abs() - num_std_dev * std_dev
    }
}
698
#[cfg(test)]
mod tests {
    use super::*;

    /// Magnitude pruning reports its percentile as sparsity and is
    /// unstructured; block-wise pruning is structured.
    #[test]
    fn test_pruning_strategy() {
        let magnitude = PruningStrategy::Magnitude {
            threshold_percentile: 50,
        };
        assert_eq!(magnitude.expected_sparsity(), 0.5);
        assert!(!magnitude.is_structured());

        let block_wise = PruningStrategy::BlockWise { block_size: (4, 4) };
        assert!(block_wise.is_structured());
    }

    /// Three runs (0..=3, 10..=12, 20) round-trip and compress.
    #[test]
    fn test_run_length_encoding() {
        let input = vec![0, 1, 2, 3, 10, 11, 12, 20];
        let rle = RunLengthEncoded::encode(&input);

        assert_eq!(rle.num_runs(), 3);
        assert_eq!(rle.decode(), input);
        assert!(rle.compression_ratio() > 1.0);
    }

    /// Evenly spaced indices round-trip and compress.
    #[test]
    fn test_delta_encoding() {
        let input = vec![5, 10, 15, 20, 25];
        let delta = DeltaEncoded::encode(&input);

        assert_eq!(delta.decode(), input);
        assert!(delta.compression_ratio() > 1.0);
    }

    /// Five of ten positions set gives a density of one half.
    #[test]
    fn test_bitmap_encoding() {
        let input = vec![0, 1, 3, 5, 7];
        let bitmap = BitmapEncoded::encode(&input, 0, 10);

        assert_eq!(bitmap.num_set_bits, 5);
        assert_eq!(bitmap.decode(), input);
        assert_eq!(bitmap.density(), 0.5);
    }

    /// The selector prefers run-length for consecutive data and run-length or
    /// delta for closely spaced data.
    #[test]
    fn test_compression_selector() {
        let selector = CompressionSelector::new();

        // Fully consecutive indices.
        let run = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        assert_eq!(
            selector.select_encoding(&run, 100),
            CompressionEncoding::RunLength
        );

        // Gaps of one or two between neighbours.
        let near = vec![0, 1, 3, 4, 6, 7, 9, 10];
        let chosen = selector.select_encoding(&near, 100);
        assert!(matches!(
            chosen,
            CompressionEncoding::Delta | CompressionEncoding::RunLength
        ));
    }

    /// Percentile and top-k thresholds on the values 1.0 through 10.0.
    #[test]
    fn test_magnitude_threshold() {
        let weights = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        // The 50th percentile of ten values lands on index 5, i.e. 6.0.
        let by_percentile = MagnitudeThresholdCalculator::from_percentile(&weights, 50);
        assert!((by_percentile - 6.0).abs() < 0.1);

        // The third largest value is 8.0.
        let by_top_k = MagnitudeThresholdCalculator::from_top_k(&weights, 3);
        assert!((by_top_k - 8.0).abs() < 0.1);
    }

    /// Builder methods populate the metadata and the accessors reflect it.
    #[test]
    fn test_pruning_metadata() {
        let strategy = PruningStrategy::Magnitude {
            threshold_percentile: 50,
        };
        let meta = PruningMetadata::new(strategy, Shape::new(vec![10, 10]), 0.5)
            .with_indices(vec![0, 1, 2, 3, 4])
            .with_threshold(0.1);

        assert_eq!(meta.sparsity(), 0.5);
        assert_eq!(meta.compression_ratio(), 2.0);
        assert_eq!(meta.num_pruned_elements(), 5);
        assert!(meta.is_element_pruned(2));
        assert!(!meta.is_element_pruned(10));
    }

    /// 1000 bytes compressed to 250 is a 4x ratio with 75% savings.
    #[test]
    fn test_compression_analysis() {
        let report = CompressionAnalysis::new(1000, 250, CompressionEncoding::Huffman, 0.75);

        assert_eq!(report.compression_ratio, 4.0);
        assert_eq!(report.space_savings, 750);
        assert!(report.is_beneficial());
        assert_eq!(report.savings_percentage(), 75.0);
    }
}