1use crate::dtype::DType;
10use crate::shape::Shape;
11
12#[cfg(not(feature = "std"))]
13use alloc::{vec, vec::Vec};
14
/// Strategy used to decide which parts of a tensor get pruned.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PruningStrategy {
    /// Magnitude pruning; `threshold_percentile` is a percentage (values above
    /// 100 are treated as 100 when estimating sparsity).
    Magnitude { threshold_percentile: u8 },

    /// Block-wise pruning with a two-dimensional block size.
    BlockWise { block_size: (usize, usize) },

    /// Channel-wise pruning of `channels_to_prune` channels.
    ChannelWise { channels_to_prune: usize },

    /// Attention-head pruning of `heads_to_prune` heads.
    AttentionHead { heads_to_prune: usize },

    /// Movement pruning parameterised by a sensitivity value.
    Movement { sensitivity: f32 },

    /// Gradual magnitude pruning between an initial and a final sparsity.
    GradualMagnitude {
        initial_sparsity: f32,
        final_sparsity: f32,
    },
}

impl PruningStrategy {
    /// Estimated fraction of pruned elements, always within `0.0..=1.0`
    /// (a NaN `final_sparsity` is passed through unchanged).
    ///
    /// * `Magnitude` – the percentile divided by 100, capped at 1.0.
    /// * `GradualMagnitude` – the final sparsity, clamped to `0.0..=1.0`.
    /// * every other strategy – a fixed estimate of `0.5`.
    pub fn expected_sparsity(&self) -> f32 {
        match *self {
            // A `u8` can hold up to 255; anything past 100 % is capped.
            Self::Magnitude {
                threshold_percentile,
            } => f32::from(threshold_percentile.min(100)) / 100.0,
            Self::GradualMagnitude { final_sparsity, .. } => final_sparsity.clamp(0.0, 1.0),
            // Listed explicitly so that adding a variant forces a decision here.
            Self::BlockWise { .. }
            | Self::ChannelWise { .. }
            | Self::AttentionHead { .. }
            | Self::Movement { .. } => 0.5,
        }
    }

    /// Returns `true` for the block-wise, channel-wise and attention-head
    /// strategies, `false` for all others.
    pub fn is_structured(&self) -> bool {
        matches!(
            self,
            Self::BlockWise { .. } | Self::ChannelWise { .. } | Self::AttentionHead { .. }
        )
    }
}
60
61#[derive(Debug, Clone)]
63pub struct PruningMetadata {
64 strategy: PruningStrategy,
66
67 pruned_indices: Option<Vec<usize>>,
69
70 pruned_blocks: Option<Vec<(usize, usize)>>,
72
73 pruned_channels: Option<Vec<usize>>,
75
76 achieved_sparsity: f32,
78
79 original_shape: Shape,
81
82 threshold_value: Option<f32>,
84
85 compression_ratio: f32,
87}
88
89impl PruningMetadata {
90 pub fn new(strategy: PruningStrategy, original_shape: Shape, achieved_sparsity: f32) -> Self {
92 let compression_ratio = 1.0 / (1.0 - achieved_sparsity);
93
94 Self {
95 strategy,
96 pruned_indices: None,
97 pruned_blocks: None,
98 pruned_channels: None,
99 achieved_sparsity,
100 original_shape,
101 threshold_value: None,
102 compression_ratio,
103 }
104 }
105
106 pub fn with_indices(mut self, indices: Vec<usize>) -> Self {
108 self.pruned_indices = Some(indices);
109 self
110 }
111
112 pub fn with_blocks(mut self, blocks: Vec<(usize, usize)>) -> Self {
114 self.pruned_blocks = Some(blocks);
115 self
116 }
117
118 pub fn with_channels(mut self, channels: Vec<usize>) -> Self {
120 self.pruned_channels = Some(channels);
121 self
122 }
123
124 pub fn with_threshold(mut self, threshold: f32) -> Self {
126 self.threshold_value = Some(threshold);
127 self
128 }
129
130 pub fn strategy(&self) -> PruningStrategy {
132 self.strategy
133 }
134
135 pub fn sparsity(&self) -> f32 {
137 self.achieved_sparsity
138 }
139
140 pub fn compression_ratio(&self) -> f32 {
142 self.compression_ratio
143 }
144
145 pub fn num_pruned_elements(&self) -> usize {
147 if let Some(ref indices) = self.pruned_indices {
148 indices.len()
149 } else if let Some(ref blocks) = self.pruned_blocks {
150 blocks.len()
151 } else if let Some(ref channels) = self.pruned_channels {
152 channels.len()
153 } else {
154 0
155 }
156 }
157
158 pub fn is_element_pruned(&self, index: usize) -> bool {
160 if let Some(ref indices) = self.pruned_indices {
161 indices.binary_search(&index).is_ok()
162 } else {
163 false
164 }
165 }
166
167 pub fn memory_savings(&self, dtype: DType) -> usize {
169 let total_elements = self.original_shape.numel();
170 let pruned_elements = (total_elements as f32 * self.achieved_sparsity) as usize;
171 pruned_elements * dtype.size()
172 }
173}
174
/// Encoding schemes available for storing pruned-index data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionEncoding {
    /// Raw storage.
    Raw,

    /// Run-length encoding.
    RunLength,

    /// Delta encoding.
    Delta,

    /// Huffman encoding.
    Huffman,

    /// Bitmap encoding.
    Bitmap,

    /// Hybrid encoding.
    Hybrid,
}

