voirs_recognizer/asr/advanced_optimization.rs

//! Advanced model optimization techniques
//!
//! This module provides state-of-the-art optimization techniques including
//! knowledge distillation, progressive pruning, mixed-precision optimization,
//! and benchmark-driven optimization selection.

use crate::RecognitionError;
use candle_core::{DType, Device, Tensor};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Instant;
use tracing::{debug, info};

/// Advanced optimization configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedOptimizationConfig {
    /// Enable knowledge distillation
    pub enable_knowledge_distillation: bool,
    /// Knowledge distillation temperature
    pub distillation_temperature: f32,
    /// Knowledge distillation loss weight
    pub distillation_alpha: f32,

    /// Enable progressive pruning
    pub enable_progressive_pruning: bool,
    /// Initial pruning ratio
    pub initial_pruning_ratio: f32,
    /// Final pruning ratio
    pub final_pruning_ratio: f32,
    /// Number of progressive pruning steps
    pub pruning_steps: usize,

    /// Enable mixed-precision optimization
    pub enable_mixed_precision: bool,
    /// Automatic precision selection
    pub auto_precision_selection: bool,
    /// Performance budget (RTF threshold)
    pub performance_budget: f32,
    /// Accuracy budget (minimum accuracy retention)
    pub accuracy_budget: f32,

    /// Enable quantization-aware training simulation
    pub enable_qat_simulation: bool,
    /// QAT simulation iterations
    pub qat_iterations: usize,
    /// QAT learning rate
    pub qat_learning_rate: f32,

    /// Enable benchmark-driven optimization
    pub enable_benchmark_optimization: bool,
    /// Target hardware platform
    pub target_platform: OptimizationPlatform,
    /// Optimization objective
    pub optimization_objective: OptimizationObjective,
}

impl Default for AdvancedOptimizationConfig {
    fn default() -> Self {
        Self {
            enable_knowledge_distillation: false,
            distillation_temperature: 4.0,
            distillation_alpha: 0.7,

            enable_progressive_pruning: false,
            initial_pruning_ratio: 0.1,
            final_pruning_ratio: 0.5,
            pruning_steps: 10,

            enable_mixed_precision: true,
            auto_precision_selection: true,
            performance_budget: 0.3, // RTF < 0.3
            accuracy_budget: 0.95,   // Retain 95% accuracy

            enable_qat_simulation: false,
            qat_iterations: 100,
            qat_learning_rate: 0.001,

            enable_benchmark_optimization: true,
            target_platform: OptimizationPlatform::CPU,
            optimization_objective: OptimizationObjective::Balanced,
        }
    }
}

/// Target platform for optimization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OptimizationPlatform {
    /// CPU
    CPU,
    /// GPU
    GPU,
    /// Mobile
    Mobile,
    /// Edge
    Edge,
    /// Server
    Server,
}

/// Optimization objective
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OptimizationObjective {
    /// Minimize latency
    Latency,
    /// Minimize memory usage
    Memory,
    /// Minimize model size
    Size,
    /// Balance all metrics
    Balanced,
    /// Maximize throughput
    Throughput,
}

/// Knowledge distillation optimizer
#[derive(Debug)]
pub struct KnowledgeDistillationOptimizer {
    /// Teacher model reference (larger, accurate model)
    teacher_layers: HashMap<String, Tensor>,
    /// Student model reference (smaller, optimized model)
    student_layers: HashMap<String, Tensor>,
    /// Configuration
    config: AdvancedOptimizationConfig,
    /// Device
    device: Device,
    /// Distillation statistics
    distillation_stats: DistillationStats,
}

/// Knowledge distillation statistics
#[derive(Debug, Clone)]
pub struct DistillationStats {
    /// Teacher-student loss over time
    pub loss_history: Vec<f32>,
    /// Knowledge transfer efficiency
    pub transfer_efficiency: f32,
    /// Layer-wise distillation effectiveness
    pub layer_effectiveness: HashMap<String, f32>,
    /// Temperature sensitivity analysis
    pub temperature_sensitivity: Vec<(f32, f32)>, // (temperature, accuracy)
}

impl KnowledgeDistillationOptimizer {
    /// Create new knowledge distillation optimizer
    #[must_use]
    pub fn new(config: AdvancedOptimizationConfig, device: Device) -> Self {
        Self {
            teacher_layers: HashMap::new(),
            student_layers: HashMap::new(),
            config,
            device,
            distillation_stats: DistillationStats {
                loss_history: Vec::new(),
                transfer_efficiency: 0.0,
                layer_effectiveness: HashMap::new(),
                temperature_sensitivity: Vec::new(),
            },
        }
    }

    /// Set teacher model layers
    pub fn set_teacher_layers(&mut self, layers: HashMap<String, Tensor>) {
        self.teacher_layers = layers;
        info!(
            "Set teacher model with {} layers",
            self.teacher_layers.len()
        );
    }

    /// Set student model layers
    pub fn set_student_layers(&mut self, layers: HashMap<String, Tensor>) {
        self.student_layers = layers;
        info!(
            "Set student model with {} layers",
            self.student_layers.len()
        );
    }

    /// Compute knowledge distillation loss
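    ///
    /// Returns only the soft-target KL term; a full distillation objective
    /// typically blends this with the hard-label loss via `distillation_alpha`
    /// (and, in the standard formulation, scales the soft term by T^2).
    ///
    /// A minimal usage sketch (hypothetical shapes; assumes `[batch, classes]`
    /// f32 logits on the optimizer's device):
    ///
    /// ```ignore
    /// let kd = KnowledgeDistillationOptimizer::new(
    ///     AdvancedOptimizationConfig::default(),
    ///     Device::Cpu,
    /// );
    /// let teacher_logits = Tensor::randn(0f32, 1.0, (8, 32), &Device::Cpu)?;
    /// let student_logits = Tensor::randn(0f32, 1.0, (8, 32), &Device::Cpu)?;
    /// let kl = kd.compute_distillation_loss(&teacher_logits, &student_logits)?;
    /// assert!(kl.is_finite());
    /// ```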
    pub fn compute_distillation_loss(
        &self,
        teacher_logits: &Tensor,
        student_logits: &Tensor,
    ) -> Result<f32, RecognitionError> {
        let temperature = self.config.distillation_temperature;

        // Apply temperature scaling
        let teacher_soft = self.apply_temperature_scaling(teacher_logits, temperature)?;
        let student_soft = self.apply_temperature_scaling(student_logits, temperature)?;

        // Compute KL divergence loss
        let kl_loss = self.compute_kl_divergence(&teacher_soft, &student_soft)?;

        debug!("Knowledge distillation loss: {:.6}", kl_loss);
        Ok(kl_loss)
    }

    /// Apply temperature scaling for knowledge distillation
    fn apply_temperature_scaling(
        &self,
        logits: &Tensor,
        temperature: f32,
    ) -> Result<Tensor, RecognitionError> {
        // The temperature is a rank-0 tensor, so it must be broadcast across
        // the logits shape before the element-wise division.
        let temp_tensor = Tensor::new(temperature, &self.device)?;
        let scaled = logits.broadcast_div(&temp_tensor)?;
        let softmax = candle_nn::ops::softmax(&scaled, 1)?;
        Ok(softmax)
    }

    /// Compute KL divergence between teacher and student distributions
    fn compute_kl_divergence(
        &self,
        teacher: &Tensor,
        student: &Tensor,
    ) -> Result<f32, RecognitionError> {
        // KL(P||Q) = sum(P * log(P/Q)); both inputs are expected to be
        // softmax outputs, so every entry is strictly positive.
        let log_ratio = teacher.div(student)?.log()?;
        let kl = teacher.mul(&log_ratio)?.sum_all()?.to_scalar::<f32>()?;
        Ok(kl)
    }

    /// Perform intermediate layer distillation
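    ///
    /// A minimal sketch (the two layer maps must share key names for a layer
    /// to be matched):
    ///
    /// ```ignore
    /// kd.set_teacher_layers(teacher_feature_map);
    /// kd.set_student_layers(student_feature_map);
    /// let losses = kd.distill_intermediate_layers()?;
    /// for (name, mse) in &losses {
    ///     println!("{name}: feature MSE {mse:.4}");
    /// }
    /// ```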
    pub fn distill_intermediate_layers(
        &mut self,
    ) -> Result<HashMap<String, f32>, RecognitionError> {
        let mut layer_losses = HashMap::new();

        for (layer_name, teacher_features) in &self.teacher_layers {
            if let Some(student_features) = self.student_layers.get(layer_name) {
                // Compute feature matching loss
                let feature_loss =
                    self.compute_feature_matching_loss(teacher_features, student_features)?;
                layer_losses.insert(layer_name.clone(), feature_loss);

                // Update layer effectiveness
                self.distillation_stats.layer_effectiveness.insert(
                    layer_name.clone(),
                    1.0 / (1.0 + feature_loss), // Effectiveness inversely related to loss
                );
            }
        }

        info!("Distilled {} intermediate layers", layer_losses.len());
        Ok(layer_losses)
    }

    /// Compute feature matching loss between teacher and student features
    fn compute_feature_matching_loss(
        &self,
        teacher: &Tensor,
        student: &Tensor,
    ) -> Result<f32, RecognitionError> {
        // MSE loss between features
        let diff = teacher.sub(student)?;
        let squared_diff = diff.sqr()?;
        let mse_loss = squared_diff.mean_all()?.to_scalar::<f32>()?;
        Ok(mse_loss)
    }

    /// Analyze temperature sensitivity
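    ///
    /// A minimal sketch (`val_batches` is a hypothetical slice of held-out
    /// logit tensors):
    ///
    /// ```ignore
    /// kd.analyze_temperature_sensitivity(vec![1.0, 2.0, 4.0, 8.0], &val_batches)
    ///     .await?;
    /// for (t, acc) in &kd.get_stats().temperature_sensitivity {
    ///     println!("T = {t:.1} -> teacher/student agreement {acc:.3}");
    /// }
    /// ```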
    pub async fn analyze_temperature_sensitivity(
        &mut self,
        temperatures: Vec<f32>,
        validation_data: &[Tensor],
    ) -> Result<(), RecognitionError> {
        info!(
            "Analyzing temperature sensitivity with {} temperature values",
            temperatures.len()
        );

        for &temperature in &temperatures {
            // Simulate distillation with this temperature
            let mut total_accuracy = 0.0;

            for data in validation_data {
                // Mock teacher and student logits for demonstration
                let teacher_logits = data.clone();
                let student_logits = data.clone(); // In real implementation, this would be student output

                let teacher_soft = self.apply_temperature_scaling(&teacher_logits, temperature)?;
                let student_soft = self.apply_temperature_scaling(&student_logits, temperature)?;

                // Compute accuracy (simplified)
                let accuracy = self.compute_prediction_accuracy(&teacher_soft, &student_soft)?;
                total_accuracy += accuracy;
            }

            let avg_accuracy = total_accuracy / validation_data.len() as f32;
            self.distillation_stats
                .temperature_sensitivity
                .push((temperature, avg_accuracy));

            debug!(
                "Temperature {:.1}: accuracy {:.3}",
                temperature, avg_accuracy
            );
        }

        Ok(())
    }

    /// Compute prediction accuracy between teacher and student
    fn compute_prediction_accuracy(
        &self,
        teacher: &Tensor,
        student: &Tensor,
    ) -> Result<f32, RecognitionError> {
        // Simplified accuracy computation
        let teacher_pred = teacher.argmax(1)?;
        let student_pred = student.argmax(1)?;
        let matches = teacher_pred
            .eq(&student_pred)?
            .to_dtype(DType::F32)?
            .mean_all()?
            .to_scalar::<f32>()?;
        Ok(matches)
    }

    /// Get distillation statistics
    #[must_use]
    pub fn get_stats(&self) -> &DistillationStats {
        &self.distillation_stats
    }
}

/// Progressive pruning optimizer
#[derive(Debug)]
pub struct ProgressivePruningOptimizer {
    /// Configuration
    config: AdvancedOptimizationConfig,
    /// Current pruning step
    current_step: usize,
    /// Pruning schedule
    pruning_schedule: Vec<f32>,
    /// Layer importance scores
    layer_importance: HashMap<String, f32>,
    /// Pruning history
    pruning_history: Vec<PruningStepResult>,
    /// Device
    device: Device,
}

/// Result of a pruning step
#[derive(Debug, Clone)]
pub struct PruningStepResult {
    /// Step number
    pub step: usize,
    /// Pruning ratio applied
    pub pruning_ratio: f32,
    /// Accuracy after pruning
    pub accuracy: f32,
    /// Model size reduction
    pub size_reduction: f32,
    /// Speedup achieved
    pub speedup: f32,
    /// Recovery iterations needed
    pub recovery_iterations: usize,
}

impl ProgressivePruningOptimizer {
    /// Create new progressive pruning optimizer
    #[must_use]
    pub fn new(config: AdvancedOptimizationConfig, device: Device) -> Self {
        let pruning_schedule = Self::create_pruning_schedule(&config);

        Self {
            config,
            current_step: 0,
            pruning_schedule,
            layer_importance: HashMap::new(),
            pruning_history: Vec::new(),
            device,
        }
    }

    /// Create pruning schedule (linear ramp from initial to final ratio)
    fn create_pruning_schedule(config: &AdvancedOptimizationConfig) -> Vec<f32> {
        let steps = config.pruning_steps;
        let initial = config.initial_pruning_ratio;
        let final_ratio = config.final_pruning_ratio;

        (0..steps)
            .map(|i| {
                // Guard against a zero denominator when only one step is configured.
                let progress = if steps > 1 {
                    i as f32 / (steps - 1) as f32
                } else {
                    1.0
                };
                initial + (final_ratio - initial) * progress
            })
            .collect()
    }

    /// Compute layer importance scores
    pub fn compute_layer_importance(
        &mut self,
        model_layers: &HashMap<String, Tensor>,
    ) -> Result<(), RecognitionError> {
        info!(
            "Computing layer importance scores for {} layers",
            model_layers.len()
        );

        for (layer_name, weights) in model_layers {
            // Use magnitude-based importance: the L1 norm normalized by the
            // parameter count, i.e. the mean absolute weight.
            let magnitude_sum = weights.abs()?.sum_all()?.to_scalar::<f32>()?;
            let num_params = weights.elem_count();
            let importance = magnitude_sum / num_params as f32;

            self.layer_importance.insert(layer_name.clone(), importance);
            debug!("Layer {} importance: {:.6}", layer_name, importance);
        }

        Ok(())
    }

    /// Execute progressive pruning step
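    ///
    /// A minimal driving-loop sketch (`validate` is a hypothetical closure;
    /// real callers would run their own evaluation pass and return accuracy):
    ///
    /// ```ignore
    /// let mut pruner = ProgressivePruningOptimizer::new(config, Device::Cpu);
    /// pruner.compute_layer_importance(&layers)?;
    /// let (done, total) = pruner.get_progress();
    /// for _ in done..total {
    ///     let result = pruner.execute_pruning_step(&mut layers, |l| validate(l))?;
    ///     println!("step {}: accuracy {:.3}", result.step, result.accuracy);
    /// }
    /// ```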
    pub fn execute_pruning_step(
        &mut self,
        model_layers: &mut HashMap<String, Tensor>,
        validation_fn: impl Fn(&HashMap<String, Tensor>) -> Result<f32, RecognitionError>,
    ) -> Result<PruningStepResult, RecognitionError> {
        if self.current_step >= self.pruning_schedule.len() {
            return Err(RecognitionError::ModelError {
                message: "All pruning steps completed".to_string(),
                source: None,
            });
        }

        let target_ratio = self.pruning_schedule[self.current_step];
        info!(
            "Executing pruning step {}: target ratio {:.2}",
            self.current_step + 1,
            target_ratio
        );

        let start_time = Instant::now();

        // Measure baseline accuracy
        let baseline_accuracy = validation_fn(model_layers)?;
        let baseline_size = self.compute_model_size(model_layers);

        // Apply structured pruning
        let _pruned_params = self.apply_structured_pruning(model_layers, target_ratio)?;

        // Measure post-pruning accuracy
        let pruned_accuracy = validation_fn(model_layers)?;
        let pruned_size = self.compute_model_size(model_layers);

        // Simulate recovery iterations (in real implementation, this would involve fine-tuning)
        let recovery_iterations =
            self.simulate_recovery_training(model_layers, baseline_accuracy)?;

        let processing_time = start_time.elapsed();
        let speedup = self.estimate_inference_speedup(target_ratio);
        let size_reduction = (baseline_size - pruned_size) / baseline_size;

        let result = PruningStepResult {
            step: self.current_step + 1,
            pruning_ratio: target_ratio,
            accuracy: pruned_accuracy,
            size_reduction,
            speedup,
            recovery_iterations,
        };

        self.pruning_history.push(result.clone());
        self.current_step += 1;

        info!(
            "Pruning step completed in {:.2}s: {:.1}% accuracy, {:.1}% size reduction",
            processing_time.as_secs_f32(),
            pruned_accuracy * 100.0,
            size_reduction * 100.0
        );

        Ok(result)
    }

    /// Apply structured pruning to model layers
    fn apply_structured_pruning(
        &self,
        model_layers: &mut HashMap<String, Tensor>,
        target_ratio: f32,
    ) -> Result<usize, RecognitionError> {
        let mut total_pruned = 0;

        for (layer_name, weights) in model_layers.iter_mut() {
            if let Some(&importance) = self.layer_importance.get(layer_name) {
                // Apply pruning based on importance (lower importance = more pruning)
                let layer_ratio = target_ratio * (1.0 - importance).max(0.1);
                let pruned = self.prune_layer_structured(weights, layer_ratio)?;
                total_pruned += pruned;

                debug!(
                    "Pruned {} parameters from layer {} (ratio: {:.3})",
                    pruned, layer_name, layer_ratio
                );
            }
        }

        Ok(total_pruned)
    }

    /// Prune individual layer with structured approach
    fn prune_layer_structured(
        &self,
        weights: &mut Tensor,
        ratio: f32,
    ) -> Result<usize, RecognitionError> {
        let shape = weights.shape();
        let total_params = shape.elem_count();
        let params_to_prune = (total_params as f32 * ratio) as usize;

        // For structured pruning, we would typically prune entire channels/filters.
        // This simplified implementation masks rows by aggregate magnitude.
        let magnitude = weights.abs()?.sum_keepdim(1)?;
        let threshold = self.compute_pruning_threshold(&magnitude, ratio)?;

        // The comparison yields a U8 mask, so it must be cast back to the
        // weights' dtype; the keepdim'd magnitude is broadcast over columns.
        let threshold_tensor = Tensor::new(threshold, &self.device)?;
        let mask = magnitude
            .broadcast_gt(&threshold_tensor)?
            .to_dtype(weights.dtype())?;
        *weights = weights.broadcast_mul(&mask)?;

        Ok(params_to_prune)
    }

    /// Compute pruning threshold based on magnitude distribution
    fn compute_pruning_threshold(
        &self,
        magnitudes: &Tensor,
        ratio: f32,
    ) -> Result<f32, RecognitionError> {
        // Find the magnitude value that corresponds to the pruning ratio
        let flat = magnitudes.flatten_all()?;
        let mut sorted_values: Vec<f32> = flat.to_vec1()?;
        // total_cmp avoids the panic that partial_cmp().unwrap() would hit on NaN.
        sorted_values.sort_by(f32::total_cmp);

        let threshold_index = (sorted_values.len() as f32 * ratio) as usize;
        let threshold = sorted_values.get(threshold_index).copied().unwrap_or(0.0);

        Ok(threshold)
    }

    /// Simulate recovery training iterations
    fn simulate_recovery_training(
        &self,
        _model_layers: &HashMap<String, Tensor>,
        target_accuracy: f32,
    ) -> Result<usize, RecognitionError> {
        // Simulate the number of iterations needed to recover accuracy
        // This is a heuristic based on pruning ratio
        let pruning_ratio = self.pruning_schedule[self.current_step];
        let base_iterations = 10;
        let recovery_iterations = (base_iterations as f32 * (1.0 + pruning_ratio * 2.0)) as usize;

        debug!(
            "Estimated {} recovery iterations for {:.1}% accuracy recovery",
            recovery_iterations,
            target_accuracy * 100.0
        );

        Ok(recovery_iterations)
    }

    /// Compute total model size
    fn compute_model_size(&self, model_layers: &HashMap<String, Tensor>) -> f32 {
        model_layers
            .values()
            .map(|tensor| tensor.elem_count() as f32)
            .sum()
    }

    /// Estimate inference speedup from pruning
    fn estimate_inference_speedup(&self, pruning_ratio: f32) -> f32 {
        // Empirical relationship between pruning ratio and speedup
        // Actual speedup depends on hardware and implementation
        1.0 + pruning_ratio * 0.8 // Conservative estimate
    }

    /// Get pruning history
    #[must_use]
    pub fn get_pruning_history(&self) -> &[PruningStepResult] {
        &self.pruning_history
    }

    /// Get current pruning progress
    #[must_use]
    pub fn get_progress(&self) -> (usize, usize) {
        (self.current_step, self.pruning_schedule.len())
    }
}

/// Mixed-precision optimizer
#[derive(Debug)]
pub struct MixedPrecisionOptimizer {
    /// Configuration
    config: AdvancedOptimizationConfig,
    /// Layer precision assignments
    layer_precisions: HashMap<String, DType>,
    /// Performance measurements per precision
    precision_performance: HashMap<DType, PerformanceMeasurement>,
    /// Automatic precision search results
    search_results: Vec<PrecisionSearchResult>,
    /// Device
    device: Device,
}

/// Performance measurement for a precision setting
#[derive(Debug, Clone)]
pub struct PerformanceMeasurement {
    /// Inference time (milliseconds)
    pub inference_time_ms: f32,
    /// Memory usage (MB)
    pub memory_usage_mb: f32,
    /// Accuracy
    pub accuracy: f32,
    /// Model size (MB)
    pub model_size_mb: f32,
}

/// Result of precision search
#[derive(Debug, Clone)]
pub struct PrecisionSearchResult {
    /// Layer name
    pub layer_name: String,
    /// Tested precision
    pub precision: DType,
    /// Performance measurement
    pub performance: PerformanceMeasurement,
    /// Meets performance budget
    pub meets_performance_budget: bool,
    /// Meets accuracy budget
    pub meets_accuracy_budget: bool,
}

impl MixedPrecisionOptimizer {
    /// Create new mixed-precision optimizer
    #[must_use]
    pub fn new(config: AdvancedOptimizationConfig, device: Device) -> Self {
        Self {
            config,
            layer_precisions: HashMap::new(),
            precision_performance: HashMap::new(),
            search_results: Vec::new(),
            device,
        }
    }

    /// Perform automatic precision selection
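    ///
    /// A minimal sketch (the closure below returns canned numbers; a real
    /// `benchmark_fn` would run inference over the candidate layers):
    ///
    /// ```ignore
    /// let mut mp = MixedPrecisionOptimizer::new(config, Device::Cpu);
    /// mp.auto_select_precisions(&layers, |_layers| {
    ///     Ok(PerformanceMeasurement {
    ///         inference_time_ms: 120.0,
    ///         memory_usage_mb: 256.0,
    ///         accuracy: 0.97,
    ///         model_size_mb: 75.0,
    ///     })
    /// })?;
    /// let chosen = mp.get_layer_precisions();
    /// ```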
    pub fn auto_select_precisions(
        &mut self,
        model_layers: &HashMap<String, Tensor>,
        benchmark_fn: impl Fn(
            &HashMap<String, Tensor>,
        ) -> Result<PerformanceMeasurement, RecognitionError>,
    ) -> Result<(), RecognitionError> {
        info!(
            "Starting automatic precision selection for {} layers",
            model_layers.len()
        );

        // candle has no I8 dtype, so U8 stands in for the INT8 setting here.
        let precisions_to_test = vec![DType::F32, DType::F16, DType::U8];

        for layer_name in model_layers.keys() {
            let mut best_precision = DType::F32;
            let mut best_score = f32::NEG_INFINITY;

            for &precision in &precisions_to_test {
                // Simulate conversion to target precision
                let test_layers = model_layers.clone();
                // In real implementation, convert layer to target precision here

                let performance = benchmark_fn(&test_layers)?;

                // Compute optimization score based on objective
                let score = self.compute_optimization_score(&performance, precision);

                // Note: inference seconds are compared directly against the RTF
                // budget, which is equivalent to assuming one second of audio.
                let meets_perf_budget =
                    performance.inference_time_ms / 1000.0 <= self.config.performance_budget;
                let meets_acc_budget = performance.accuracy >= self.config.accuracy_budget;

                let search_result = PrecisionSearchResult {
                    layer_name: layer_name.clone(),
                    precision,
                    performance: performance.clone(),
                    meets_performance_budget: meets_perf_budget,
                    meets_accuracy_budget: meets_acc_budget,
                };

                self.search_results.push(search_result);

                if meets_perf_budget && meets_acc_budget && score > best_score {
                    best_score = score;
                    best_precision = precision;
                }

                debug!(
                    "Layer {}, precision {:?}: score {:.3}, perf budget: {}, acc budget: {}",
                    layer_name, precision, score, meets_perf_budget, meets_acc_budget
                );
            }

            self.layer_precisions
                .insert(layer_name.clone(), best_precision);
            info!(
                "Selected {:?} precision for layer {}",
                best_precision, layer_name
            );
        }

        Ok(())
    }


    /// Compute optimization score based on objective
    fn compute_optimization_score(
        &self,
        performance: &PerformanceMeasurement,
        _precision: DType,
    ) -> f32 {
        match self.config.optimization_objective {
            OptimizationObjective::Latency => {
                -performance.inference_time_ms // Lower is better
            }
            OptimizationObjective::Memory => {
                -performance.memory_usage_mb // Lower is better
            }
            OptimizationObjective::Size => {
                -performance.model_size_mb // Lower is better
            }
            OptimizationObjective::Balanced => {
                // Weighted combination
                let latency_score = -performance.inference_time_ms / 1000.0;
                let memory_score = -performance.memory_usage_mb / 1000.0;
                let accuracy_score = performance.accuracy;
                let size_score = -performance.model_size_mb / 100.0;

                0.3 * latency_score + 0.2 * memory_score + 0.4 * accuracy_score + 0.1 * size_score
            }
            OptimizationObjective::Throughput => {
                1000.0 / performance.inference_time_ms // Higher throughput is better
            }
        }
    }

    /// Apply mixed-precision configuration to model
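    ///
    /// A minimal sketch (assumes `auto_select_precisions` has already filled
    /// the per-layer assignments):
    ///
    /// ```ignore
    /// let stats = mp.apply_mixed_precision(&mut layers)?;
    /// println!(
    ///     "fp16: {}, int8: {}, est. speedup: {:.2}x",
    ///     stats.fp16_layers, stats.int8_layers, stats.estimated_speedup
    /// );
    /// ```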
    pub fn apply_mixed_precision(
        &self,
        model_layers: &mut HashMap<String, Tensor>,
    ) -> Result<MixedPrecisionStats, RecognitionError> {
        let mut stats = MixedPrecisionStats {
            total_layers: model_layers.len(),
            fp32_layers: 0,
            fp16_layers: 0,
            int8_layers: 0,
            estimated_speedup: 1.0,
            estimated_memory_reduction: 0.0,
        };

        for (layer_name, weights) in model_layers.iter_mut() {
            if let Some(&target_precision) = self.layer_precisions.get(layer_name) {
                // Convert weights to target precision. Note that `to_dtype` is
                // a plain cast; real INT8 quantization would also need scale
                // and zero-point handling.
                let converted_weights = weights.to_dtype(target_precision)?;
                *weights = converted_weights;

                // Update statistics
                match target_precision {
                    DType::F32 => stats.fp32_layers += 1,
                    DType::F16 => stats.fp16_layers += 1,
                    DType::U8 => stats.int8_layers += 1,
                    _ => {}
                }

                debug!("Converted layer {} to {:?}", layer_name, target_precision);
            }
        }

        // Estimate performance improvements
        stats.estimated_speedup = self.estimate_mixed_precision_speedup(&stats);
        stats.estimated_memory_reduction = self.estimate_memory_reduction(&stats);

        info!(
            "Applied mixed-precision: {:.1}x speedup, {:.1}% memory reduction",
            stats.estimated_speedup,
            stats.estimated_memory_reduction * 100.0
        );

        Ok(stats)
    }

    /// Estimate speedup from mixed-precision configuration
    fn estimate_mixed_precision_speedup(&self, stats: &MixedPrecisionStats) -> f32 {
        let total = stats.total_layers as f32;
        let fp16_ratio = stats.fp16_layers as f32 / total;
        let int8_ratio = stats.int8_layers as f32 / total;

        // Empirical speedup estimates
        1.0 + fp16_ratio * 0.4 + int8_ratio * 0.8
    }

    /// Estimate memory reduction from mixed-precision
    fn estimate_memory_reduction(&self, stats: &MixedPrecisionStats) -> f32 {
        let total = stats.total_layers as f32;
        let fp16_ratio = stats.fp16_layers as f32 / total;
        let int8_ratio = stats.int8_layers as f32 / total;

        // Memory reduction estimates (FP16 = 50% reduction, INT8 = 75% reduction)
        fp16_ratio * 0.5 + int8_ratio * 0.75
    }

    /// Get layer precision assignments
    #[must_use]
    pub fn get_layer_precisions(&self) -> &HashMap<String, DType> {
        &self.layer_precisions
    }

    /// Get search results
    #[must_use]
    pub fn get_search_results(&self) -> &[PrecisionSearchResult] {
        &self.search_results
    }
}

/// Mixed-precision optimization statistics
#[derive(Debug, Clone)]
pub struct MixedPrecisionStats {
    /// Total number of layers
    pub total_layers: usize,
    /// Number of FP32 layers
    pub fp32_layers: usize,
    /// Number of FP16 layers
    pub fp16_layers: usize,
    /// Number of INT8 layers
    pub int8_layers: usize,
    /// Estimated speedup
    pub estimated_speedup: f32,
    /// Estimated memory reduction (0.0 to 1.0)
    pub estimated_memory_reduction: f32,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::asr::whisper::quantization::MovingAverageTracker;

    #[test]
    fn test_optimization_config_creation() {
        let config = AdvancedOptimizationConfig::default();

        assert!(!config.enable_knowledge_distillation);
        assert!(config.enable_mixed_precision);
        assert!(config.auto_precision_selection);
        assert_eq!(config.distillation_temperature, 4.0);
        assert_eq!(config.performance_budget, 0.3);
        assert_eq!(config.accuracy_budget, 0.95);
    }

    #[test]
    fn test_pruning_schedule_creation() {
        let config = AdvancedOptimizationConfig {
            pruning_steps: 5,
            initial_pruning_ratio: 0.1,
            final_pruning_ratio: 0.5,
            ..Default::default()
        };

        let schedule = ProgressivePruningOptimizer::create_pruning_schedule(&config);

        assert_eq!(schedule.len(), 5);
        assert_eq!(schedule[0], 0.1);
        assert_eq!(schedule[4], 0.5);

        // Verify monotonic increase
        for i in 1..schedule.len() {
            assert!(schedule[i] >= schedule[i - 1]);
        }
    }
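
    // Two hand-computed sanity checks, added for illustration of the pruning
    // helpers' arithmetic.

    // Mean-absolute-weight importance: (|1| + |-1| + |2| + |-2|) / 4 = 1.5.
    #[test]
    fn test_layer_importance_mean_abs() {
        let mut pruner = ProgressivePruningOptimizer::new(
            AdvancedOptimizationConfig::default(),
            Device::Cpu,
        );
        let mut layers = HashMap::new();
        layers.insert(
            "proj".to_string(),
            Tensor::new(&[[1f32, -1.0], [2.0, -2.0]], &Device::Cpu).unwrap(),
        );
        pruner.compute_layer_importance(&layers).unwrap();
        assert!((pruner.layer_importance["proj"] - 1.5).abs() < 1e-6);
    }

    // Threshold selection: with magnitudes [1, 2, 3, 4] and ratio 0.5, the
    // index is 4 * 0.5 = 2, so the threshold is the third-smallest value, 3.
    #[test]
    fn test_pruning_threshold_quantile() {
        let pruner = ProgressivePruningOptimizer::new(
            AdvancedOptimizationConfig::default(),
            Device::Cpu,
        );
        let magnitudes = Tensor::new(&[4f32, 1.0, 3.0, 2.0], &Device::Cpu).unwrap();
        let threshold = pruner.compute_pruning_threshold(&magnitudes, 0.5).unwrap();
        assert_eq!(threshold, 3.0);
    }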

    #[test]
    fn test_moving_average_tracker() {
        let mut tracker = MovingAverageTracker::new(3);

        tracker.update(1.0, 2.0);
        tracker.update(2.0, 3.0);
        tracker.update(3.0, 4.0);

        let (avg_min, avg_max) = tracker.get_averaged_range();
        assert_eq!(avg_min, 2.0); // (1+2+3)/3
        assert_eq!(avg_max, 3.0); // (2+3+4)/3
    }
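
    // A hand-computed sanity check for the private KL helper, added for
    // illustration: KL([0.5, 0.5] || [0.25, 0.75])
    //   = 0.5 * ln(0.5 / 0.25) + 0.5 * ln(0.5 / 0.75) ≈ 0.1438
    #[test]
    fn test_kl_divergence_hand_computed() {
        let kd = KnowledgeDistillationOptimizer::new(
            AdvancedOptimizationConfig::default(),
            Device::Cpu,
        );
        let teacher = Tensor::new(&[0.5f32, 0.5], &Device::Cpu).unwrap();
        let student = Tensor::new(&[0.25f32, 0.75], &Device::Cpu).unwrap();
        let kl = kd.compute_kl_divergence(&teacher, &student).unwrap();
        assert!((kl - 0.1438).abs() < 1e-3);
    }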

    #[test]
    fn test_mixed_precision_stats() {
        let stats = MixedPrecisionStats {
            total_layers: 10,
            fp32_layers: 4,
            fp16_layers: 4,
            int8_layers: 2,
            estimated_speedup: 1.0,
            estimated_memory_reduction: 0.0,
        };

        let optimizer =
            MixedPrecisionOptimizer::new(AdvancedOptimizationConfig::default(), Device::Cpu);

        let speedup = optimizer.estimate_mixed_precision_speedup(&stats);
        let memory_reduction = optimizer.estimate_memory_reduction(&stats);

        assert!(speedup > 1.0);
        assert!(memory_reduction > 0.0);
        assert!(memory_reduction < 1.0);
    }
920    }
921}