trustformers_optim/monitoring.rs

//! # Optimizer Monitoring and Analysis Tools
//!
//! This module provides tools for monitoring optimizer performance, tracking metrics,
//! and diagnosing optimization issues during training.
//!
//! ## Features
//!
//! - **Optimizer State Tracking**: Monitor learning rates, gradient norms, parameter changes
//! - **Convergence Analysis**: Track loss trends, detect plateaus, measure convergence rates
//! - **Performance Profiling**: Measure optimizer overhead and memory usage
//! - **Debugging Tools**: Detect gradient explosions, vanishing gradients, oscillations
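//!
//! ## Example
//!
//! A minimal sketch of wiring the monitor into a training loop, assuming the
//! crate/module path shown above; `num_steps`, `parameters`, `learning_rate`,
//! and `loss` are placeholders for whatever your training code provides:
//!
//! ```ignore
//! use trustformers_optim::monitoring::OptimizerMonitor;
//!
//! let mut monitor = OptimizerMonitor::with_defaults();
//! for _ in 0..num_steps {
//!     monitor.before_step();
//!     // ... run the optimizer step on `parameters` ...
//!     monitor.after_step(learning_rate, &parameters, Some(loss))?;
//!     if monitor.should_log() {
//!         println!("{}", monitor.get_summary_report());
//!     }
//! }
//! ```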

use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};
use trustformers_core::tensor::Tensor;

/// Configuration for optimizer monitoring.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringConfig {
    /// Whether to track gradient norms
    pub track_gradient_norms: bool,
    /// Whether to track parameter changes
    pub track_parameter_changes: bool,
    /// Whether to track learning rate changes
    pub track_learning_rates: bool,
    /// Whether to track convergence metrics
    pub track_convergence: bool,
    /// Whether to track performance metrics
    pub track_performance: bool,
    /// History window size for rolling statistics
    pub history_window: usize,
    /// Frequency of detailed logging (every N steps)
    pub log_frequency: usize,
}

impl Default for MonitoringConfig {
    fn default() -> Self {
        Self {
            track_gradient_norms: true,
            track_parameter_changes: true,
            track_learning_rates: true,
            track_convergence: true,
            track_performance: false,
            history_window: 100,
            log_frequency: 10,
        }
    }
}

/// Statistics for a single metric over time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricStats {
    /// Recent values (up to history_window size)
    pub values: VecDeque<f32>,
    /// Current value
    pub current: f32,
    /// Mean over history window
    pub mean: f32,
    /// Standard deviation over history window
    pub std: f32,
    /// Minimum value in history
    pub min: f32,
    /// Maximum value in history
    pub max: f32,
    /// Trend (positive = increasing, negative = decreasing)
    pub trend: f32,
}

impl MetricStats {
    pub fn new(window_size: usize) -> Self {
        Self {
            values: VecDeque::with_capacity(window_size),
            current: 0.0,
            mean: 0.0,
            std: 0.0,
            min: f32::INFINITY,
            max: f32::NEG_INFINITY,
            trend: 0.0,
        }
    }

    /// Update the metric with a new value.
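    ///
    /// A small sketch of the rolling-window behavior (values and window size
    /// chosen purely for illustration):
    ///
    /// ```ignore
    /// let mut stats = MetricStats::new(3);
    /// for v in [1.0, 2.0, 3.0, 4.0] {
    ///     stats.update(v, 3);
    /// }
    /// // Only the last 3 values remain: mean = (2 + 3 + 4) / 3 = 3.0, and the
    /// // positive linear-regression slope shows up in `trend`.
    /// assert_eq!(stats.mean, 3.0);
    /// assert!(stats.is_increasing(0.5));
    /// ```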
    pub fn update(&mut self, value: f32, window_size: usize) {
        self.current = value;
        self.values.push_back(value);

        if self.values.len() > window_size {
            self.values.pop_front();
        }

        self.compute_statistics();
    }

    fn compute_statistics(&mut self) {
        if self.values.is_empty() {
            return;
        }

        // Basic statistics
        let sum: f32 = self.values.iter().sum();
        self.mean = sum / self.values.len() as f32;

        let variance: f32 = self.values.iter().map(|x| (x - self.mean).powi(2)).sum::<f32>()
            / self.values.len() as f32;
        self.std = variance.sqrt();

        self.min = self.values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
        self.max = self.values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));

        // Compute trend (linear regression slope)
        if self.values.len() >= 2 {
            let n = self.values.len() as f32;
            let x_mean = (n - 1.0) / 2.0; // mean of x = 0, 1, ..., n-1

            let mut numerator = 0.0;
            let mut denominator = 0.0;

            for (i, &y) in self.values.iter().enumerate() {
                let x = i as f32;
                numerator += (x - x_mean) * (y - self.mean);
                denominator += (x - x_mean).powi(2);
            }

            self.trend = if denominator > 1e-8 { numerator / denominator } else { 0.0 };
        }
    }

    /// Check if the metric has plateaued (low variance and trend).
    pub fn is_plateaued(&self, variance_threshold: f32, trend_threshold: f32) -> bool {
        self.std < variance_threshold && self.trend.abs() < trend_threshold
    }

    /// Check if the metric is trending upward.
    pub fn is_increasing(&self, threshold: f32) -> bool {
        self.trend > threshold
    }

    /// Check if the metric is trending downward.
    pub fn is_decreasing(&self, threshold: f32) -> bool {
        self.trend < -threshold
    }
}

/// Performance monitoring data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceStats {
    /// Total time spent in optimization steps
    pub total_step_time: Duration,
    /// Average time per step
    pub avg_step_time: Duration,
    /// Number of optimization steps
    pub step_count: usize,
    /// Memory usage statistics
    pub memory_usage: Option<MemoryStats>,
}

impl PerformanceStats {
    pub fn new() -> Self {
        Self {
            total_step_time: Duration::new(0, 0),
            avg_step_time: Duration::new(0, 0),
            step_count: 0,
            memory_usage: None,
        }
    }

    /// Record a step timing.
    pub fn record_step_time(&mut self, duration: Duration) {
        self.total_step_time += duration;
        self.step_count += 1;
        self.avg_step_time = self.total_step_time / self.step_count as u32;
    }
}

impl Default for PerformanceStats {
    fn default() -> Self {
        Self::new()
    }
}

/// Memory usage statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
    /// GPU memory usage in bytes
    pub gpu_memory_bytes: usize,
    /// CPU memory usage in bytes
    pub cpu_memory_bytes: usize,
    /// Peak memory usage
    pub peak_memory_bytes: usize,
}

/// Comprehensive optimizer monitoring data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerMetrics {
    /// Current step number
    pub step: usize,
    /// Learning rate statistics
    pub learning_rate: MetricStats,
    /// Gradient norm statistics
    pub gradient_norm: MetricStats,
    /// Parameter change norm statistics
    pub parameter_change_norm: MetricStats,
    /// Loss statistics (if provided)
    pub loss: MetricStats,
    /// Performance statistics
    pub performance: PerformanceStats,
    /// Parameter-specific gradient norms
    pub parameter_gradient_norms: HashMap<String, MetricStats>,
    /// Convergence indicators
    pub convergence_indicators: ConvergenceIndicators,
}

/// Convergence analysis indicators.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceIndicators {
    /// Whether loss appears to have plateaued
    pub loss_plateaued: bool,
    /// Whether gradients are vanishing (very small norms)
    pub gradients_vanishing: bool,
    /// Whether gradients are exploding (very large norms)
    pub gradients_exploding: bool,
    /// Whether training appears to be oscillating
    pub oscillating: bool,
    /// Estimated convergence rate
    pub convergence_rate: f32,
}

impl ConvergenceIndicators {
    pub fn new() -> Self {
        Self {
            loss_plateaued: false,
            gradients_vanishing: false,
            gradients_exploding: false,
            oscillating: false,
            convergence_rate: 0.0,
        }
    }
}

impl Default for ConvergenceIndicators {
    fn default() -> Self {
        Self::new()
    }
}

/// Main optimizer monitor that tracks various metrics.
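///
/// Loss can also be recorded independently of optimizer steps, and the derived
/// convergence indicators inspected at any time (a usage sketch with arbitrary
/// loss values):
///
/// ```ignore
/// let mut monitor = OptimizerMonitor::with_defaults();
/// monitor.update_loss(0.52);
/// monitor.update_loss(0.51);
/// println!("{}", monitor.get_convergence_report());
/// ```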
#[derive(Debug)]
pub struct OptimizerMonitor {
    config: MonitoringConfig,
    metrics: OptimizerMetrics,
    step_start_time: Option<Instant>,
    previous_parameters: Option<Vec<Tensor>>,
}

impl OptimizerMonitor {
    /// Create a new optimizer monitor.
    pub fn new(config: MonitoringConfig) -> Self {
        Self {
            metrics: OptimizerMetrics {
                step: 0,
                learning_rate: MetricStats::new(config.history_window),
                gradient_norm: MetricStats::new(config.history_window),
                parameter_change_norm: MetricStats::new(config.history_window),
                loss: MetricStats::new(config.history_window),
                performance: PerformanceStats::new(),
                parameter_gradient_norms: HashMap::new(),
                convergence_indicators: ConvergenceIndicators::new(),
            },
            config,
            step_start_time: None,
            previous_parameters: None,
        }
    }

    /// Create a monitor with default configuration.
    pub fn with_defaults() -> Self {
        Self::new(MonitoringConfig::default())
    }

    /// Called before an optimizer step to start timing.
    pub fn before_step(&mut self) {
        if self.config.track_performance {
            self.step_start_time = Some(Instant::now());
        }
    }

    /// Called after an optimizer step to record metrics.
    pub fn after_step(
        &mut self,
        learning_rate: f32,
        parameters: &[Tensor],
        loss: Option<f32>,
    ) -> Result<()> {
        self.metrics.step += 1;

        // Record timing
        if let Some(start_time) = self.step_start_time.take() {
            let duration = start_time.elapsed();
            self.metrics.performance.record_step_time(duration);
        }

        // Track learning rate
        if self.config.track_learning_rates {
            self.metrics.learning_rate.update(learning_rate, self.config.history_window);
        }

        // Track gradient norms
        if self.config.track_gradient_norms {
            // Hoist the window size into a local so the `entry` closure below
            // does not need to borrow `self.config` while `self.metrics` is
            // mutably borrowed.
            let window = self.config.history_window;
            let total_grad_norm = self.compute_total_gradient_norm(parameters)?;
            self.metrics.gradient_norm.update(total_grad_norm, window);

            // Track per-parameter gradient norms
            for (i, param) in parameters.iter().enumerate() {
                if let Ok(grad) = param.grad() {
                    let param_name = format!("param_{}", i);
                    let grad_norm = grad.norm()?;

                    let param_stats = self
                        .metrics
                        .parameter_gradient_norms
                        .entry(param_name)
                        .or_insert_with(|| MetricStats::new(window));
                    param_stats.update(grad_norm, window);
                }
            }
        }

        // Track parameter changes
        if self.config.track_parameter_changes {
            if let Some(prev_params) = &self.previous_parameters {
                let change_norm = self.compute_parameter_change_norm(parameters, prev_params)?;
                self.metrics
                    .parameter_change_norm
                    .update(change_norm, self.config.history_window);
            }
            self.previous_parameters = Some(parameters.to_vec());
        }

        // Track loss
        if let Some(loss_value) = loss {
            self.metrics.loss.update(loss_value, self.config.history_window);
        }

        // Update convergence indicators
        if self.config.track_convergence {
            self.update_convergence_indicators();
        }

        Ok(())
    }

    /// Update loss value (can be called independently of step).
    pub fn update_loss(&mut self, loss: f32) {
        self.metrics.loss.update(loss, self.config.history_window);
        if self.config.track_convergence {
            self.update_convergence_indicators();
        }
    }

    /// Get current metrics.
    pub fn get_metrics(&self) -> &OptimizerMetrics {
        &self.metrics
    }

    /// Check if we should log detailed metrics this step.
    pub fn should_log(&self) -> bool {
        self.metrics.step % self.config.log_frequency == 0
    }

    /// Get a summary report of current optimizer status.
    pub fn get_summary_report(&self) -> String {
        format!(
            "Step {}: LR={:.6}, GradNorm={:.6}±{:.6}, ParamChange={:.6}, Loss={:.6} (trend: {:.6})",
            self.metrics.step,
            self.metrics.learning_rate.current,
            self.metrics.gradient_norm.mean,
            self.metrics.gradient_norm.std,
            self.metrics.parameter_change_norm.current,
            self.metrics.loss.current,
            self.metrics.loss.trend
        )
    }

    /// Get convergence status report.
    pub fn get_convergence_report(&self) -> String {
        let indicators = &self.metrics.convergence_indicators;
        format!(
            "Convergence Status: Loss Plateaued: {}, Gradients Vanishing: {}, Gradients Exploding: {}, Oscillating: {}, Rate: {:.6}",
            indicators.loss_plateaued,
            indicators.gradients_vanishing,
            indicators.gradients_exploding,
            indicators.oscillating,
            indicators.convergence_rate
        )
    }

    /// Reset monitoring state.
    pub fn reset(&mut self) {
        self.metrics = OptimizerMetrics {
            step: 0,
            learning_rate: MetricStats::new(self.config.history_window),
            gradient_norm: MetricStats::new(self.config.history_window),
            parameter_change_norm: MetricStats::new(self.config.history_window),
            loss: MetricStats::new(self.config.history_window),
            performance: PerformanceStats::new(),
            parameter_gradient_norms: HashMap::new(),
            convergence_indicators: ConvergenceIndicators::new(),
        };
        self.previous_parameters = None;
        self.step_start_time = None;
    }

    fn compute_total_gradient_norm(&self, parameters: &[Tensor]) -> Result<f32> {
        let mut total_norm_sq = 0.0;
        for param in parameters {
            if let Ok(grad) = param.grad() {
                let norm_sq = grad.norm_squared()?.to_scalar()?;
                total_norm_sq += norm_sq;
            }
        }
        Ok(total_norm_sq.sqrt())
    }

    fn compute_parameter_change_norm(
        &self,
        current: &[Tensor],
        previous: &[Tensor],
    ) -> Result<f32> {
        if current.len() != previous.len() {
            return Err(anyhow!("Parameter count mismatch"));
        }

        let mut total_change_sq = 0.0;
        for (curr, prev) in current.iter().zip(previous.iter()) {
            let diff = curr.sub(prev)?;
            let norm_sq = diff.norm_squared()?.to_scalar()?;
            total_change_sq += norm_sq;
        }
        Ok(total_change_sq.sqrt())
    }

    fn update_convergence_indicators(&mut self) {
        let indicators = &mut self.metrics.convergence_indicators;

        // Check for loss plateau
        indicators.loss_plateaued = self.metrics.loss.is_plateaued(1e-6, 1e-6);

        // Check for vanishing gradients (very small gradient norms)
        indicators.gradients_vanishing = self.metrics.gradient_norm.current < 1e-8;

        // Check for exploding gradients (very large gradient norms)
        indicators.gradients_exploding = self.metrics.gradient_norm.current > 100.0;

        // Check for oscillations (high variance in loss)
        indicators.oscillating = self.metrics.loss.std > self.metrics.loss.mean * 0.1;

        // Estimate convergence rate from loss trend
        indicators.convergence_rate = -self.metrics.loss.trend; // Negative trend = positive convergence
    }
}

/// Configuration for hyperparameter sensitivity analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperparameterSensitivityConfig {
    /// Whether to enable sensitivity analysis
    pub enabled: bool,
    /// Perturbation magnitude for finite difference approximation
    pub perturbation_magnitude: f32,
    /// Number of steps to analyze sensitivity over
    pub analysis_window: usize,
    /// Minimum number of samples before computing sensitivity
    pub min_samples: usize,
    /// Whether to analyze learning rate
    pub analyze_learning_rate: bool,
    /// Whether to analyze momentum parameters
    pub analyze_momentum: bool,
    /// Whether to analyze weight decay
    pub analyze_weight_decay: bool,
    /// Whether to analyze epsilon (for Adam-like optimizers)
    pub analyze_epsilon: bool,
    /// Frequency of sensitivity analysis (every N steps)
    pub analysis_frequency: usize,
}

impl Default for HyperparameterSensitivityConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            perturbation_magnitude: 0.01, // 1% perturbation
            analysis_window: 50,
            min_samples: 10,
            analyze_learning_rate: true,
            analyze_momentum: true,
            analyze_weight_decay: true,
            analyze_epsilon: false,
            analysis_frequency: 25,
        }
    }
}

/// Sensitivity metrics for a specific hyperparameter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperparameterSensitivityMetrics {
    /// Hyperparameter name
    pub name: String,
    /// Current sensitivity estimate (∂loss/∂hyperparameter)
    pub current_sensitivity: f32,
    /// Historical sensitivity values
    pub sensitivity_history: VecDeque<f32>,
    /// Mean sensitivity over history
    pub mean_sensitivity: f32,
    /// Standard deviation of sensitivity
    pub std_sensitivity: f32,
    /// Normalized sensitivity (sensitivity / hyperparameter_value)
    pub normalized_sensitivity: f32,
    /// Relative importance score (0-1)
    pub importance_score: f32,
}

impl HyperparameterSensitivityMetrics {
    pub fn new(name: String, window_size: usize) -> Self {
        Self {
            name,
            current_sensitivity: 0.0,
            sensitivity_history: VecDeque::with_capacity(window_size),
            mean_sensitivity: 0.0,
            std_sensitivity: 0.0,
            normalized_sensitivity: 0.0,
            importance_score: 0.0,
        }
    }

    /// Update sensitivity with a new measurement.
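    ///
    /// A sketch of the derived values (numbers chosen purely for
    /// illustration): a sensitivity of 0.5 at a hyperparameter value of 0.01
    /// yields a normalized sensitivity of roughly 50 and a nonzero
    /// importance score:
    ///
    /// ```ignore
    /// let mut m = HyperparameterSensitivityMetrics::new("learning_rate".to_string(), 10);
    /// m.update(0.5, 0.01, 10);
    /// assert!((m.normalized_sensitivity - 50.0).abs() < 1e-3);
    /// assert!(m.importance_score > 0.0);
    /// ```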
    pub fn update(&mut self, sensitivity: f32, hyperparameter_value: f32, window_size: usize) {
        self.current_sensitivity = sensitivity;
        self.sensitivity_history.push_back(sensitivity);

        if self.sensitivity_history.len() > window_size {
            self.sensitivity_history.pop_front();
        }

        self.compute_statistics(hyperparameter_value);
    }

    fn compute_statistics(&mut self, hyperparameter_value: f32) {
        if self.sensitivity_history.is_empty() {
            return;
        }

        // Basic statistics
        let sum: f32 = self.sensitivity_history.iter().sum();
        self.mean_sensitivity = sum / self.sensitivity_history.len() as f32;

        let variance: f32 = self
            .sensitivity_history
            .iter()
            .map(|x| (x - self.mean_sensitivity).powi(2))
            .sum::<f32>()
            / self.sensitivity_history.len() as f32;
        self.std_sensitivity = variance.sqrt();

        // Normalized sensitivity (relative to hyperparameter value)
        if hyperparameter_value.abs() > 1e-8 {
            self.normalized_sensitivity = self.current_sensitivity / hyperparameter_value;
        } else {
            self.normalized_sensitivity = 0.0;
        }

        // Importance score based on magnitude and stability
        let magnitude_score = self.normalized_sensitivity.abs().tanh(); // Bounded [0,1]
        let stability_score = (-self.std_sensitivity.abs()).exp(); // Higher for more stable
        self.importance_score = magnitude_score * stability_score;
    }

    /// Check if this hyperparameter is highly sensitive.
    pub fn is_highly_sensitive(&self, threshold: f32) -> bool {
        self.importance_score > threshold
    }

    /// Check if sensitivity is stable (low variance).
    pub fn is_stable(&self, variance_threshold: f32) -> bool {
        self.std_sensitivity < variance_threshold
    }
}

/// Main hyperparameter sensitivity analyzer.
#[derive(Debug)]
pub struct HyperparameterSensitivity {
    config: HyperparameterSensitivityConfig,
    sensitivity_metrics: HashMap<String, HyperparameterSensitivityMetrics>,
    baseline_loss: Option<f32>,
    perturbation_losses: HashMap<String, f32>,
    step_count: usize,
}

impl HyperparameterSensitivity {
    /// Create a new sensitivity analyzer.
    pub fn new(config: HyperparameterSensitivityConfig) -> Self {
        Self {
            config,
            sensitivity_metrics: HashMap::new(),
            baseline_loss: None,
            perturbation_losses: HashMap::new(),
            step_count: 0,
        }
    }

    /// Create analyzer with default configuration.
    pub fn with_defaults() -> Self {
        Self::new(HyperparameterSensitivityConfig::default())
    }

    /// Record baseline loss for sensitivity analysis.
    pub fn record_baseline_loss(&mut self, loss: f32) {
        self.baseline_loss = Some(loss);
    }

    /// Record loss after hyperparameter perturbation.
    pub fn record_perturbation_loss(&mut self, hyperparameter_name: String, loss: f32) {
        self.perturbation_losses.insert(hyperparameter_name, loss);
    }

    /// Compute sensitivity for a specific hyperparameter.
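    ///
    /// A worked finite-difference example (the same numbers appear in the
    /// unit tests below): perturbing a learning rate of 0.01 to 0.0101 with
    /// an observed loss change of 0.1 gives a sensitivity of
    /// 0.1 / 0.0001 = 1000.
    ///
    /// ```ignore
    /// let mut analyzer = HyperparameterSensitivity::with_defaults();
    /// let s = analyzer.compute_sensitivity("learning_rate", 0.01, 0.0101, 0.1);
    /// assert!((s - 1000.0).abs() < 1.0);
    /// ```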
    pub fn compute_sensitivity(
        &mut self,
        hyperparameter_name: &str,
        hyperparameter_value: f32,
        perturbed_value: f32,
        loss_change: f32,
    ) -> f32 {
        let param_change = perturbed_value - hyperparameter_value;

        // Avoid division by zero
        if param_change.abs() < 1e-12 {
            return 0.0;
        }

        // Finite difference approximation: ∂loss/∂param ≈ Δloss/Δparam
        let sensitivity = loss_change / param_change;

        // Hoist the window size into a local so the `entry` closure does not
        // borrow `self.config` while `self.sensitivity_metrics` is mutably
        // borrowed.
        let window = self.config.analysis_window;

        // Update or create sensitivity metrics
        let metrics = self
            .sensitivity_metrics
            .entry(hyperparameter_name.to_string())
            .or_insert_with(|| {
                HyperparameterSensitivityMetrics::new(hyperparameter_name.to_string(), window)
            });

        metrics.update(sensitivity, hyperparameter_value, window);

        sensitivity
    }

    /// Analyze sensitivity for learning rate.
    pub fn analyze_learning_rate_sensitivity(
        &mut self,
        current_lr: f32,
        baseline_loss: f32,
        perturbed_loss: f32,
    ) -> f32 {
        let perturbed_lr = current_lr * (1.0 + self.config.perturbation_magnitude);
        let loss_change = perturbed_loss - baseline_loss;

        self.compute_sensitivity("learning_rate", current_lr, perturbed_lr, loss_change)
    }

    /// Analyze sensitivity for momentum parameter.
    pub fn analyze_momentum_sensitivity(
        &mut self,
        current_momentum: f32,
        baseline_loss: f32,
        perturbed_loss: f32,
    ) -> f32 {
        let perturbed_momentum = current_momentum * (1.0 + self.config.perturbation_magnitude);
        let loss_change = perturbed_loss - baseline_loss;

        self.compute_sensitivity(
            "momentum",
            current_momentum,
            perturbed_momentum,
            loss_change,
        )
    }

    /// Analyze sensitivity for weight decay.
    pub fn analyze_weight_decay_sensitivity(
        &mut self,
        current_weight_decay: f32,
        baseline_loss: f32,
        perturbed_loss: f32,
    ) -> f32 {
        let perturbed_weight_decay =
            current_weight_decay * (1.0 + self.config.perturbation_magnitude);
        let loss_change = perturbed_loss - baseline_loss;

        self.compute_sensitivity(
            "weight_decay",
            current_weight_decay,
            perturbed_weight_decay,
            loss_change,
        )
    }

    /// Analyze sensitivity for epsilon parameter.
    pub fn analyze_epsilon_sensitivity(
        &mut self,
        current_epsilon: f32,
        baseline_loss: f32,
        perturbed_loss: f32,
    ) -> f32 {
        let perturbed_epsilon = current_epsilon * (1.0 + self.config.perturbation_magnitude);
        let loss_change = perturbed_loss - baseline_loss;

        self.compute_sensitivity("epsilon", current_epsilon, perturbed_epsilon, loss_change)
    }

    /// Get sensitivity metrics for a specific hyperparameter.
    pub fn get_sensitivity_metrics(
        &self,
        hyperparameter: &str,
    ) -> Option<&HyperparameterSensitivityMetrics> {
        self.sensitivity_metrics.get(hyperparameter)
    }

    /// Get all sensitivity metrics.
    pub fn get_all_sensitivity_metrics(
        &self,
    ) -> &HashMap<String, HyperparameterSensitivityMetrics> {
        &self.sensitivity_metrics
    }

    /// Get most sensitive hyperparameters (sorted by importance score).
    pub fn get_most_sensitive_hyperparameters(
        &self,
    ) -> Vec<(&String, &HyperparameterSensitivityMetrics)> {
        let mut sorted: Vec<_> = self.sensitivity_metrics.iter().collect();
        sorted.sort_by(|a, b| {
            b.1.importance_score
                .partial_cmp(&a.1.importance_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        sorted
    }

    /// Check if sensitivity analysis should be performed this step.
    pub fn should_analyze(&self) -> bool {
        self.config.enabled
            && self.step_count % self.config.analysis_frequency == 0
            && self.step_count >= self.config.min_samples
    }

    /// Increment step count.
    pub fn step(&mut self) {
        self.step_count += 1;
    }

    /// Get a summary report of hyperparameter sensitivities.
    pub fn get_sensitivity_report(&self) -> String {
        let mut report = String::from("Hyperparameter Sensitivity Analysis:\n");

        let sorted_metrics = self.get_most_sensitive_hyperparameters();

        for (name, metrics) in sorted_metrics.iter().take(5) {
            // Top 5 most sensitive
            report.push_str(&format!(
                "  {}: Sensitivity={:.6}, Normalized={:.6}, Importance={:.3} ({})\n",
                name,
                metrics.current_sensitivity,
                metrics.normalized_sensitivity,
                metrics.importance_score,
                if metrics.is_highly_sensitive(0.5) { "HIGH" } else { "LOW" }
            ));
        }

        if sorted_metrics.is_empty() {
            report.push_str("  No sensitivity data available yet.\n");
        }

        report
    }

    /// Get recommendations based on sensitivity analysis.
    pub fn get_recommendations(&self) -> Vec<String> {
        let mut recommendations = Vec::new();

        for (name, metrics) in &self.sensitivity_metrics {
            if metrics.is_highly_sensitive(0.7) {
                recommendations.push(format!(
                    "Consider careful tuning of {}: high sensitivity detected (score: {:.3})",
                    name, metrics.importance_score
                ));
            }

            if !metrics.is_stable(0.1) {
                recommendations.push(format!(
                    "Consider stabilizing {}: sensitivity varies significantly (std: {:.6})",
                    name, metrics.std_sensitivity
                ));
            }
        }

        if recommendations.is_empty() {
            recommendations.push(
                "All hyperparameters appear to have reasonable sensitivity profiles.".to_string(),
            );
        }

        recommendations
    }

    /// Reset sensitivity analysis state.
    pub fn reset(&mut self) {
        self.sensitivity_metrics.clear();
        self.baseline_loss = None;
        self.perturbation_losses.clear();
        self.step_count = 0;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_metric_stats_creation() {
        let stats = MetricStats::new(10);
        // `VecDeque::with_capacity` only guarantees a lower bound on capacity.
        assert!(stats.values.capacity() >= 10);
        assert_eq!(stats.current, 0.0);
        assert_eq!(stats.mean, 0.0);
    }

    #[test]
    fn test_metric_stats_update() {
        let mut stats = MetricStats::new(3);

        stats.update(1.0, 3);
        assert_eq!(stats.current, 1.0);
        assert_eq!(stats.mean, 1.0);

        stats.update(2.0, 3);
        assert_eq!(stats.current, 2.0);
        assert_eq!(stats.mean, 1.5);

        stats.update(3.0, 3);
        assert_eq!(stats.current, 3.0);
        assert_eq!(stats.mean, 2.0);

        // Should maintain window size
        stats.update(4.0, 3);
        assert_eq!(stats.values.len(), 3);
        assert_eq!(stats.mean, 3.0); // (2+3+4)/3
    }

    #[test]
    fn test_metric_stats_trend() {
        let mut stats = MetricStats::new(10);

        // Add increasing values
        for i in 1..=5 {
            stats.update(i as f32, 10);
        }

        // Should detect positive trend
        assert!(stats.trend > 0.0);
        assert!(stats.is_increasing(0.5));
        assert!(!stats.is_decreasing(0.5));
    }

    #[test]
    fn test_performance_stats() {
        let mut perf = PerformanceStats::new();
        assert_eq!(perf.step_count, 0);

        perf.record_step_time(Duration::from_millis(100));
        assert_eq!(perf.step_count, 1);
        assert_eq!(perf.avg_step_time, Duration::from_millis(100));

        perf.record_step_time(Duration::from_millis(200));
        assert_eq!(perf.step_count, 2);
        assert_eq!(perf.avg_step_time, Duration::from_millis(150));
    }

    #[test]
    fn test_convergence_indicators() {
        let indicators = ConvergenceIndicators::new();
        assert!(!indicators.loss_plateaued);
        assert!(!indicators.gradients_vanishing);
        assert!(!indicators.gradients_exploding);
        assert!(!indicators.oscillating);
        assert_eq!(indicators.convergence_rate, 0.0);
    }

    #[test]
    fn test_optimizer_monitor_creation() {
        let monitor = OptimizerMonitor::with_defaults();
        assert_eq!(monitor.metrics.step, 0);
        assert!(monitor.previous_parameters.is_none());
    }

    #[test]
    fn test_monitor_should_log() {
        let mut monitor = OptimizerMonitor::with_defaults();

        // Should log at step 0
        assert!(monitor.should_log());

        monitor.metrics.step = 5;
        assert!(!monitor.should_log()); // Default frequency is 10

        monitor.metrics.step = 10;
        assert!(monitor.should_log());
    }

    #[test]
    fn test_hyperparameter_sensitivity_config() {
        let config = HyperparameterSensitivityConfig::default();
        assert!(config.enabled);
        assert_eq!(config.perturbation_magnitude, 0.01);
        assert_eq!(config.analysis_window, 50);
        assert_eq!(config.min_samples, 10);
        assert!(config.analyze_learning_rate);
        assert!(config.analyze_momentum);
        assert!(config.analyze_weight_decay);
        assert!(!config.analyze_epsilon);
        assert_eq!(config.analysis_frequency, 25);
    }

    #[test]
    fn test_hyperparameter_sensitivity_metrics() {
        let mut metrics = HyperparameterSensitivityMetrics::new("learning_rate".to_string(), 10);
        assert_eq!(metrics.name, "learning_rate");
        assert_eq!(metrics.current_sensitivity, 0.0);
        assert_eq!(metrics.importance_score, 0.0);

        // Update with some sensitivity values
        metrics.update(0.5, 0.01, 10); // sensitivity=0.5, lr=0.01
        assert_eq!(metrics.current_sensitivity, 0.5);
        assert_eq!(metrics.normalized_sensitivity, 0.5 / 0.01);
        assert!(metrics.importance_score > 0.0);

        metrics.update(0.3, 0.01, 10);
        assert_eq!(metrics.current_sensitivity, 0.3);
        assert_eq!(metrics.sensitivity_history.len(), 2);

        // Check if metrics are computed correctly
        assert_eq!(metrics.mean_sensitivity, 0.4); // (0.5 + 0.3) / 2
    }

    #[test]
    fn test_hyperparameter_sensitivity_analyzer() {
        let mut analyzer = HyperparameterSensitivity::with_defaults();

        // Test baseline loss recording
        analyzer.record_baseline_loss(1.0);
        assert_eq!(analyzer.baseline_loss, Some(1.0));

        // Test perturbation loss recording
        analyzer.record_perturbation_loss("learning_rate".to_string(), 1.1);
        assert_eq!(
            analyzer.perturbation_losses.get("learning_rate"),
            Some(&1.1)
        );

        // Test sensitivity computation
        let sensitivity = analyzer.compute_sensitivity("learning_rate", 0.01, 0.0101, 0.1);
        let expected = 0.1 / 0.0001; // loss_change / param_change = 1000.0
        assert!(
            (sensitivity - expected).abs() < 0.01,
            "Expected {}, got {}",
            expected,
            sensitivity
        );

        // Check that metrics were created
        assert!(analyzer.sensitivity_metrics.contains_key("learning_rate"));
    }

    #[test]
    fn test_sensitivity_analysis_methods() {
        let mut analyzer = HyperparameterSensitivity::with_defaults();

        // Test learning rate sensitivity
        let lr_sensitivity = analyzer.analyze_learning_rate_sensitivity(0.01, 1.0, 1.1);
        assert!(lr_sensitivity > 0.0);
        assert!(analyzer.sensitivity_metrics.contains_key("learning_rate"));

        // Test momentum sensitivity
        let momentum_sensitivity = analyzer.analyze_momentum_sensitivity(0.9, 1.0, 0.95);
        assert!(momentum_sensitivity < 0.0); // Loss decreased
        assert!(analyzer.sensitivity_metrics.contains_key("momentum"));

        // Test weight decay sensitivity
        let wd_sensitivity = analyzer.analyze_weight_decay_sensitivity(0.01, 1.0, 1.05);
        assert!(wd_sensitivity > 0.0);
        assert!(analyzer.sensitivity_metrics.contains_key("weight_decay"));
    }

    #[test]
    fn test_sensitivity_should_analyze() {
        let mut analyzer = HyperparameterSensitivity::with_defaults();

        // Should not analyze initially (below min_samples)
        assert!(!analyzer.should_analyze());

        // Step forward to reach min_samples
        for _ in 0..10 {
            analyzer.step();
        }
        assert!(!analyzer.should_analyze()); // Still not at frequency

        // Step to reach analysis frequency
        for _ in 0..15 {
            analyzer.step();
        }
        assert!(analyzer.should_analyze()); // Now at step 25
    }

    #[test]
    fn test_sensitivity_report_generation() {
        let mut analyzer = HyperparameterSensitivity::with_defaults();

        // Add some sensitivity data
        analyzer.compute_sensitivity("learning_rate", 0.01, 0.0101, 0.1);
        analyzer.compute_sensitivity("momentum", 0.9, 0.909, -0.05);

        let report = analyzer.get_sensitivity_report();
        assert!(report.contains("Hyperparameter Sensitivity Analysis"));
        assert!(report.contains("learning_rate"));
        assert!(report.contains("momentum"));

        let recommendations = analyzer.get_recommendations();
        assert!(!recommendations.is_empty());
    }

    #[test]
    fn test_sensitivity_most_sensitive_hyperparameters() {
        let mut analyzer = HyperparameterSensitivity::with_defaults();

        // Add different sensitivity levels
        analyzer.compute_sensitivity("learning_rate", 0.01, 0.0101, 0.2); // High sensitivity
        analyzer.compute_sensitivity("momentum", 0.9, 0.909, 0.01); // Low sensitivity
        analyzer.compute_sensitivity("weight_decay", 0.01, 0.0101, 0.15); // Medium sensitivity

        let most_sensitive = analyzer.get_most_sensitive_hyperparameters();
        assert_eq!(most_sensitive.len(), 3);

        // Should be sorted by importance score (descending)
        let first_importance = most_sensitive[0].1.importance_score;
        let second_importance = most_sensitive[1].1.importance_score;
        assert!(first_importance >= second_importance);
    }

    #[test]
    fn test_sensitivity_reset() {
        let mut analyzer = HyperparameterSensitivity::with_defaults();

        // Add some data
        analyzer.record_baseline_loss(1.0);
        analyzer.record_perturbation_loss("learning_rate".to_string(), 1.1);
        analyzer.compute_sensitivity("learning_rate", 0.01, 0.0101, 0.1);
        analyzer.step();

        // Verify data exists
        assert!(analyzer.baseline_loss.is_some());
        assert!(!analyzer.perturbation_losses.is_empty());
        assert!(!analyzer.sensitivity_metrics.is_empty());
        assert_eq!(analyzer.step_count, 1);

        // Reset and verify everything is cleared
        analyzer.reset();
        assert!(analyzer.baseline_loss.is_none());
        assert!(analyzer.perturbation_losses.is_empty());
        assert!(analyzer.sensitivity_metrics.is_empty());
        assert_eq!(analyzer.step_count, 0);
    }
}

/// Optimizer Performance Analysis and Selection Tool
///
/// Helps users choose the right optimizer based on their requirements,
/// including performance characteristics, model size, and training objectives.
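///
/// A usage sketch with the builder-style setters defined below:
///
/// ```ignore
/// let selector = OptimizerSelector::new(100_000)
///     .time_sensitive(true)
///     .fast_convergence(true);
/// for rec in selector.get_recommendations() {
///     println!("{}: {}", rec.name, rec.description);
/// }
/// println!("{}", selector.generate_report());
/// ```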
#[derive(Debug, Clone)]
pub struct OptimizerSelector {
    /// Model parameter count
    pub model_size: usize,
    /// Training duration requirements (training time sensitivity)
    pub time_sensitive: bool,
    /// Memory constraints
    pub memory_constrained: bool,
    /// Convergence speed priority
    pub fast_convergence: bool,
    /// Robustness requirements (handling diverse training conditions)
    pub robustness_priority: bool,
    /// Advanced features requirements (entropy weighting, adaptive norms)
    pub advanced_features: bool,
}

/// Optimizer recommendation with performance characteristics
#[derive(Debug, Clone)]
pub struct OptimizerRecommendation {
    pub name: String,
    pub description: String,
    pub performance_tier: PerformanceTier,
    pub convergence_speed: ConvergenceSpeed,
    pub memory_usage: MemoryUsage,
    pub use_cases: Vec<String>,
    pub estimated_overhead: f32, // runtime multiplier relative to the AdamW baseline
}

#[derive(Debug, Clone)]
pub enum PerformanceTier {
    Fastest,  // Traditional optimizers (Adam, SGD, AdamW)
    Moderate, // HN-Adam, Lookahead variants
    Advanced, // BGE-Adam, complex entropy-based methods
}

#[derive(Debug, Clone)]
pub enum ConvergenceSpeed {
    Fast,
    Moderate,
    Superior, // Better than standard methods
}

#[derive(Debug, Clone)]
pub enum MemoryUsage {
    Low,      // Similar to SGD
    Standard, // Adam-level
    High,     // Complex state tracking
}

impl OptimizerSelector {
    pub fn new(model_size: usize) -> Self {
        Self {
            model_size,
            time_sensitive: false,
            memory_constrained: false,
            fast_convergence: false,
            robustness_priority: false,
            advanced_features: false,
        }
    }

    pub fn time_sensitive(mut self, sensitive: bool) -> Self {
        self.time_sensitive = sensitive;
        self
    }

    pub fn memory_constrained(mut self, constrained: bool) -> Self {
        self.memory_constrained = constrained;
        self
    }

    pub fn fast_convergence(mut self, fast: bool) -> Self {
        self.fast_convergence = fast;
        self
    }

    pub fn robustness_priority(mut self, robust: bool) -> Self {
        self.robustness_priority = robust;
        self
    }

    pub fn advanced_features(mut self, advanced: bool) -> Self {
        self.advanced_features = advanced;
        self
    }

    /// Get optimizer recommendations ranked by suitability
    pub fn get_recommendations(&self) -> Vec<OptimizerRecommendation> {
        let mut recommendations = self.generate_all_recommendations();
        self.rank_recommendations(&mut recommendations);
        recommendations
    }

    /// Generate recommendations for all available optimizers
    fn generate_all_recommendations(&self) -> Vec<OptimizerRecommendation> {
        vec![
            OptimizerRecommendation {
                name: "AdamW".to_string(),
                description: "Decoupled weight decay Adam - excellent all-around optimizer"
                    .to_string(),
                performance_tier: PerformanceTier::Fastest,
                convergence_speed: ConvergenceSpeed::Fast,
                memory_usage: MemoryUsage::Standard,
                use_cases: vec![
                    "General purpose training".to_string(),
                    "Large language models".to_string(),
                    "Computer vision".to_string(),
                    "Production training".to_string(),
                ],
                estimated_overhead: 1.0, // baseline
            },
            OptimizerRecommendation {
                name: "Adam".to_string(),
                description: "Classic adaptive moment estimation optimizer".to_string(),
                performance_tier: PerformanceTier::Fastest,
                convergence_speed: ConvergenceSpeed::Fast,
                memory_usage: MemoryUsage::Standard,
                use_cases: vec![
                    "General purpose training".to_string(),
                    "Research and experimentation".to_string(),
                    "Quick prototyping".to_string(),
                ],
                estimated_overhead: 1.05, // slightly slower than AdamW
            },
            OptimizerRecommendation {
                name: "SGD".to_string(),
                description: "Stochastic gradient descent with momentum - simple and effective"
                    .to_string(),
                performance_tier: PerformanceTier::Fastest,
                convergence_speed: ConvergenceSpeed::Moderate,
                memory_usage: MemoryUsage::Low,
                use_cases: vec![
                    "Memory-constrained training".to_string(),
                    "Simple models".to_string(),
                    "Fine-tuning".to_string(),
                    "Educational purposes".to_string(),
                ],
                estimated_overhead: 1.1, // simple but effective
            },
            OptimizerRecommendation {
                name: "HN-Adam".to_string(),
                description: "Hybrid Norm Adam with adaptive step size based on parameter norms"
                    .to_string(),
                performance_tier: PerformanceTier::Moderate,
                convergence_speed: ConvergenceSpeed::Superior,
                memory_usage: MemoryUsage::Standard,
                use_cases: vec![
                    "Transformer training".to_string(),
                    "Computer vision tasks".to_string(),
                    "When adaptive learning rates are needed".to_string(),
                    "Research requiring latest optimization techniques".to_string(),
                ],
                estimated_overhead: 2.5, // ~2.5x slower but adaptive
            },
            OptimizerRecommendation {
                name: "BGE-Adam".to_string(),
                description: "Entropy-weighted Adam with adaptive gradient strategies".to_string(),
                performance_tier: PerformanceTier::Advanced,
                convergence_speed: ConvergenceSpeed::Superior,
                memory_usage: MemoryUsage::High,
                use_cases: vec![
                    "Research and experimentation".to_string(),
                    "Complex training scenarios".to_string(),
                    "When robustness is critical".to_string(),
                    "Handling diverse gradient conditions".to_string(),
                ],
                estimated_overhead: 13.0, // ~13x slower due to entropy calculations
            },
        ]
    }

    /// Rank recommendations based on user requirements
    fn rank_recommendations(&self, recommendations: &mut [OptimizerRecommendation]) {
        recommendations.sort_by(|a, b| {
            let score_a = self.calculate_suitability_score(a);
            let score_b = self.calculate_suitability_score(b);
            // Scores are finite in practice, but avoid panicking on a pathological NaN.
            score_b.partial_cmp(&score_a).unwrap_or(std::cmp::Ordering::Equal)
        });
    }

    /// Calculate suitability score for a recommendation
    fn calculate_suitability_score(&self, rec: &OptimizerRecommendation) -> f32 {
        let mut score = 0.0;

        // Performance requirements
        if self.time_sensitive {
            score += match rec.performance_tier {
                PerformanceTier::Fastest => 10.0,
                PerformanceTier::Moderate => 5.0,
                PerformanceTier::Advanced => 1.0,
            };
        }

        // Memory constraints
        if self.memory_constrained {
            score += match rec.memory_usage {
                MemoryUsage::Low => 10.0,
                MemoryUsage::Standard => 5.0,
                MemoryUsage::High => 1.0,
            };
        }

        // Convergence speed priority
        if self.fast_convergence {
            score += match rec.convergence_speed {
                ConvergenceSpeed::Superior => 10.0,
                ConvergenceSpeed::Fast => 7.0,
                ConvergenceSpeed::Moderate => 3.0,
            };
        }

        // Robustness priority
        if self.robustness_priority {
            match rec.name.as_str() {
                "BGE-Adam" => score += 10.0, // Highest robustness
                "HN-Adam" => score += 7.0,   // Good robustness
                "AdamW" => score += 5.0,     // Standard robustness
                _ => score += 3.0,
            }
        }

        // Advanced features
        if self.advanced_features {
            match rec.name.as_str() {
                "BGE-Adam" => score += 10.0, // Entropy weighting
                "HN-Adam" => score += 8.0,   // Adaptive norms
                _ => score += 2.0,
            }
        }

        // Model size considerations
        if self.model_size > 1_000_000 {
            // Large models benefit from stable optimizers
            match rec.name.as_str() {
                "AdamW" | "Adam" => score += 5.0,
                "HN-Adam" => score += 3.0,
                _ => score += 1.0,
            }
        }

        // Base score for general usability
        match rec.name.as_str() {
            "AdamW" => score += 8.0,    // Excellent general purpose
            "Adam" => score += 7.0,     // Good general purpose
            "HN-Adam" => score += 6.0,  // Good with advanced features
            "SGD" => score += 5.0,      // Simple and reliable
            "BGE-Adam" => score += 4.0, // Specialized use case
            _ => score += 2.0,
        }

        score
    }

    /// Generate a detailed report with recommendations
    pub fn generate_report(&self) -> String {
        let recommendations = self.get_recommendations();
        let mut report = String::new();

        report.push_str("🚀 TrustformeRS Optimizer Selection Report\n");
        report.push_str("=========================================\n\n");

        report.push_str("📊 Model Configuration:\n");
        report.push_str(&format!(
            "   • Model size: {} parameters\n",
            self.model_size
        ));
        report.push_str(&format!("   • Time sensitive: {}\n", self.time_sensitive));
        report.push_str(&format!(
            "   • Memory constrained: {}\n",
            self.memory_constrained
        ));
        report.push_str(&format!(
            "   • Fast convergence priority: {}\n",
            self.fast_convergence
        ));
        report.push_str(&format!(
            "   • Robustness priority: {}\n",
            self.robustness_priority
        ));
        report.push_str(&format!(
            "   • Advanced features: {}\n\n",
            self.advanced_features
        ));

        report.push_str("🏆 Recommended Optimizers (ranked by suitability):\n\n");

        for (i, rec) in recommendations.iter().enumerate() {
            let rank_emoji = match i {
                0 => "🥇",
                1 => "🥈",
                2 => "🥉",
                _ => "📊",
            };

            report.push_str(&format!(
                "{} {} - {}\n",
                rank_emoji, rec.name, rec.description
            ));
            report.push_str(&format!(
                "   Performance: {:?} | Convergence: {:?} | Memory: {:?}\n",
                rec.performance_tier, rec.convergence_speed, rec.memory_usage
            ));
            report.push_str(&format!(
                "   Overhead: {:.1}x compared to baseline\n",
                rec.estimated_overhead
            ));
            report.push_str(&format!("   Use cases: {}\n\n", rec.use_cases.join(", ")));
        }

        report.push_str("💡 Performance Insights from Latest Benchmarks:\n");
        report.push_str("   • AdamW: 238µs/iter (100K params) - Fast and reliable\n");
        report.push_str("   • Adam: 248µs/iter - Slightly slower than AdamW\n");
        report.push_str("   • SGD: 257µs/iter - Simple and memory efficient\n");
        report.push_str("   • HN-Adam: 633µs/iter - 2.5x slower, adaptive step sizes\n");
        report.push_str("   • BGE-Adam: 3.3ms/iter - 13x slower, entropy-based robustness\n\n");

        report.push_str("🎯 Quick Selection Guide:\n");
        report.push_str("   • Production training: AdamW\n");
        report.push_str("   • Memory constrained: SGD\n");
        report.push_str("   • Research/experimentation: HN-Adam or BGE-Adam\n");
        report.push_str("   • Maximum robustness: BGE-Adam\n");
        report.push_str("   • Adaptive learning rates: HN-Adam\n");

        report
    }
}

#[cfg(test)]
mod optimizer_selection_tests {
    use super::*;

    #[test]
    fn test_optimizer_selector_basic() {
        let selector = OptimizerSelector::new(10000);
        let recommendations = selector.get_recommendations();
        assert!(!recommendations.is_empty());
        assert_eq!(recommendations.len(), 5); // All available optimizers
    }

    #[test]
    fn test_time_sensitive_selection() {
        let selector = OptimizerSelector::new(10000).time_sensitive(true);

        let recommendations = selector.get_recommendations();
        let top_rec = &recommendations[0];

        // Should prioritize fastest optimizers
        assert!(matches!(top_rec.performance_tier, PerformanceTier::Fastest));
        assert!(top_rec.name == "AdamW" || top_rec.name == "Adam" || top_rec.name == "SGD");
    }

    #[test]
    fn test_memory_constrained_selection() {
        let selector = OptimizerSelector::new(10000).memory_constrained(true);

        let recommendations = selector.get_recommendations();
        let top_rec = &recommendations[0];

        // Should prioritize low memory optimizers
        assert!(top_rec.name == "SGD" || matches!(top_rec.memory_usage, MemoryUsage::Low));
    }

    #[test]
    fn test_robustness_priority_selection() {
        let selector = OptimizerSelector::new(10000).robustness_priority(true);

        let recommendations = selector.get_recommendations();
        let top_rec = &recommendations[0];

        // Should prioritize BGE-Adam for robustness
        assert!(top_rec.name == "BGE-Adam" || top_rec.name == "HN-Adam");
    }

    #[test]
    fn test_advanced_features_selection() {
        let selector = OptimizerSelector::new(10000).advanced_features(true);

        let recommendations = selector.get_recommendations();
        let top_rec = &recommendations[0];

        // Should prioritize advanced optimizers
        assert!(top_rec.name == "BGE-Adam" || top_rec.name == "HN-Adam");
    }

    #[test]
    fn test_report_generation() {
        let selector = OptimizerSelector::new(50000).time_sensitive(true).fast_convergence(true);

        let report = selector.generate_report();
        assert!(report.contains("TrustformeRS Optimizer Selection Report"));
        assert!(report.contains("Model size: 50000"));
        assert!(report.contains("Time sensitive: true"));
        assert!(report.contains("🥇")); // Should have rankings
        assert!(report.contains("Performance Insights"));
    }
}