//! Weight pruning utilities for model compression: magnitude, gradient,
//! structured, and global pruning, with one-shot and iterative schedules.

use scirs2_core::ndarray::{Array2, ArrayD, Axis, Ix2};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use crate::error::{TrainError, TrainResult};

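/// Configuration shared by all pruners in this module.
///
/// A minimal construction sketch (hypothetical values; relies only on the
/// fields and `Default` impl defined below):
///
/// ```ignore
/// let config = PruningConfig {
///     pruning_ratio: 0.7,  // prune 70% of weights in one shot
///     iterative: true,     // ramp the ratio over several iterations
///     num_iterations: 20,
///     final_ratio: 0.95,
///     ..Default::default()
/// };
/// assert_eq!(config.schedule, "linear"); // from Default
/// ```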
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PruningConfig {
    /// Fraction of weights to prune in one-shot mode (0.0..=1.0).
    pub pruning_ratio: f64,
    /// Use structured pruning (whole rows/columns) instead of unstructured.
    pub structured: bool,
    /// Ramp the pruning ratio over multiple iterations.
    pub iterative: bool,
    /// Number of iterations for iterative pruning.
    pub num_iterations: usize,
    /// Starting ratio for iterative pruning.
    pub initial_ratio: f64,
    /// Final ratio for iterative pruning.
    pub final_ratio: f64,
    /// Ratio schedule: "linear", "exponential", or "cosine".
    pub schedule: String,
    /// Weights with magnitude below this threshold are always pruned.
    pub min_threshold: f64,
    /// Rank weights across all layers rather than per layer.
    pub global_pruning: bool,
}

impl Default for PruningConfig {
    fn default() -> Self {
        Self {
            pruning_ratio: 0.5,
            structured: false,
            iterative: false,
            num_iterations: 10,
            initial_ratio: 0.0,
            final_ratio: 0.9,
            schedule: "linear".to_string(),
            min_threshold: 1e-8,
            global_pruning: false,
        }
    }
}

/// Axis along which structured pruning removes whole units.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StructuredPruningAxis {
    /// Prune entire rows (e.g. output neurons).
    Rows,
    /// Prune entire columns (e.g. input features).
    Columns,
    /// Prune both rows and columns.
    Both,
}

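/// Binary pruning mask with the same shape as the weights: `1.0` keeps an
/// entry, `0.0` prunes it.
///
/// A small sketch of common mask queries (`mask` is a hypothetical value;
/// this mirrors how the tests below count active weights):
///
/// ```ignore
/// let active = mask.iter().filter(|&&m| m > 0.5).count();
/// let sparsity = 1.0 - active as f64 / mask.len() as f64;
/// ```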
pub type PruningMask = ArrayD<f64>;

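/// Aggregate statistics produced after pruning.
///
/// A worked sketch of the derived metrics (hypothetical numbers):
///
/// ```ignore
/// let stats = PruningStats {
///     total_params: 1_000,
///     active_params: 250,
///     pruning_ratio: 0.75,
///     iterations: 1,
///     per_layer_stats: HashMap::new(),
/// };
/// assert_eq!(stats.compression_ratio(), 4.0); // 1000 / 250
/// println!("{}", stats.summary());
/// ```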
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PruningStats {
    /// Total number of parameters before pruning.
    pub total_params: usize,
    /// Number of parameters remaining after pruning.
    pub active_params: usize,
    /// Overall fraction of parameters pruned.
    pub pruning_ratio: f64,
    /// Number of pruning iterations performed.
    pub iterations: usize,
    /// Per-layer breakdown, keyed by layer name.
    pub per_layer_stats: HashMap<String, LayerPruningStats>,
}

/// Pruning statistics for a single layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerPruningStats {
    /// Layer name.
    pub name: String,
    /// Parameter count before pruning.
    pub original_params: usize,
    /// Parameter count after pruning.
    pub active_params: usize,
    /// Fraction of this layer's parameters pruned.
    pub ratio: f64,
}

impl PruningStats {
    /// Compression ratio: total parameters over active parameters.
    pub fn compression_ratio(&self) -> f64 {
        if self.active_params == 0 {
            0.0
        } else {
            self.total_params as f64 / self.active_params as f64
        }
    }

    /// Estimated FLOPs reduction, approximated by the pruning ratio.
    pub fn flops_reduction(&self) -> f64 {
        self.pruning_ratio
    }

    /// Human-readable summary of the pruning statistics.
    pub fn summary(&self) -> String {
        format!(
            "Pruning Stats:\n\
             - Total params: {}\n\
             - Active params: {}\n\
             - Pruned: {} ({:.2}%)\n\
             - Compression: {:.2}x\n\
             - Est. FLOPs reduction: {:.2}%",
            self.total_params,
            self.active_params,
            self.total_params - self.active_params,
            self.pruning_ratio * 100.0,
            self.compression_ratio(),
            self.flops_reduction() * 100.0
        )
    }
}

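/// Common interface for pruning strategies.
///
/// A sketch of an iterative prune/retrain loop built on this trait
/// (`retrain` is a hypothetical caller-supplied step, not part of this
/// module):
///
/// ```ignore
/// fn iterative_prune<P: Pruner>(
///     pruner: &mut P,
///     mut weights: Array2<f64>,
/// ) -> TrainResult<(Array2<f64>, PruningMask)> {
///     let n = pruner.config().num_iterations;
///     let mut mask = pruner.generate_mask(&weights)?;
///     for i in 0..n {
///         pruner.update_ratio(i);
///         let (w, m) = pruner.prune(&weights)?;
///         weights = w;
///         mask = m;
///         // retrain(&mut weights, &mask); // hypothetical fine-tuning step
///     }
///     Ok((weights, mask))
/// }
/// ```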
pub trait Pruner {
    /// Prune `weights`, returning the pruned weights and the mask used.
    fn prune(&self, weights: &Array2<f64>) -> TrainResult<(Array2<f64>, PruningMask)>;

    /// Generate a pruning mask for `weights` without applying it.
    fn generate_mask(&self, weights: &Array2<f64>) -> TrainResult<PruningMask>;

    /// Apply an existing mask to `weights`.
    fn apply_mask(&self, weights: &Array2<f64>, mask: &PruningMask) -> TrainResult<Array2<f64>>;

    /// Access the pruner's configuration.
    fn config(&self) -> &PruningConfig;

    /// Advance the pruning ratio according to the configured schedule.
    fn update_ratio(&mut self, iteration: usize);
}

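/// Magnitude-based pruner: removes the weights with the smallest absolute
/// values.
///
/// One-shot usage sketch (hypothetical 2x2 weights):
///
/// ```ignore
/// let pruner = MagnitudePruner::new(PruningConfig {
///     pruning_ratio: 0.5,
///     ..Default::default()
/// });
/// let weights = Array2::from_shape_vec((2, 2), vec![0.1, 0.9, 0.05, 0.7])?;
/// let (pruned, mask) = pruner.prune(&weights)?;
/// // The two smallest magnitudes (0.1 and 0.05) are zeroed.
/// ```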
pub struct MagnitudePruner {
    config: PruningConfig,
    current_ratio: f64,
}

impl MagnitudePruner {
    /// Create a new magnitude pruner from `config`.
    pub fn new(config: PruningConfig) -> Self {
        let current_ratio = if config.iterative {
            config.initial_ratio
        } else {
            config.pruning_ratio
        };
        Self {
            config,
            current_ratio,
        }
    }

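    /// Magnitude threshold below which weights are pruned at the current
    /// ratio: the `prune_count`-th smallest |w|.
    ///
    /// Worked sketch: for 9 weights at `current_ratio = 0.3`,
    /// `prune_count = (9.0 * 0.3) as usize = 2`, so the threshold is the
    /// third-smallest |w| and (assuming distinct magnitudes) exactly the two
    /// smallest weights fall below it.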
    fn calculate_threshold(&self, weights: &Array2<f64>) -> f64 {
        let mut abs_weights: Vec<f64> = weights.iter().map(|w| w.abs()).collect();
        // total_cmp avoids a panic if any weight is NaN.
        abs_weights.sort_by(|a, b| a.total_cmp(b));

        let prune_count = (abs_weights.len() as f64 * self.current_ratio) as usize;
        if prune_count >= abs_weights.len() {
            abs_weights.last().copied().unwrap_or(0.0)
        } else {
            abs_weights[prune_count]
        }
    }
}

impl Pruner for MagnitudePruner {
    fn prune(&self, weights: &Array2<f64>) -> TrainResult<(Array2<f64>, PruningMask)> {
        let mask = self.generate_mask(weights)?;
        let pruned = self.apply_mask(weights, &mask)?;
        Ok((pruned, mask))
    }

    fn generate_mask(&self, weights: &Array2<f64>) -> TrainResult<PruningMask> {
        let threshold = self
            .calculate_threshold(weights)
            .max(self.config.min_threshold);

        let mask = weights.mapv(|w| if w.abs() >= threshold { 1.0 } else { 0.0 });
        Ok(mask.into_dyn())
    }

    fn apply_mask(&self, weights: &Array2<f64>, mask: &PruningMask) -> TrainResult<Array2<f64>> {
        let mask_2d = mask
            .clone()
            .into_dimensionality::<Ix2>()
            .map_err(|e| TrainError::ConfigError(format!("Mask shape mismatch: {}", e)))?;

        if weights.shape() != mask_2d.shape() {
            return Err(TrainError::ConfigError(format!(
                "Weight and mask shapes do not match: {:?} vs {:?}",
                weights.shape(),
                mask_2d.shape()
            )));
        }

        Ok(weights * &mask_2d)
    }

    fn config(&self) -> &PruningConfig {
        &self.config
    }

    fn update_ratio(&mut self, iteration: usize) {
        if !self.config.iterative || iteration >= self.config.num_iterations {
            return;
        }

        // Guard against division by zero when num_iterations == 1.
        let last = (self.config.num_iterations - 1).max(1) as f64;
        let progress = iteration as f64 / last;
        self.current_ratio = match self.config.schedule.as_str() {
            "linear" => {
                self.config.initial_ratio
                    + (self.config.final_ratio - self.config.initial_ratio) * progress
            }
            "exponential" => {
                // Interpolate in log space; clamp away from zero so ln() stays finite.
                let log_initial = self.config.initial_ratio.max(1e-8).ln();
                let log_final = self.config.final_ratio.max(1e-8).ln();
                (log_initial + (log_final - log_initial) * progress).exp()
            }
            "cosine" => {
                let cosine_decay = 0.5 * (1.0 + (std::f64::consts::PI * progress).cos());
                self.config.final_ratio
                    + (self.config.initial_ratio - self.config.final_ratio) * cosine_decay
            }
            _ => self.config.pruning_ratio,
        };
    }
}

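/// Gradient-based pruner: ranks weights by the average magnitude of their
/// recorded gradients, falling back to magnitude pruning when no gradient
/// history is available.
///
/// Usage sketch with a recorded gradient history (`grads_step1`,
/// `grads_step2`, and `weights` are hypothetical values):
///
/// ```ignore
/// let mut pruner = GradientPruner::new(PruningConfig {
///     pruning_ratio: 0.5,
///     ..Default::default()
/// });
/// pruner.update_gradients("dense1", grads_step1);
/// pruner.update_gradients("dense1", grads_step2);
/// let (pruned, mask) = pruner.prune_with_history(&weights, "dense1")?;
/// ```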
pub struct GradientPruner {
    config: PruningConfig,
    current_ratio: f64,
    gradient_history: HashMap<String, Vec<Array2<f64>>>,
}

impl GradientPruner {
    /// Create a new gradient pruner from `config`.
    pub fn new(config: PruningConfig) -> Self {
        let current_ratio = if config.iterative {
            config.initial_ratio
        } else {
            config.pruning_ratio
        };
        Self {
            config,
            current_ratio,
            gradient_history: HashMap::new(),
        }
    }

    /// Record a gradient snapshot for `layer_name`.
    pub fn update_gradients(&mut self, layer_name: &str, gradients: Array2<f64>) {
        self.gradient_history
            .entry(layer_name.to_string())
            .or_default()
            .push(gradients);
    }

    /// Element-wise average of the absolute gradients recorded for a layer.
    fn average_gradient_magnitude(&self, layer_name: &str) -> Option<Array2<f64>> {
        let gradients = self.gradient_history.get(layer_name)?;
        if gradients.is_empty() {
            return None;
        }

        let mut sum = gradients[0].mapv(|g| g.abs());
        for grad in &gradients[1..] {
            sum = sum + grad.mapv(|g| g.abs());
        }
        Some(sum / gradients.len() as f64)
    }

    /// Gradient-magnitude threshold below which weights are pruned.
    fn calculate_threshold(&self, gradients: &Array2<f64>) -> f64 {
        let mut abs_grads: Vec<f64> = gradients.iter().map(|g| g.abs()).collect();
        // total_cmp avoids a panic if any gradient is NaN.
        abs_grads.sort_by(|a, b| a.total_cmp(b));

        let prune_count = (abs_grads.len() as f64 * self.current_ratio) as usize;
        if prune_count >= abs_grads.len() {
            abs_grads.last().copied().unwrap_or(0.0)
        } else {
            abs_grads[prune_count]
        }
    }

    /// Prune using the recorded gradient history for `layer_name`, falling
    /// back to magnitude pruning when no history exists.
    pub fn prune_with_history(
        &self,
        weights: &Array2<f64>,
        layer_name: &str,
    ) -> TrainResult<(Array2<f64>, PruningMask)> {
        if let Some(avg_grads) = self.average_gradient_magnitude(layer_name) {
            let threshold = self
                .calculate_threshold(&avg_grads)
                .max(self.config.min_threshold);
            let mask = avg_grads.mapv(|g| if g >= threshold { 1.0 } else { 0.0 });
            let pruned = weights * &mask;
            Ok((pruned, mask.into_dyn()))
        } else {
            let magnitude_pruner = MagnitudePruner::new(self.config.clone());
            magnitude_pruner.prune(weights)
        }
    }
}

impl Pruner for GradientPruner {
    // Without a layer name there is no gradient history to consult, so the
    // trait methods delegate to magnitude pruning.
    fn prune(&self, weights: &Array2<f64>) -> TrainResult<(Array2<f64>, PruningMask)> {
        let magnitude_pruner = MagnitudePruner::new(self.config.clone());
        magnitude_pruner.prune(weights)
    }

    fn generate_mask(&self, weights: &Array2<f64>) -> TrainResult<PruningMask> {
        let magnitude_pruner = MagnitudePruner::new(self.config.clone());
        magnitude_pruner.generate_mask(weights)
    }

    fn apply_mask(&self, weights: &Array2<f64>, mask: &PruningMask) -> TrainResult<Array2<f64>> {
        let magnitude_pruner = MagnitudePruner::new(self.config.clone());
        magnitude_pruner.apply_mask(weights, mask)
    }

    fn config(&self) -> &PruningConfig {
        &self.config
    }

    fn update_ratio(&mut self, iteration: usize) {
        if !self.config.iterative || iteration >= self.config.num_iterations {
            return;
        }

        // Guard against division by zero when num_iterations == 1.
        let last = (self.config.num_iterations - 1).max(1) as f64;
        let progress = iteration as f64 / last;
        self.current_ratio = match self.config.schedule.as_str() {
            "linear" => {
                self.config.initial_ratio
                    + (self.config.final_ratio - self.config.initial_ratio) * progress
            }
            "exponential" => {
                let log_initial = self.config.initial_ratio.max(1e-8).ln();
                let log_final = self.config.final_ratio.max(1e-8).ln();
                (log_initial + (log_final - log_initial) * progress).exp()
            }
            "cosine" => {
                let cosine_decay = 0.5 * (1.0 + (std::f64::consts::PI * progress).cos());
                self.config.final_ratio
                    + (self.config.initial_ratio - self.config.final_ratio) * cosine_decay
            }
            _ => self.config.pruning_ratio,
        };
    }
}

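/// Structured pruner: removes entire rows and/or columns ranked by their
/// L2 norm, yielding dense shape reductions rather than element-wise
/// sparsity.
///
/// Usage sketch pruning the lowest-norm half of the rows (`weights` is a
/// hypothetical matrix):
///
/// ```ignore
/// let pruner = StructuredPruner::new(
///     PruningConfig {
///         pruning_ratio: 0.5,
///         structured: true,
///         ..Default::default()
///     },
///     StructuredPruningAxis::Rows,
/// );
/// let (pruned, mask) = pruner.prune(&weights)?;
/// // Half of the rows are now entirely zero.
/// ```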
pub struct StructuredPruner {
    config: PruningConfig,
    axis: StructuredPruningAxis,
    current_ratio: f64,
}

impl StructuredPruner {
    /// Create a new structured pruner for the given axis.
    pub fn new(config: PruningConfig, axis: StructuredPruningAxis) -> Self {
        let current_ratio = if config.iterative {
            config.initial_ratio
        } else {
            config.pruning_ratio
        };
        Self {
            config,
            axis,
            current_ratio,
        }
    }

    /// L2 norm of each slice along `axis`, used as a unit importance score.
    fn calculate_importance(&self, weights: &Array2<f64>, axis: Axis) -> Vec<f64> {
        let axis_len = weights.len_of(axis);
        (0..axis_len)
            .map(|i| {
                let slice = weights.index_axis(axis, i);
                slice.iter().map(|&w| w * w).sum::<f64>().sqrt()
            })
            .collect()
    }

    /// Indices of the `current_ratio` fraction of units with the lowest
    /// importance scores.
    fn select_units_to_prune(&self, importance: &[f64]) -> Vec<usize> {
        let mut indexed: Vec<(usize, f64)> = importance.iter().copied().enumerate().collect();
        // total_cmp avoids a panic if any importance score is NaN.
        indexed.sort_by(|a, b| a.1.total_cmp(&b.1));

        let prune_count = (importance.len() as f64 * self.current_ratio) as usize;
        indexed
            .iter()
            .take(prune_count)
            .map(|(idx, _)| *idx)
            .collect()
    }

    /// Build a mask that zeroes whole rows/columns according to `self.axis`.
    fn generate_structured_mask(&self, weights: &Array2<f64>) -> TrainResult<PruningMask> {
        let (nrows, ncols) = weights.dim();
        let mut mask = Array2::ones((nrows, ncols));

        match self.axis {
            StructuredPruningAxis::Rows => {
                let importance = self.calculate_importance(weights, Axis(0));
                let to_prune = self.select_units_to_prune(&importance);
                for &row_idx in &to_prune {
                    mask.row_mut(row_idx).fill(0.0);
                }
            }
            StructuredPruningAxis::Columns => {
                let importance = self.calculate_importance(weights, Axis(1));
                let to_prune = self.select_units_to_prune(&importance);
                for &col_idx in &to_prune {
                    mask.column_mut(col_idx).fill(0.0);
                }
            }
            StructuredPruningAxis::Both => {
                let row_importance = self.calculate_importance(weights, Axis(0));
                let col_importance = self.calculate_importance(weights, Axis(1));

                let rows_to_prune = self.select_units_to_prune(&row_importance);
                let cols_to_prune = self.select_units_to_prune(&col_importance);

                for &row_idx in &rows_to_prune {
                    mask.row_mut(row_idx).fill(0.0);
                }
                for &col_idx in &cols_to_prune {
                    mask.column_mut(col_idx).fill(0.0);
                }
            }
        }

        Ok(mask.into_dyn())
    }
}

impl Pruner for StructuredPruner {
    fn prune(&self, weights: &Array2<f64>) -> TrainResult<(Array2<f64>, PruningMask)> {
        let mask = self.generate_structured_mask(weights)?;
        let pruned = self.apply_mask(weights, &mask)?;
        Ok((pruned, mask))
    }

    fn generate_mask(&self, weights: &Array2<f64>) -> TrainResult<PruningMask> {
        self.generate_structured_mask(weights)
    }

    fn apply_mask(&self, weights: &Array2<f64>, mask: &PruningMask) -> TrainResult<Array2<f64>> {
        let mask_2d = mask
            .clone()
            .into_dimensionality::<Ix2>()
            .map_err(|e| TrainError::ConfigError(format!("Mask shape mismatch: {}", e)))?;

        if weights.shape() != mask_2d.shape() {
            return Err(TrainError::ConfigError(format!(
                "Weight and mask shapes do not match: {:?} vs {:?}",
                weights.shape(),
                mask_2d.shape()
            )));
        }

        Ok(weights * &mask_2d)
    }

    fn config(&self) -> &PruningConfig {
        &self.config
    }

    fn update_ratio(&mut self, iteration: usize) {
        if !self.config.iterative || iteration >= self.config.num_iterations {
            return;
        }

        // Guard against division by zero when num_iterations == 1.
        let last = (self.config.num_iterations - 1).max(1) as f64;
        let progress = iteration as f64 / last;
        self.current_ratio = match self.config.schedule.as_str() {
            "linear" => {
                self.config.initial_ratio
                    + (self.config.final_ratio - self.config.initial_ratio) * progress
            }
            "exponential" => {
                let log_initial = self.config.initial_ratio.max(1e-8).ln();
                let log_final = self.config.final_ratio.max(1e-8).ln();
                (log_initial + (log_final - log_initial) * progress).exp()
            }
            "cosine" => {
                let cosine_decay = 0.5 * (1.0 + (std::f64::consts::PI * progress).cos());
                self.config.final_ratio
                    + (self.config.initial_ratio - self.config.final_ratio) * cosine_decay
            }
            _ => self.config.pruning_ratio,
        };
    }
}

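/// Global pruner: ranks weight magnitudes across all registered layers
/// against a single threshold, so sparsity is distributed unevenly in
/// favor of the layers with larger weights.
///
/// Usage sketch over two layers (`w1` and `w2` are hypothetical matrices):
///
/// ```ignore
/// let mut pruner = GlobalPruner::new(PruningConfig {
///     pruning_ratio: 0.5,
///     global_pruning: true,
///     ..Default::default()
/// });
/// pruner.add_layer("dense1", w1);
/// pruner.add_layer("dense2", w2);
/// let pruned = pruner.prune_all()?;
/// println!("{}", pruner.statistics(&pruned).summary());
/// ```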
pub struct GlobalPruner {
    config: PruningConfig,
    layer_weights: HashMap<String, Array2<f64>>,
}

impl GlobalPruner {
    /// Create a new global pruner from `config`.
    pub fn new(config: PruningConfig) -> Self {
        Self {
            config,
            layer_weights: HashMap::new(),
        }
    }

    /// Register a layer's weights for global pruning.
    pub fn add_layer(&mut self, name: &str, weights: Array2<f64>) {
        self.layer_weights.insert(name.to_string(), weights);
    }

    /// Single magnitude threshold computed over all registered layers.
    fn calculate_global_threshold(&self) -> f64 {
        let mut all_weights: Vec<f64> = self
            .layer_weights
            .values()
            .flat_map(|w| w.iter().map(|x| x.abs()))
            .collect();

        // total_cmp avoids a panic if any weight is NaN.
        all_weights.sort_by(|a, b| a.total_cmp(b));

        let total_params = all_weights.len();
        let prune_count = (total_params as f64 * self.config.pruning_ratio) as usize;

        if prune_count >= total_params {
            all_weights.last().copied().unwrap_or(0.0)
        } else {
            all_weights[prune_count]
        }
    }

    /// Prune every registered layer against the global threshold.
    pub fn prune_all(&self) -> TrainResult<HashMap<String, (Array2<f64>, PruningMask)>> {
        let threshold = self
            .calculate_global_threshold()
            .max(self.config.min_threshold);

        let mut results = HashMap::new();
        for (name, weights) in &self.layer_weights {
            let mask = weights.mapv(|w| if w.abs() >= threshold { 1.0 } else { 0.0 });
            let pruned = weights * &mask;
            results.insert(name.clone(), (pruned, mask.into_dyn()));
        }

        Ok(results)
    }

    /// Aggregate and per-layer statistics for a `prune_all` result.
    pub fn statistics(&self, pruned: &HashMap<String, (Array2<f64>, PruningMask)>) -> PruningStats {
        let mut total_params = 0;
        let mut active_params = 0;
        let mut per_layer_stats = HashMap::new();

        for (name, weights) in &self.layer_weights {
            let layer_total = weights.len();
            total_params += layer_total;

            if let Some((_, mask)) = pruned.get(name) {
                let layer_active = mask.iter().filter(|&&m| m > 0.5).count();
                active_params += layer_active;

                per_layer_stats.insert(
                    name.clone(),
                    LayerPruningStats {
                        name: name.clone(),
                        original_params: layer_total,
                        active_params: layer_active,
                        ratio: 1.0 - (layer_active as f64 / layer_total as f64),
                    },
                );
            }
        }

        // Avoid NaN when no layers have been registered.
        let pruning_ratio = if total_params == 0 {
            0.0
        } else {
            1.0 - (active_params as f64 / total_params as f64)
        };

        PruningStats {
            total_params,
            active_params,
            pruning_ratio,
            iterations: 1,
            per_layer_stats,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_magnitude_pruner() {
        let weights =
            Array2::from_shape_vec((3, 3), vec![0.1, 0.5, 0.9, 0.2, 0.6, 0.01, 0.3, 0.7, 0.8])
                .unwrap();

        let config = PruningConfig {
            pruning_ratio: 0.3,
            structured: false,
            iterative: false,
            ..Default::default()
        };

        let pruner = MagnitudePruner::new(config);
        let (pruned, mask) = pruner.prune(&weights).unwrap();

        // With ratio 0.3 over 9 weights, (9.0 * 0.3) as usize = 2 are pruned.
        let active_count = mask.iter().filter(|&&m| m > 0.5).count();
        let prune_count = (9.0 * 0.3) as usize;
        let expected_active = 9 - prune_count;
        assert_eq!(active_count, expected_active);

        // Masked entries are zeroed; the rest keep their original values.
        for ((&p, &m), &w) in pruned.iter().zip(mask.iter()).zip(weights.iter()) {
            if m < 0.5 {
                assert_abs_diff_eq!(p, 0.0, epsilon = 1e-10);
            } else {
                assert_abs_diff_eq!(p, w, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_iterative_pruning() {
        let mut pruner = MagnitudePruner::new(PruningConfig {
            pruning_ratio: 0.0,
            structured: false,
            iterative: true,
            num_iterations: 5,
            initial_ratio: 0.0,
            final_ratio: 0.5,
            schedule: "linear".to_string(),
            ..Default::default()
        });

        assert_abs_diff_eq!(pruner.current_ratio, 0.0, epsilon = 1e-10);

        pruner.update_ratio(0);
        assert_abs_diff_eq!(pruner.current_ratio, 0.0, epsilon = 1e-3);

        pruner.update_ratio(2);
        assert_abs_diff_eq!(pruner.current_ratio, 0.25, epsilon = 1e-3);

        pruner.update_ratio(4);
        assert_abs_diff_eq!(pruner.current_ratio, 0.5, epsilon = 1e-3);
    }

    #[test]
    fn test_structured_pruner_rows() {
        let weights = Array2::from_shape_vec(
            (4, 3),
            vec![
                0.1, 0.1, 0.1, // row 0: low norm
                0.9, 0.9, 0.9, // row 1: high norm
                0.2, 0.2, 0.2, // row 2: low norm
                0.8, 0.8, 0.8, // row 3: high norm
            ],
        )
        .unwrap();

        let config = PruningConfig {
            pruning_ratio: 0.5,
            structured: true,
            ..Default::default()
        };

        let pruner = StructuredPruner::new(config, StructuredPruningAxis::Rows);
        let (pruned, _mask) = pruner.prune(&weights).unwrap();

        // Half of the rows should be fully zeroed.
        let zero_rows = (0..4)
            .filter(|&i| pruned.row(i).iter().all(|&x| x.abs() < 1e-10))
            .count();
        assert_eq!(zero_rows, 2);

        // Rows 0 and 2 have the lowest L2 norms.
        assert!(pruned.row(0).iter().all(|&x| x.abs() < 1e-10));
        assert!(pruned.row(2).iter().all(|&x| x.abs() < 1e-10));
    }

    #[test]
    fn test_structured_pruner_columns() {
        let weights = Array2::from_shape_vec(
            (3, 4),
            vec![
                0.1, 0.9, 0.2, 0.8, //
                0.1, 0.9, 0.2, 0.8, //
                0.1, 0.9, 0.2, 0.8,
            ],
        )
        .unwrap();

        let config = PruningConfig {
            pruning_ratio: 0.5,
            structured: true,
            ..Default::default()
        };

        let pruner = StructuredPruner::new(config, StructuredPruningAxis::Columns);
        let (pruned, _mask) = pruner.prune(&weights).unwrap();

        // Columns 0 and 2 have the lowest L2 norms and should be zeroed.
        let zero_cols = (0..4)
            .filter(|&i| pruned.column(i).iter().all(|&x| x.abs() < 1e-10))
            .count();
        assert_eq!(zero_cols, 2);
    }

    #[test]
    fn test_global_pruner() {
        let mut global_pruner = GlobalPruner::new(PruningConfig {
            pruning_ratio: 0.5,
            global_pruning: true,
            ..Default::default()
        });

        let layer1 = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
        let layer2 = Array2::from_shape_vec((2, 2), vec![0.5, 0.6, 0.7, 0.8]).unwrap();

        global_pruner.add_layer("layer1", layer1);
        global_pruner.add_layer("layer2", layer2);

        let pruned = global_pruner.prune_all().unwrap();
        let stats = global_pruner.statistics(&pruned);

        assert_eq!(stats.total_params, 8);
        assert_eq!(stats.active_params, 4);
        assert_abs_diff_eq!(stats.pruning_ratio, 0.5, epsilon = 1e-3);
    }

    #[test]
    fn test_pruning_stats() {
        let stats = PruningStats {
            total_params: 1000,
            active_params: 200,
            pruning_ratio: 0.8,
            iterations: 5,
            per_layer_stats: HashMap::new(),
        };

        assert_abs_diff_eq!(stats.compression_ratio(), 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(stats.flops_reduction(), 0.8, epsilon = 1e-10);

        let summary = stats.summary();
        assert!(summary.contains("1000"));
        assert!(summary.contains("200"));
        assert!(summary.contains("5.00x"));
    }

    #[test]
    fn test_gradient_pruner_fallback() {
        let weights =
            Array2::from_shape_vec((3, 3), vec![0.1, 0.5, 0.9, 0.2, 0.6, 0.01, 0.3, 0.7, 0.8])
                .unwrap();

        let config = PruningConfig {
            pruning_ratio: 0.3,
            ..Default::default()
        };

        let pruner = GradientPruner::new(config);
        let (_pruned, mask) = pruner.prune(&weights).unwrap();

        // With no gradient history, the pruner falls back to magnitude
        // pruning: (9.0 * 0.3) as usize = 2 weights are pruned.
        let active_count = mask.iter().filter(|&&m| m > 0.5).count();
        let prune_count = (9.0 * 0.3) as usize;
        let expected_active = 9 - prune_count;
        assert_eq!(active_count, expected_active);
    }

    #[test]
    fn test_gradient_pruner_with_history() {
        let weights = Array2::from_shape_vec((2, 2), vec![0.5, 0.6, 0.7, 0.8]).unwrap();

        let grads1 = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
        let grads2 = Array2::from_shape_vec((2, 2), vec![0.15, 0.25, 0.35, 0.45]).unwrap();

        let config = PruningConfig {
            pruning_ratio: 0.5,
            ..Default::default()
        };

        let mut pruner = GradientPruner::new(config);
        pruner.update_gradients("layer1", grads1);
        pruner.update_gradients("layer1", grads2);

        let (pruned, _mask) = pruner.prune_with_history(&weights, "layer1").unwrap();

        // The average gradient magnitudes are lowest at [0, 0] and [0, 1],
        // so those weights are pruned.
        assert_abs_diff_eq!(pruned[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(pruned[[0, 1]], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_exponential_schedule() {
        let mut pruner = MagnitudePruner::new(PruningConfig {
            pruning_ratio: 0.0,
            iterative: true,
            num_iterations: 5,
            initial_ratio: 0.1,
            final_ratio: 0.9,
            schedule: "exponential".to_string(),
            ..Default::default()
        });

        pruner.update_ratio(0);
        let ratio_0 = pruner.current_ratio;
        pruner.update_ratio(2);
        let ratio_2 = pruner.current_ratio;
        pruner.update_ratio(4);
        let ratio_4 = pruner.current_ratio;

        // The ratio should increase monotonically from initial to final.
        assert!(ratio_0 < ratio_2);
        assert!(ratio_2 < ratio_4);
        assert_abs_diff_eq!(ratio_0, 0.1, epsilon = 1e-2);
        assert_abs_diff_eq!(ratio_4, 0.9, epsilon = 1e-2);
    }

    #[test]
    fn test_cosine_schedule() {
        let mut pruner = MagnitudePruner::new(PruningConfig {
            pruning_ratio: 0.0,
            iterative: true,
            num_iterations: 5,
            initial_ratio: 0.1,
            final_ratio: 0.9,
            schedule: "cosine".to_string(),
            ..Default::default()
        });

        pruner.update_ratio(0);
        let ratio_0 = pruner.current_ratio;
        pruner.update_ratio(4);
        let ratio_4 = pruner.current_ratio;

        assert_abs_diff_eq!(ratio_0, 0.1, epsilon = 1e-2);
        assert_abs_diff_eq!(ratio_4, 0.9, epsilon = 1e-2);
    }

    #[test]
    fn test_min_threshold() {
        let weights = Array2::from_shape_vec((2, 2), vec![1e-10, 1e-9, 1e-8, 0.5]).unwrap();

        let config = PruningConfig {
            pruning_ratio: 0.0,
            min_threshold: 1e-7,
            ..Default::default()
        };

        let pruner = MagnitudePruner::new(config);
        let (pruned, _mask) = pruner.prune(&weights).unwrap();

        // Even with a pruning ratio of 0.0, every weight whose magnitude is
        // below min_threshold is pruned; only 0.5 survives.
        assert_abs_diff_eq!(pruned[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(pruned[[0, 1]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(pruned[[1, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(pruned[[1, 1]], 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_apply_mask() {
        let weights = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();

        let mask = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 1.0, 0.0])
            .unwrap()
            .into_dyn();

        let config = PruningConfig::default();
        let pruner = MagnitudePruner::new(config);
        let pruned = pruner.apply_mask(&weights, &mask).unwrap();

        assert_abs_diff_eq!(pruned[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(pruned[[0, 1]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(pruned[[1, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(pruned[[1, 1]], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_structured_both_axes() {
        let weights = Array2::from_shape_vec(
            (4, 4),
            vec![
                0.1, 0.1, 0.8, 0.1, //
                0.1, 0.1, 0.8, 0.1, //
                0.1, 0.1, 0.8, 0.1, //
                0.9, 0.9, 0.1, 0.9,
            ],
        )
        .unwrap();

        let config = PruningConfig {
            pruning_ratio: 0.25,
            structured: true,
            ..Default::default()
        };

        let pruner = StructuredPruner::new(config, StructuredPruningAxis::Both);

        // Smoke test: pruning along both axes should succeed.
        let (_pruned, _mask) = pruner.prune(&weights).unwrap();
    }
}