impl CompressionEncoding {
    /// Nominal compression ratio associated with each encoding
    /// (a fixed constant per variant, `1.0` for `Raw`).
    pub fn expected_compression_ratio(&self) -> f32 {
        match *self {
            CompressionEncoding::Raw => 1.0,
            CompressionEncoding::Delta => 1.5,
            CompressionEncoding::RunLength => 2.0,
            CompressionEncoding::Huffman => 2.5,
            CompressionEncoding::Bitmap => 3.0,
            CompressionEncoding::Hybrid => 3.5,
        }
    }

    /// Whether the encoding is flagged as needing sorted indices:
    /// `true` for run-length, delta and hybrid; `false` otherwise.
    pub fn requires_sorted_indices(&self) -> bool {
        match *self {
            CompressionEncoding::RunLength
            | CompressionEncoding::Delta
            | CompressionEncoding::Hybrid => true,
            CompressionEncoding::Raw
            | CompressionEncoding::Huffman
            | CompressionEncoding::Bitmap => false,
        }
    }
}
215
/// Index list stored as runs of consecutive values.
#[derive(Debug, Clone)]
pub struct RunLengthEncoded {
    /// First index of each run.
    start_indices: Vec<usize>,

    /// Length of each run; parallel to `start_indices`.
    run_lengths: Vec<usize>,

    /// Number of indices that were encoded.
    total_elements: usize,
}

impl RunLengthEncoded {
    /// Encodes `indices` as `(start, length)` runs.
    ///
    /// A run continues while each index is exactly one greater than the one
    /// before it; any other step (including a repeat or a decrease) starts a
    /// new run. An empty slice produces an encoding with no runs.
    pub fn encode(indices: &[usize]) -> Self {
        let mut start_indices = Vec::new();
        let mut run_lengths = Vec::new();

        let mut remaining = indices.iter().copied();
        if let Some(first) = remaining.next() {
            let mut run_start = first;
            let mut run_len = 1usize;
            let mut previous = first;

            for index in remaining {
                // `checked_add` keeps `usize::MAX` from overflowing: nothing
                // can follow it consecutively, so the run simply ends there.
                if previous.checked_add(1) == Some(index) {
                    run_len += 1;
                } else {
                    start_indices.push(run_start);
                    run_lengths.push(run_len);
                    run_start = index;
                    run_len = 1;
                }
                previous = index;
            }

            // Flush the run that was still open when the input ended.
            start_indices.push(run_start);
            run_lengths.push(run_len);
        }

        Self {
            start_indices,
            run_lengths,
            total_elements: indices.len(),
        }
    }

    /// Expands the runs back into the index list, in the order the runs were
    /// recorded.
    pub fn decode(&self) -> Vec<usize> {
        let mut indices = Vec::with_capacity(self.total_elements);

        for (&start, &length) in self.start_indices.iter().zip(&self.run_lengths) {
            // Offsets are added one by one (rather than building
            // `start..start + length`) so a run ending at `usize::MAX` never
            // computes an out-of-range end bound.
            indices.extend((0..length).map(|offset| start + offset));
        }

        indices
    }

    /// Ratio of the plain index list's size to the size of the stored runs,
    /// both counted in `usize` words; `1.0` when there are no runs.
    pub fn compression_ratio(&self) -> f32 {
        if self.start_indices.is_empty() {
            return 1.0;
        }

        // `core::mem` rather than `std::mem` so this also builds without `std`.
        let word = core::mem::size_of::<usize>();
        let original_size = self.total_elements * word;
        let compressed_size = (self.start_indices.len() + self.run_lengths.len()) * word;

        original_size as f32 / compressed_size as f32
    }

