//! Model pruning for compression and acceleration.
//!
//! This module provides various pruning strategies to reduce model size and computational cost:
//! - **Unstructured pruning**: Remove individual weights based on magnitude, gradient, or other criteria
//! - **Structured pruning**: Remove entire neurons, channels, or filters
//! - **Iterative pruning**: Gradually increase the pruning ratio during training
//! - **Dynamic pruning**: Adaptively prune based on runtime statistics
//!
//! # Pruning Strategies
//!
//! ## Magnitude-based Pruning
//! Prune weights with the smallest absolute values (the most common and effective strategy):
//! ```rust
//! use tensorlogic_train::{MagnitudePruner, Pruner, PruningConfig};
//! use scirs2_core::ndarray::Array2;
//!
//! let weights = Array2::from_shape_vec((3, 3), vec![
//!     0.1, 0.5, 0.9,
//!     0.2, 0.6, 0.01,
//!     0.3, 0.7, 0.8,
//! ]).unwrap();
//!
//! let config = PruningConfig {
//!     pruning_ratio: 0.3, // Remove 30% of weights
//!     structured: false,
//!     iterative: false,
//!     ..Default::default()
//! };
//!
//! let pruner = MagnitudePruner::new(config);
//! let (pruned_weights, mask) = pruner.prune(&weights).unwrap();
//! ```
//!
//! ## Gradient-based Pruning
//! Prune weights with the smallest accumulated gradient magnitudes (the loss is
//! least sensitive to these weights, so removing them disturbs training the least):
//! ```rust
//! use tensorlogic_train::{GradientPruner, PruningConfig};
//! use scirs2_core::ndarray::Array2;
//!
//! let gradients = Array2::<f64>::zeros((3, 3));
//! let config = PruningConfig::default();
//! let mut pruner = GradientPruner::new(config);
//! pruner.update_gradients("layer1", gradients);
//! ```
//!
//! ## Structured Pruning
//! Remove entire neurons, channels, or filters:
//! ```rust
//! use tensorlogic_train::{StructuredPruner, PruningConfig, StructuredPruningAxis};
//!
//! let config = PruningConfig {
//!     pruning_ratio: 0.5,
//!     structured: true,
//!     ..Default::default()
//! };
//!
//! let pruner = StructuredPruner::new(config, StructuredPruningAxis::Rows);
//! ```
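//!
//! ## Iterative Pruning
//! Gradually raise the pruning ratio over training. A minimal sketch that drives
//! the linear schedule by hand via `update_ratio` (in practice this loop runs
//! inside training):
//! ```rust
//! use tensorlogic_train::{MagnitudePruner, Pruner, PruningConfig};
//! use scirs2_core::ndarray::Array2;
//!
//! let config = PruningConfig {
//!     iterative: true,
//!     num_iterations: 5,
//!     initial_ratio: 0.0,
//!     final_ratio: 0.5,
//!     schedule: "linear".to_string(),
//!     ..Default::default()
//! };
//!
//! let mut pruner = MagnitudePruner::new(config);
//! let weights = Array2::<f64>::from_elem((4, 4), 0.5);
//! for iteration in 0..5 {
//!     pruner.update_ratio(iteration);
//!     let (_pruned, _mask) = pruner.prune(&weights).unwrap();
//! }
//! ```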

use scirs2_core::ndarray::{Array2, ArrayD, Axis, Ix2};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use crate::error::{TrainError, TrainResult};

/// Configuration for pruning strategies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PruningConfig {
    /// Fraction of weights to prune (0.0 to 1.0)
    pub pruning_ratio: f64,
    /// Whether to use structured pruning (entire neurons/channels)
    pub structured: bool,
    /// Use iterative pruning (gradually increase pruning ratio)
    pub iterative: bool,
    /// Number of iterations for iterative pruning
    pub num_iterations: usize,
    /// Initial pruning ratio for iterative pruning
    pub initial_ratio: f64,
    /// Final pruning ratio for iterative pruning
    pub final_ratio: f64,
    /// Pruning schedule: "linear", "exponential", or "cosine"
    pub schedule: String,
    /// Minimum weight magnitude threshold (weights below this are always pruned)
    pub min_threshold: f64,
    /// Whether to use global pruning (across all layers) or local (per-layer)
    pub global_pruning: bool,
}

impl Default for PruningConfig {
    fn default() -> Self {
        Self {
            pruning_ratio: 0.5,
            structured: false,
            iterative: false,
            num_iterations: 10,
            initial_ratio: 0.0,
            final_ratio: 0.9,
            schedule: "linear".to_string(),
            min_threshold: 1e-8,
            global_pruning: false,
        }
    }
}

/// Axis for structured pruning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StructuredPruningAxis {
    /// Prune rows (output neurons)
    Rows,
    /// Prune columns (input neurons)
    Columns,
    /// Prune both (for convolutional filters)
    Both,
}

/// Pruning mask indicating which weights are kept (1.0) or removed (0.0).
pub type PruningMask = ArrayD<f64>;

/// Statistics about a pruned model.
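///
/// # Example
///
/// A sketch with hand-filled fields (normally these are produced by a pruner):
/// ```rust
/// use tensorlogic_train::PruningStats;
/// use std::collections::HashMap;
///
/// let stats = PruningStats {
///     total_params: 1000,
///     active_params: 250,
///     pruning_ratio: 0.75,
///     iterations: 1,
///     per_layer_stats: HashMap::new(),
/// };
/// // 1000 / 250 = 4x compression.
/// assert_eq!(stats.compression_ratio(), 4.0);
/// ```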
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PruningStats {
    /// Total number of parameters before pruning
    pub total_params: usize,
    /// Number of parameters remaining after pruning
    pub active_params: usize,
    /// Pruning ratio achieved
    pub pruning_ratio: f64,
    /// Number of pruning iterations performed
    pub iterations: usize,
    /// Per-layer pruning statistics
    pub per_layer_stats: HashMap<String, LayerPruningStats>,
}

/// Pruning statistics for a single layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerPruningStats {
    /// Layer name
    pub name: String,
    /// Original parameter count
    pub original_params: usize,
    /// Active parameter count after pruning
    pub active_params: usize,
    /// Pruning ratio for this layer
    pub ratio: f64,
}

impl PruningStats {
    /// Calculate the compression ratio (original size / pruned size).
    ///
    /// Returns 0.0 when no parameters remain, to avoid dividing by zero.
    pub fn compression_ratio(&self) -> f64 {
        if self.active_params == 0 {
            0.0
        } else {
            self.total_params as f64 / self.active_params as f64
        }
    }

    /// Approximate FLOPs reduction, taken to be the achieved pruning ratio.
    pub fn flops_reduction(&self) -> f64 {
        self.pruning_ratio
    }

    /// Pretty-print pruning statistics.
    pub fn summary(&self) -> String {
        format!(
            "Pruning Stats:\n\
             - Total params: {}\n\
             - Active params: {}\n\
             - Pruned: {} ({:.2}%)\n\
             - Compression: {:.2}x\n\
             - Est. FLOPs reduction: {:.2}%",
            self.total_params,
            self.active_params,
            self.total_params - self.active_params,
            self.pruning_ratio * 100.0,
            self.compression_ratio(),
            self.flops_reduction() * 100.0
        )
    }
}

/// Trait for pruning strategies.
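///
/// Implementors share a mask-based workflow: `generate_mask` decides which
/// weights survive and `apply_mask` zeroes the rest; `prune` is the two steps
/// combined. A minimal sketch of the split form:
///
/// ```rust
/// use tensorlogic_train::{MagnitudePruner, Pruner, PruningConfig};
/// use scirs2_core::ndarray::Array2;
///
/// let weights = Array2::from_shape_vec((2, 2), vec![0.9, 0.1, 0.8, 0.2]).unwrap();
/// let config = PruningConfig { pruning_ratio: 0.5, ..Default::default() };
/// let pruner = MagnitudePruner::new(config);
///
/// let mask = pruner.generate_mask(&weights).unwrap();
/// let pruned = pruner.apply_mask(&weights, &mask).unwrap();
/// // Half of the four weights survive.
/// assert_eq!(mask.iter().filter(|&&m| m > 0.5).count(), 2);
/// ```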
pub trait Pruner {
    /// Prune weights and return pruned weights and mask.
    fn prune(&self, weights: &Array2<f64>) -> TrainResult<(Array2<f64>, PruningMask)>;

    /// Generate pruning mask without modifying weights.
    fn generate_mask(&self, weights: &Array2<f64>) -> TrainResult<PruningMask>;

    /// Apply existing mask to weights.
    fn apply_mask(&self, weights: &Array2<f64>, mask: &PruningMask) -> TrainResult<Array2<f64>>;

    /// Get pruning configuration.
    fn config(&self) -> &PruningConfig;

    /// Update pruning ratio for iterative pruning.
    fn update_ratio(&mut self, iteration: usize);
}

/// Magnitude-based pruning (prune smallest weights).
pub struct MagnitudePruner {
    config: PruningConfig,
    current_ratio: f64,
}

impl MagnitudePruner {
    /// Create a new magnitude-based pruner.
    pub fn new(config: PruningConfig) -> Self {
        let current_ratio = if config.iterative {
            config.initial_ratio
        } else {
            config.pruning_ratio
        };
        Self {
            config,
            current_ratio,
        }
    }

    /// Calculate the pruning threshold based on the weight distribution.
    ///
    /// The threshold is the `prune_count`-th smallest absolute weight, so
    /// masking everything below it removes roughly `current_ratio` of the weights.
    fn calculate_threshold(&self, weights: &Array2<f64>) -> f64 {
        let mut abs_weights: Vec<f64> = weights.iter().map(|w| w.abs()).collect();
        // `total_cmp` gives a total order, so NaN weights cannot panic the sort.
        abs_weights.sort_by(|a, b| a.total_cmp(b));

        let prune_count = (abs_weights.len() as f64 * self.current_ratio) as usize;
        if prune_count >= abs_weights.len() {
            abs_weights.last().copied().unwrap_or(0.0)
        } else {
            abs_weights[prune_count]
        }
    }
}

impl Pruner for MagnitudePruner {
    fn prune(&self, weights: &Array2<f64>) -> TrainResult<(Array2<f64>, PruningMask)> {
        let mask = self.generate_mask(weights)?;
        let pruned = self.apply_mask(weights, &mask)?;
        Ok((pruned, mask))
    }

    fn generate_mask(&self, weights: &Array2<f64>) -> TrainResult<PruningMask> {
        let threshold = self
            .calculate_threshold(weights)
            .max(self.config.min_threshold);

        let mask = weights.mapv(|w| if w.abs() >= threshold { 1.0 } else { 0.0 });
        Ok(mask.into_dyn())
    }

    fn apply_mask(&self, weights: &Array2<f64>, mask: &PruningMask) -> TrainResult<Array2<f64>> {
        let mask_2d = mask
            .clone()
            .into_dimensionality::<Ix2>()
            .map_err(|e| TrainError::ConfigError(format!("Mask shape mismatch: {}", e)))?;

        if weights.shape() != mask_2d.shape() {
            return Err(TrainError::ConfigError(format!(
                "Weight and mask shapes do not match: {:?} vs {:?}",
                weights.shape(),
                mask_2d.shape()
            )));
        }

        Ok(weights * &mask_2d)
    }

    fn config(&self) -> &PruningConfig {
        &self.config
    }

    fn update_ratio(&mut self, iteration: usize) {
        if !self.config.iterative || iteration >= self.config.num_iterations {
            return;
        }

        // Guard against division by zero when only one iteration is configured.
        let denom = (self.config.num_iterations - 1).max(1) as f64;
        let progress = iteration as f64 / denom;
        self.current_ratio = match self.config.schedule.as_str() {
            "linear" => {
                self.config.initial_ratio
                    + (self.config.final_ratio - self.config.initial_ratio) * progress
            }
            "exponential" => {
                // Interpolate in log space; clamp the initial ratio away from
                // zero so its logarithm stays finite.
                let log_initial = self.config.initial_ratio.max(1e-8).ln();
                let log_final = self.config.final_ratio.ln();
                (log_initial + (log_final - log_initial) * progress).exp()
            }
            "cosine" => {
                let cosine_decay = 0.5 * (1.0 + (std::f64::consts::PI * progress).cos());
                self.config.final_ratio
                    + (self.config.initial_ratio - self.config.final_ratio) * cosine_decay
            }
            _ => self.config.pruning_ratio,
        };
    }
}

/// Gradient-based pruning (prune weights with smallest gradients).
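///
/// Record gradients during training with `update_gradients`, then call
/// `prune_with_history` to drop the weights whose average gradient magnitude
/// is smallest. A minimal sketch:
///
/// ```rust
/// use tensorlogic_train::{GradientPruner, PruningConfig};
/// use scirs2_core::ndarray::Array2;
///
/// let weights = Array2::from_shape_vec((2, 2), vec![0.5, 0.6, 0.7, 0.8]).unwrap();
/// let grads = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
///
/// let config = PruningConfig { pruning_ratio: 0.5, ..Default::default() };
/// let mut pruner = GradientPruner::new(config);
/// pruner.update_gradients("layer1", grads);
///
/// // The two weights with the smallest average gradients are zeroed.
/// let (pruned, _mask) = pruner.prune_with_history(&weights, "layer1").unwrap();
/// assert_eq!(pruned[[0, 0]], 0.0);
/// ```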
pub struct GradientPruner {
    config: PruningConfig,
    current_ratio: f64,
    gradient_history: HashMap<String, Vec<Array2<f64>>>,
}

impl GradientPruner {
    /// Create a new gradient-based pruner.
    pub fn new(config: PruningConfig) -> Self {
        let current_ratio = if config.iterative {
            config.initial_ratio
        } else {
            config.pruning_ratio
        };
        Self {
            config,
            current_ratio,
            gradient_history: HashMap::new(),
        }
    }

    /// Record a batch of gradients for a layer.
    pub fn update_gradients(&mut self, layer_name: &str, gradients: Array2<f64>) {
        self.gradient_history
            .entry(layer_name.to_string())
            .or_default()
            .push(gradients);
    }

    /// Calculate the element-wise average gradient magnitude for a layer.
    fn average_gradient_magnitude(&self, layer_name: &str) -> Option<Array2<f64>> {
        let gradients = self.gradient_history.get(layer_name)?;
        if gradients.is_empty() {
            return None;
        }

        let mut sum = gradients[0].mapv(|g| g.abs());
        for grad in &gradients[1..] {
            sum = sum + grad.mapv(|g| g.abs());
        }
        Some(sum / gradients.len() as f64)
    }

    /// Calculate the pruning threshold based on the gradient distribution.
    fn calculate_threshold(&self, gradients: &Array2<f64>) -> f64 {
        let mut abs_grads: Vec<f64> = gradients.iter().map(|g| g.abs()).collect();
        // NaN-safe total ordering, matching `MagnitudePruner::calculate_threshold`.
        abs_grads.sort_by(|a, b| a.total_cmp(b));

        let prune_count = (abs_grads.len() as f64 * self.current_ratio) as usize;
        if prune_count >= abs_grads.len() {
            abs_grads.last().copied().unwrap_or(0.0)
        } else {
            abs_grads[prune_count]
        }
    }

    /// Prune based on the accumulated gradient history.
    pub fn prune_with_history(
        &self,
        weights: &Array2<f64>,
        layer_name: &str,
    ) -> TrainResult<(Array2<f64>, PruningMask)> {
        if let Some(avg_grads) = self.average_gradient_magnitude(layer_name) {
            let threshold = self
                .calculate_threshold(&avg_grads)
                .max(self.config.min_threshold);
            let mask = avg_grads.mapv(|g| if g >= threshold { 1.0 } else { 0.0 });
            let pruned = weights * &mask;
            Ok((pruned, mask.into_dyn()))
        } else {
            // Fall back to magnitude pruning if there is no gradient history
            let magnitude_pruner = MagnitudePruner::new(self.config.clone());
            magnitude_pruner.prune(weights)
        }
    }
}

impl Pruner for GradientPruner {
    fn prune(&self, weights: &Array2<f64>) -> TrainResult<(Array2<f64>, PruningMask)> {
        // Without gradient information, fall back to magnitude pruning
        let magnitude_pruner = MagnitudePruner::new(self.config.clone());
        magnitude_pruner.prune(weights)
    }

    fn generate_mask(&self, weights: &Array2<f64>) -> TrainResult<PruningMask> {
        let magnitude_pruner = MagnitudePruner::new(self.config.clone());
        magnitude_pruner.generate_mask(weights)
    }

    fn apply_mask(&self, weights: &Array2<f64>, mask: &PruningMask) -> TrainResult<Array2<f64>> {
        let magnitude_pruner = MagnitudePruner::new(self.config.clone());
        magnitude_pruner.apply_mask(weights, mask)
    }

    fn config(&self) -> &PruningConfig {
        &self.config
    }

    fn update_ratio(&mut self, iteration: usize) {
        if !self.config.iterative || iteration >= self.config.num_iterations {
            return;
        }

        // Same schedule logic as `MagnitudePruner::update_ratio`.
        let denom = (self.config.num_iterations - 1).max(1) as f64;
        let progress = iteration as f64 / denom;
        self.current_ratio = match self.config.schedule.as_str() {
            "linear" => {
                self.config.initial_ratio
                    + (self.config.final_ratio - self.config.initial_ratio) * progress
            }
            "exponential" => {
                let log_initial = self.config.initial_ratio.max(1e-8).ln();
                let log_final = self.config.final_ratio.ln();
                (log_initial + (log_final - log_initial) * progress).exp()
            }
            "cosine" => {
                let cosine_decay = 0.5 * (1.0 + (std::f64::consts::PI * progress).cos());
                self.config.final_ratio
                    + (self.config.initial_ratio - self.config.final_ratio) * cosine_decay
            }
            _ => self.config.pruning_ratio,
        };
    }
}

/// Structured pruning (remove entire neurons/channels/filters).
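///
/// Units are ranked by their L2 norm along the chosen axis, and the
/// lowest-norm units are zeroed wholesale. A minimal sketch pruning half of
/// the rows:
///
/// ```rust
/// use tensorlogic_train::{Pruner, PruningConfig, StructuredPruner, StructuredPruningAxis};
/// use scirs2_core::ndarray::Array2;
///
/// let weights = Array2::from_shape_vec(
///     (4, 2),
///     vec![0.1, 0.1, 0.9, 0.9, 0.2, 0.2, 0.8, 0.8],
/// ).unwrap();
///
/// let config = PruningConfig { pruning_ratio: 0.5, structured: true, ..Default::default() };
/// let pruner = StructuredPruner::new(config, StructuredPruningAxis::Rows);
///
/// // Rows 0 and 2 have the smallest norms, so both are zeroed entirely.
/// let (pruned, _mask) = pruner.prune(&weights).unwrap();
/// assert!(pruned.row(0).iter().all(|&x| x == 0.0));
/// assert!(pruned.row(2).iter().all(|&x| x == 0.0));
/// ```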
pub struct StructuredPruner {
    config: PruningConfig,
    axis: StructuredPruningAxis,
    current_ratio: f64,
}

impl StructuredPruner {
    /// Create a new structured pruner.
    pub fn new(config: PruningConfig, axis: StructuredPruningAxis) -> Self {
        let current_ratio = if config.iterative {
            config.initial_ratio
        } else {
            config.pruning_ratio
        };
        Self {
            config,
            axis,
            current_ratio,
        }
    }

    /// Calculate importance scores for rows or columns.
    fn calculate_importance(&self, weights: &Array2<f64>, axis: Axis) -> Vec<f64> {
        let axis_len = weights.len_of(axis);
        (0..axis_len)
            .map(|i| {
                let slice = weights.index_axis(axis, i);
                // L2 norm as importance metric
                slice.iter().map(|&w| w * w).sum::<f64>().sqrt()
            })
            .collect()
    }

    /// Determine which units to prune based on importance scores.
    fn select_units_to_prune(&self, importance: &[f64]) -> Vec<usize> {
        let mut indexed: Vec<(usize, f64)> = importance.iter().copied().enumerate().collect();
        // NaN-safe total ordering; lowest-importance units come first.
        indexed.sort_by(|a, b| a.1.total_cmp(&b.1));

        let prune_count = (importance.len() as f64 * self.current_ratio) as usize;
        indexed
            .iter()
            .take(prune_count)
            .map(|(idx, _)| *idx)
            .collect()
    }

    /// Generate a mask for structured pruning.
    fn generate_structured_mask(&self, weights: &Array2<f64>) -> TrainResult<PruningMask> {
        let (nrows, ncols) = weights.dim();
        let mut mask = Array2::ones((nrows, ncols));

        match self.axis {
            StructuredPruningAxis::Rows => {
                let importance = self.calculate_importance(weights, Axis(0));
                let to_prune = self.select_units_to_prune(&importance);
                for &row_idx in &to_prune {
                    mask.row_mut(row_idx).fill(0.0);
                }
            }
            StructuredPruningAxis::Columns => {
                let importance = self.calculate_importance(weights, Axis(1));
                let to_prune = self.select_units_to_prune(&importance);
                for &col_idx in &to_prune {
                    mask.column_mut(col_idx).fill(0.0);
                }
            }
            StructuredPruningAxis::Both => {
                // Prune both rows and columns
                let row_importance = self.calculate_importance(weights, Axis(0));
                let col_importance = self.calculate_importance(weights, Axis(1));

                let rows_to_prune = self.select_units_to_prune(&row_importance);
                let cols_to_prune = self.select_units_to_prune(&col_importance);

                for &row_idx in &rows_to_prune {
                    mask.row_mut(row_idx).fill(0.0);
                }
                for &col_idx in &cols_to_prune {
                    mask.column_mut(col_idx).fill(0.0);
                }
            }
        }

        Ok(mask.into_dyn())
    }
}

impl Pruner for StructuredPruner {
    fn prune(&self, weights: &Array2<f64>) -> TrainResult<(Array2<f64>, PruningMask)> {
        let mask = self.generate_structured_mask(weights)?;
        let pruned = self.apply_mask(weights, &mask)?;
        Ok((pruned, mask))
    }

    fn generate_mask(&self, weights: &Array2<f64>) -> TrainResult<PruningMask> {
        self.generate_structured_mask(weights)
    }

    fn apply_mask(&self, weights: &Array2<f64>, mask: &PruningMask) -> TrainResult<Array2<f64>> {
        let mask_2d = mask
            .clone()
            .into_dimensionality::<Ix2>()
            .map_err(|e| TrainError::ConfigError(format!("Mask shape mismatch: {}", e)))?;

        if weights.shape() != mask_2d.shape() {
            return Err(TrainError::ConfigError(format!(
                "Weight and mask shapes do not match: {:?} vs {:?}",
                weights.shape(),
                mask_2d.shape()
            )));
        }

        Ok(weights * &mask_2d)
    }

    fn config(&self) -> &PruningConfig {
        &self.config
    }

    fn update_ratio(&mut self, iteration: usize) {
        if !self.config.iterative || iteration >= self.config.num_iterations {
            return;
        }

        // Same schedule logic as `MagnitudePruner::update_ratio`.
        let denom = (self.config.num_iterations - 1).max(1) as f64;
        let progress = iteration as f64 / denom;
        self.current_ratio = match self.config.schedule.as_str() {
            "linear" => {
                self.config.initial_ratio
                    + (self.config.final_ratio - self.config.initial_ratio) * progress
            }
            "exponential" => {
                let log_initial = self.config.initial_ratio.max(1e-8).ln();
                let log_final = self.config.final_ratio.ln();
                (log_initial + (log_final - log_initial) * progress).exp()
            }
            "cosine" => {
                let cosine_decay = 0.5 * (1.0 + (std::f64::consts::PI * progress).cos());
                self.config.final_ratio
                    + (self.config.initial_ratio - self.config.final_ratio) * cosine_decay
            }
            _ => self.config.pruning_ratio,
        };
    }
}

/// Global pruning across multiple layers.
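///
/// Unlike the per-layer pruners, a single magnitude threshold is computed over
/// every registered layer, so lower-magnitude layers are pruned more heavily.
/// A minimal sketch:
///
/// ```rust
/// use tensorlogic_train::{GlobalPruner, PruningConfig};
/// use scirs2_core::ndarray::Array2;
///
/// let mut pruner = GlobalPruner::new(PruningConfig {
///     pruning_ratio: 0.5,
///     global_pruning: true,
///     ..Default::default()
/// });
/// pruner.add_layer("fc1", Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap());
/// pruner.add_layer("fc2", Array2::from_shape_vec((2, 2), vec![0.5, 0.6, 0.7, 0.8]).unwrap());
///
/// let pruned = pruner.prune_all().unwrap();
/// let stats = pruner.statistics(&pruned);
/// // All four surviving weights come from the higher-magnitude layer.
/// assert_eq!(stats.active_params, 4);
/// ```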
pub struct GlobalPruner {
    config: PruningConfig,
    layer_weights: HashMap<String, Array2<f64>>,
}

impl GlobalPruner {
    /// Create a new global pruner.
    pub fn new(config: PruningConfig) -> Self {
        Self {
            config,
            layer_weights: HashMap::new(),
        }
    }

    /// Add a layer to the global pruning pool.
    pub fn add_layer(&mut self, name: &str, weights: Array2<f64>) {
        self.layer_weights.insert(name.to_string(), weights);
    }

    /// Calculate the global magnitude threshold across all layers.
    fn calculate_global_threshold(&self) -> f64 {
        let mut all_weights: Vec<f64> = self
            .layer_weights
            .values()
            .flat_map(|w| w.iter().map(|x| x.abs()))
            .collect();

        // NaN-safe total ordering, as in the per-layer pruners.
        all_weights.sort_by(|a, b| a.total_cmp(b));

        let total_params = all_weights.len();
        let prune_count = (total_params as f64 * self.config.pruning_ratio) as usize;

        if prune_count >= total_params {
            all_weights.last().copied().unwrap_or(0.0)
        } else {
            all_weights[prune_count]
        }
    }

    /// Prune all layers using the global threshold.
    pub fn prune_all(&self) -> TrainResult<HashMap<String, (Array2<f64>, PruningMask)>> {
        let threshold = self
            .calculate_global_threshold()
            .max(self.config.min_threshold);

        let mut results = HashMap::new();
        for (name, weights) in &self.layer_weights {
            let mask = weights.mapv(|w| if w.abs() >= threshold { 1.0 } else { 0.0 });
            let pruned = weights * &mask;
            results.insert(name.clone(), (pruned, mask.into_dyn()));
        }

        Ok(results)
    }

    /// Generate pruning statistics.
    pub fn statistics(&self, pruned: &HashMap<String, (Array2<f64>, PruningMask)>) -> PruningStats {
        let mut total_params = 0;
        let mut active_params = 0;
        let mut per_layer_stats = HashMap::new();

        for (name, weights) in &self.layer_weights {
            let layer_total = weights.len();
            total_params += layer_total;

            if let Some((_, mask)) = pruned.get(name) {
                let layer_active = mask.iter().filter(|&&m| m > 0.5).count();
                active_params += layer_active;

                per_layer_stats.insert(
                    name.clone(),
                    LayerPruningStats {
                        name: name.clone(),
                        original_params: layer_total,
                        active_params: layer_active,
                        ratio: 1.0 - (layer_active as f64 / layer_total as f64),
                    },
                );
            }
        }

        PruningStats {
            total_params,
            active_params,
            pruning_ratio: 1.0 - (active_params as f64 / total_params as f64),
            iterations: 1,
            per_layer_stats,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_magnitude_pruner() {
        let weights =
            Array2::from_shape_vec((3, 3), vec![0.1, 0.5, 0.9, 0.2, 0.6, 0.01, 0.3, 0.7, 0.8])
                .unwrap();

        let config = PruningConfig {
            pruning_ratio: 0.3,
            structured: false,
            iterative: false,
            ..Default::default()
        };

        let pruner = MagnitudePruner::new(config);
        let (pruned, mask) = pruner.prune(&weights).unwrap();

        // Check that the smallest weights are pruned
        let active_count = mask.iter().filter(|&&m| m > 0.5).count();
        // With a 30% pruning ratio, prune_count = (9 * 0.3) = 2.7 -> 2,
        // so 9 - 2 = 7 weights are kept (about 78%)
        let prune_count = (9.0 * 0.3) as usize;
        let expected_active = 9 - prune_count;
        assert_eq!(active_count, expected_active);

        // Check that pruned weights are zeroed
        for ((&p, &m), &w) in pruned.iter().zip(mask.iter()).zip(weights.iter()) {
            if m < 0.5 {
                assert_abs_diff_eq!(p, 0.0, epsilon = 1e-10);
            } else {
                assert_abs_diff_eq!(p, w, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_iterative_pruning() {
        let mut pruner = MagnitudePruner::new(PruningConfig {
            pruning_ratio: 0.0,
            structured: false,
            iterative: true,
            num_iterations: 5,
            initial_ratio: 0.0,
            final_ratio: 0.5,
            schedule: "linear".to_string(),
            ..Default::default()
        });

        assert_abs_diff_eq!(pruner.current_ratio, 0.0, epsilon = 1e-10);

        pruner.update_ratio(0);
        assert_abs_diff_eq!(pruner.current_ratio, 0.0, epsilon = 1e-3);

        pruner.update_ratio(2);
        assert_abs_diff_eq!(pruner.current_ratio, 0.25, epsilon = 1e-3);

        pruner.update_ratio(4);
        assert_abs_diff_eq!(pruner.current_ratio, 0.5, epsilon = 1e-3);
    }

    #[test]
    fn test_structured_pruner_rows() {
        let weights = Array2::from_shape_vec(
            (4, 3),
            vec![
                0.1, 0.1, 0.1, // Row 0: low magnitude
                0.9, 0.9, 0.9, // Row 1: high magnitude
                0.2, 0.2, 0.2, // Row 2: low magnitude
                0.8, 0.8, 0.8, // Row 3: high magnitude
            ],
        )
        .unwrap();

        let config = PruningConfig {
            pruning_ratio: 0.5, // Prune 50% of rows (2 out of 4)
            structured: true,
            ..Default::default()
        };

        let pruner = StructuredPruner::new(config, StructuredPruningAxis::Rows);
        let (pruned, _mask) = pruner.prune(&weights).unwrap();

        // Check that 2 rows are completely zeroed
        let zero_rows = (0..4)
            .filter(|&i| pruned.row(i).iter().all(|&x| x.abs() < 1e-10))
            .count();
        assert_eq!(zero_rows, 2);

        // Check that the low-magnitude rows (0 and 2) are pruned
        assert!(pruned.row(0).iter().all(|&x| x.abs() < 1e-10));
        assert!(pruned.row(2).iter().all(|&x| x.abs() < 1e-10));
    }

    #[test]
    fn test_structured_pruner_columns() {
        let weights = Array2::from_shape_vec(
            (3, 4),
            vec![
                // Columns 0 and 2 have low magnitudes; columns 1 and 3 are high
                0.1, 0.9, 0.2, 0.8, //
                0.1, 0.9, 0.2, 0.8, //
                0.1, 0.9, 0.2, 0.8,
            ],
        )
        .unwrap();

        let config = PruningConfig {
            pruning_ratio: 0.5,
            structured: true,
            ..Default::default()
        };

        let pruner = StructuredPruner::new(config, StructuredPruningAxis::Columns);
        let (pruned, _mask) = pruner.prune(&weights).unwrap();

        // Check that 2 columns are completely zeroed
        let zero_cols = (0..4)
            .filter(|&i| pruned.column(i).iter().all(|&x| x.abs() < 1e-10))
            .count();
        assert_eq!(zero_cols, 2);
    }

    #[test]
    fn test_global_pruner() {
        let mut global_pruner = GlobalPruner::new(PruningConfig {
            pruning_ratio: 0.5,
            global_pruning: true,
            ..Default::default()
        });

        let layer1 = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
        let layer2 = Array2::from_shape_vec((2, 2), vec![0.5, 0.6, 0.7, 0.8]).unwrap();

        global_pruner.add_layer("layer1", layer1);
        global_pruner.add_layer("layer2", layer2);

        let pruned = global_pruner.prune_all().unwrap();
        let stats = global_pruner.statistics(&pruned);

        assert_eq!(stats.total_params, 8);
        assert_eq!(stats.active_params, 4);
        assert_abs_diff_eq!(stats.pruning_ratio, 0.5, epsilon = 1e-3);
    }

    #[test]
    fn test_pruning_stats() {
        let stats = PruningStats {
            total_params: 1000,
            active_params: 200,
            pruning_ratio: 0.8,
            iterations: 5,
            per_layer_stats: HashMap::new(),
        };

        assert_abs_diff_eq!(stats.compression_ratio(), 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(stats.flops_reduction(), 0.8, epsilon = 1e-10);

        let summary = stats.summary();
        assert!(summary.contains("1000"));
        assert!(summary.contains("200"));
        assert!(summary.contains("5.00x"));
    }

    #[test]
    fn test_gradient_pruner_fallback() {
        let weights =
            Array2::from_shape_vec((3, 3), vec![0.1, 0.5, 0.9, 0.2, 0.6, 0.01, 0.3, 0.7, 0.8])
                .unwrap();

        let config = PruningConfig {
            pruning_ratio: 0.3,
            ..Default::default()
        };

        let pruner = GradientPruner::new(config);
        let (_pruned, mask) = pruner.prune(&weights).unwrap();

        // Without gradient history, this should fall back to magnitude pruning
        let active_count = mask.iter().filter(|&&m| m > 0.5).count();
        // With a 30% pruning ratio, prune_count = (9 * 0.3) = 2.7 -> 2,
        // so 9 - 2 = 7 weights are kept
        let prune_count = (9.0 * 0.3) as usize;
        let expected_active = 9 - prune_count;
        assert_eq!(active_count, expected_active);
    }

    #[test]
    fn test_gradient_pruner_with_history() {
        let weights = Array2::from_shape_vec((2, 2), vec![0.5, 0.6, 0.7, 0.8]).unwrap();

        let grads1 = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
        let grads2 = Array2::from_shape_vec((2, 2), vec![0.15, 0.25, 0.35, 0.45]).unwrap();

        let config = PruningConfig {
            pruning_ratio: 0.5,
            ..Default::default()
        };

        let mut pruner = GradientPruner::new(config);
        pruner.update_gradients("layer1", grads1);
        pruner.update_gradients("layer1", grads2);

        let (pruned, _mask) = pruner.prune_with_history(&weights, "layer1").unwrap();

        // Weights with the smallest average gradients should be pruned.
        // Average gradients: [0.125, 0.225, 0.325, 0.425],
        // so the two smallest (0.125 and 0.225) are pruned.
        assert_abs_diff_eq!(pruned[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(pruned[[0, 1]], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_exponential_schedule() {
        let mut pruner = MagnitudePruner::new(PruningConfig {
            pruning_ratio: 0.0,
            iterative: true,
            num_iterations: 5,
            initial_ratio: 0.1,
            final_ratio: 0.9,
            schedule: "exponential".to_string(),
            ..Default::default()
        });

        pruner.update_ratio(0);
        let ratio_0 = pruner.current_ratio;
        pruner.update_ratio(2);
        let ratio_2 = pruner.current_ratio;
        pruner.update_ratio(4);
        let ratio_4 = pruner.current_ratio;

        // The exponential schedule should take larger steps later on
        assert!(ratio_0 < ratio_2);
        assert!(ratio_2 < ratio_4);
        assert_abs_diff_eq!(ratio_0, 0.1, epsilon = 1e-2);
        assert_abs_diff_eq!(ratio_4, 0.9, epsilon = 1e-2);
    }

    #[test]
    fn test_cosine_schedule() {
        let mut pruner = MagnitudePruner::new(PruningConfig {
            pruning_ratio: 0.0,
            iterative: true,
            num_iterations: 5,
            initial_ratio: 0.1,
            final_ratio: 0.9,
            schedule: "cosine".to_string(),
            ..Default::default()
        });

        pruner.update_ratio(0);
        let ratio_0 = pruner.current_ratio;
        pruner.update_ratio(4);
        let ratio_4 = pruner.current_ratio;

        assert_abs_diff_eq!(ratio_0, 0.1, epsilon = 1e-2);
        assert_abs_diff_eq!(ratio_4, 0.9, epsilon = 1e-2);
    }

    #[test]
    fn test_min_threshold() {
        let weights = Array2::from_shape_vec((2, 2), vec![1e-10, 1e-9, 1e-8, 0.5]).unwrap();

        let config = PruningConfig {
            pruning_ratio: 0.0,  // Don't prune by ratio
            min_threshold: 1e-7, // But prune by threshold
            ..Default::default()
        };

        let pruner = MagnitudePruner::new(config);
        let (pruned, _mask) = pruner.prune(&weights).unwrap();

        // All weights below the threshold should be pruned
        assert_abs_diff_eq!(pruned[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(pruned[[0, 1]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(pruned[[1, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(pruned[[1, 1]], 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_apply_mask() {
        let weights = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();

        let mask = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 1.0, 0.0])
            .unwrap()
            .into_dyn();

        let config = PruningConfig::default();
        let pruner = MagnitudePruner::new(config);
        let pruned = pruner.apply_mask(&weights, &mask).unwrap();

        assert_abs_diff_eq!(pruned[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(pruned[[0, 1]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(pruned[[1, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(pruned[[1, 1]], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_structured_both_axes() {
        let weights = Array2::from_shape_vec(
            (4, 4),
            vec![
                0.1, 0.1, 0.8, 0.1, //
                0.1, 0.1, 0.8, 0.1, //
                0.1, 0.1, 0.8, 0.1, //
                0.9, 0.9, 0.1, 0.9,
            ],
        )
        .unwrap();

        let config = PruningConfig {
            pruning_ratio: 0.25, // Prune 1 of 4 rows and 1 of 4 columns
            structured: true,
            ..Default::default()
        };

        let pruner = StructuredPruner::new(config, StructuredPruningAxis::Both);
        let (pruned, _mask) = pruner.prune(&weights).unwrap();

        // Pruning both axes should zero exactly one row and one column
        // (the lowest-norm unit along each axis).
        let zero_rows = (0..4)
            .filter(|&i| pruned.row(i).iter().all(|&x| x.abs() < 1e-10))
            .count();
        let zero_cols = (0..4)
            .filter(|&i| pruned.column(i).iter().all(|&x| x.abs() < 1e-10))
            .count();
        assert_eq!(zero_rows, 1);
        assert_eq!(zero_cols, 1);
    }
}
979}