1use anyhow::{anyhow, Result};
14use serde::{Deserialize, Serialize};
15use std::collections::{HashMap, VecDeque};
16use std::time::{Duration, Instant};
17use trustformers_core::tensor::Tensor;
18
/// Settings that control which optimizer statistics are collected and how
/// often they are logged.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringConfig {
    /// Record the global gradient norm and per-parameter gradient norms.
    pub track_gradient_norms: bool,
    /// Record the norm of the parameter delta between consecutive steps.
    pub track_parameter_changes: bool,
    /// Record the learning rate used at each step.
    pub track_learning_rates: bool,
    /// Update convergence indicators (plateau, vanishing/exploding gradients).
    pub track_convergence: bool,
    /// Measure wall-clock time per optimizer step.
    pub track_performance: bool,
    /// Number of most-recent samples kept per metric (sliding window).
    pub history_window: usize,
    /// `OptimizerMonitor::should_log` returns true every `log_frequency` steps.
    pub log_frequency: usize,
}
37
38impl Default for MonitoringConfig {
39 fn default() -> Self {
40 Self {
41 track_gradient_norms: true,
42 track_parameter_changes: true,
43 track_learning_rates: true,
44 track_convergence: true,
45 track_performance: false,
46 history_window: 100,
47 log_frequency: 10,
48 }
49 }
50}
51
/// Sliding-window summary statistics for a single scalar training metric.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricStats {
    /// Raw samples, oldest first; length is bounded by the window size
    /// passed to `update`.
    pub values: VecDeque<f32>,
    /// Most recently recorded value.
    pub current: f32,
    /// Mean over the current window.
    pub mean: f32,
    /// Population standard deviation over the current window.
    pub std: f32,
    /// Minimum over the current window.
    pub min: f32,
    /// Maximum over the current window.
    pub max: f32,
    /// Least-squares slope of the samples against their window index.
    pub trend: f32,
}
70
71impl MetricStats {
72 pub fn new(window_size: usize) -> Self {
73 Self {
74 values: VecDeque::with_capacity(window_size),
75 current: 0.0,
76 mean: 0.0,
77 std: 0.0,
78 min: f32::INFINITY,
79 max: f32::NEG_INFINITY,
80 trend: 0.0,
81 }
82 }
83
84 pub fn update(&mut self, value: f32, window_size: usize) {
86 self.current = value;
87 self.values.push_back(value);
88
89 if self.values.len() > window_size {
90 self.values.pop_front();
91 }
92
93 self.compute_statistics();
94 }
95
96 fn compute_statistics(&mut self) {
97 if self.values.is_empty() {
98 return;
99 }
100
101 let sum: f32 = self.values.iter().sum();
103 self.mean = sum / self.values.len() as f32;
104
105 let variance: f32 = self.values.iter().map(|x| (x - self.mean).powi(2)).sum::<f32>()
106 / self.values.len() as f32;
107 self.std = variance.sqrt();
108
109 self.min = self.values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
110 self.max = self.values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
111
112 if self.values.len() >= 2 {
114 let n = self.values.len() as f32;
115 let x_mean = (n - 1.0) / 2.0; let mut numerator = 0.0;
118 let mut denominator = 0.0;
119
120 for (i, &y) in self.values.iter().enumerate() {
121 let x = i as f32;
122 numerator += (x - x_mean) * (y - self.mean);
123 denominator += (x - x_mean).powi(2);
124 }
125
126 self.trend = if denominator > 1e-8 { numerator / denominator } else { 0.0 };
127 }
128 }
129
130 pub fn is_plateaued(&self, variance_threshold: f32, trend_threshold: f32) -> bool {
132 self.std < variance_threshold && self.trend.abs() < trend_threshold
133 }
134
135 pub fn is_increasing(&self, threshold: f32) -> bool {
137 self.trend > threshold
138 }
139
140 pub fn is_decreasing(&self, threshold: f32) -> bool {
142 self.trend < -threshold
143 }
144}
145
/// Wall-clock timing statistics for optimizer steps.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceStats {
    /// Sum of all recorded step durations.
    pub total_step_time: Duration,
    /// Running average step duration (`total_step_time / step_count`).
    pub avg_step_time: Duration,
    /// Number of steps recorded.
    pub step_count: usize,
    /// Optional memory snapshot; nothing in this module populates it.
    pub memory_usage: Option<MemoryStats>,
}
158
159impl PerformanceStats {
160 pub fn new() -> Self {
161 Self {
162 total_step_time: Duration::new(0, 0),
163 avg_step_time: Duration::new(0, 0),
164 step_count: 0,
165 memory_usage: None,
166 }
167 }
168
169 pub fn record_step_time(&mut self, duration: Duration) {
171 self.total_step_time += duration;
172 self.step_count += 1;
173 self.avg_step_time = self.total_step_time / self.step_count as u32;
174 }
175}
176
impl Default for PerformanceStats {
    /// Equivalent to [`PerformanceStats::new`].
    fn default() -> Self {
        Self::new()
    }
}
182
/// Snapshot of memory consumption, in bytes. Nothing in this module fills
/// these fields in; callers are expected to populate them externally.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
    /// Bytes currently allocated on the GPU.
    pub gpu_memory_bytes: usize,
    /// Bytes currently allocated on the host.
    pub cpu_memory_bytes: usize,
    /// Highest allocation observed.
    pub peak_memory_bytes: usize,
}
193
/// Aggregated optimizer telemetry collected by [`OptimizerMonitor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerMetrics {
    /// Number of optimizer steps observed so far.
    pub step: usize,
    /// Learning-rate history.
    pub learning_rate: MetricStats,
    /// Global (all-parameter) gradient-norm history.
    pub gradient_norm: MetricStats,
    /// History of the norm of the parameter delta between consecutive steps.
    pub parameter_change_norm: MetricStats,
    /// Loss history.
    pub loss: MetricStats,
    /// Wall-clock timing statistics.
    pub performance: PerformanceStats,
    /// Per-parameter gradient-norm histories, keyed `"param_{index}"`.
    pub parameter_gradient_norms: HashMap<String, MetricStats>,
    /// Heuristic convergence flags derived from the histories above.
    pub convergence_indicators: ConvergenceIndicators,
}
214
/// Heuristic flags describing the current convergence behaviour of a run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceIndicators {
    /// Loss window is low-variance and nearly flat.
    pub loss_plateaued: bool,
    /// Current global gradient norm is near zero.
    pub gradients_vanishing: bool,
    /// Current global gradient norm is very large.
    pub gradients_exploding: bool,
    /// Loss variance is large relative to its mean.
    pub oscillating: bool,
    /// Negated loss trend: positive when the loss is decreasing.
    pub convergence_rate: f32,
}
229
230impl ConvergenceIndicators {
231 pub fn new() -> Self {
232 Self {
233 loss_plateaued: false,
234 gradients_vanishing: false,
235 gradients_exploding: false,
236 oscillating: false,
237 convergence_rate: 0.0,
238 }
239 }
240}
241
impl Default for ConvergenceIndicators {
    /// Equivalent to [`ConvergenceIndicators::new`].
    fn default() -> Self {
        Self::new()
    }
}
247
/// Observes an optimizer across training steps, maintaining sliding-window
/// statistics and heuristic convergence indicators.
#[derive(Debug)]
pub struct OptimizerMonitor {
    // What to track and how much history to keep.
    config: MonitoringConfig,
    // All collected statistics.
    metrics: OptimizerMetrics,
    // Set by `before_step` when performance tracking is enabled; consumed
    // by the next `after_step`.
    step_start_time: Option<Instant>,
    // Snapshot of the parameters from the previous step, used to measure
    // the per-step parameter change norm.
    previous_parameters: Option<Vec<Tensor>>,
}
256
257impl OptimizerMonitor {
258 pub fn new(config: MonitoringConfig) -> Self {
260 Self {
261 metrics: OptimizerMetrics {
262 step: 0,
263 learning_rate: MetricStats::new(config.history_window),
264 gradient_norm: MetricStats::new(config.history_window),
265 parameter_change_norm: MetricStats::new(config.history_window),
266 loss: MetricStats::new(config.history_window),
267 performance: PerformanceStats::new(),
268 parameter_gradient_norms: HashMap::new(),
269 convergence_indicators: ConvergenceIndicators::new(),
270 },
271 config,
272 step_start_time: None,
273 previous_parameters: None,
274 }
275 }
276
277 pub fn with_defaults() -> Self {
279 Self::new(MonitoringConfig::default())
280 }
281
282 pub fn before_step(&mut self) {
284 if self.config.track_performance {
285 self.step_start_time = Some(Instant::now());
286 }
287 }
288
289 pub fn after_step(
291 &mut self,
292 learning_rate: f32,
293 parameters: &[Tensor],
294 loss: Option<f32>,
295 ) -> Result<()> {
296 self.metrics.step += 1;
297
298 if let Some(start_time) = self.step_start_time.take() {
300 let duration = start_time.elapsed();
301 self.metrics.performance.record_step_time(duration);
302 }
303
304 if self.config.track_learning_rates {
306 self.metrics.learning_rate.update(learning_rate, self.config.history_window);
307 }
308
309 if self.config.track_gradient_norms {
311 let total_grad_norm = self.compute_total_gradient_norm(parameters)?;
312 self.metrics.gradient_norm.update(total_grad_norm, self.config.history_window);
313
314 for (i, param) in parameters.iter().enumerate() {
316 if let Ok(grad) = param.grad() {
317 let param_name = format!("param_{}", i);
318 let grad_norm = grad.norm()?;
319
320 let param_stats = self
321 .metrics
322 .parameter_gradient_norms
323 .entry(param_name)
324 .or_insert_with(|| MetricStats::new(self.config.history_window));
325 param_stats.update(grad_norm, self.config.history_window);
326 }
327 }
328 }
329
330 if self.config.track_parameter_changes {
332 if let Some(prev_params) = &self.previous_parameters {
333 let change_norm = self.compute_parameter_change_norm(parameters, prev_params)?;
334 self.metrics
335 .parameter_change_norm
336 .update(change_norm, self.config.history_window);
337 }
338 self.previous_parameters = Some(parameters.to_vec());
339 }
340
341 if let Some(loss_value) = loss {
343 self.metrics.loss.update(loss_value, self.config.history_window);
344 }
345
346 if self.config.track_convergence {
348 self.update_convergence_indicators();
349 }
350
351 Ok(())
352 }
353
354 pub fn update_loss(&mut self, loss: f32) {
356 self.metrics.loss.update(loss, self.config.history_window);
357 if self.config.track_convergence {
358 self.update_convergence_indicators();
359 }
360 }
361
362 pub fn get_metrics(&self) -> &OptimizerMetrics {
364 &self.metrics
365 }
366
367 pub fn should_log(&self) -> bool {
369 self.metrics.step % self.config.log_frequency == 0
370 }
371
372 pub fn get_summary_report(&self) -> String {
374 format!(
375 "Step {}: LR={:.6}, GradNorm={:.6}±{:.6}, ParamChange={:.6}, Loss={:.6} (trend: {:.6})",
376 self.metrics.step,
377 self.metrics.learning_rate.current,
378 self.metrics.gradient_norm.mean,
379 self.metrics.gradient_norm.std,
380 self.metrics.parameter_change_norm.current,
381 self.metrics.loss.current,
382 self.metrics.loss.trend
383 )
384 }
385
386 pub fn get_convergence_report(&self) -> String {
388 let indicators = &self.metrics.convergence_indicators;
389 format!(
390 "Convergence Status: Loss Plateaued: {}, Gradients Vanishing: {}, Gradients Exploding: {}, Oscillating: {}, Rate: {:.6}",
391 indicators.loss_plateaued,
392 indicators.gradients_vanishing,
393 indicators.gradients_exploding,
394 indicators.oscillating,
395 indicators.convergence_rate
396 )
397 }
398
399 pub fn reset(&mut self) {
401 self.metrics = OptimizerMetrics {
402 step: 0,
403 learning_rate: MetricStats::new(self.config.history_window),
404 gradient_norm: MetricStats::new(self.config.history_window),
405 parameter_change_norm: MetricStats::new(self.config.history_window),
406 loss: MetricStats::new(self.config.history_window),
407 performance: PerformanceStats::new(),
408 parameter_gradient_norms: HashMap::new(),
409 convergence_indicators: ConvergenceIndicators::new(),
410 };
411 self.previous_parameters = None;
412 self.step_start_time = None;
413 }
414
415 fn compute_total_gradient_norm(&self, parameters: &[Tensor]) -> Result<f32> {
416 let mut total_norm_sq = 0.0;
417 for param in parameters {
418 if let Ok(grad) = param.grad() {
419 let norm_sq = grad.norm_squared()?.to_scalar()?;
420 total_norm_sq += norm_sq;
421 }
422 }
423 Ok(total_norm_sq.sqrt())
424 }
425
426 fn compute_parameter_change_norm(
427 &self,
428 current: &[Tensor],
429 previous: &[Tensor],
430 ) -> Result<f32> {
431 if current.len() != previous.len() {
432 return Err(anyhow!("Parameter count mismatch"));
433 }
434
435 let mut total_change_sq = 0.0;
436 for (curr, prev) in current.iter().zip(previous.iter()) {
437 let diff = curr.sub(prev)?;
438 let norm_sq = diff.norm_squared()?.to_scalar()?;
439 total_change_sq += norm_sq;
440 }
441 Ok(total_change_sq.sqrt())
442 }
443
444 fn update_convergence_indicators(&mut self) {
445 let indicators = &mut self.metrics.convergence_indicators;
446
447 indicators.loss_plateaued = self.metrics.loss.is_plateaued(1e-6, 1e-6);
449
450 indicators.gradients_vanishing = self.metrics.gradient_norm.current < 1e-8;
452
453 indicators.gradients_exploding = self.metrics.gradient_norm.current > 100.0;
455
456 indicators.oscillating = self.metrics.loss.std > self.metrics.loss.mean * 0.1;
458
459 indicators.convergence_rate = -self.metrics.loss.trend; }
462}
463
/// Configuration for finite-difference hyperparameter sensitivity analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperparameterSensitivityConfig {
    /// Master switch for the analyzer.
    pub enabled: bool,
    /// Relative perturbation applied to a hyperparameter (0.01 = 1%).
    pub perturbation_magnitude: f32,
    /// Sliding-window size for per-hyperparameter statistics.
    pub analysis_window: usize,
    /// Minimum steps before `should_analyze` can return true.
    pub min_samples: usize,
    /// Analyze the learning rate.
    pub analyze_learning_rate: bool,
    /// Analyze momentum.
    pub analyze_momentum: bool,
    /// Analyze weight decay.
    pub analyze_weight_decay: bool,
    /// Analyze the numerical-stability epsilon.
    pub analyze_epsilon: bool,
    /// Run the analysis every this many steps.
    pub analysis_frequency: usize,
}
486
487impl Default for HyperparameterSensitivityConfig {
488 fn default() -> Self {
489 Self {
490 enabled: true,
491 perturbation_magnitude: 0.01, analysis_window: 50,
493 min_samples: 10,
494 analyze_learning_rate: true,
495 analyze_momentum: true,
496 analyze_weight_decay: true,
497 analyze_epsilon: false,
498 analysis_frequency: 25,
499 }
500 }
501}
502
/// Sliding-window sensitivity statistics for one hyperparameter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperparameterSensitivityMetrics {
    /// Hyperparameter name, e.g. `"learning_rate"`.
    pub name: String,
    /// Most recent sensitivity sample (d loss / d hyperparameter).
    pub current_sensitivity: f32,
    /// Raw samples, oldest first; bounded by the analysis window.
    pub sensitivity_history: VecDeque<f32>,
    /// Mean sensitivity over the window.
    pub mean_sensitivity: f32,
    /// Population standard deviation over the window.
    pub std_sensitivity: f32,
    /// Current sensitivity divided by the hyperparameter value (scale-free).
    pub normalized_sensitivity: f32,
    /// Combined magnitude/stability score in [0, 1).
    pub importance_score: f32,
}
521
522impl HyperparameterSensitivityMetrics {
523 pub fn new(name: String, window_size: usize) -> Self {
524 Self {
525 name,
526 current_sensitivity: 0.0,
527 sensitivity_history: VecDeque::with_capacity(window_size),
528 mean_sensitivity: 0.0,
529 std_sensitivity: 0.0,
530 normalized_sensitivity: 0.0,
531 importance_score: 0.0,
532 }
533 }
534
535 pub fn update(&mut self, sensitivity: f32, hyperparameter_value: f32, window_size: usize) {
537 self.current_sensitivity = sensitivity;
538 self.sensitivity_history.push_back(sensitivity);
539
540 if self.sensitivity_history.len() > window_size {
541 self.sensitivity_history.pop_front();
542 }
543
544 self.compute_statistics(hyperparameter_value);
545 }
546
547 fn compute_statistics(&mut self, hyperparameter_value: f32) {
548 if self.sensitivity_history.is_empty() {
549 return;
550 }
551
552 let sum: f32 = self.sensitivity_history.iter().sum();
554 self.mean_sensitivity = sum / self.sensitivity_history.len() as f32;
555
556 let variance: f32 = self
557 .sensitivity_history
558 .iter()
559 .map(|x| (x - self.mean_sensitivity).powi(2))
560 .sum::<f32>()
561 / self.sensitivity_history.len() as f32;
562 self.std_sensitivity = variance.sqrt();
563
564 if hyperparameter_value.abs() > 1e-8 {
566 self.normalized_sensitivity = self.current_sensitivity / hyperparameter_value;
567 } else {
568 self.normalized_sensitivity = 0.0;
569 }
570
571 let magnitude_score = self.normalized_sensitivity.abs().tanh(); let stability_score = (-self.std_sensitivity.abs()).exp(); self.importance_score = magnitude_score * stability_score;
575 }
576
577 pub fn is_highly_sensitive(&self, threshold: f32) -> bool {
579 self.importance_score > threshold
580 }
581
582 pub fn is_stable(&self, variance_threshold: f32) -> bool {
584 self.std_sensitivity < variance_threshold
585 }
586}
587
/// Finite-difference sensitivity analyzer for optimizer hyperparameters.
#[derive(Debug)]
pub struct HyperparameterSensitivity {
    // Analysis settings.
    config: HyperparameterSensitivityConfig,
    // Per-hyperparameter statistics, keyed by hyperparameter name.
    sensitivity_metrics: HashMap<String, HyperparameterSensitivityMetrics>,
    // Reference loss recorded before perturbing. Informational only here:
    // the analyze_* methods take the baseline explicitly.
    baseline_loss: Option<f32>,
    // Losses observed under perturbation, keyed by hyperparameter name.
    perturbation_losses: HashMap<String, f32>,
    // Steps seen; drives the `should_analyze` cadence.
    step_count: usize,
}
597
598impl HyperparameterSensitivity {
599 pub fn new(config: HyperparameterSensitivityConfig) -> Self {
601 Self {
602 config,
603 sensitivity_metrics: HashMap::new(),
604 baseline_loss: None,
605 perturbation_losses: HashMap::new(),
606 step_count: 0,
607 }
608 }
609
610 pub fn with_defaults() -> Self {
612 Self::new(HyperparameterSensitivityConfig::default())
613 }
614
615 pub fn record_baseline_loss(&mut self, loss: f32) {
617 self.baseline_loss = Some(loss);
618 }
619
620 pub fn record_perturbation_loss(&mut self, hyperparameter_name: String, loss: f32) {
622 self.perturbation_losses.insert(hyperparameter_name, loss);
623 }
624
625 pub fn compute_sensitivity(
627 &mut self,
628 hyperparameter_name: &str,
629 hyperparameter_value: f32,
630 perturbed_value: f32,
631 loss_change: f32,
632 ) -> f32 {
633 let param_change = perturbed_value - hyperparameter_value;
634
635 if param_change.abs() < 1e-12 {
637 return 0.0;
638 }
639
640 let sensitivity = loss_change / param_change;
642
643 let metrics = self
645 .sensitivity_metrics
646 .entry(hyperparameter_name.to_string())
647 .or_insert_with(|| {
648 HyperparameterSensitivityMetrics::new(
649 hyperparameter_name.to_string(),
650 self.config.analysis_window,
651 )
652 });
653
654 metrics.update(
655 sensitivity,
656 hyperparameter_value,
657 self.config.analysis_window,
658 );
659
660 sensitivity
661 }
662
663 pub fn analyze_learning_rate_sensitivity(
665 &mut self,
666 current_lr: f32,
667 baseline_loss: f32,
668 perturbed_loss: f32,
669 ) -> f32 {
670 let perturbed_lr = current_lr * (1.0 + self.config.perturbation_magnitude);
671 let loss_change = perturbed_loss - baseline_loss;
672
673 self.compute_sensitivity("learning_rate", current_lr, perturbed_lr, loss_change)
674 }
675
676 pub fn analyze_momentum_sensitivity(
678 &mut self,
679 current_momentum: f32,
680 baseline_loss: f32,
681 perturbed_loss: f32,
682 ) -> f32 {
683 let perturbed_momentum = current_momentum * (1.0 + self.config.perturbation_magnitude);
684 let loss_change = perturbed_loss - baseline_loss;
685
686 self.compute_sensitivity(
687 "momentum",
688 current_momentum,
689 perturbed_momentum,
690 loss_change,
691 )
692 }
693
694 pub fn analyze_weight_decay_sensitivity(
696 &mut self,
697 current_weight_decay: f32,
698 baseline_loss: f32,
699 perturbed_loss: f32,
700 ) -> f32 {
701 let perturbed_weight_decay =
702 current_weight_decay * (1.0 + self.config.perturbation_magnitude);
703 let loss_change = perturbed_loss - baseline_loss;
704
705 self.compute_sensitivity(
706 "weight_decay",
707 current_weight_decay,
708 perturbed_weight_decay,
709 loss_change,
710 )
711 }
712
713 pub fn analyze_epsilon_sensitivity(
715 &mut self,
716 current_epsilon: f32,
717 baseline_loss: f32,
718 perturbed_loss: f32,
719 ) -> f32 {
720 let perturbed_epsilon = current_epsilon * (1.0 + self.config.perturbation_magnitude);
721 let loss_change = perturbed_loss - baseline_loss;
722
723 self.compute_sensitivity("epsilon", current_epsilon, perturbed_epsilon, loss_change)
724 }
725
726 pub fn get_sensitivity_metrics(
728 &self,
729 hyperparameter: &str,
730 ) -> Option<&HyperparameterSensitivityMetrics> {
731 self.sensitivity_metrics.get(hyperparameter)
732 }
733
734 pub fn get_all_sensitivity_metrics(
736 &self,
737 ) -> &HashMap<String, HyperparameterSensitivityMetrics> {
738 &self.sensitivity_metrics
739 }
740
741 pub fn get_most_sensitive_hyperparameters(
743 &self,
744 ) -> Vec<(&String, &HyperparameterSensitivityMetrics)> {
745 let mut sorted: Vec<_> = self.sensitivity_metrics.iter().collect();
746 sorted.sort_by(|a, b| {
747 b.1.importance_score
748 .partial_cmp(&a.1.importance_score)
749 .unwrap_or(std::cmp::Ordering::Equal)
750 });
751 sorted
752 }
753
754 pub fn should_analyze(&self) -> bool {
756 self.config.enabled
757 && self.step_count % self.config.analysis_frequency == 0
758 && self.step_count >= self.config.min_samples
759 }
760
761 pub fn step(&mut self) {
763 self.step_count += 1;
764 }
765
766 pub fn get_sensitivity_report(&self) -> String {
768 let mut report = String::from("Hyperparameter Sensitivity Analysis:\n");
769
770 let sorted_metrics = self.get_most_sensitive_hyperparameters();
771
772 for (name, metrics) in sorted_metrics.iter().take(5) {
773 report.push_str(&format!(
775 " {}: Sensitivity={:.6}, Normalized={:.6}, Importance={:.3} ({})\n",
776 name,
777 metrics.current_sensitivity,
778 metrics.normalized_sensitivity,
779 metrics.importance_score,
780 if metrics.is_highly_sensitive(0.5) { "HIGH" } else { "LOW" }
781 ));
782 }
783
784 if sorted_metrics.is_empty() {
785 report.push_str(" No sensitivity data available yet.\n");
786 }
787
788 report
789 }
790
791 pub fn get_recommendations(&self) -> Vec<String> {
793 let mut recommendations = Vec::new();
794
795 for (name, metrics) in &self.sensitivity_metrics {
796 if metrics.is_highly_sensitive(0.7) {
797 recommendations.push(format!(
798 "Consider careful tuning of {}: high sensitivity detected (score: {:.3})",
799 name, metrics.importance_score
800 ));
801 }
802
803 if !metrics.is_stable(0.1) {
804 recommendations.push(format!(
805 "Consider stabilizing {}: sensitivity varies significantly (std: {:.6})",
806 name, metrics.std_sensitivity
807 ));
808 }
809 }
810
811 if recommendations.is_empty() {
812 recommendations.push(
813 "All hyperparameters appear to have reasonable sensitivity profiles.".to_string(),
814 );
815 }
816
817 recommendations
818 }
819
820 pub fn reset(&mut self) {
822 self.sensitivity_metrics.clear();
823 self.baseline_loss = None;
824 self.perturbation_losses.clear();
825 self.step_count = 0;
826 }
827}
828
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_metric_stats_creation() {
        let stats = MetricStats::new(10);
        // `VecDeque::with_capacity` only guarantees *at least* the requested
        // capacity, so an exact equality check is fragile across std versions.
        assert!(stats.values.capacity() >= 10);
        assert_eq!(stats.current, 0.0);
        assert_eq!(stats.mean, 0.0);
    }

    #[test]
    fn test_metric_stats_update() {
        let mut stats = MetricStats::new(3);

        stats.update(1.0, 3);
        assert_eq!(stats.current, 1.0);
        assert_eq!(stats.mean, 1.0);

        stats.update(2.0, 3);
        assert_eq!(stats.current, 2.0);
        assert_eq!(stats.mean, 1.5);

        stats.update(3.0, 3);
        assert_eq!(stats.current, 3.0);
        assert_eq!(stats.mean, 2.0);

        // Fourth sample evicts the oldest (1.0); window is now [2, 3, 4].
        stats.update(4.0, 3);
        assert_eq!(stats.values.len(), 3);
        assert_eq!(stats.mean, 3.0);
    }

    #[test]
    fn test_metric_stats_trend() {
        let mut stats = MetricStats::new(10);

        // A strictly increasing sequence must produce a positive slope.
        for i in 1..=5 {
            stats.update(i as f32, 10);
        }

        assert!(stats.trend > 0.0);
        assert!(stats.is_increasing(0.5));
        assert!(!stats.is_decreasing(0.5));
    }

    #[test]
    fn test_performance_stats() {
        let mut perf = PerformanceStats::new();
        assert_eq!(perf.step_count, 0);

        perf.record_step_time(Duration::from_millis(100));
        assert_eq!(perf.step_count, 1);
        assert_eq!(perf.avg_step_time, Duration::from_millis(100));

        perf.record_step_time(Duration::from_millis(200));
        assert_eq!(perf.step_count, 2);
        assert_eq!(perf.avg_step_time, Duration::from_millis(150));
    }

    #[test]
    fn test_convergence_indicators() {
        let indicators = ConvergenceIndicators::new();
        assert!(!indicators.loss_plateaued);
        assert!(!indicators.gradients_vanishing);
        assert!(!indicators.gradients_exploding);
        assert!(!indicators.oscillating);
        assert_eq!(indicators.convergence_rate, 0.0);
    }

    #[test]
    fn test_optimizer_monitor_creation() {
        let monitor = OptimizerMonitor::with_defaults();
        assert_eq!(monitor.metrics.step, 0);
        assert!(monitor.previous_parameters.is_none());
    }

    #[test]
    fn test_monitor_should_log() {
        let mut monitor = OptimizerMonitor::with_defaults();

        // Step 0 is on the cadence (0 % 10 == 0).
        assert!(monitor.should_log());

        monitor.metrics.step = 5;
        assert!(!monitor.should_log());

        monitor.metrics.step = 10;
        assert!(monitor.should_log());
    }

    #[test]
    fn test_hyperparameter_sensitivity_config() {
        let config = HyperparameterSensitivityConfig::default();
        assert!(config.enabled);
        assert_eq!(config.perturbation_magnitude, 0.01);
        assert_eq!(config.analysis_window, 50);
        assert_eq!(config.min_samples, 10);
        assert!(config.analyze_learning_rate);
        assert!(config.analyze_momentum);
        assert!(config.analyze_weight_decay);
        assert!(!config.analyze_epsilon);
        assert_eq!(config.analysis_frequency, 25);
    }

    #[test]
    fn test_hyperparameter_sensitivity_metrics() {
        let mut metrics = HyperparameterSensitivityMetrics::new("learning_rate".to_string(), 10);
        assert_eq!(metrics.name, "learning_rate");
        assert_eq!(metrics.current_sensitivity, 0.0);
        assert_eq!(metrics.importance_score, 0.0);

        metrics.update(0.5, 0.01, 10);
        assert_eq!(metrics.current_sensitivity, 0.5);
        assert_eq!(metrics.normalized_sensitivity, 0.5 / 0.01);
        assert!(metrics.importance_score > 0.0);

        metrics.update(0.3, 0.01, 10);
        assert_eq!(metrics.current_sensitivity, 0.3);
        assert_eq!(metrics.sensitivity_history.len(), 2);
        assert_eq!(metrics.mean_sensitivity, 0.4);
    }

    #[test]
    fn test_hyperparameter_sensitivity_analyzer() {
        let mut analyzer = HyperparameterSensitivity::with_defaults();

        analyzer.record_baseline_loss(1.0);
        assert_eq!(analyzer.baseline_loss, Some(1.0));

        analyzer.record_perturbation_loss("learning_rate".to_string(), 1.1);
        assert_eq!(
            analyzer.perturbation_losses.get("learning_rate"),
            Some(&1.1)
        );

        // d(loss)/d(lr) = 0.1 / (0.0101 - 0.01).
        let sensitivity = analyzer.compute_sensitivity("learning_rate", 0.01, 0.0101, 0.1);
        let expected = 0.1 / 0.0001;
        assert!(
            (sensitivity - expected).abs() < 0.01,
            "Expected {}, got {}",
            expected,
            sensitivity
        );

        assert!(analyzer.sensitivity_metrics.contains_key("learning_rate"));
    }

    #[test]
    fn test_sensitivity_analysis_methods() {
        let mut analyzer = HyperparameterSensitivity::with_defaults();

        // Loss increased under perturbation -> positive sensitivity.
        let lr_sensitivity = analyzer.analyze_learning_rate_sensitivity(0.01, 1.0, 1.1);
        assert!(lr_sensitivity > 0.0);
        assert!(analyzer.sensitivity_metrics.contains_key("learning_rate"));

        // Loss decreased under perturbation -> negative sensitivity.
        let momentum_sensitivity = analyzer.analyze_momentum_sensitivity(0.9, 1.0, 0.95);
        assert!(momentum_sensitivity < 0.0);
        assert!(analyzer.sensitivity_metrics.contains_key("momentum"));

        let wd_sensitivity = analyzer.analyze_weight_decay_sensitivity(0.01, 1.0, 1.05);
        assert!(wd_sensitivity > 0.0);
        assert!(analyzer.sensitivity_metrics.contains_key("weight_decay"));
    }

    #[test]
    fn test_sensitivity_should_analyze() {
        let mut analyzer = HyperparameterSensitivity::with_defaults();

        // Step 0: warm-up (min_samples = 10) not yet satisfied.
        assert!(!analyzer.should_analyze());

        for _ in 0..10 {
            analyzer.step();
        }
        // Step 10: warm-up done but not on the 25-step cadence.
        assert!(!analyzer.should_analyze());

        for _ in 0..15 {
            analyzer.step();
        }
        // Step 25: on cadence and past warm-up.
        assert!(analyzer.should_analyze());
    }

    #[test]
    fn test_sensitivity_report_generation() {
        let mut analyzer = HyperparameterSensitivity::with_defaults();

        analyzer.compute_sensitivity("learning_rate", 0.01, 0.0101, 0.1);
        analyzer.compute_sensitivity("momentum", 0.9, 0.909, -0.05);

        let report = analyzer.get_sensitivity_report();
        assert!(report.contains("Hyperparameter Sensitivity Analysis"));
        assert!(report.contains("learning_rate"));
        assert!(report.contains("momentum"));

        let recommendations = analyzer.get_recommendations();
        assert!(!recommendations.is_empty());
    }

    #[test]
    fn test_sensitivity_most_sensitive_hyperparameters() {
        let mut analyzer = HyperparameterSensitivity::with_defaults();

        analyzer.compute_sensitivity("learning_rate", 0.01, 0.0101, 0.2);
        analyzer.compute_sensitivity("momentum", 0.9, 0.909, 0.01);
        analyzer.compute_sensitivity("weight_decay", 0.01, 0.0101, 0.15);

        let most_sensitive = analyzer.get_most_sensitive_hyperparameters();
        assert_eq!(most_sensitive.len(), 3);

        // Ranking is by descending importance score.
        let first_importance = most_sensitive[0].1.importance_score;
        let second_importance = most_sensitive[1].1.importance_score;
        assert!(first_importance >= second_importance);
    }

    #[test]
    fn test_sensitivity_reset() {
        let mut analyzer = HyperparameterSensitivity::with_defaults();

        analyzer.record_baseline_loss(1.0);
        analyzer.record_perturbation_loss("learning_rate".to_string(), 1.1);
        analyzer.compute_sensitivity("learning_rate", 0.01, 0.0101, 0.1);
        analyzer.step();

        assert!(analyzer.baseline_loss.is_some());
        assert!(!analyzer.perturbation_losses.is_empty());
        assert!(!analyzer.sensitivity_metrics.is_empty());
        assert_eq!(analyzer.step_count, 1);

        analyzer.reset();

        assert!(analyzer.baseline_loss.is_none());
        assert!(analyzer.perturbation_losses.is_empty());
        assert!(analyzer.sensitivity_metrics.is_empty());
        assert_eq!(analyzer.step_count, 0);
    }
}
1086
/// Builder-style helper that scores and ranks this crate's optimizers for
/// a given model size and set of training priorities.
#[derive(Debug, Clone)]
pub struct OptimizerSelector {
    /// Number of trainable parameters in the model.
    pub model_size: usize,
    /// Prioritize wall-clock training speed.
    pub time_sensitive: bool,
    /// Prioritize low memory usage.
    pub memory_constrained: bool,
    /// Prioritize convergence speed.
    pub fast_convergence: bool,
    /// Prioritize robustness to difficult gradient conditions.
    pub robustness_priority: bool,
    /// Prefer optimizers with advanced/experimental features.
    pub advanced_features: bool,
}
1106
/// One ranked optimizer suggestion produced by [`OptimizerSelector`].
#[derive(Debug, Clone)]
pub struct OptimizerRecommendation {
    /// Optimizer name, e.g. `"AdamW"`.
    pub name: String,
    /// Short human-readable description.
    pub description: String,
    /// Relative runtime cost class.
    pub performance_tier: PerformanceTier,
    /// Qualitative convergence-speed class.
    pub convergence_speed: ConvergenceSpeed,
    /// Optimizer-state memory footprint class.
    pub memory_usage: MemoryUsage,
    /// Scenarios the optimizer is well suited for.
    pub use_cases: Vec<String>,
    /// Relative per-step overhead multiplier (1.0 = cheapest entry in the
    /// recommendation table).
    pub estimated_overhead: f32,
}
1118
/// Relative per-step runtime cost class of an optimizer.
#[derive(Debug, Clone)]
pub enum PerformanceTier {
    /// Minimal overhead.
    Fastest,
    /// Noticeable but moderate overhead.
    Moderate,
    /// Heavyweight, feature-rich optimizers.
    Advanced,
}
1125
/// Qualitative convergence-speed class, as assessed in this module's
/// recommendation table.
#[derive(Debug, Clone)]
pub enum ConvergenceSpeed {
    Fast,
    Moderate,
    Superior,
}
1132
/// Optimizer-state memory footprint class.
#[derive(Debug, Clone)]
pub enum MemoryUsage {
    /// Little or no extra state (e.g. SGD).
    Low,
    /// Typical adaptive-optimizer state.
    Standard,
    /// Extra state beyond standard Adam-style moments.
    High,
}
1139
1140impl OptimizerSelector {
    /// Creates a selector for a model with `model_size` trainable
    /// parameters; every priority flag starts disabled.
    pub fn new(model_size: usize) -> Self {
        Self {
            model_size,
            time_sensitive: false,
            memory_constrained: false,
            fast_convergence: false,
            robustness_priority: false,
            advanced_features: false,
        }
    }
1151
1152 pub fn time_sensitive(mut self, sensitive: bool) -> Self {
1153 self.time_sensitive = sensitive;
1154 self
1155 }
1156
1157 pub fn memory_constrained(mut self, constrained: bool) -> Self {
1158 self.memory_constrained = constrained;
1159 self
1160 }
1161
1162 pub fn fast_convergence(mut self, fast: bool) -> Self {
1163 self.fast_convergence = fast;
1164 self
1165 }
1166
1167 pub fn robustness_priority(mut self, robust: bool) -> Self {
1168 self.robustness_priority = robust;
1169 self
1170 }
1171
1172 pub fn advanced_features(mut self, advanced: bool) -> Self {
1173 self.advanced_features = advanced;
1174 self
1175 }
1176
    /// Returns all known optimizers ranked by suitability for the
    /// configured priorities, best first.
    pub fn get_recommendations(&self) -> Vec<OptimizerRecommendation> {
        let mut recommendations = self.generate_all_recommendations();
        self.rank_recommendations(&mut recommendations);
        recommendations
    }
1183
    /// Static catalogue of every optimizer this selector knows about, with
    /// qualitative ratings. Order here is irrelevant; ranking happens in
    /// `rank_recommendations`.
    fn generate_all_recommendations(&self) -> Vec<OptimizerRecommendation> {
        vec![
            // Baseline recommendation; overhead 1.0 anchors the scale.
            OptimizerRecommendation {
                name: "AdamW".to_string(),
                description: "Decoupled weight decay Adam - excellent all-around optimizer"
                    .to_string(),
                performance_tier: PerformanceTier::Fastest,
                convergence_speed: ConvergenceSpeed::Fast,
                memory_usage: MemoryUsage::Standard,
                use_cases: vec![
                    "General purpose training".to_string(),
                    "Large language models".to_string(),
                    "Computer vision".to_string(),
                    "Production training".to_string(),
                ],
                estimated_overhead: 1.0,
            },
            OptimizerRecommendation {
                name: "Adam".to_string(),
                description: "Classic adaptive moment estimation optimizer".to_string(),
                performance_tier: PerformanceTier::Fastest,
                convergence_speed: ConvergenceSpeed::Fast,
                memory_usage: MemoryUsage::Standard,
                use_cases: vec![
                    "General purpose training".to_string(),
                    "Research and experimentation".to_string(),
                    "Quick prototyping".to_string(),
                ],
                estimated_overhead: 1.05,
            },
            OptimizerRecommendation {
                name: "SGD".to_string(),
                description: "Stochastic gradient descent with momentum - simple and effective"
                    .to_string(),
                performance_tier: PerformanceTier::Fastest,
                convergence_speed: ConvergenceSpeed::Moderate,
                memory_usage: MemoryUsage::Low,
                use_cases: vec![
                    "Memory-constrained training".to_string(),
                    "Simple models".to_string(),
                    "Fine-tuning".to_string(),
                    "Educational purposes".to_string(),
                ],
                estimated_overhead: 1.1,
            },
            OptimizerRecommendation {
                name: "HN-Adam".to_string(),
                description: "Hybrid Norm Adam with adaptive step size based on parameter norms"
                    .to_string(),
                performance_tier: PerformanceTier::Moderate,
                convergence_speed: ConvergenceSpeed::Superior,
                memory_usage: MemoryUsage::Standard,
                use_cases: vec![
                    "Transformer training".to_string(),
                    "Computer vision tasks".to_string(),
                    "When adaptive learning rates are needed".to_string(),
                    "Research requiring latest optimization techniques".to_string(),
                ],
                estimated_overhead: 2.5,
            },
            OptimizerRecommendation {
                name: "BGE-Adam".to_string(),
                description: "Entropy-weighted Adam with adaptive gradient strategies".to_string(),
                performance_tier: PerformanceTier::Advanced,
                convergence_speed: ConvergenceSpeed::Superior,
                memory_usage: MemoryUsage::High,
                use_cases: vec![
                    "Research and experimentation".to_string(),
                    "Complex training scenarios".to_string(),
                    "When robustness is critical".to_string(),
                    "Handling diverse gradient conditions".to_string(),
                ],
                estimated_overhead: 13.0,
            },
        ]
    }
1261
1262 fn rank_recommendations(&self, recommendations: &mut [OptimizerRecommendation]) {
1264 recommendations.sort_by(|a, b| {
1265 let score_a = self.calculate_suitability_score(a);
1266 let score_b = self.calculate_suitability_score(b);
1267 score_b.partial_cmp(&score_a).unwrap()
1268 });
1269 }
1270
1271 fn calculate_suitability_score(&self, rec: &OptimizerRecommendation) -> f32 {
1273 let mut score = 0.0;
1274
1275 if self.time_sensitive {
1277 score += match rec.performance_tier {
1278 PerformanceTier::Fastest => 10.0,
1279 PerformanceTier::Moderate => 5.0,
1280 PerformanceTier::Advanced => 1.0,
1281 };
1282 }
1283
1284 if self.memory_constrained {
1286 score += match rec.memory_usage {
1287 MemoryUsage::Low => 10.0,
1288 MemoryUsage::Standard => 5.0,
1289 MemoryUsage::High => 1.0,
1290 };
1291 }
1292
1293 if self.fast_convergence {
1295 score += match rec.convergence_speed {
1296 ConvergenceSpeed::Superior => 10.0,
1297 ConvergenceSpeed::Fast => 7.0,
1298 ConvergenceSpeed::Moderate => 3.0,
1299 };
1300 }
1301
1302 if self.robustness_priority {
1304 match rec.name.as_str() {
1305 "BGE-Adam" => score += 10.0, "HN-Adam" => score += 7.0, "AdamW" => score += 5.0, _ => score += 3.0,
1309 }
1310 }
1311
1312 if self.advanced_features {
1314 match rec.name.as_str() {
1315 "BGE-Adam" => score += 10.0, "HN-Adam" => score += 8.0, _ => score += 2.0,
1318 }
1319 }
1320
1321 if self.model_size > 1_000_000 {
1323 match rec.name.as_str() {
1325 "AdamW" | "Adam" => score += 5.0,
1326 "HN-Adam" => score += 3.0,
1327 _ => score += 1.0,
1328 }
1329 }
1330
1331 match rec.name.as_str() {
1333 "AdamW" => score += 8.0, "Adam" => score += 7.0, "HN-Adam" => score += 6.0, "SGD" => score += 5.0, "BGE-Adam" => score += 4.0, _ => score += 2.0,
1339 }
1340
1341 score
1342 }
1343
    /// Builds a human-readable optimizer-selection report.
    ///
    /// Sections, in order: header, the selector's configured priorities,
    /// the ranked recommendations (best first, medal emoji for top three),
    /// and a static benchmark/quick-guide appendix.
    ///
    /// Returns the assembled report as a `String`; nothing is printed here,
    /// so the caller decides where it goes (stdout, log file, UI).
    pub fn generate_report(&self) -> String {
        // Recommendations arrive already ranked by suitability, best first.
        let recommendations = self.get_recommendations();
        let mut report = String::new();

        report.push_str("🚀 TrustformeRS Optimizer Selection Report\n");
        report.push_str("=========================================\n\n");

        // Echo the configuration so readers can see why the ranking came
        // out the way it did.
        report.push_str("📊 Model Configuration:\n");
        report.push_str(&format!(
            " • Model size: {} parameters\n",
            self.model_size
        ));
        report.push_str(&format!(" • Time sensitive: {}\n", self.time_sensitive));
        report.push_str(&format!(
            " • Memory constrained: {}\n",
            self.memory_constrained
        ));
        report.push_str(&format!(
            " • Fast convergence priority: {}\n",
            self.fast_convergence
        ));
        report.push_str(&format!(
            " • Robustness priority: {}\n",
            self.robustness_priority
        ));
        report.push_str(&format!(
            " • Advanced features: {}\n\n",
            self.advanced_features
        ));

        report.push_str("🏆 Recommended Optimizers (ranked by suitability):\n\n");

        for (i, rec) in recommendations.iter().enumerate() {
            // Medal emoji for ranks 1-3, a neutral marker afterwards.
            let rank_emoji = match i {
                0 => "🥇",
                1 => "🥈",
                2 => "🥉",
                _ => "📊",
            };

            report.push_str(&format!(
                "{} {} - {}\n",
                rank_emoji, rec.name, rec.description
            ));
            report.push_str(&format!(
                "   Performance: {:?} | Convergence: {:?} | Memory: {:?}\n",
                rec.performance_tier, rec.convergence_speed, rec.memory_usage
            ));
            report.push_str(&format!(
                "   Overhead: {:.1}x compared to baseline\n",
                rec.estimated_overhead
            ));
            report.push_str(&format!("   Use cases: {}\n\n", rec.use_cases.join(", ")));
        }

        // Static appendix. NOTE(review): these per-iteration figures are
        // hard-coded, not derived from selector state — re-verify them
        // whenever the benchmarks are re-run.
        report.push_str("💡 Performance Insights from Latest Benchmarks:\n");
        report.push_str(" • AdamW: 238µs/iter (100K params) - Fast and reliable\n");
        report.push_str(" • Adam: 248µs/iter - Slightly slower than AdamW\n");
        report.push_str(" • SGD: 257µs/iter - Simple and memory efficient\n");
        report.push_str(" • HN-Adam: 633µs/iter - 2.5x slower, adaptive step sizes\n");
        report.push_str(" • BGE-Adam: 3.3ms/iter - 13x slower, entropy-based robustness\n\n");

        report.push_str("🎯 Quick Selection Guide:\n");
        report.push_str(" • Production training: AdamW\n");
        report.push_str(" • Memory constrained: SGD\n");
        report.push_str(" • Research/experimentation: HN-Adam or BGE-Adam\n");
        report.push_str(" • Maximum robustness: BGE-Adam\n");
        report.push_str(" • Adaptive learning rates: HN-Adam\n");

        report
    }
1416}
1417
#[cfg(test)]
mod optimizer_selection_tests {
    use super::*;

