1#![allow(clippy::excessive_nesting)] #![allow(unused_variables)] use crate::tensor::Tensor;
9use anyhow::{anyhow, Result};
10use scirs2_core::random::*; use std::collections::{HashMap, HashSet};
12
/// Configuration controlling how a model is pruned.
#[derive(Debug, Clone)]
pub struct PruningConfig {
    // Fraction of parameters to remove, expected in [0.0, 1.0]
    // (validated by PruningUtils::validate_config).
    pub target_sparsity: f32,
    // Whether pruning should proceed over multiple rounds.
    // NOTE(review): not consulted anywhere in this module — confirm usage.
    pub iterative: bool,
    // Number of pruning rounds; must be > 0 (validated).
    pub iterations: usize,
    // Presumably triggers fine-tuning after pruning — not consulted here.
    pub fine_tune: bool,
    // Layer names that must never be pruned (honored by LayerPruner).
    pub exclude_layers: HashSet<String>,
    // Optional absolute-magnitude cutoff. NOTE(review): only range-validated,
    // never read by any strategy in this module — confirm intent.
    pub magnitude_threshold: Option<f32>,
    // RNG seed for stochastic strategies.
    // NOTE(review): UnstructuredPruner currently ignores this and uses
    // thread_rng() — confirm whether reproducibility is required.
    pub seed: Option<u64>,
}