    /// Number of runs in the encoding.
    pub fn num_runs(&self) -> usize {
        self.start_indices.len()
    }
}
301
/// Index list stored as a base index plus successive differences.
#[derive(Debug, Clone)]
pub struct DeltaEncoded {
    /// First index of the original list (`0` when the list was empty).
    base_index: usize,

    /// One 32-bit difference per index after the first. A slot holds a `0`
    /// placeholder when the real difference is recorded in `wide_deltas`.
    deltas: Vec<i32>,

    /// `(position in deltas, full difference)` for every difference that does
    /// not fit in an `i32`, in ascending position order.
    wide_deltas: Vec<(usize, isize)>,

    /// Number of indices that were encoded.
    total_elements: usize,
}

impl DeltaEncoded {
    /// Encodes `indices` as the first value plus the difference between each
    /// pair of neighbours.
    ///
    /// Differences that fit in an `i32` are stored compactly; larger ones are
    /// kept at full width in a side table, so decoding is always lossless
    /// instead of silently truncating.
    pub fn encode(indices: &[usize]) -> Self {
        let mut deltas = Vec::with_capacity(indices.len().saturating_sub(1));
        let mut wide_deltas = Vec::new();

        for pair in indices.windows(2) {
            // Wrapping subtraction reinterpreted as signed: adding it back with
            // wrapping arithmetic always reproduces `pair[1]` exactly.
            let delta = pair[1].wrapping_sub(pair[0]) as isize;
            let narrow = delta as i32;

            if narrow as isize == delta {
                deltas.push(narrow);
            } else {
                wide_deltas.push((deltas.len(), delta));
                deltas.push(0);
            }
        }

        Self {
            base_index: indices.first().copied().unwrap_or(0),
            deltas,
            wide_deltas,
            total_elements: indices.len(),
        }
    }

    /// Rebuilds the original index list.
    pub fn decode(&self) -> Vec<usize> {
        if self.total_elements == 0 {
            return Vec::new();
        }

        let mut indices = Vec::with_capacity(self.total_elements);
        indices.push(self.base_index);

        let mut current = self.base_index;
        // Cursor into `wide_deltas`, which is ordered by position.
        let mut next_wide = 0usize;

        for (position, &narrow) in self.deltas.iter().enumerate() {
            let step = match self.wide_deltas.get(next_wide) {
                Some(&(wide_position, wide_delta)) if wide_position == position => {
                    next_wide += 1;
                    wide_delta
                }
                _ => narrow as isize,
            };

            current = current.wrapping_add(step as usize);
            indices.push(current);
        }

        indices
    }

    /// Ratio of the plain index list's size to the encoded size (base index,
    /// 32-bit deltas and any wide-delta entries); `1.0` for an empty encoding.
    pub fn compression_ratio(&self) -> f32 {
        if self.total_elements == 0 {
            return 1.0;
        }

        // `core::mem` rather than `std::mem` so this also builds without `std`.
        let original_size = self.total_elements * core::mem::size_of::<usize>();
        let compressed_size = core::mem::size_of::<usize>()
            + self.deltas.len() * core::mem::size_of::<i32>()
            + self.wide_deltas.len() * core::mem::size_of::<(usize, isize)>();

        original_size as f32 / compressed_size as f32
    }
}
372
/// Index set over the half-open range `start..end`, stored one bit per slot.
#[derive(Debug, Clone)]
pub struct BitmapEncoded {
    /// Index represented by bit 0 of the first word.
    start_index: usize,

    /// Bit `i % 64` of word `i / 64` is set when `start_index + i` is present.
    bitmap: Vec<u64>,

    /// Width of the covered range (`end - start`, or 0 for an inverted range).
    num_elements: usize,

    /// Number of distinct indices recorded in the bitmap.
    num_set_bits: usize,
}

impl BitmapEncoded {
    /// Builds a bitmap for the indices that fall inside `start..end`.
    ///
    /// Indices outside the range are ignored, duplicates are counted once, and
    /// an inverted range (`end < start`) is treated as empty.
    pub fn encode(indices: &[usize], start: usize, end: usize) -> Self {
        let num_elements = end.saturating_sub(start);
        // Ceiling division written so that it cannot overflow.
        let num_words = num_elements / 64 + usize::from(num_elements % 64 != 0);
        let mut bitmap = vec![0u64; num_words];
        let mut num_set_bits = 0;

        for &idx in indices {
            if idx < start || idx >= end {
                continue;
            }

            let bit_pos = idx - start;
            let mask = 1u64 << (bit_pos % 64);
            let word = &mut bitmap[bit_pos / 64];

            // Only count a bit the first time it is set so that repeated
            // indices do not inflate `num_set_bits`.
            if *word & mask == 0 {
                *word |= mask;
                num_set_bits += 1;
            }
        }

        Self {
            start_index: start,
            bitmap,
            num_elements,
            num_set_bits,
        }
    }

