1use ndarray::{Array1, Array2, ArrayD};
7use thiserror::Error;
8
/// Errors returned by pruning configuration and execution.
#[derive(Debug, Error)]
pub enum PruningError {
    /// The requested sparsity ratio is outside the half-open interval `[0, 1)`.
    #[error("Invalid sparsity ratio: {0}. Must be in [0, 1).")]
    InvalidSparsityRatio(f64),
    /// The tensor's rank or shape is incompatible with the requested operation.
    #[error("Shape mismatch: {0}")]
    ShapeMismatch(String),
    /// A block/group size does not evenly divide the relevant dimension
    /// (also used for invalid `n >= m` in N:M pruning).
    #[error("Block size {0} does not divide dimension {1}")]
    InvalidBlockSize(usize, usize),
    /// The input tensor has no elements.
    #[error("Empty tensor")]
    EmptyTensor,
}
21
/// Shape of the sparsity that pruning should produce.
#[derive(Debug, Clone, PartialEq)]
pub enum SparsityPattern {
    /// Individual elements are zeroed regardless of position.
    Unstructured,
    /// Whole `block_h x block_w` tiles are zeroed.
    Block { block_h: usize, block_w: usize },
    /// Whole rows are zeroed.
    Row,
    /// Whole columns are zeroed.
    Column,
    /// In each group of `m` consecutive elements along a row, only the `n`
    /// largest-magnitude entries survive.
    NM { n: usize, m: usize },
}
36
37impl SparsityPattern {
38 pub fn name(&self) -> &'static str {
40 match self {
41 SparsityPattern::Unstructured => "unstructured",
42 SparsityPattern::Block { .. } => "block",
43 SparsityPattern::Row => "row",
44 SparsityPattern::Column => "column",
45 SparsityPattern::NM { .. } => "n:m",
46 }
47 }
48
49 pub fn is_structured(&self) -> bool {
51 !matches!(self, SparsityPattern::Unstructured)
52 }
53}
54
/// Summary of a tensor's measured sparsity and a rough speedup estimate.
#[derive(Debug, Clone)]
pub struct SparsityStats {
    /// Fraction of elements that are exactly zero, in `[0, 1]`.
    pub actual_sparsity: f64,
    /// Number of elements that are exactly zero.
    pub zero_count: usize,
    /// Total number of elements in the tensor.
    pub total_count: usize,
    /// Heuristic speedup estimate; structured patterns score higher
    /// (see `SparsityStats::compute` for the formula).
    pub theoretical_speedup: f64,
    /// Pattern assumed to have produced the sparsity.
    pub pattern: SparsityPattern,
}
69
70impl SparsityStats {
71 pub fn compute(tensor: &ArrayD<f64>, pattern: SparsityPattern) -> Self {
73 let total_count = tensor.len();
74 let zero_count = tensor.iter().filter(|&&v| v == 0.0).count();
75 let actual_sparsity = if total_count == 0 {
76 0.0
77 } else {
78 zero_count as f64 / total_count as f64
79 };
80 let theoretical_speedup = if pattern.is_structured() {
82 1.0 / (1.0 - actual_sparsity + 1e-9)
83 } else {
84 1.0 + actual_sparsity * 0.5 };
86 SparsityStats {
87 actual_sparsity,
88 zero_count,
89 total_count,
90 theoretical_speedup,
91 pattern,
92 }
93 }
94}
95
/// Configuration for a magnitude-based pruning pass.
#[derive(Debug, Clone)]
pub struct PruningConfig {
    /// Desired fraction of zeros, in `[0, 1)`.
    pub target_sparsity: f64,
    /// Sparsity pattern to enforce.
    pub pattern: SparsityPattern,
    /// Whether surviving weights are rescaled after pruning (defaults to `false`).
    pub rescale: bool,
}
106
107impl PruningConfig {
108 pub fn new(target_sparsity: f64, pattern: SparsityPattern) -> Result<Self, PruningError> {
112 if !(0.0..1.0).contains(&target_sparsity) {
113 return Err(PruningError::InvalidSparsityRatio(target_sparsity));
114 }
115 Ok(PruningConfig {
116 target_sparsity,
117 pattern,
118 rescale: false,
119 })
120 }
121
122 pub fn with_rescale(mut self, rescale: bool) -> Self {
124 self.rescale = rescale;
125 self
126 }
127}
128
/// Prunes tensors by zeroing their smallest-magnitude elements (or groups of
/// elements) according to a [`PruningConfig`].
pub struct MagnitudePruner {
    // Immutable configuration applied to every tensor passed to this pruner.
    config: PruningConfig,
}
133
134impl MagnitudePruner {
135 pub fn new(config: PruningConfig) -> Self {
137 MagnitudePruner { config }
138 }
139
140 pub fn prune_2d(&self, matrix: &mut Array2<f64>) -> Result<SparsityStats, PruningError> {
142 if matrix.is_empty() {
143 return Err(PruningError::EmptyTensor);
144 }
145 match &self.config.pattern {
146 SparsityPattern::Unstructured => {
147 self.prune_unstructured_2d(matrix)?;
148 }
149 SparsityPattern::Block { block_h, block_w } => {
150 self.prune_block_2d(matrix, *block_h, *block_w)?;
151 }
152 SparsityPattern::Row => {
153 self.prune_rows_2d(matrix)?;
154 }
155 SparsityPattern::Column => {
156 self.prune_columns_2d(matrix)?;
157 }
158 SparsityPattern::NM { n, m } => {
159 self.prune_nm_2d(matrix, *n, *m)?;
160 }
161 }
162 if self.config.rescale {
163 self.rescale_nonzero(matrix);
164 }
165 Ok(SparsityStats::compute(
166 &matrix.clone().into_dyn(),
167 self.config.pattern.clone(),
168 ))
169 }
170
171 pub fn prune(&self, tensor: &mut ArrayD<f64>) -> Result<SparsityStats, PruningError> {
173 if tensor.is_empty() {
174 return Err(PruningError::EmptyTensor);
175 }
176 match &self.config.pattern {
177 SparsityPattern::Unstructured => {
178 self.prune_unstructured_nd(tensor)?;
179 }
180 _ => {
181 if tensor.ndim() != 2 {
183 return Err(PruningError::ShapeMismatch(format!(
184 "Structured pruning requires 2D tensor, got {}D",
185 tensor.ndim()
186 )));
187 }
188 let mut mat = tensor
189 .clone()
190 .into_dimensionality::<ndarray::Ix2>()
191 .map_err(|e| PruningError::ShapeMismatch(e.to_string()))?;
192 self.prune_2d(&mut mat)?;
193 *tensor = mat.into_dyn();
194 }
195 }
196 Ok(SparsityStats::compute(tensor, self.config.pattern.clone()))
197 }
198
199 fn prune_unstructured_nd(&self, tensor: &mut ArrayD<f64>) -> Result<(), PruningError> {
200 let k = ((1.0 - self.config.target_sparsity) * tensor.len() as f64).ceil() as usize;
201 let mut mags: Vec<f64> = tensor.iter().map(|v| v.abs()).collect();
202 mags.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
203 let threshold = if k < mags.len() {
204 mags[mags.len() - k]
205 } else {
206 0.0
207 };
208 tensor.mapv_inplace(|v| if v.abs() >= threshold { v } else { 0.0 });
209 Ok(())
210 }
211
212 fn prune_unstructured_2d(&self, matrix: &mut Array2<f64>) -> Result<(), PruningError> {
213 let k = ((1.0 - self.config.target_sparsity) * matrix.len() as f64).ceil() as usize;
214 let mut mags: Vec<f64> = matrix.iter().map(|v| v.abs()).collect();
215 mags.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
216 let threshold = if k < mags.len() {
217 mags[mags.len() - k]
218 } else {
219 0.0
220 };
221 matrix.mapv_inplace(|v| if v.abs() >= threshold { v } else { 0.0 });
222 Ok(())
223 }
224
225 fn prune_rows_2d(&self, matrix: &mut Array2<f64>) -> Result<(), PruningError> {
226 let nrows = matrix.nrows();
227 let n_prune = (self.config.target_sparsity * nrows as f64).round() as usize;
228 let mut norms: Vec<(usize, f64)> = (0..nrows)
230 .map(|i| {
231 let norm: f64 = matrix.row(i).iter().map(|v| v * v).sum::<f64>().sqrt();
232 (i, norm)
233 })
234 .collect();
235 norms.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
236 for &(row_idx, _) in &norms[..n_prune] {
237 matrix.row_mut(row_idx).fill(0.0);
238 }
239 Ok(())
240 }
241
242 fn prune_columns_2d(&self, matrix: &mut Array2<f64>) -> Result<(), PruningError> {
243 let ncols = matrix.ncols();
244 let n_prune = (self.config.target_sparsity * ncols as f64).round() as usize;
245 let mut norms: Vec<(usize, f64)> = (0..ncols)
246 .map(|j| {
247 let norm: f64 = matrix.column(j).iter().map(|v| v * v).sum::<f64>().sqrt();
248 (j, norm)
249 })
250 .collect();
251 norms.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
252 for &(col_idx, _) in &norms[..n_prune] {
253 matrix.column_mut(col_idx).fill(0.0);
254 }
255 Ok(())
256 }
257
258 fn prune_block_2d(
259 &self,
260 matrix: &mut Array2<f64>,
261 bh: usize,
262 bw: usize,
263 ) -> Result<(), PruningError> {
264 let (rows, cols) = (matrix.nrows(), matrix.ncols());
265 if rows % bh != 0 {
266 return Err(PruningError::InvalidBlockSize(bh, rows));
267 }
268 if cols % bw != 0 {
269 return Err(PruningError::InvalidBlockSize(bw, cols));
270 }
271 let n_blocks_r = rows / bh;
272 let n_blocks_c = cols / bw;
273 let total_blocks = n_blocks_r * n_blocks_c;
274 let n_prune = (self.config.target_sparsity * total_blocks as f64).round() as usize;
275 let mut block_norms: Vec<(usize, usize, f64)> = Vec::with_capacity(total_blocks);
277 for br in 0..n_blocks_r {
278 for bc in 0..n_blocks_c {
279 let norm: f64 = matrix
280 .slice(ndarray::s![br * bh..(br + 1) * bh, bc * bw..(bc + 1) * bw])
281 .iter()
282 .map(|v| v * v)
283 .sum::<f64>()
284 .sqrt();
285 block_norms.push((br, bc, norm));
286 }
287 }
288 block_norms.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
289 for &(br, bc, _) in &block_norms[..n_prune] {
290 matrix
291 .slice_mut(ndarray::s![br * bh..(br + 1) * bh, bc * bw..(bc + 1) * bw])
292 .fill(0.0);
293 }
294 Ok(())
295 }
296
297 fn prune_nm_2d(
298 &self,
299 matrix: &mut Array2<f64>,
300 n: usize,
301 m: usize,
302 ) -> Result<(), PruningError> {
303 if n >= m {
304 return Err(PruningError::InvalidBlockSize(n, m));
305 }
306 let ncols = matrix.ncols();
308 for i in 0..matrix.nrows() {
309 let mut col = 0;
310 while col + m <= ncols {
311 let group: Vec<f64> = (col..col + m).map(|j| matrix[[i, j]]).collect();
312 let mut mags: Vec<(usize, f64)> = group
313 .iter()
314 .enumerate()
315 .map(|(j, &v)| (j, v.abs()))
316 .collect();
317 mags.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
318 let keep: std::collections::HashSet<usize> =
319 mags[..n].iter().map(|&(j, _)| j).collect();
320 for j in 0..m {
321 if !keep.contains(&j) {
322 matrix[[i, col + j]] = 0.0;
323 }
324 }
325 col += m;
326 }
327 }
328 Ok(())
329 }
330
331 fn rescale_nonzero(&self, matrix: &mut Array2<f64>) {
332 let total = matrix.len() as f64;
333 let nonzero = matrix.iter().filter(|&&v| v != 0.0).count() as f64;
334 if nonzero > 0.0 {
335 let scale = total / nonzero;
336 matrix.mapv_inplace(|v| if v != 0.0 { v * scale } else { 0.0 });
337 }
338 }
339}
340
341pub fn compute_sparsity(tensor: &ArrayD<f64>) -> f64 {
345 if tensor.is_empty() {
346 return 0.0;
347 }
348 let zeros = tensor.iter().filter(|&&v| v == 0.0).count();
349 zeros as f64 / tensor.len() as f64
350}
351
352pub fn row_norms(matrix: &Array2<f64>) -> Array1<f64> {
354 Array1::from_iter(
355 matrix
356 .rows()
357 .into_iter()
358 .map(|row| row.iter().map(|v| v * v).sum::<f64>().sqrt()),
359 )
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // Converts a 2-D array into the dynamic-rank type expected by `prune`.
    fn dyn2d(data: Array2<f64>) -> ArrayD<f64> {
        data.into_dyn()
    }

    #[test]
    fn test_sparsity_pattern_names() {
        assert_eq!(SparsityPattern::Unstructured.name(), "unstructured");
        assert_eq!(
            SparsityPattern::Block {
                block_h: 2,
                block_w: 2
            }
            .name(),
            "block"
        );
        assert_eq!(SparsityPattern::Row.name(), "row");
        assert_eq!(SparsityPattern::Column.name(), "column");
        assert_eq!(SparsityPattern::NM { n: 2, m: 4 }.name(), "n:m");
    }

    #[test]
    fn test_sparsity_pattern_is_structured() {
        // Only Unstructured is considered unstructured; everything else is structured.
        assert!(!SparsityPattern::Unstructured.is_structured());
        assert!(SparsityPattern::Block {
            block_h: 2,
            block_w: 2
        }
        .is_structured());
        assert!(SparsityPattern::Row.is_structured());
        assert!(SparsityPattern::Column.is_structured());
        assert!(SparsityPattern::NM { n: 1, m: 4 }.is_structured());
    }

    #[test]
    fn test_pruning_config_invalid_ratio() {
        // The valid interval is half-open: 1.0 and negatives are both rejected.
        let result = PruningConfig::new(1.0, SparsityPattern::Unstructured);
        assert!(result.is_err());
        let result_neg = PruningConfig::new(-0.1, SparsityPattern::Unstructured);
        assert!(result_neg.is_err());
    }

    #[test]
    fn test_pruning_config_valid() {
        let result = PruningConfig::new(0.5, SparsityPattern::Unstructured);
        assert!(result.is_ok());
        let cfg = result.expect("valid config");
        assert!((cfg.target_sparsity - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_unstructured_pruning_zeros_out() {
        let mut mat = array![
            [1.0, 2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0, 12.0],
            [13.0, 14.0, 15.0, 16.0],
        ];
        let cfg = PruningConfig::new(0.5, SparsityPattern::Unstructured).expect("valid config");
        let pruner = MagnitudePruner::new(cfg);
        let stats = pruner.prune_2d(&mut mat).expect("prune ok");
        // Threshold ties can shift the achieved sparsity, so allow a band around 0.5.
        assert!(stats.actual_sparsity >= 0.4 && stats.actual_sparsity <= 0.6);
    }

    #[test]
    fn test_unstructured_preserves_largest() {
        let mut mat = array![[10.0, 20.0], [30.0, 40.0]];
        let cfg = PruningConfig::new(0.5, SparsityPattern::Unstructured).expect("valid config");
        let pruner = MagnitudePruner::new(cfg);
        pruner.prune_2d(&mut mat).expect("prune ok");
        // The two largest magnitudes (30, 40) live in the bottom row; at least
        // one must survive, and 40 — the maximum — must always survive.
        assert!(mat[[1, 0]] != 0.0 || mat[[1, 1]] != 0.0);
        assert!(mat[[1, 1]] != 0.0);
    }

    #[test]
    fn test_row_pruning_zeros_weakest_rows() {
        // Row 0 has a tiny L2 norm, row 1 a huge one; pruning half the rows
        // must zero row 0 entirely and keep row 1.
        let mut mat = array![[0.001, 0.001], [100.0, 100.0]];
        let cfg = PruningConfig::new(0.5, SparsityPattern::Row).expect("valid config");
        let pruner = MagnitudePruner::new(cfg);
        pruner.prune_2d(&mut mat).expect("prune ok");
        assert_eq!(mat[[0, 0]], 0.0);
        assert_eq!(mat[[0, 1]], 0.0);
        assert!(mat[[1, 0]] != 0.0);
    }

    #[test]
    fn test_column_pruning_zeros_weakest_cols() {
        // Column 0 is the weak one here; it must be fully zeroed.
        let mut mat = array![[0.001, 100.0], [0.001, 100.0]];
        let cfg = PruningConfig::new(0.5, SparsityPattern::Column).expect("valid config");
        let pruner = MagnitudePruner::new(cfg);
        pruner.prune_2d(&mut mat).expect("prune ok");
        assert_eq!(mat[[0, 0]], 0.0);
        assert_eq!(mat[[1, 0]], 0.0);
        assert!(mat[[0, 1]] != 0.0);
    }

    #[test]
    fn test_block_pruning_basic() {
        // Four 2x2 blocks; the two low-norm blocks (left half) get pruned,
        // giving exactly 50% element sparsity.
        let mut mat = array![
            [1.0, 2.0, 100.0, 200.0],
            [3.0, 4.0, 300.0, 400.0],
            [0.1, 0.2, 50.0, 60.0],
            [0.3, 0.4, 70.0, 80.0],
        ];
        let cfg = PruningConfig::new(
            0.5,
            SparsityPattern::Block {
                block_h: 2,
                block_w: 2,
            },
        )
        .expect("valid config");
        let pruner = MagnitudePruner::new(cfg);
        let stats = pruner.prune_2d(&mut mat).expect("prune ok");
        assert!((stats.actual_sparsity - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_block_pruning_invalid_size() {
        // 3x3 blocks cannot tile a 4x3 matrix (3 does not divide 4).
        let mut mat = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0]
        ];
        let cfg = PruningConfig::new(
            0.5,
            SparsityPattern::Block {
                block_h: 3,
                block_w: 3,
            },
        )
        .expect("valid config");
        let pruner = MagnitudePruner::new(cfg);
        let result = pruner.prune_2d(&mut mat);
        assert!(matches!(result, Err(PruningError::InvalidBlockSize(_, _))));
    }

    #[test]
    fn test_nm_pruning_keeps_n_per_m() {
        // 2:4 sparsity on one group of four: the two largest (3.0 and 4.0,
        // at columns 2 and 3) must survive.
        let mut mat = array![[1.0, 2.0, 3.0, 4.0]];
        let cfg =
            PruningConfig::new(0.5, SparsityPattern::NM { n: 2, m: 4 }).expect("valid config");
        let pruner = MagnitudePruner::new(cfg);
        pruner.prune_2d(&mut mat).expect("prune ok");
        let nonzero_count = mat.iter().filter(|&&v| v != 0.0).count();
        assert_eq!(nonzero_count, 2);
        assert!(mat[[0, 2]] != 0.0);
        assert!(mat[[0, 3]] != 0.0);
    }

    #[test]
    fn test_nm_invalid_n_ge_m() {
        // n must be strictly less than m; 4:4 keeps everything and is rejected.
        let mut mat = array![[1.0, 2.0, 3.0, 4.0]];
        let cfg =
            PruningConfig::new(0.1, SparsityPattern::NM { n: 4, m: 4 }).expect("valid config");
        let pruner = MagnitudePruner::new(cfg);
        let result = pruner.prune_2d(&mut mat);
        assert!(matches!(result, Err(PruningError::InvalidBlockSize(_, _))));
    }

    #[test]
    fn test_rescale_preserves_sum() {
        let original = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];

        // Prune once without rescaling to get a baseline sum.
        let mut mat_no_rescale = original.clone();
        let cfg_no = PruningConfig::new(0.5, SparsityPattern::Unstructured).expect("valid config");
        let pruner_no = MagnitudePruner::new(cfg_no);
        pruner_no.prune_2d(&mut mat_no_rescale).expect("prune ok");
        let sum_no_rescale: f64 = mat_no_rescale.iter().copied().sum();

        // Prune again with rescaling enabled.
        let mut mat = original.clone();
        let cfg = PruningConfig::new(0.5, SparsityPattern::Unstructured)
            .expect("valid config")
            .with_rescale(true);
        let pruner = MagnitudePruner::new(cfg);
        pruner.prune_2d(&mut mat).expect("prune ok");
        let sum_rescaled: f64 = mat.iter().copied().sum();

        // Rescaling boosts the survivors, so the rescaled sum must be larger.
        assert!(
            sum_rescaled > sum_no_rescale,
            "rescaled sum ({sum_rescaled}) should exceed unrescaled pruned sum ({sum_no_rescale})"
        );
        // Rescaling only scales survivors; the zero/non-zero mask is unchanged.
        let nz_no = mat_no_rescale.iter().filter(|&&v| v != 0.0).count();
        let nz_rescaled = mat.iter().filter(|&&v| v != 0.0).count();
        assert_eq!(
            nz_no, nz_rescaled,
            "rescale should not change which elements are zero"
        );
    }

    #[test]
    fn test_sparsity_stats_compute() {
        let mat = array![[0.0, 1.0], [0.0, 2.0]];
        let stats = SparsityStats::compute(&mat.into_dyn(), SparsityPattern::Unstructured);
        assert_eq!(stats.zero_count, 2);
        assert_eq!(stats.total_count, 4);
        assert!((stats.actual_sparsity - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_sparsity_stats_speedup_structured() {
        // Same data, different pattern: structured sparsity is credited with
        // a larger theoretical speedup than unstructured.
        let mat = array![[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 2.0]];
        let structured_stats =
            SparsityStats::compute(&mat.clone().into_dyn(), SparsityPattern::Row);
        let unstructured_stats =
            SparsityStats::compute(&mat.into_dyn(), SparsityPattern::Unstructured);
        assert!(
            structured_stats.theoretical_speedup > unstructured_stats.theoretical_speedup,
            "structured speedup ({}) should exceed unstructured ({})",
            structured_stats.theoretical_speedup,
            unstructured_stats.theoretical_speedup
        );
    }

    #[test]
    fn test_compute_sparsity_dense() {
        let mat = array![[1.0, 2.0], [3.0, 4.0]];
        let sparsity = compute_sparsity(&mat.into_dyn());
        assert!((sparsity - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_compute_sparsity_half() {
        let mat = array![[0.0, 1.0], [0.0, 2.0]];
        let sparsity = compute_sparsity(&mat.into_dyn());
        assert!((sparsity - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_row_norms_correctness() {
        // 3-4-5 triangle for row 0; all-zero row has norm 0.
        let mat = array![[3.0, 4.0], [0.0, 0.0]];
        let norms = row_norms(&mat);
        assert!(
            (norms[0] - 5.0).abs() < 1e-10,
            "norm[0] should be 5.0, got {}",
            norms[0]
        );
        assert!(
            (norms[1] - 0.0).abs() < 1e-10,
            "norm[1] should be 0.0, got {}",
            norms[1]
        );
    }

    #[test]
    fn test_prune_nd_tensor() {
        use ndarray::Array3;
        // Unstructured pruning must work on a rank-3 tensor via `prune`.
        let data: Array3<f64> =
            Array3::from_shape_fn((2, 3, 4), |(i, j, k)| (i * 12 + j * 4 + k + 1) as f64);
        let mut tensor = data.into_dyn();
        let cfg = PruningConfig::new(0.5, SparsityPattern::Unstructured).expect("valid config");
        let pruner = MagnitudePruner::new(cfg);
        let stats = pruner.prune(&mut tensor).expect("prune ok");
        assert!(
            stats.actual_sparsity >= 0.4 && stats.actual_sparsity <= 0.6,
            "sparsity={} not near 0.5",
            stats.actual_sparsity
        );
    }

    #[test]
    fn test_prune_empty_tensor() {
        use ndarray::Array2;
        // A 0x4 matrix has zero elements and must be rejected, not pruned.
        let mut empty: ArrayD<f64> = dyn2d(Array2::zeros((0, 4)));
        let cfg = PruningConfig::new(0.5, SparsityPattern::Unstructured).expect("valid config");
        let pruner = MagnitudePruner::new(cfg);
        let result = pruner.prune(&mut empty);
        assert!(matches!(result, Err(PruningError::EmptyTensor)));
    }
}