impl Default for PruningConfig {
    /// Baseline: one-shot 50% sparsity with fine-tuning, no exclusions.
    fn default() -> Self {
        Self {
            target_sparsity: 0.5,
            iterative: false,
            iterations: 1,
            fine_tune: true,
            exclude_layers: HashSet::new(),
            magnitude_threshold: None,
            seed: None,
        }
    }
}
45
/// A weight-level pruning algorithm.
///
/// Implementations decide *which* elements of a weight tensor to zero out;
/// they never change the tensor's shape.
pub trait PruningStrategy: Send + Sync {
    /// Return a copy of `weights` with the selected elements set to 0.0.
    fn prune_weights(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor>;

    /// Return a {0.0, 1.0} mask of the same shape as `weights`
    /// (1.0 = keep, 0.0 = prune).
    fn get_mask(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor>;

    /// Human-readable strategy name.
    fn name(&self) -> &str;
}
57
/// Outcome of pruning a model with a [`Pruner`].
#[derive(Debug, Clone)]
pub struct PruningResult<M>
where
    M: crate::traits::Model,
{
    // The pruned model.
    pub model: M,
    // Overall fraction of parameters that are zero after pruning.
    pub sparsity: f32,
    // Number of parameters pruned (zeroed).
    pub pruned_params: usize,
    // Total parameter count of the model.
    pub total_params: usize,
    // Per-layer sparsity, keyed by layer name.
    pub layer_sparsity: HashMap<String, f32>,
}
70
/// A model-level pruning driver.
pub trait Pruner: Send + Sync {
    /// Prune `model` according to `config` and report the result.
    fn prune<M>(&self, model: M, config: &PruningConfig) -> Result<PruningResult<M>>
    where
        M: crate::traits::Model + Clone;

    /// Estimate what pruning with `config` would achieve, without modifying
    /// the model.
    fn estimate_pruning_potential<M>(
        &self,
        model: &M,
        config: &PruningConfig,
    ) -> Result<PruningStats>
    where
        M: crate::traits::Model;
}
87
/// Aggregate sparsity statistics for a whole model.
#[derive(Debug, Clone)]
pub struct PruningStats {
    // Total number of parameters considered.
    pub total_params: usize,
    // Number of parameters that are (or would be) zero.
    pub zero_params: usize,
    // zero_params / total_params.
    pub sparsity: f32,
    // Per-layer breakdown, keyed by layer name.
    pub layer_stats: HashMap<String, LayerPruningStats>,
}

/// Sparsity statistics for a single layer.
#[derive(Debug, Clone)]
pub struct LayerPruningStats {
    // Parameters in this layer.
    pub total_params: usize,
    // Zero-valued parameters in this layer.
    pub zero_params: usize,
    // zero_params / total_params for this layer.
    pub sparsity: f32,
}
103
/// Prunes the individual weights with the smallest absolute values.
pub struct MagnitudePruner {
    // NOTE(review): stored but never read — the sparsity target actually used
    // comes from the PruningConfig passed to each call.
    #[allow(dead_code)]
    threshold: f32,
}

impl MagnitudePruner {
    /// Create a magnitude pruner. The `sparsity` argument is stored as
    /// `threshold` but is currently unused (see field note).
    pub fn new(sparsity: f32) -> Self {
        Self {
            threshold: sparsity,
        }
    }
}
117
118impl PruningStrategy for MagnitudePruner {
119 fn prune_weights(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
120 let mask = self.get_mask(weights, config)?;
121
122 let pruned = weights
124 .data()?
125 .iter()
126 .zip(mask.data()?.iter())
127 .map(|(w, m)| if *m > 0.5 { *w } else { 0.0 })
128 .collect::<Vec<_>>();
129
130 Ok(Tensor::from_vec(pruned, &weights.shape())?)
131 }
132
133 fn get_mask(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
134 let data = weights.data()?;
135 let mut abs_weights: Vec<(f32, usize)> =
136 data.iter().enumerate().map(|(i, &w)| (w.abs(), i)).collect();
137
138 abs_weights.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Partial comparison failed"));
140
141 let num_prune = (data.len() as f32 * config.target_sparsity) as usize;
143 let mut mask = vec![1.0; data.len()];
144
145 for i in 0..num_prune.min(abs_weights.len()) {
147 mask[abs_weights[i].1] = 0.0;
148 }
149
150 Ok(Tensor::from_vec(mask, &weights.shape())?)
151 }
152
153 fn name(&self) -> &str {
154 "MagnitudePruner"
155 }
156}
157
/// Prunes whole structures (slices) along one tensor dimension.
pub struct StructuredPruner {
    // Dimension whose slices ("structures") are pruned as a unit.
    pruning_dim: usize,
}

impl StructuredPruner {
    /// Create a structured pruner operating along `pruning_dim`.
    pub fn new(pruning_dim: usize) -> Self {
        Self { pruning_dim }
    }
}
168
169impl PruningStrategy for StructuredPruner {
170 fn prune_weights(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
171 let shape = weights.shape();
173 if shape.len() < 2 {
174 return Err(anyhow!("Structured pruning requires at least 2D tensors"));
175 }
176
177 let importance_scores = self.calculate_importance(weights)?;
179
180 let num_structures = shape[self.pruning_dim];
182 let num_prune = (num_structures as f32 * config.target_sparsity) as usize;
183
184 let mut indices: Vec<usize> = (0..num_structures).collect();
185 indices.sort_by(|&a, &b| {
186 importance_scores[a]
187 .partial_cmp(&importance_scores[b])
188 .expect("Partial comparison failed")
189 });
190
191 let pruned_indices: HashSet<_> = indices.iter().take(num_prune).cloned().collect();
192
193 let data = weights.data()?;
195 let mut pruned_data = Vec::with_capacity(data.len());
196
197 for (i, &val) in data.iter().enumerate() {
199 let structure_idx = (i / shape.iter().skip(self.pruning_dim + 1).product::<usize>())
200 % shape[self.pruning_dim];
201
202 if pruned_indices.contains(&structure_idx) {
203 pruned_data.push(0.0);
204 } else {
205 pruned_data.push(val);
206 }
207 }
208
209 Ok(Tensor::from_vec(pruned_data, &shape)?)
210 }
211
212 fn get_mask(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
213 Ok(Tensor::ones(&weights.shape())?)
215 }
216
217 fn name(&self) -> &str {
218 "StructuredPruner"
219 }
220}
221
impl StructuredPruner {
    /// L2 norm of each structure along `pruning_dim`, used as its importance.
    ///
    /// Flat indexing assumes row-major layout:
    /// `idx = batch * num_structures * structure_size + structure * structure_size + k`,
    /// where `structure_size` is the product of dims after `pruning_dim` and
    /// `structures_per_batch` covers all dims before it.
    fn calculate_importance(&self, weights: &Tensor) -> Result<Vec<f32>> {
        let shape = weights.shape();
        let num_structures = shape[self.pruning_dim];
        let mut importance = vec![0.0; num_structures];

        let data = weights.data()?;
        let structure_size = shape.iter().skip(self.pruning_dim + 1).product::<usize>();
        let structures_per_batch = shape.iter().take(self.pruning_dim).product::<usize>();

        for (i, importance_ref) in importance.iter_mut().enumerate() {
            let mut sum_sq = 0.0;
            for j in 0..structures_per_batch {
                for k in 0..structure_size {
                    let idx = j * num_structures * structure_size + i * structure_size + k;
                    // Defensive bound check: out-of-range indices are skipped.
                    if idx < data.len() {
                        sum_sq += data[idx] * data[idx];
                    }
                }
            }
            // Importance = Euclidean norm of the structure's weights.
            *importance_ref = sum_sq.sqrt();
        }

        Ok(importance)
    }
}
249
/// Element-wise pruner: random or magnitude-based selection.
pub struct UnstructuredPruner {
    // If true, prune uniformly at random; otherwise fall back to
    // magnitude-based pruning.
    random: bool,
}

impl UnstructuredPruner {
    /// Create an unstructured pruner; `random` selects random vs. magnitude
    /// pruning.
    pub fn new(random: bool) -> Self {
        Self { random }
    }
}
260
impl PruningStrategy for UnstructuredPruner {
    /// Zero out `target_sparsity` of the elements — chosen uniformly at random
    /// when `self.random`, otherwise by smallest magnitude.
    ///
    /// NOTE(review): `config.seed` is ignored — a fresh `thread_rng()` is used,
    /// so results are not reproducible. Also, the random indices drawn here are
    /// independent of those drawn in `get_mask`, so a mask obtained separately
    /// does not correspond to the weights actually pruned. Confirm whether
    /// that is acceptable for callers.
    fn prune_weights(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
        let data = weights.data()?;
        let num_prune = (data.len() as f32 * config.target_sparsity) as usize;

        let mut pruned = data.to_vec();

        if self.random {
            let mut rng = thread_rng();
            let mut indices: Vec<usize> = (0..data.len()).collect();

            // Fisher-Yates shuffle; the first num_prune shuffled indices are
            // the randomly selected prune set.
            for i in (1..indices.len()).rev() {
                let j = rng.random_range(0..=i);
                indices.swap(i, j);
            }

            for i in 0..num_prune.min(indices.len()) {
                pruned[indices[i]] = 0.0;
            }
        } else {
            // Non-random path delegates entirely to magnitude pruning.
            let magnitude_pruner = MagnitudePruner::new(config.target_sparsity);
            return magnitude_pruner.prune_weights(weights, config);
        }

        Ok(Tensor::from_vec(pruned, &weights.shape())?)
    }

    /// Build a keep/prune mask.
    ///
    /// NOTE(review): when `self.random` is false this returns an all-ones mask
    /// instead of delegating to MagnitudePruner, which is inconsistent with
    /// prune_weights — confirm intended behavior.
    fn get_mask(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
        let data = weights.data()?;
        let num_prune = (data.len() as f32 * config.target_sparsity) as usize;
        let mut mask = vec![1.0; data.len()];

        if self.random {
            let mut rng = thread_rng();
            let mut indices: Vec<usize> = (0..data.len()).collect();

            // Independent Fisher-Yates shuffle (a different draw than the one
            // in prune_weights).
            for i in (1..indices.len()).rev() {
                let j = rng.random_range(0..=i);
                indices.swap(i, j);
            }

            for i in 0..num_prune.min(indices.len()) {
                mask[indices[i]] = 0.0;
            }
        }

        Ok(Tensor::from_vec(mask, &weights.shape())?)
    }

    fn name(&self) -> &str {
        "UnstructuredPruner"
    }
}
318
/// Schedules a linear ramp of sparsity from `initial_sparsity` to
/// `final_sparsity` between `begin_step` and `end_step`.
pub struct GradualPruner {
    initial_sparsity: f32,
    final_sparsity: f32,
    begin_step: usize,
    end_step: usize,
    #[allow(dead_code)]
    frequency: usize,
}

impl GradualPruner {
    /// Create a gradual pruning schedule over `[begin_step, end_step)`.
    pub fn new(
        initial_sparsity: f32,
        final_sparsity: f32,
        begin_step: usize,
        end_step: usize,
        frequency: usize,
    ) -> Self {
        Self {
            initial_sparsity,
            final_sparsity,
            begin_step,
            end_step,
            frequency,
        }
    }

    /// Sparsity target at `step`: 0.0 before `begin_step`, `final_sparsity`
    /// at or after `end_step`, and a linear interpolation from
    /// `initial_sparsity` to `final_sparsity` in between.
    pub fn get_sparsity_at_step(&self, step: usize) -> f32 {
        match step {
            s if s < self.begin_step => 0.0,
            s if s >= self.end_step => self.final_sparsity,
            s => {
                let done = (s - self.begin_step) as f32;
                let span = (self.end_step - self.begin_step) as f32;
                let progress = done / span;
                self.initial_sparsity + (self.final_sparsity - self.initial_sparsity) * progress
            }
        }
    }
}
358
/// When pruning is applied during training.
/// NOTE(review): declared but not consumed in this module — confirm callers.
#[derive(Debug, Clone)]
pub enum PruningSchedule {
    // Prune once at the given step.
    OneShot { step: usize },
    // Prune between begin_step and end_step, presumably every `frequency`
    // steps — confirm against the consumer of this enum.
    Gradual {
        begin_step: usize,
        end_step: usize,
        frequency: usize,
    },
    // Prune at explicit steps with matching target sparsities
    // (parallel vectors).
    Iterative {
        steps: Vec<usize>,
        sparsities: Vec<f32>,
    },
}
376
/// Prunes whole channels (dimension 1) of 4D NCHW tensors.
pub struct ChannelPruner {
    // Metric used to score channel importance.
    importance_metric: ChannelImportanceMetric,
}

/// How channel importance is measured.
#[derive(Debug, Clone)]
pub enum ChannelImportanceMetric {
    // Mean absolute value of the channel's weights.
    L1Norm,
    // Root-mean-square of the channel's weights.
    L2Norm,
    // NOTE(review): currently computed identically to L1Norm over *weights*,
    // not activations — confirm intended semantics.
    MeanActivation,
    // NOTE(review): currently computed identically to L1Norm — placeholder.
    GeometricMedian,
}

impl ChannelPruner {
    /// Create a channel pruner using the given importance metric.
    pub fn new(metric: ChannelImportanceMetric) -> Self {
        Self {
            importance_metric: metric,
        }
    }
}
401
402impl PruningStrategy for ChannelPruner {
403 fn prune_weights(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
404 let shape = weights.shape();
405 if shape.len() != 4 {
406 return Err(anyhow!("Channel pruning requires 4D tensors (NCHW format)"));
407 }
408
409 let num_channels = shape[1]; let channel_importance = self.calculate_channel_importance(weights)?;
411
412 let num_prune = (num_channels as f32 * config.target_sparsity) as usize;
414 let mut sorted_channels: Vec<(f32, usize)> =
415 channel_importance.iter().enumerate().map(|(i, &score)| (score, i)).collect();
416 sorted_channels.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Partial comparison failed"));
417
418 let pruned_channels: HashSet<usize> =
419 sorted_channels.iter().take(num_prune).map(|(_, idx)| *idx).collect();
420
421 let data = weights.data()?;
423 let mut pruned_data = data.to_vec();
424 let channel_size = shape[2] * shape[3]; let batch_channel_size = num_channels * channel_size;
426
427 for batch in 0..shape[0] {
428 for channel in &pruned_channels {
429 let start_idx = batch * batch_channel_size + channel * channel_size;
430 let end_idx = start_idx + channel_size;
431 for i in start_idx..end_idx.min(pruned_data.len()) {
432 pruned_data[i] = 0.0;
433 }
434 }
435 }
436
437 Ok(Tensor::from_vec(pruned_data, &shape)?)
438 }
439
440 fn get_mask(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
441 let shape = weights.shape();
442 let num_channels = shape[1];
443 let channel_importance = self.calculate_channel_importance(weights)?;
444
445 let num_prune = (num_channels as f32 * config.target_sparsity) as usize;
446 let mut sorted_channels: Vec<(f32, usize)> =
447 channel_importance.iter().enumerate().map(|(i, &score)| (score, i)).collect();
448 sorted_channels.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Partial comparison failed"));
449
450 let pruned_channels: HashSet<usize> =
451 sorted_channels.iter().take(num_prune).map(|(_, idx)| *idx).collect();
452
453 let data = weights.data()?;
454 let mut mask = vec![1.0; data.len()];
455 let channel_size = shape[2] * shape[3];
456 let batch_channel_size = num_channels * channel_size;
457
458 for batch in 0..shape[0] {
459 for channel in &pruned_channels {
460 let start_idx = batch * batch_channel_size + channel * channel_size;
461 let end_idx = start_idx + channel_size;
462 for i in start_idx..end_idx.min(mask.len()) {
463 mask[i] = 0.0;
464 }
465 }
466 }
467
468 Ok(Tensor::from_vec(mask, &shape)?)
469 }
470
471 fn name(&self) -> &str {
472 "ChannelPruner"
473 }
474}
475
impl ChannelPruner {
    /// Score each channel of an NCHW tensor with the configured metric,
    /// averaging over all batches and spatial positions.
    fn calculate_channel_importance(&self, weights: &Tensor) -> Result<Vec<f32>> {
        let shape = weights.shape();
        let num_channels = shape[1];
        let channel_size = shape[2] * shape[3];
        let data = weights.data()?;
        let mut importance = vec![0.0; num_channels];

        for (channel, importance_ref) in importance.iter_mut().enumerate() {
            let mut channel_score = 0.0;
            let mut count = 0;

            for batch in 0..shape[0] {
                // Row-major NCHW: this channel's slice for this batch is the
                // contiguous range [start_idx, end_idx).
                let start_idx = batch * num_channels * channel_size + channel * channel_size;
                let end_idx = start_idx + channel_size;

                for data_ref in data.iter().take(end_idx.min(data.len())).skip(start_idx) {
                    match self.importance_metric {
                        ChannelImportanceMetric::L1Norm => channel_score += data_ref.abs(),
                        ChannelImportanceMetric::L2Norm => channel_score += data_ref * data_ref,
                        // NOTE(review): MeanActivation and GeometricMedian are
                        // currently identical to L1Norm over weights.
                        ChannelImportanceMetric::MeanActivation => channel_score += data_ref.abs(),
                        ChannelImportanceMetric::GeometricMedian => channel_score += data_ref.abs(),
                    }
                    count += 1;
                }
            }

            // NOTE(review): if count is 0 (degenerate shape) this divides by
            // zero and yields NaN — confirm inputs always have elements.
            *importance_ref = match self.importance_metric {
                ChannelImportanceMetric::L2Norm => (channel_score / count as f32).sqrt(),
                _ => channel_score / count as f32,
            };
        }

        Ok(importance)
    }
}
512
/// Prunes entire convolution filters (dimension 0) of 4D tensors.
pub struct FilterPruner {
    // Metric used to score filter importance.
    importance_metric: FilterImportanceMetric,
}

/// How filter importance is measured.
#[derive(Debug, Clone)]
pub enum FilterImportanceMetric {
    // Sum of absolute weight values.
    L1Norm,
    // Euclidean norm of the filter's weights.
    L2Norm,
    // Average Percentage of Zeros.
    // NOTE(review): computed here from the weights (fraction of exactly-zero
    // weights), not from activations as APoZ is usually defined.
    APoZ,
}

impl FilterPruner {
    /// Create a filter pruner using the given importance metric.
    pub fn new(metric: FilterImportanceMetric) -> Self {
        Self {
            importance_metric: metric,
        }
    }
}
535
536impl PruningStrategy for FilterPruner {
537 fn prune_weights(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
538 let shape = weights.shape();
539 if shape.len() != 4 {
540 return Err(anyhow!("Filter pruning requires 4D tensors (NCHW format)"));
541 }
542
543 let num_filters = shape[0]; let filter_importance = self.calculate_filter_importance(weights)?;
545
546 let num_prune = (num_filters as f32 * config.target_sparsity) as usize;
548 let mut sorted_filters: Vec<(f32, usize)> =
549 filter_importance.iter().enumerate().map(|(i, &score)| (score, i)).collect();
550 sorted_filters.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Partial comparison failed"));
551
552 let pruned_filters: HashSet<usize> =
553 sorted_filters.iter().take(num_prune).map(|(_, idx)| *idx).collect();
554
555 let data = weights.data()?;
557 let mut pruned_data = data.to_vec();
558 let filter_size = shape[1] * shape[2] * shape[3]; for filter_idx in &pruned_filters {
561 let start_idx = filter_idx * filter_size;
562 let end_idx = start_idx + filter_size;
563 for i in start_idx..end_idx.min(pruned_data.len()) {
564 pruned_data[i] = 0.0;
565 }
566 }
567
568 Ok(Tensor::from_vec(pruned_data, &shape)?)
569 }
570
571 fn get_mask(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
572 let shape = weights.shape();
573 let num_filters = shape[0];
574 let filter_importance = self.calculate_filter_importance(weights)?;
575
576 let num_prune = (num_filters as f32 * config.target_sparsity) as usize;
577 let mut sorted_filters: Vec<(f32, usize)> =
578 filter_importance.iter().enumerate().map(|(i, &score)| (score, i)).collect();
579 sorted_filters.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Partial comparison failed"));
580
581 let pruned_filters: HashSet<usize> =
582 sorted_filters.iter().take(num_prune).map(|(_, idx)| *idx).collect();
583
584 let data = weights.data()?;
585 let mut mask = vec![1.0; data.len()];
586 let filter_size = shape[1] * shape[2] * shape[3];
587
588 for filter_idx in &pruned_filters {
589 let start_idx = filter_idx * filter_size;
590 let end_idx = start_idx + filter_size;
591 for i in start_idx..end_idx.min(mask.len()) {
592 mask[i] = 0.0;
593 }
594 }
595
596 Ok(Tensor::from_vec(mask, &shape)?)
597 }
598
599 fn name(&self) -> &str {
600 "FilterPruner"
601 }
602}
603
impl FilterPruner {
    /// Score each filter (dim 0) of an NCHW tensor with the configured metric.
    fn calculate_filter_importance(&self, weights: &Tensor) -> Result<Vec<f32>> {
        let shape = weights.shape();
        let num_filters = shape[0];
        let filter_size = shape[1] * shape[2] * shape[3];
        let data = weights.data()?;
        let mut importance = vec![0.0; num_filters];

        for (filter, importance_ref) in importance.iter_mut().enumerate() {
            // Row-major layout: each filter is a contiguous slice.
            let start_idx = filter * filter_size;
            let end_idx = start_idx + filter_size;
            let mut filter_score = 0.0;

            for data_ref in data.iter().take(end_idx.min(data.len())).skip(start_idx) {
                match self.importance_metric {
                    FilterImportanceMetric::L1Norm => filter_score += data_ref.abs(),
                    FilterImportanceMetric::L2Norm => filter_score += data_ref * data_ref,
                    // NOTE(review): counts exactly-zero *weights*, not zero
                    // activations as APoZ is usually defined.
                    FilterImportanceMetric::APoZ => {
                        filter_score += if *data_ref == 0.0 { 1.0 } else { 0.0 }
                    },
                }
            }

            *importance_ref = match self.importance_metric {
                FilterImportanceMetric::L2Norm => filter_score.sqrt(),
                // Normalize the zero count to a fraction of the filter.
                FilterImportanceMetric::APoZ => filter_score / filter_size as f32,
                _ => filter_score,
            };
        }

        Ok(importance)
    }
}
637
/// Prunes attention heads in 2D attention weight matrices.
pub struct HeadPruner {
    // Number of attention heads in the matrix.
    num_heads: usize,
    // Width (number of columns) occupied by each head.
    head_dim: usize,
}

impl HeadPruner {
    /// Create a head pruner for `num_heads` heads of `head_dim` columns each.
    pub fn new(num_heads: usize, head_dim: usize) -> Self {
        Self {
            num_heads,
            head_dim,
        }
    }
}
652
653impl PruningStrategy for HeadPruner {
654 fn prune_weights(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
655 let shape = weights.shape();
656 if shape.len() != 2 {
657 return Err(anyhow!(
658 "Head pruning requires 2D tensors (attention weight matrices)"
659 ));
660 }
661
662 let num_prune = (self.num_heads as f32 * config.target_sparsity) as usize;
664 let head_importance = self.calculate_head_importance(weights)?;
665
666 let mut sorted_heads: Vec<(f32, usize)> =
667 head_importance.iter().enumerate().map(|(i, &score)| (score, i)).collect();
668 sorted_heads.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Partial comparison failed"));
669
670 let pruned_heads: HashSet<usize> =
671 sorted_heads.iter().take(num_prune).map(|(_, idx)| *idx).collect();
672
673 let data = weights.data()?;
675 let mut pruned_data = data.to_vec();
676
677 for head_idx in &pruned_heads {
679 let start_col = head_idx * self.head_dim;
680 let end_col = start_col + self.head_dim;
681
682 for row in 0..shape[0] {
683 for col in start_col..end_col.min(shape[1]) {
684 let idx = row * shape[1] + col;
685 if idx < pruned_data.len() {
686 pruned_data[idx] = 0.0;
687 }
688 }
689 }
690 }
691
692 Ok(Tensor::from_vec(pruned_data, &shape)?)
693 }
694
695 fn get_mask(&self, weights: &Tensor, config: &PruningConfig) -> Result<Tensor> {
696 let shape = weights.shape();
697 let num_prune = (self.num_heads as f32 * config.target_sparsity) as usize;
698 let head_importance = self.calculate_head_importance(weights)?;
699
700 let mut sorted_heads: Vec<(f32, usize)> =
701 head_importance.iter().enumerate().map(|(i, &score)| (score, i)).collect();
702 sorted_heads.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Partial comparison failed"));
703
704 let pruned_heads: HashSet<usize> =
705 sorted_heads.iter().take(num_prune).map(|(_, idx)| *idx).collect();
706
707 let data = weights.data()?;
708 let mut mask = vec![1.0; data.len()];
709
710 for head_idx in &pruned_heads {
711 let start_col = head_idx * self.head_dim;
712 let end_col = start_col + self.head_dim;
713
714 for row in 0..shape[0] {
715 for col in start_col..end_col.min(shape[1]) {
716 let idx = row * shape[1] + col;
717 if idx < mask.len() {
718 mask[idx] = 0.0;
719 }
720 }
721 }
722 }
723
724 Ok(Tensor::from_vec(mask, &shape)?)
725 }
726
727 fn name(&self) -> &str {
728 "HeadPruner"
729 }
730}
731
impl HeadPruner {
    /// Importance of each head: root-mean-square of the weights in its column
    /// range `[head * head_dim, (head + 1) * head_dim)`.
    fn calculate_head_importance(&self, weights: &Tensor) -> Result<Vec<f32>> {
        let shape = weights.shape();
        let data = weights.data()?;
        let mut importance = vec![0.0; self.num_heads];

        for (head, importance_ref) in importance.iter_mut().enumerate() {
            let start_col = head * self.head_dim;
            let end_col = start_col + self.head_dim;
            let mut head_score = 0.0;
            let mut count = 0;

            for row in 0..shape[0] {
                // Columns are clamped to the matrix width; indices past the
                // data length are skipped defensively.
                for col in start_col..end_col.min(shape[1]) {
                    let idx = row * shape[1] + col;
                    if idx < data.len() {
                        head_score += data[idx] * data[idx];
                        count += 1;
                    }
                }
            }

            // RMS over the head's weights; heads with no in-range columns
            // score 0 instead of dividing by zero.
            *importance_ref = if count > 0 { (head_score / count as f32).sqrt() } else { 0.0 };
        }

        Ok(importance)
    }
}
760
/// Ranks whole layers for removal based on importance scores.
pub struct LayerPruner {
    // Importance score per layer name (higher = more important to keep).
    layer_importance: HashMap<String, f32>,
}

impl Default for LayerPruner {
    /// Equivalent to `LayerPruner::new()`.
    fn default() -> Self {
        Self::new()
    }
}
771
772impl LayerPruner {
773 pub fn new() -> Self {
774 Self {
775 layer_importance: HashMap::new(),
776 }
777 }
778
779 pub fn with_importance_scores(scores: HashMap<String, f32>) -> Self {
780 Self {
781 layer_importance: scores,
782 }
783 }
784
785 pub fn analyze_model<M>(&mut self, model: &M) -> Result<()>
787 where
788 M: crate::traits::Model,
789 {
790 let total_params = model.num_parameters();
794
795 let typical_layers = vec![
797 ("embedding".to_string(), 0.8),
798 ("attention_0".to_string(), 0.6),
799 ("feedforward_0".to_string(), 0.4),
800 ("attention_1".to_string(), 0.5),
801 ("feedforward_1".to_string(), 0.3),
802 ("output".to_string(), 0.9),
803 ];
804
805 for (name, importance) in typical_layers {
806 self.layer_importance.insert(name, importance * total_params as f32);
807 }
808
809 Ok(())
810 }
811
812 pub fn get_pruning_candidates(&self, config: &PruningConfig) -> Result<Vec<String>> {
814 let total_layers = self.layer_importance.len();
815 let num_prune = (total_layers as f32 * config.target_sparsity) as usize;
816
817 let mut sorted_layers: Vec<(f32, String)> = self
819 .layer_importance
820 .iter()
821 .map(|(name, &score)| (score, name.clone()))
822 .collect();
823 sorted_layers.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Partial comparison failed"));
824
825 let pruned_layers: Vec<String> = sorted_layers
826 .iter()
827 .take(num_prune)
828 .map(|(_, name)| name.clone())
829 .filter(|name| !config.exclude_layers.contains(name))
830 .collect();
831
832 Ok(pruned_layers)
833 }
834}
835
/// Dispatches to a per-layer-type pruning strategy, with a fallback default.
pub struct AutomaticPruner {
    // Strategy per detected layer type (e.g. "conv", "attention", "linear").
    strategies: HashMap<String, Box<dyn PruningStrategy>>,
    // Strategy used when no layer-type-specific one is registered.
    default_strategy: Box<dyn PruningStrategy>,
}
841
842impl AutomaticPruner {
843 pub fn new() -> Self {
844 let mut strategies = HashMap::new();
845
846 strategies.insert(
848 "conv".to_string(),
849 Box::new(FilterPruner::new(FilterImportanceMetric::L2Norm)) as Box<dyn PruningStrategy>,
850 );
851 strategies.insert(
852 "attention".to_string(),
853 Box::new(HeadPruner::new(12, 64)) as Box<dyn PruningStrategy>,
854 );
855 strategies.insert(
856 "linear".to_string(),
857 Box::new(MagnitudePruner::new(0.5)) as Box<dyn PruningStrategy>,
858 );
859
860 let default_strategy = Box::new(MagnitudePruner::new(0.5));
861
862 Self {
863 strategies,
864 default_strategy,
865 }
866 }
867
868 pub fn with_strategy(mut self, layer_type: String, strategy: Box<dyn PruningStrategy>) -> Self {
869 self.strategies.insert(layer_type, strategy);
870 self
871 }
872
873 pub fn with_default_strategy(mut self, strategy: Box<dyn PruningStrategy>) -> Self {
874 self.default_strategy = strategy;
875 self
876 }
877
878 #[allow(dead_code)]
879 fn detect_layer_type(&self, layer_name: &str) -> String {
880 let name_lower = layer_name.to_lowercase();
881
882 if name_lower.contains("conv") {
883 "conv".to_string()
884 } else if name_lower.contains("attention") || name_lower.contains("attn") {
885 "attention".to_string()
886 } else if name_lower.contains("linear")
887 || name_lower.contains("dense")
888 || name_lower.contains("fc")
889 {
890 "linear".to_string()
891 } else if name_lower.contains("embed") {
892 "embedding".to_string()
893 } else {
894 "unknown".to_string()
895 }
896 }
897}
898
impl Pruner for AutomaticPruner {
    /// Prune `model` according to `config`.
    ///
    /// NOTE(review): placeholder implementation — the model is returned
    /// unchanged and all statistics are simulated from `target_sparsity`
    /// rather than measured from real weights. Confirm whether real pruning
    /// is wired up elsewhere.
    fn prune<M>(&self, model: M, config: &PruningConfig) -> Result<PruningResult<M>>
    where
        M: crate::traits::Model + Clone,
    {
        let total_params = model.num_parameters();
        let estimated_pruned_params = (total_params as f32 * config.target_sparsity) as usize;

        // Simulated per-layer sparsities (fixed layer types, not derived from
        // the actual model).
        let mut layer_sparsity = HashMap::new();
        let simulated_layers = vec![
            ("embedding", 0.2),
            ("attention", 0.4),
            ("feedforward", 0.6),
            ("output", 0.1),
        ];

        for (layer_type, base_sparsity) in simulated_layers {
            // Scale by the target and cap each layer at 90% sparsity.
            let actual_sparsity = (base_sparsity * config.target_sparsity).min(0.9);
            layer_sparsity.insert(layer_type.to_string(), actual_sparsity);
        }

        let overall_sparsity = config.target_sparsity;

        // Model passes through untouched (see NOTE above).
        let pruned_model = model;

        Ok(PruningResult {
            model: pruned_model,
            sparsity: overall_sparsity,
            pruned_params: estimated_pruned_params,
            total_params,
            layer_sparsity,
        })
    }

    /// Estimate pruning statistics for `model` under `config`.
    ///
    /// NOTE(review): layer parameter counts use fixed fractions of the total
    /// (15/30/45/5%), not the model's real layer sizes.
    fn estimate_pruning_potential<M>(
        &self,
        model: &M,
        config: &PruningConfig,
    ) -> Result<PruningStats>
    where
        M: crate::traits::Model,
    {
        let total_params = model.num_parameters();
        let estimated_zero_params = (total_params as f32 * config.target_sparsity) as usize;

        let mut layer_stats = HashMap::new();
        let simulated_layers = vec![
            ("embedding", 0.15),
            ("attention", 0.30),
            ("feedforward", 0.45),
            ("output", 0.05),
        ];

        for (layer_name, param_fraction) in simulated_layers {
            let layer_total = (total_params as f32 * param_fraction) as usize;
            let layer_zeros = (layer_total as f32 * config.target_sparsity) as usize;
            // Guard against empty layers to avoid division by zero.
            let layer_sparsity =
                if layer_total > 0 { layer_zeros as f32 / layer_total as f32 } else { 0.0 };

            layer_stats.insert(
                layer_name.to_string(),
                LayerPruningStats {
                    total_params: layer_total,
                    zero_params: layer_zeros,
                    sparsity: layer_sparsity,
                },
            );
        }

        // Overall sparsity; 0.0 for an empty model.
        let overall_sparsity = if total_params > 0 {
            estimated_zero_params as f32 / total_params as f32
        } else {
            0.0
        };

        Ok(PruningStats {
            total_params,
            zero_params: estimated_zero_params,
            sparsity: overall_sparsity,
            layer_stats,
        })
    }
}
988
impl Default for AutomaticPruner {
    /// Equivalent to `AutomaticPruner::new()`.
    fn default() -> Self {
        Self::new()
    }
}
994
995pub struct PruningUtils;
997
998impl PruningUtils {
999 pub fn calculate_layer_sensitivities<M>(
1001 model: &M,
1002 _validation_data: &[Tensor],
1003 ) -> Result<HashMap<String, f32>>
1004 where
1005 M: crate::traits::Model,
1006 {
1007 let mut sensitivities = HashMap::new();
1008
1009 let _total_params = model.num_parameters(); let typical_sensitivities = vec![
1014 ("embedding".to_string(), 0.95), ("attention".to_string(), 0.75), ("feedforward".to_string(), 0.50), ("output".to_string(), 0.90), ("classifier".to_string(), 0.90), ];
1020
1021 for (layer_name, sensitivity) in typical_sensitivities {
1022 sensitivities.insert(layer_name, sensitivity);
1023 }
1024
1025 Ok(sensitivities)
1026 }
1027
1028 pub fn generate_pruning_schedule(
1030 initial_sparsity: f32,
1031 final_sparsity: f32,
1032 num_steps: usize,
1033 ) -> Vec<f32> {
1034 let mut schedule = Vec::new();
1035
1036 for i in 0..num_steps {
1037 let progress = i as f32 / (num_steps - 1) as f32;
1038 let cubic_progress = progress * progress * progress;
1040 let sparsity = initial_sparsity + (final_sparsity - initial_sparsity) * cubic_progress;
1041 schedule.push(sparsity);
1042 }
1043
1044 schedule
1045 }
1046
1047 pub fn estimate_compression_ratio(target_sparsity: f32, quantization_bits: Option<u8>) -> f32 {
1049 let sparsity_compression = 1.0 / (1.0 - target_sparsity);
1050
1051 match quantization_bits {
1052 Some(bits) => sparsity_compression * (32.0 / bits as f32), None => sparsity_compression,
1054 }
1055 }
1056
1057 pub fn validate_config(config: &PruningConfig) -> Result<()> {
1059 if config.target_sparsity < 0.0 || config.target_sparsity > 1.0 {
1060 return Err(anyhow!("Target sparsity must be between 0.0 and 1.0"));
1061 }
1062
1063 if config.iterations == 0 {
1064 return Err(anyhow!("Number of iterations must be greater than 0"));
1065 }
1066
1067 if let Some(threshold) = config.magnitude_threshold {
1068 if threshold < 0.0 {
1069 return Err(anyhow!("Magnitude threshold must be non-negative"));
1070 }
1071 }
1072
1073 Ok(())
1074 }
1075}
1076
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults should match the documented baseline configuration.
    #[test]
    fn test_pruning_config_default() {
        let config = PruningConfig::default();
        assert_eq!(config.target_sparsity, 0.5);
        assert!(!config.iterative);
        assert_eq!(config.iterations, 1);
        assert!(config.fine_tune);
    }

    // 50% sparsity on 6 elements should zero exactly 3 mask entries.
    #[test]
    fn test_magnitude_pruner() -> Result<()> {
        let pruner = MagnitudePruner::new(0.5);
        let weights = Tensor::from_vec(vec![0.1, -0.8, 0.3, -0.2, 0.9, -0.1], &[2, 3])?;
        let config = PruningConfig {
            target_sparsity: 0.5,
            ..Default::default()
        };

        let mask = pruner.get_mask(&weights, &config)?;
        let mask_data = mask.data()?;
        let zero_count = mask_data.iter().filter(|&&x| x == 0.0).count();

        assert_eq!(zero_count, 3);
        Ok(())
    }

    // Sparsity outside [0, 1] must be rejected by validation.
    #[test]
    fn test_pruning_utils_validation() {
        let valid_config = PruningConfig::default();
        assert!(PruningUtils::validate_config(&valid_config).is_ok());

        let invalid_config = PruningConfig {
            target_sparsity: 1.5,
            ..Default::default()
        };
        assert!(PruningUtils::validate_config(&invalid_config).is_err());
    }

    // 50% sparsity halves storage; 8-bit quantization adds a further 4x.
    #[test]
    fn test_compression_ratio_estimation() {
        let ratio = PruningUtils::estimate_compression_ratio(0.5, None);
        assert_eq!(ratio, 2.0);
        let ratio_with_quant = PruningUtils::estimate_compression_ratio(0.5, Some(8));
        assert_eq!(ratio_with_quant, 8.0);
    }

    // Cubic schedule hits both endpoints and is monotonically non-decreasing.
    #[test]
    fn test_pruning_schedule() {
        let schedule = PruningUtils::generate_pruning_schedule(0.0, 0.8, 5);
        assert_eq!(schedule.len(), 5);
        assert_eq!(schedule[0], 0.0);
        assert_eq!(schedule[4], 0.8);
        for i in 1..schedule.len() {
            assert!(schedule[i] >= schedule[i - 1]);
        }
    }
}