    /// Returns the recorded indices in ascending order.
    pub fn decode(&self) -> Vec<usize> {
        let mut indices = Vec::with_capacity(self.num_set_bits);

        for (word_idx, &word) in self.bitmap.iter().enumerate() {
            let mut pending = word;

            while pending != 0 {
                let bit_idx = pending.trailing_zeros() as usize;
                let offset = word_idx * 64 + bit_idx;

                // Guard against padding bits in the final word.
                if offset < self.num_elements {
                    indices.push(self.start_index + offset);
                }

                // Clear the lowest set bit and move on to the next one.
                pending &= pending - 1;
            }
        }

        indices
    }

    /// Ratio of a plain `usize` list of the set indices to the bitmap's size
    /// (one `usize` plus the `u64` words); `1.0` when no bit is set.
    pub fn compression_ratio(&self) -> f32 {
        if self.num_set_bits == 0 {
            return 1.0;
        }

        // `core::mem` rather than `std::mem` so this also builds without `std`.
        let original_size = self.num_set_bits * core::mem::size_of::<usize>();
        let compressed_size =
            core::mem::size_of::<usize>() + self.bitmap.len() * core::mem::size_of::<u64>();

        original_size as f32 / compressed_size as f32
    }

    /// Fraction of the covered range that is set; `0.0` for an empty range.
    pub fn density(&self) -> f32 {
        if self.num_elements == 0 {
            return 0.0;
        }

        self.num_set_bits as f32 / self.num_elements as f32
    }
}
455
456#[derive(Debug, Clone)]
458pub struct CompressionAnalysis {
459 pub original_size: usize,
461
462 pub compressed_size: usize,
464
465 pub compression_ratio: f32,
467
468 pub space_savings: usize,
470
471 pub encoding: CompressionEncoding,
473
474 pub sparsity: f32,
476
477 pub efficiency_score: u8,
479}
480
481impl CompressionAnalysis {
482 pub fn new(
484 original_size: usize,
485 compressed_size: usize,
486 encoding: CompressionEncoding,
487 sparsity: f32,
488 ) -> Self {
489 let compression_ratio = if compressed_size > 0 {
490 original_size as f32 / compressed_size as f32
491 } else {
492 1.0
493 };
494
495 let space_savings = original_size.saturating_sub(compressed_size);
496
497 let theoretical_max = encoding.expected_compression_ratio();
499 let efficiency_score = ((compression_ratio / theoretical_max) * 100.0).min(100.0) as u8;
500
501 Self {
502 original_size,
503 compressed_size,
504 compression_ratio,
505 space_savings,
506 encoding,
507 sparsity,
508 efficiency_score,
509 }
510 }
511
512 pub fn is_beneficial(&self) -> bool {
514 self.compression_ratio > 1.1 }
516
517 pub fn savings_percentage(&self) -> f32 {
519 (self.space_savings as f32 / self.original_size as f32) * 100.0
520 }
521}
522
523#[derive(Debug, Clone)]
525pub struct CompressionSelector {
526 sparsity_threshold: f32,
528
529 preferred_encodings: Vec<CompressionEncoding>,
531}
532
533impl CompressionSelector {
534 pub fn new() -> Self {
536 Self {
537 sparsity_threshold: 0.3, preferred_encodings: vec![
539 CompressionEncoding::Hybrid,
540 CompressionEncoding::Huffman,
541 CompressionEncoding::Bitmap,
542 CompressionEncoding::RunLength,
543 CompressionEncoding::Delta,
544 ],
545 }
546 }
547
548 pub fn with_sparsity_threshold(mut self, threshold: f32) -> Self {
550 self.sparsity_threshold = threshold;
551 self
552 }
553
554 pub fn preferred_encodings(&self) -> &[CompressionEncoding] {
556 &self.preferred_encodings
557 }
558
559 pub fn select_encoding(&self, indices: &[usize], total_size: usize) -> CompressionEncoding {
561 if indices.is_empty() {
562 return CompressionEncoding::Raw;
563 }
564
565 let sparsity = 1.0 - (indices.len() as f32 / total_size as f32);
566
567 if sparsity < self.sparsity_threshold {
569 return CompressionEncoding::Raw;
570 }
571
572 let consecutive_ratio = self.calculate_consecutive_ratio(indices);
574 if consecutive_ratio > 0.7 {
575 return CompressionEncoding::RunLength;
576 }
577
578 let avg_delta = self.calculate_average_delta(indices);
580 if avg_delta < 10.0 {
581 return CompressionEncoding::Delta;
582 }
583
584 if self.has_dense_regions(indices) {
586 return CompressionEncoding::Bitmap;
587 }
588
589 CompressionEncoding::Hybrid
591 }
592
593 fn calculate_consecutive_ratio(&self, indices: &[usize]) -> f32 {
594 if indices.len() < 2 {
595 return 0.0;
596 }
597
598 let mut consecutive_count = 0;
599 for i in 1..indices.len() {
600 if indices[i] == indices[i - 1] + 1 {
601 consecutive_count += 1;
602 }
603 }
604
605 consecutive_count as f32 / (indices.len() - 1) as f32
606 }
607
608 fn calculate_average_delta(&self, indices: &[usize]) -> f32 {
609 if indices.len() < 2 {
610 return 0.0;
611 }
612
613 let mut total_delta = 0i64;
614 for i in 1..indices.len() {
615 total_delta += (indices[i] as i64 - indices[i - 1] as i64).abs();
616 }
617
618 total_delta as f32 / (indices.len() - 1) as f32
619 }
620
621 fn has_dense_regions(&self, indices: &[usize]) -> bool {
622 if indices.len() < 10 {
623 return false;
624 }
625
626 let min_idx = *indices.iter().min().expect("reduction should succeed");
628 let max_idx = *indices.iter().max().expect("reduction should succeed");
629 let range = max_idx - min_idx + 1;
630
631 if range == 0 {
632 return false;
633 }
634
635 let density = indices.len() as f32 / range as f32;
636 density > 0.8
637 }
638}
639
640impl Default for CompressionSelector {
641 fn default() -> Self {
642 Self::new()
643 }
644}
645
/// Stateless helpers that derive a pruning threshold from a set of values.
#[derive(Debug, Clone)]
pub struct MagnitudeThresholdCalculator;

