tensorlogic_infer/pruning.rs

//! Structured sparsity patterns and pruning for TensorLogic.
//!
//! Provides magnitude-based and structured pruning to reduce model size
//! while maintaining hardware efficiency through regular sparsity patterns.
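//!
//! # Example
//!
//! A minimal usage sketch (illustrative; the `tensorlogic_infer::pruning`
//! path is assumed from the file location):
//!
//! ```no_run
//! use ndarray::array;
//! use tensorlogic_infer::pruning::{MagnitudePruner, PruningConfig, SparsityPattern};
//!
//! // Target 50% unstructured sparsity on a small weight matrix.
//! let cfg = PruningConfig::new(0.5, SparsityPattern::Unstructured).expect("valid config");
//! let pruner = MagnitudePruner::new(cfg);
//!
//! let mut weights = array![[1.0, -2.0, 3.0, -4.0], [5.0, -6.0, 7.0, -8.0]];
//! let stats = pruner.prune_2d(&mut weights).expect("prune ok");
//! assert!(stats.actual_sparsity >= 0.4);
//! ```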

use ndarray::{Array1, Array2, ArrayD};
use thiserror::Error;

/// Errors from pruning operations.
#[derive(Debug, Error)]
pub enum PruningError {
    #[error("Invalid sparsity ratio: {0}. Must be in [0, 1).")]
    InvalidSparsityRatio(f64),
    #[error("Shape mismatch: {0}")]
    ShapeMismatch(String),
    #[error("Block size {0} does not divide dimension {1}")]
    InvalidBlockSize(usize, usize),
    #[error("Empty tensor")]
    EmptyTensor,
}

/// Structured sparsity pattern types.
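///
/// As an illustration (not part of the original docs), 2:4 sparsity
/// (`NM { n: 2, m: 4 }`) keeps the two largest-magnitude weights in every
/// group of four consecutive weights along a row:
///
/// ```text
/// dense row:   [ 0.1, -0.9, 0.4, 0.05 | 0.7, 0.2, -0.3, 0.8 ]
/// 2:4 pruned:  [ 0.0, -0.9, 0.4, 0.0  | 0.7, 0.0,  0.0, 0.8 ]
/// ```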
#[derive(Debug, Clone, PartialEq)]
pub enum SparsityPattern {
    /// Element-wise (unstructured) sparsity — zero out individual elements.
    Unstructured,
    /// Block sparsity — zero out rectangular blocks of size (block_h × block_w).
    Block { block_h: usize, block_w: usize },
    /// Channel/row sparsity — zero out entire rows (for weight matrices).
    Row,
    /// Column sparsity — zero out entire columns.
    Column,
    /// N:M sparsity — keep N non-zero values per group of M (common on Ampere+ GPUs).
    NM { n: usize, m: usize },
}

impl SparsityPattern {
    /// Human-readable name for this pattern.
    pub fn name(&self) -> &'static str {
        match self {
            SparsityPattern::Unstructured => "unstructured",
            SparsityPattern::Block { .. } => "block",
            SparsityPattern::Row => "row",
            SparsityPattern::Column => "column",
            SparsityPattern::NM { .. } => "n:m",
        }
    }

    /// Whether this is a structured pattern (more hardware efficient).
    pub fn is_structured(&self) -> bool {
        !matches!(self, SparsityPattern::Unstructured)
    }
}

/// Statistics about the sparsity of a pruned tensor.
#[derive(Debug, Clone)]
pub struct SparsityStats {
    /// Fraction of zero elements (0.0 = dense, 1.0 = all-zero).
    pub actual_sparsity: f64,
    /// Number of zero elements.
    pub zero_count: usize,
    /// Total number of elements.
    pub total_count: usize,
    /// Theoretical compute speedup from sparsity (rough estimate).
    pub theoretical_speedup: f64,
    /// Pattern used for pruning.
    pub pattern: SparsityPattern,
}

impl SparsityStats {
    /// Compute sparsity statistics for a tensor with the given pattern.
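    ///
    /// The speedup estimate is a rough heuristic mirroring the code below:
    /// structured patterns are credited roughly `1 / (1 - sparsity)`, while
    /// unstructured sparsity only gets `1 + 0.5 * sparsity`, since irregular
    /// zeros map poorly onto dense hardware kernels.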
    pub fn compute(tensor: &ArrayD<f64>, pattern: SparsityPattern) -> Self {
        let total_count = tensor.len();
        let zero_count = tensor.iter().filter(|&&v| v == 0.0).count();
        let actual_sparsity = if total_count == 0 {
            0.0
        } else {
            zero_count as f64 / total_count as f64
        };
        // Rough speedup: structured sparsity > unstructured
        let theoretical_speedup = if pattern.is_structured() {
            1.0 / (1.0 - actual_sparsity + 1e-9)
        } else {
            1.0 + actual_sparsity * 0.5 // unstructured has limited hw benefit
        };
        SparsityStats {
            actual_sparsity,
            zero_count,
            total_count,
            theoretical_speedup,
            pattern,
        }
    }
}

/// Configuration for the pruning process.
#[derive(Debug, Clone)]
pub struct PruningConfig {
    /// Target sparsity ratio [0, 1).
    pub target_sparsity: f64,
    /// Sparsity pattern to apply.
    pub pattern: SparsityPattern,
    /// Whether to rescale remaining weights after pruning.
    pub rescale: bool,
}

impl PruningConfig {
    /// Create a new pruning config with the given sparsity and pattern.
    ///
    /// Returns an error if `target_sparsity` is not in `[0, 1)`.
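    ///
    /// # Example
    ///
    /// A brief sketch (illustrative, assuming the `tensorlogic_infer::pruning`
    /// module path):
    ///
    /// ```no_run
    /// use tensorlogic_infer::pruning::{PruningConfig, SparsityPattern};
    ///
    /// // 2:4 sparsity with rescaling of the surviving weights.
    /// let cfg = PruningConfig::new(0.5, SparsityPattern::NM { n: 2, m: 4 })
    ///     .expect("valid config")
    ///     .with_rescale(true);
    /// assert!(cfg.rescale);
    /// ```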
    pub fn new(target_sparsity: f64, pattern: SparsityPattern) -> Result<Self, PruningError> {
        if !(0.0..1.0).contains(&target_sparsity) {
            return Err(PruningError::InvalidSparsityRatio(target_sparsity));
        }
        Ok(PruningConfig {
            target_sparsity,
            pattern,
            rescale: false,
        })
    }

    /// Set whether to rescale non-zero weights after pruning.
    pub fn with_rescale(mut self, rescale: bool) -> Self {
        self.rescale = rescale;
        self
    }
}

/// Magnitude-based pruner: zero out elements with smallest absolute values.
pub struct MagnitudePruner {
    config: PruningConfig,
}

impl MagnitudePruner {
    /// Create a new magnitude pruner with the given config.
    pub fn new(config: PruningConfig) -> Self {
        MagnitudePruner { config }
    }

    /// Prune a 2D matrix in-place according to the pattern.
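    ///
    /// # Example
    ///
    /// A sketch of row pruning (illustrative, assuming the
    /// `tensorlogic_infer::pruning` module path):
    ///
    /// ```no_run
    /// use ndarray::array;
    /// use tensorlogic_infer::pruning::{MagnitudePruner, PruningConfig, SparsityPattern};
    ///
    /// // Prune the weakest half of the rows (by L2 norm).
    /// let cfg = PruningConfig::new(0.5, SparsityPattern::Row).expect("valid config");
    /// let pruner = MagnitudePruner::new(cfg);
    /// let mut w = array![[0.01, 0.02], [1.0, 2.0]];
    /// let stats = pruner.prune_2d(&mut w).expect("prune ok");
    /// assert_eq!(w[[0, 0]], 0.0); // weak row zeroed
    /// assert!(stats.actual_sparsity > 0.0);
    /// ```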
    pub fn prune_2d(&self, matrix: &mut Array2<f64>) -> Result<SparsityStats, PruningError> {
        if matrix.is_empty() {
            return Err(PruningError::EmptyTensor);
        }
        match &self.config.pattern {
            SparsityPattern::Unstructured => {
                self.prune_unstructured_2d(matrix)?;
            }
            SparsityPattern::Block { block_h, block_w } => {
                self.prune_block_2d(matrix, *block_h, *block_w)?;
            }
            SparsityPattern::Row => {
                self.prune_rows_2d(matrix)?;
            }
            SparsityPattern::Column => {
                self.prune_columns_2d(matrix)?;
            }
            SparsityPattern::NM { n, m } => {
                self.prune_nm_2d(matrix, *n, *m)?;
            }
        }
        if self.config.rescale {
            self.rescale_nonzero(matrix);
        }
        Ok(SparsityStats::compute(
            &matrix.clone().into_dyn(),
            self.config.pattern.clone(),
        ))
    }

    /// Prune a general N-D tensor in-place.
    ///
    /// Unstructured pruning works on tensors of any rank; structured patterns
    /// require a 2D tensor and return a `ShapeMismatch` error otherwise.
    pub fn prune(&self, tensor: &mut ArrayD<f64>) -> Result<SparsityStats, PruningError> {
        if tensor.is_empty() {
            return Err(PruningError::EmptyTensor);
        }
        match &self.config.pattern {
            SparsityPattern::Unstructured => {
                self.prune_unstructured_nd(tensor)?;
                // Apply the optional rescale here as well, so N-D unstructured
                // pruning honours `config.rescale` just like `prune_2d`.
                if self.config.rescale {
                    let total = tensor.len() as f64;
                    let nonzero = tensor.iter().filter(|&&v| v != 0.0).count() as f64;
                    if nonzero > 0.0 {
                        let scale = total / nonzero;
                        tensor.mapv_inplace(|v| if v != 0.0 { v * scale } else { 0.0 });
                    }
                }
            }
            _ => {
                // Structured patterns operate on rows/columns/blocks, so require 2D
                if tensor.ndim() != 2 {
                    return Err(PruningError::ShapeMismatch(format!(
                        "Structured pruning requires 2D tensor, got {}D",
                        tensor.ndim()
                    )));
                }
                let mut mat = tensor
                    .clone()
                    .into_dimensionality::<ndarray::Ix2>()
                    .map_err(|e| PruningError::ShapeMismatch(e.to_string()))?;
                self.prune_2d(&mut mat)?;
                *tensor = mat.into_dyn();
            }
        }
        Ok(SparsityStats::compute(tensor, self.config.pattern.clone()))
    }

    fn prune_unstructured_nd(&self, tensor: &mut ArrayD<f64>) -> Result<(), PruningError> {
        let k = ((1.0 - self.config.target_sparsity) * tensor.len() as f64).ceil() as usize;
        let mut mags: Vec<f64> = tensor.iter().map(|v| v.abs()).collect();
        mags.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let threshold = if k < mags.len() {
            mags[mags.len() - k]
        } else {
            0.0
        };
        tensor.mapv_inplace(|v| if v.abs() >= threshold { v } else { 0.0 });
        Ok(())
    }

    fn prune_unstructured_2d(&self, matrix: &mut Array2<f64>) -> Result<(), PruningError> {
        let k = ((1.0 - self.config.target_sparsity) * matrix.len() as f64).ceil() as usize;
        let mut mags: Vec<f64> = matrix.iter().map(|v| v.abs()).collect();
        mags.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let threshold = if k < mags.len() {
            mags[mags.len() - k]
        } else {
            0.0
        };
        matrix.mapv_inplace(|v| if v.abs() >= threshold { v } else { 0.0 });
        Ok(())
    }

    fn prune_rows_2d(&self, matrix: &mut Array2<f64>) -> Result<(), PruningError> {
        let nrows = matrix.nrows();
        let n_prune = (self.config.target_sparsity * nrows as f64).round() as usize;
        // Compute row L2 norms
        let mut norms: Vec<(usize, f64)> = (0..nrows)
            .map(|i| {
                let norm: f64 = matrix.row(i).iter().map(|v| v * v).sum::<f64>().sqrt();
                (i, norm)
            })
            .collect();
        norms.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        for &(row_idx, _) in &norms[..n_prune] {
            matrix.row_mut(row_idx).fill(0.0);
        }
        Ok(())
    }

    fn prune_columns_2d(&self, matrix: &mut Array2<f64>) -> Result<(), PruningError> {
        let ncols = matrix.ncols();
        let n_prune = (self.config.target_sparsity * ncols as f64).round() as usize;
        let mut norms: Vec<(usize, f64)> = (0..ncols)
            .map(|j| {
                let norm: f64 = matrix.column(j).iter().map(|v| v * v).sum::<f64>().sqrt();
                (j, norm)
            })
            .collect();
        norms.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        for &(col_idx, _) in &norms[..n_prune] {
            matrix.column_mut(col_idx).fill(0.0);
        }
        Ok(())
    }

    fn prune_block_2d(
        &self,
        matrix: &mut Array2<f64>,
        bh: usize,
        bw: usize,
    ) -> Result<(), PruningError> {
        let (rows, cols) = (matrix.nrows(), matrix.ncols());
        if rows % bh != 0 {
            return Err(PruningError::InvalidBlockSize(bh, rows));
        }
        if cols % bw != 0 {
            return Err(PruningError::InvalidBlockSize(bw, cols));
        }
        let n_blocks_r = rows / bh;
        let n_blocks_c = cols / bw;
        let total_blocks = n_blocks_r * n_blocks_c;
        let n_prune = (self.config.target_sparsity * total_blocks as f64).round() as usize;
        // Compute block norms
        let mut block_norms: Vec<(usize, usize, f64)> = Vec::with_capacity(total_blocks);
        for br in 0..n_blocks_r {
            for bc in 0..n_blocks_c {
                let norm: f64 = matrix
                    .slice(ndarray::s![br * bh..(br + 1) * bh, bc * bw..(bc + 1) * bw])
                    .iter()
                    .map(|v| v * v)
                    .sum::<f64>()
                    .sqrt();
                block_norms.push((br, bc, norm));
            }
        }
        block_norms.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
        for &(br, bc, _) in &block_norms[..n_prune] {
            matrix
                .slice_mut(ndarray::s![br * bh..(br + 1) * bh, bc * bw..(bc + 1) * bw])
                .fill(0.0);
        }
        Ok(())
    }

    fn prune_nm_2d(
        &self,
        matrix: &mut Array2<f64>,
        n: usize,
        m: usize,
    ) -> Result<(), PruningError> {
        if n >= m {
            return Err(PruningError::InvalidBlockSize(n, m));
        }
        // For each row, for each group of m consecutive elements, keep top-n by magnitude
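        // Trailing columns (when ncols is not a multiple of m) fall outside any group
        // and are left untouched.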
        let ncols = matrix.ncols();
        for i in 0..matrix.nrows() {
            let mut col = 0;
            while col + m <= ncols {
                let group: Vec<f64> = (col..col + m).map(|j| matrix[[i, j]]).collect();
                let mut mags: Vec<(usize, f64)> = group
                    .iter()
                    .enumerate()
                    .map(|(j, &v)| (j, v.abs()))
                    .collect();
                mags.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                let keep: std::collections::HashSet<usize> =
                    mags[..n].iter().map(|&(j, _)| j).collect();
                for j in 0..m {
                    if !keep.contains(&j) {
                        matrix[[i, col + j]] = 0.0;
                    }
                }
                col += m;
            }
        }
        Ok(())
    }

    fn rescale_nonzero(&self, matrix: &mut Array2<f64>) {
        let total = matrix.len() as f64;
        let nonzero = matrix.iter().filter(|&&v| v != 0.0).count() as f64;
        if nonzero > 0.0 {
            let scale = total / nonzero;
            matrix.mapv_inplace(|v| if v != 0.0 { v * scale } else { 0.0 });
        }
    }
}

/// Compute the overall sparsity of an N-D tensor.
///
/// Returns the fraction of zero elements (0.0 = fully dense, 1.0 = all zeros).
pub fn compute_sparsity(tensor: &ArrayD<f64>) -> f64 {
    if tensor.is_empty() {
        return 0.0;
    }
    let zeros = tensor.iter().filter(|&&v| v == 0.0).count();
    zeros as f64 / tensor.len() as f64
}

/// Compute per-row L2 norms for a 2D matrix.
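///
/// # Example
///
/// A brief sketch (illustrative, assuming the `tensorlogic_infer::pruning`
/// module path):
///
/// ```no_run
/// use ndarray::array;
/// use tensorlogic_infer::pruning::row_norms;
///
/// let norms = row_norms(&array![[3.0, 4.0], [0.0, 0.0]]);
/// assert!((norms[0] - 5.0).abs() < 1e-10); // sqrt(3^2 + 4^2)
/// assert!(norms[1].abs() < 1e-10);
/// ```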
pub fn row_norms(matrix: &Array2<f64>) -> Array1<f64> {
    Array1::from_iter(
        matrix
            .rows()
            .into_iter()
            .map(|row| row.iter().map(|v| v * v).sum::<f64>().sqrt()),
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // Helper: build a 2D array and convert to dynamic.
    fn dyn2d(data: Array2<f64>) -> ArrayD<f64> {
        data.into_dyn()
    }

    #[test]
    fn test_sparsity_pattern_names() {
        assert_eq!(SparsityPattern::Unstructured.name(), "unstructured");
        assert_eq!(
            SparsityPattern::Block {
                block_h: 2,
                block_w: 2
            }
            .name(),
            "block"
        );
        assert_eq!(SparsityPattern::Row.name(), "row");
        assert_eq!(SparsityPattern::Column.name(), "column");
        assert_eq!(SparsityPattern::NM { n: 2, m: 4 }.name(), "n:m");
    }

    #[test]
    fn test_sparsity_pattern_is_structured() {
        assert!(!SparsityPattern::Unstructured.is_structured());
        assert!(SparsityPattern::Block {
            block_h: 2,
            block_w: 2
        }
        .is_structured());
        assert!(SparsityPattern::Row.is_structured());
        assert!(SparsityPattern::Column.is_structured());
        assert!(SparsityPattern::NM { n: 1, m: 4 }.is_structured());
    }

    #[test]
    fn test_pruning_config_invalid_ratio() {
        let result = PruningConfig::new(1.0, SparsityPattern::Unstructured);
        assert!(result.is_err());
        let result_neg = PruningConfig::new(-0.1, SparsityPattern::Unstructured);
        assert!(result_neg.is_err());
    }

    #[test]
    fn test_pruning_config_valid() {
        let result = PruningConfig::new(0.5, SparsityPattern::Unstructured);
        assert!(result.is_ok());
        let cfg = result.expect("valid config");
        assert!((cfg.target_sparsity - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_unstructured_pruning_zeros_out() {
        // 4×4 matrix, 50% sparsity → ~8 zeros out of 16
        let mut mat = array![
            [1.0, 2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0, 12.0],
            [13.0, 14.0, 15.0, 16.0],
        ];
        let cfg = PruningConfig::new(0.5, SparsityPattern::Unstructured).expect("valid config");
        let pruner = MagnitudePruner::new(cfg);
        let stats = pruner.prune_2d(&mut mat).expect("prune ok");
        // Should have ~50% zeros
        assert!(stats.actual_sparsity >= 0.4 && stats.actual_sparsity <= 0.6);
    }

    #[test]
    fn test_unstructured_preserves_largest() {
        // Elements 10, 20, 30, 40 — with 50% sparsity (keep 50%), 30 and 40 must survive
        let mut mat = array![[10.0, 20.0], [30.0, 40.0]];
        let cfg = PruningConfig::new(0.5, SparsityPattern::Unstructured).expect("valid config");
        let pruner = MagnitudePruner::new(cfg);
        pruner.prune_2d(&mut mat).expect("prune ok");
        // Largest two (30, 40) must be non-zero
        assert!(mat[[1, 0]] != 0.0 || mat[[1, 1]] != 0.0);
        assert!(mat[[1, 1]] != 0.0); // 40 is the largest, must survive
    }

    #[test]
    fn test_row_pruning_zeros_weakest_rows() {
        // Row 0 has tiny values, row 1 has large values
        let mut mat = array![[0.001, 0.001], [100.0, 100.0]];
        let cfg = PruningConfig::new(0.5, SparsityPattern::Row).expect("valid config");
        let pruner = MagnitudePruner::new(cfg);
        pruner.prune_2d(&mut mat).expect("prune ok");
        // Row 0 (weakest) should be zeroed
        assert_eq!(mat[[0, 0]], 0.0);
        assert_eq!(mat[[0, 1]], 0.0);
        // Row 1 should survive
        assert!(mat[[1, 0]] != 0.0);
    }

    #[test]
    fn test_column_pruning_zeros_weakest_cols() {
        // Col 0 has tiny values, col 1 has large values
        let mut mat = array![[0.001, 100.0], [0.001, 100.0]];
        let cfg = PruningConfig::new(0.5, SparsityPattern::Column).expect("valid config");
        let pruner = MagnitudePruner::new(cfg);
        pruner.prune_2d(&mut mat).expect("prune ok");
        // Column 0 should be zeroed
        assert_eq!(mat[[0, 0]], 0.0);
        assert_eq!(mat[[1, 0]], 0.0);
        // Column 1 should survive
        assert!(mat[[0, 1]] != 0.0);
    }

    #[test]
    fn test_block_pruning_basic() {
        // 4×4 with 2×2 blocks → 4 blocks total, 50% → 2 blocks zeroed
        let mut mat = array![
            [1.0, 2.0, 100.0, 200.0],
            [3.0, 4.0, 300.0, 400.0],
            [0.1, 0.2, 50.0, 60.0],
            [0.3, 0.4, 70.0, 80.0],
        ];
        let cfg = PruningConfig::new(
            0.5,
            SparsityPattern::Block {
                block_h: 2,
                block_w: 2,
            },
        )
        .expect("valid config");
        let pruner = MagnitudePruner::new(cfg);
        let stats = pruner.prune_2d(&mut mat).expect("prune ok");
        // 2 out of 4 blocks zeroed → 50% element sparsity
        assert!((stats.actual_sparsity - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_block_pruning_invalid_size() {
        // 4 rows but block_h=3 → does not divide
        let mut mat = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0]
        ];
        let cfg = PruningConfig::new(
            0.5,
            SparsityPattern::Block {
                block_h: 3,
                block_w: 3,
            },
        )
        .expect("valid config");
        let pruner = MagnitudePruner::new(cfg);
        let result = pruner.prune_2d(&mut mat);
        assert!(matches!(result, Err(PruningError::InvalidBlockSize(_, _))));
    }

    #[test]
    fn test_nm_pruning_keeps_n_per_m() {
        // 2:4 sparsity on a single row [1, 2, 3, 4]
        // Keep 2 largest (3 and 4), zero out 1 and 2
        let mut mat = array![[1.0, 2.0, 3.0, 4.0]];
        let cfg =
            PruningConfig::new(0.5, SparsityPattern::NM { n: 2, m: 4 }).expect("valid config");
        let pruner = MagnitudePruner::new(cfg);
        pruner.prune_2d(&mut mat).expect("prune ok");
        let nonzero_count = mat.iter().filter(|&&v| v != 0.0).count();
        assert_eq!(nonzero_count, 2);
        // The two largest must survive
        assert!(mat[[0, 2]] != 0.0); // 3.0
        assert!(mat[[0, 3]] != 0.0); // 4.0
    }

    #[test]
    fn test_nm_invalid_n_ge_m() {
        let mut mat = array![[1.0, 2.0, 3.0, 4.0]];
        // n=4 >= m=4 is invalid
        let cfg =
            PruningConfig::new(0.1, SparsityPattern::NM { n: 4, m: 4 }).expect("valid config");
        let pruner = MagnitudePruner::new(cfg);
        let result = pruner.prune_2d(&mut mat);
        assert!(matches!(result, Err(PruningError::InvalidBlockSize(_, _))));
    }

    #[test]
    fn test_rescale_preserves_sum() {
        // After rescaling, non-zero elements are scaled by total/nonzero.
        // With 50% sparsity: the surviving elements are scaled by 2×,
        // so their sum equals the sum of the top-50% elements × 2.
        // We verify the rescaled sum is strictly larger than the unrescaled pruned sum.
        let original = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];

        // Prune without rescale
        let mut mat_no_rescale = original.clone();
        let cfg_no = PruningConfig::new(0.5, SparsityPattern::Unstructured).expect("valid config");
        let pruner_no = MagnitudePruner::new(cfg_no);
        pruner_no.prune_2d(&mut mat_no_rescale).expect("prune ok");
        let sum_no_rescale: f64 = mat_no_rescale.iter().copied().sum();

        // Prune with rescale
        let mut mat = original.clone();
        let cfg = PruningConfig::new(0.5, SparsityPattern::Unstructured)
            .expect("valid config")
            .with_rescale(true);
        let pruner = MagnitudePruner::new(cfg);
        pruner.prune_2d(&mut mat).expect("prune ok");
        let sum_rescaled: f64 = mat.iter().copied().sum();

        // Rescaled sum should be larger than the non-rescaled pruned sum
        // (weights are scaled up to compensate for pruned weights)
        assert!(
            sum_rescaled > sum_no_rescale,
            "rescaled sum ({sum_rescaled}) should exceed unrescaled pruned sum ({sum_no_rescale})"
        );
        // And the number of non-zero elements should be the same
        let nz_no = mat_no_rescale.iter().filter(|&&v| v != 0.0).count();
        let nz_rescaled = mat.iter().filter(|&&v| v != 0.0).count();
        assert_eq!(
            nz_no, nz_rescaled,
            "rescale should not change which elements are zero"
        );
    }

    #[test]
    fn test_sparsity_stats_compute() {
        let mat = array![[0.0, 1.0], [0.0, 2.0]];
        let stats = SparsityStats::compute(&mat.into_dyn(), SparsityPattern::Unstructured);
        assert_eq!(stats.zero_count, 2);
        assert_eq!(stats.total_count, 4);
        assert!((stats.actual_sparsity - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_sparsity_stats_speedup_structured() {
        // A 75% sparse tensor: structured speedup should exceed unstructured speedup
        let mat = array![[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 2.0]];
        let structured_stats =
            SparsityStats::compute(&mat.clone().into_dyn(), SparsityPattern::Row);
        let unstructured_stats =
            SparsityStats::compute(&mat.into_dyn(), SparsityPattern::Unstructured);
        assert!(
            structured_stats.theoretical_speedup > unstructured_stats.theoretical_speedup,
            "structured speedup ({}) should exceed unstructured ({})",
            structured_stats.theoretical_speedup,
            unstructured_stats.theoretical_speedup
        );
    }

    #[test]
    fn test_compute_sparsity_dense() {
        let mat = array![[1.0, 2.0], [3.0, 4.0]];
        let sparsity = compute_sparsity(&mat.into_dyn());
        assert!((sparsity - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_compute_sparsity_half() {
        let mat = array![[0.0, 1.0], [0.0, 2.0]];
        let sparsity = compute_sparsity(&mat.into_dyn());
        assert!((sparsity - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_row_norms_correctness() {
        let mat = array![[3.0, 4.0], [0.0, 0.0]];
        let norms = row_norms(&mat);
        assert!(
            (norms[0] - 5.0).abs() < 1e-10,
            "norm[0] should be 5.0, got {}",
            norms[0]
        );
        assert!(
            (norms[1] - 0.0).abs() < 1e-10,
            "norm[1] should be 0.0, got {}",
            norms[1]
        );
    }

    #[test]
    fn test_prune_nd_tensor() {
        // 3D tensor (2×3×4), apply unstructured pruning
        use ndarray::Array3;
        let data: Array3<f64> =
            Array3::from_shape_fn((2, 3, 4), |(i, j, k)| (i * 12 + j * 4 + k + 1) as f64);
        let mut tensor = data.into_dyn();
        let cfg = PruningConfig::new(0.5, SparsityPattern::Unstructured).expect("valid config");
        let pruner = MagnitudePruner::new(cfg);
        let stats = pruner.prune(&mut tensor).expect("prune ok");
        // Roughly 50% should be zeros
        assert!(
            stats.actual_sparsity >= 0.4 && stats.actual_sparsity <= 0.6,
            "sparsity={} not near 0.5",
            stats.actual_sparsity
        );
    }

    #[test]
    fn test_prune_empty_tensor() {
        use ndarray::Array2;
        let mut empty: ArrayD<f64> = dyn2d(Array2::zeros((0, 4)));
        let cfg = PruningConfig::new(0.5, SparsityPattern::Unstructured).expect("valid config");
        let pruner = MagnitudePruner::new(cfg);
        let result = pruner.prune(&mut empty);
        assert!(matches!(result, Err(PruningError::EmptyTensor)));
    }
}