scirs2_io/ml_framework/
optimization.rs1use crate::error::Result;
4use crate::ml_framework::MLModel;
5
/// Optimization passes that a `ModelOptimizer` can be configured with.
///
/// Only [`OptimizationTechnique::Pruning`] is acted on by
/// `ModelOptimizer::optimize` in this module; the remaining variants are
/// accepted and leave the model unchanged.
#[derive(Debug, Clone, PartialEq)]
pub enum OptimizationTechnique {
    /// Magnitude-based weight pruning, applied to each weight tensor
    /// independently.
    Pruning {
        /// Selects the pruning threshold: weights whose magnitude is strictly
        /// below the value found at index `len * sparsity` of the tensor's
        /// sorted magnitudes are set to zero.
        sparsity: f32,
    },
    /// Operator fusion. Currently a no-op in `ModelOptimizer::optimize`.
    OperatorFusion,
    /// Constant folding. Currently a no-op in `ModelOptimizer::optimize`.
    ConstantFolding,
    /// Graph-level optimization. Currently a no-op in `ModelOptimizer::optimize`.
    GraphOptimization,
    /// Distillation. Currently a no-op in `ModelOptimizer::optimize`.
    Distillation,
}
20
21pub struct ModelOptimizer {
23 techniques: Vec<OptimizationTechnique>,
24}
25
26impl Default for ModelOptimizer {
27 fn default() -> Self {
28 Self::new()
29 }
30}
31
32impl ModelOptimizer {
33 pub fn new() -> Self {
34 Self {
35 techniques: Vec::new(),
36 }
37 }
38
39 pub fn add_technique(mut self, technique: OptimizationTechnique) -> Self {
40 self.techniques.push(technique);
41 self
42 }
43
44 pub fn optimize(&self, model: &MLModel) -> Result<MLModel> {
46 let mut optimized = model.clone();
47
48 for technique in &self.techniques {
49 match technique {
50 OptimizationTechnique::Pruning { sparsity } => {
51 optimized = self.apply_pruning(optimized, *sparsity)?;
52 }
53 OptimizationTechnique::OperatorFusion => {
54 }
56 _ => {}
57 }
58 }
59
60 Ok(optimized)
61 }
62
63 fn apply_pruning(&self, mut model: MLModel, sparsity: f32) -> Result<MLModel> {
64 for (_, tensor) in model.weights.iter_mut() {
65 let data = tensor.data.as_slice_mut().unwrap();
66 let threshold = self.compute_pruning_threshold(data, sparsity);
67
68 for val in data.iter_mut() {
69 if val.abs() < threshold {
70 *val = 0.0;
71 }
72 }
73 }
74
75 Ok(model)
76 }
77
78 fn compute_pruning_threshold(&self, data: &[f32], sparsity: f32) -> f32 {
79 let mut sorted: Vec<f32> = data.iter().map(|x| x.abs()).collect();
80 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
81 let idx = (sorted.len() as f32 * sparsity) as usize;
82 sorted.get(idx).copied().unwrap_or(0.0)
83 }
84}