impl MagnitudeThresholdCalculator {
    /// Threshold at the given percentile of the absolute values.
    ///
    /// The magnitudes are sorted ascending and the element at position
    /// `percentile / 100 * len` (clamped to the last element) is returned.
    /// NaN inputs are ignored; returns `0.0` when no comparable value remains.
    pub fn from_percentile(values: &[f32], percentile: u8) -> f32 {
        let mut magnitudes = Self::comparable_magnitudes(values);
        if magnitudes.is_empty() {
            return 0.0;
        }

        magnitudes.sort_unstable_by(|a, b| {
            a.partial_cmp(b)
                .expect("absolute values should be comparable (no NaN)")
        });

        let last = magnitudes.len() - 1;
        let index = ((percentile as f32 / 100.0) * magnitudes.len() as f32) as usize;

        magnitudes[index.min(last)]
    }

    /// Magnitude of the `k`-th largest absolute value (`k` is clamped to the
    /// number of comparable values).
    ///
    /// NaN inputs are ignored; returns `0.0` when `k` is zero or no comparable
    /// value remains.
    pub fn from_top_k(values: &[f32], k: usize) -> f32 {
        if k == 0 {
            return 0.0;
        }

        let mut magnitudes = Self::comparable_magnitudes(values);
        if magnitudes.is_empty() {
            return 0.0;
        }

        // Descending order, so position `k - 1` is the k-th largest.
        magnitudes.sort_unstable_by(|a, b| {
            b.partial_cmp(a)
                .expect("absolute values should be comparable (no NaN)")
        });

        let k = k.min(magnitudes.len());
        magnitudes[k - 1]
    }

    /// `|mean| - num_std_dev * std_dev`, using the population standard
    /// deviation of the raw (signed) values; `0.0` for empty input.
    ///
    /// The result can be negative when the spread is large relative to the mean.
    pub fn from_std_dev(values: &[f32], num_std_dev: f32) -> f32 {
        if values.is_empty() {
            return 0.0;
        }

        let count = values.len() as f32;
        let mean = values.iter().sum::<f32>() / count;
        let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / count;
        let std_dev = variance.sqrt();

        mean.abs() - num_std_dev * std_dev
    }