    /// With no priorities set, the full catalogue of five built-in
    /// recommendations is returned.
    #[test]
    fn test_optimizer_selector_basic() {
        let recs = OptimizerSelector::new(10000).get_recommendations();
        assert!(!recs.is_empty());
        assert_eq!(recs.len(), 5);
    }

    /// Prioritising wall-clock speed must surface a `Fastest`-tier
    /// optimizer from the Adam/SGD family.
    #[test]
    fn test_time_sensitive_selection() {
        let recs = OptimizerSelector::new(10000).time_sensitive(true).get_recommendations();
        let winner = &recs[0];

        assert!(matches!(winner.performance_tier, PerformanceTier::Fastest));
        assert!(matches!(winner.name.as_str(), "AdamW" | "Adam" | "SGD"));
    }

    /// Memory pressure should pick SGD or another low-footprint optimizer.
    #[test]
    fn test_memory_constrained_selection() {
        let recs = OptimizerSelector::new(10000).memory_constrained(true).get_recommendations();
        let winner = &recs[0];

        let is_sgd = winner.name == "SGD";
        let low_memory = matches!(winner.memory_usage, MemoryUsage::Low);
        assert!(is_sgd || low_memory);
    }

    /// Robustness priority favours the entropy/norm-adaptive Adam variants.
    #[test]
    fn test_robustness_priority_selection() {
        let recs = OptimizerSelector::new(10000).robustness_priority(true).get_recommendations();
        let winner = &recs[0];

        assert!(matches!(winner.name.as_str(), "BGE-Adam" | "HN-Adam"));
    }

    /// Advanced-feature priority likewise favours the research optimizers.
    #[test]
    fn test_advanced_features_selection() {
        let recs = OptimizerSelector::new(10000).advanced_features(true).get_recommendations();
        let winner = &recs[0];

        assert!(matches!(winner.name.as_str(), "BGE-Adam" | "HN-Adam"));
    }

    /// The generated report must contain every expected section marker.
    #[test]
    fn test_report_generation() {
        let report = OptimizerSelector::new(50000)
            .time_sensitive(true)
            .fast_convergence(true)
            .generate_report();

        for expected in [
            "TrustformeRS Optimizer Selection Report",
            "Model size: 50000",
            "Time sensitive: true",
            "🥇",
            "Performance Insights",
        ]
        .iter()
        {
            assert!(report.contains(expected), "report missing: {}", expected);
        }
    }
}