    /// Absolute values of `values` with NaN entries dropped, so that the
    /// remaining magnitudes can always be ordered.
    fn comparable_magnitudes(values: &[f32]) -> Vec<f32> {
        values
            .iter()
            .map(|v| v.abs())
            .filter(|magnitude| !magnitude.is_nan())
            .collect()
    }
}
698
#[cfg(test)]
mod tests {
    use super::*;

    /// Magnitude pruning reports its percentile as a fraction and is unstructured;
    /// block-wise pruning is structured.
    #[test]
    fn test_pruning_strategy() {
        let magnitude = PruningStrategy::Magnitude {
            threshold_percentile: 50,
        };
        assert_eq!(magnitude.expected_sparsity(), 0.5);
        assert!(!magnitude.is_structured());

        let block_wise = PruningStrategy::BlockWise { block_size: (4, 4) };
        assert!(block_wise.is_structured());
    }

    /// Three separate runs (0..=3, 10..=12, 20) round-trip and compress.
    #[test]
    fn test_run_length_encoding() {
        let input = vec![0, 1, 2, 3, 10, 11, 12, 20];
        let rle = RunLengthEncoded::encode(&input);

        assert_eq!(rle.num_runs(), 3);
        assert_eq!(rle.decode(), input);
        assert!(rle.compression_ratio() > 1.0);
    }

    /// Evenly spaced indices round-trip through delta encoding and compress.
    #[test]
    fn test_delta_encoding() {
        let input = vec![5, 10, 15, 20, 25];
        let delta = DeltaEncoded::encode(&input);

        assert_eq!(delta.decode(), input);
        assert!(delta.compression_ratio() > 1.0);
    }

    /// Five of ten slots set: round-trip, bit count and density.
    #[test]
    fn test_bitmap_encoding() {
        let input = vec![0, 1, 3, 5, 7];
        let bitmap = BitmapEncoded::encode(&input, 0, 10);

        assert_eq!(bitmap.num_set_bits, 5);
        assert_eq!(bitmap.decode(), input);
        assert_eq!(bitmap.density(), 0.5);
    }

    /// Fully consecutive input selects run-length; small gaps select delta or
    /// run-length.
    #[test]
    fn test_compression_selector() {
        let selector = CompressionSelector::new();

        let all_consecutive = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        assert_eq!(
            selector.select_encoding(&all_consecutive, 100),
            CompressionEncoding::RunLength
        );

        let small_gaps = vec![0, 1, 3, 4, 6, 7, 9, 10];
        let chosen = selector.select_encoding(&small_gaps, 100);
        assert!(matches!(
            chosen,
            CompressionEncoding::Delta | CompressionEncoding::RunLength
        ));
    }

    /// Percentile and top-k thresholds over the values 1.0 through 10.0.
    #[test]
    fn test_magnitude_threshold() {
        let samples = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        // 50 % of 10 values selects sorted position 5, i.e. 6.0.
        let by_percentile = MagnitudeThresholdCalculator::from_percentile(&samples, 50);
        assert!((by_percentile - 6.0).abs() < 0.1);

        // The third largest magnitude is 8.0.
        let by_top_k = MagnitudeThresholdCalculator::from_top_k(&samples, 3);
        assert!((by_top_k - 8.0).abs() < 0.1);
    }

    /// Builder round-trip: sparsity, derived ratio, element count and lookups.
    #[test]
    fn test_pruning_metadata() {
        let strategy = PruningStrategy::Magnitude {
            threshold_percentile: 50,
        };
        let meta = PruningMetadata::new(strategy, Shape::new(vec![10, 10]), 0.5)
            .with_indices(vec![0, 1, 2, 3, 4])
            .with_threshold(0.1);

        assert_eq!(meta.sparsity(), 0.5);
        assert_eq!(meta.compression_ratio(), 2.0);
        assert_eq!(meta.num_pruned_elements(), 5);
        assert!(meta.is_element_pruned(2));
        assert!(!meta.is_element_pruned(10));
    }

    /// 1000 → 250 bytes: ratio 4, 750 bytes saved, 75 % savings.
    #[test]
    fn test_compression_analysis() {
        let report = CompressionAnalysis::new(1000, 250, CompressionEncoding::Huffman, 0.75);

        assert_eq!(report.compression_ratio, 4.0);
        assert_eq!(report.space_savings, 750);
        assert!(report.is_beneficial());
        assert_eq!(report.savings_percentage(), 75.0);
    }
}