use crate::OptimizerResult;
use parking_lot::RwLock;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use torsh_core::error::Result;
use torsh_tensor::Tensor;

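/// Collects per-step optimizer telemetry (gradient, parameter, and update
/// norms) and derives convergence diagnostics and tuning recommendations.
///
/// A minimal usage sketch, mirroring the tests at the bottom of this file:
///
/// ```ignore
/// let mut analyzer = OptimizerAnalyzer::new(None);
/// let params = vec![Arc::new(RwLock::new(randn::<f32>(&[10, 10]).unwrap()))];
/// analyzer.analyze_step(1, &params, &[0.01], Some(0.5))?;
/// let report = analyzer.generate_report();
/// ```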
pub struct OptimizerAnalyzer {
    step_history: Vec<OptimizationStep>,
    gradient_stats: GradientStatistics,
    parameter_stats: ParameterStatistics,
    convergence_tracker: ConvergenceTracker,
    config: AnalyzerConfig,
}

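/// Configuration for [`OptimizerAnalyzer`]: history capacity, which metric
/// families to track, and the window size used for moving averages and
/// variances.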
#[derive(Debug, Clone)]
pub struct AnalyzerConfig {
    pub max_history_size: usize,
    pub track_gradient_norms: bool,
    pub track_parameter_norms: bool,
    pub track_gradient_flow: bool,
    pub moving_average_window: usize,
}

impl Default for AnalyzerConfig {
    fn default() -> Self {
        Self {
            max_history_size: 10000,
            track_gradient_norms: true,
            track_parameter_norms: true,
            track_gradient_flow: true,
            moving_average_window: 100,
        }
    }
}

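/// Snapshot of a single optimization step.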
#[derive(Debug, Clone)]
pub struct OptimizationStep {
    pub step: usize,
    pub learning_rates: Vec<f32>,
    pub gradient_norms: Vec<f32>,
    pub parameter_norms: Vec<f32>,
    pub update_norms: Vec<f32>,
    pub loss: Option<f32>,
    pub timestamp: std::time::Instant,
}

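/// Running statistics over recent gradient norms, including counts of
/// suspected explosions (norm > 10.0) and vanishing gradients (norm < 1e-7).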
#[derive(Debug, Clone)]
pub struct GradientStatistics {
    pub norm_history: VecDeque<f32>,
    pub average_norm: f32,
    pub max_norm: f32,
    pub min_norm: f32,
    pub explosion_count: usize,
    pub vanishing_count: usize,
    pub recent_variance: f32,
}

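/// Running statistics over parameter norms and update-to-parameter ratios.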
#[derive(Debug, Clone)]
pub struct ParameterStatistics {
    pub norm_history: VecDeque<f32>,
    pub update_ratios: VecDeque<f32>,
    pub average_update_ratio: f32,
    pub velocity: f32,
    pub stability_score: f32,
}

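/// Tracks the loss trajectory: moving average, slope-based convergence rate,
/// plateau length, and the best loss seen so far.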
#[derive(Debug, Clone)]
pub struct ConvergenceTracker {
    pub loss_history: VecDeque<f32>,
    pub loss_moving_average: f32,
    pub convergence_rate: f32,
    pub is_converged: bool,
    pub steps_since_improvement: usize,
    pub best_loss: f32,
    pub plateau_length: usize,
}

impl OptimizerAnalyzer {
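    /// Creates an analyzer, falling back to [`AnalyzerConfig::default`] when
    /// `config` is `None`.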
    pub fn new(config: Option<AnalyzerConfig>) -> Self {
        let config = config.unwrap_or_default();
        let max_size = config.moving_average_window;

        Self {
            step_history: Vec::new(),
            gradient_stats: GradientStatistics {
                norm_history: VecDeque::with_capacity(max_size),
                average_norm: 0.0,
                max_norm: 0.0,
                min_norm: f32::INFINITY,
                explosion_count: 0,
                vanishing_count: 0,
                recent_variance: 0.0,
            },
            parameter_stats: ParameterStatistics {
                norm_history: VecDeque::with_capacity(max_size),
                update_ratios: VecDeque::with_capacity(max_size),
                average_update_ratio: 0.0,
                velocity: 0.0,
                stability_score: 1.0,
            },
            convergence_tracker: ConvergenceTracker {
                loss_history: VecDeque::with_capacity(max_size),
                loss_moving_average: 0.0,
                convergence_rate: 0.0,
                is_converged: false,
                steps_since_improvement: 0,
                best_loss: f32::INFINITY,
                plateau_length: 0,
            },
            config,
        }
    }

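    /// Records one optimization step: computes per-parameter gradient,
    /// parameter, and update norms, appends them to the bounded history, and
    /// refreshes whichever statistics trackers are enabled in the config.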
    pub fn analyze_step(
        &mut self,
        step: usize,
        params: &[Arc<RwLock<Tensor>>],
        learning_rates: &[f32],
        loss: Option<f32>,
    ) -> Result<()> {
        let timestamp = std::time::Instant::now();

        let mut gradient_norms = Vec::new();
        let mut parameter_norms = Vec::new();
        let mut update_norms = Vec::new();

        for (i, param) in params.iter().enumerate() {
            let param_tensor = param.read();

            let param_norm = param_tensor.norm()?;
            parameter_norms.push(param_norm.item()?);

            if let Some(grad) = param_tensor.grad() {
                let grad_norm = grad.norm()?.item()?;
                gradient_norms.push(grad_norm);

                // Prefer the per-parameter learning rate; fall back to the
                // first entry when a single shared rate is given, then to a
                // nominal default.
                let lr = learning_rates
                    .get(i)
                    .or_else(|| learning_rates.first())
                    .copied()
                    .unwrap_or(0.001);
                update_norms.push(lr * grad_norm);
            } else {
                gradient_norms.push(0.0);
                update_norms.push(0.0);
            }
        }

        let opt_step = OptimizationStep {
            step,
            learning_rates: learning_rates.to_vec(),
            gradient_norms: gradient_norms.clone(),
            parameter_norms: parameter_norms.clone(),
            update_norms: update_norms.clone(),
            loss,
            timestamp,
        };

        self.step_history.push(opt_step);
        if self.step_history.len() > self.config.max_history_size {
            self.step_history.remove(0);
        }

        if self.config.track_gradient_norms {
            self.update_gradient_stats(&gradient_norms)?;
        }

        if self.config.track_parameter_norms {
            self.update_parameter_stats(&parameter_norms, &update_norms)?;
        }

        if let Some(loss_val) = loss {
            self.update_convergence_tracking(loss_val)?;
        }

        Ok(())
    }

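    /// Updates gradient-norm statistics: bounded history, running extrema,
    /// explosion/vanishing counters, and the mean and variance over the
    /// current window.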
    fn update_gradient_stats(&mut self, gradient_norms: &[f32]) -> Result<()> {
        for &norm in gradient_norms {
            self.gradient_stats.norm_history.push_back(norm);
            if self.gradient_stats.norm_history.len() > self.config.moving_average_window {
                self.gradient_stats.norm_history.pop_front();
            }

            self.gradient_stats.max_norm = self.gradient_stats.max_norm.max(norm);
            self.gradient_stats.min_norm = self.gradient_stats.min_norm.min(norm);

            if norm > 10.0 {
                self.gradient_stats.explosion_count += 1;
            }

            if norm < 1e-7 {
                self.gradient_stats.vanishing_count += 1;
            }
        }

        if !self.gradient_stats.norm_history.is_empty() {
            let sum: f32 = self.gradient_stats.norm_history.iter().sum();
            self.gradient_stats.average_norm = sum / self.gradient_stats.norm_history.len() as f32;

            let mean = self.gradient_stats.average_norm;
            let variance: f32 = self
                .gradient_stats
                .norm_history
                .iter()
                .map(|&x| (x - mean).powi(2))
                .sum::<f32>()
                / self.gradient_stats.norm_history.len() as f32;
            self.gradient_stats.recent_variance = variance;
        }

        Ok(())
    }

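    /// Updates parameter-norm history, update-to-parameter ratios, the norm
    /// velocity (absolute change between consecutive steps), and a stability
    /// score `1 / (1 + var(ratios))` that approaches 1 for steady updates.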
    fn update_parameter_stats(
        &mut self,
        parameter_norms: &[f32],
        update_norms: &[f32],
    ) -> Result<()> {
        for (param_norm, update_norm) in parameter_norms.iter().zip(update_norms.iter()) {
            self.parameter_stats.norm_history.push_back(*param_norm);
            if self.parameter_stats.norm_history.len() > self.config.moving_average_window {
                self.parameter_stats.norm_history.pop_front();
            }

            let ratio = if *param_norm > 1e-8 {
                update_norm / param_norm
            } else {
                0.0
            };

            self.parameter_stats.update_ratios.push_back(ratio);
            if self.parameter_stats.update_ratios.len() > self.config.moving_average_window {
                self.parameter_stats.update_ratios.pop_front();
            }
        }

        if !self.parameter_stats.update_ratios.is_empty() {
            let sum: f32 = self.parameter_stats.update_ratios.iter().sum();
            self.parameter_stats.average_update_ratio =
                sum / self.parameter_stats.update_ratios.len() as f32;
        }

        if self.parameter_stats.norm_history.len() >= 2 {
            let current = self
                .parameter_stats
                .norm_history
                .back()
                .expect("norm_history should not be empty");
            let previous = self
                .parameter_stats
                .norm_history
                .get(self.parameter_stats.norm_history.len() - 2)
                .expect("second-to-last element should exist");
            self.parameter_stats.velocity = (current - previous).abs();
        }

        if self.parameter_stats.update_ratios.len() > 1 {
            let mean = self.parameter_stats.average_update_ratio;
            let variance: f32 = self
                .parameter_stats
                .update_ratios
                .iter()
                .map(|&x| (x - mean).powi(2))
                .sum::<f32>()
                / self.parameter_stats.update_ratios.len() as f32;
            self.parameter_stats.stability_score = 1.0 / (1.0 + variance);
        }

        Ok(())
    }

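    /// Updates loss tracking: moving average, best-loss and plateau
    /// bookkeeping, and a convergence rate estimated as the least-squares
    /// slope of the last ten losses. The run is flagged as converged after a
    /// plateau longer than 1000 steps, or once the slope magnitude drops
    /// below 1e-6 with more than 100 recorded losses.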
    fn update_convergence_tracking(&mut self, loss: f32) -> Result<()> {
        self.convergence_tracker.loss_history.push_back(loss);
        if self.convergence_tracker.loss_history.len() > self.config.moving_average_window {
            self.convergence_tracker.loss_history.pop_front();
        }

        if !self.convergence_tracker.loss_history.is_empty() {
            let sum: f32 = self.convergence_tracker.loss_history.iter().sum();
            self.convergence_tracker.loss_moving_average =
                sum / self.convergence_tracker.loss_history.len() as f32;
        }

        if loss < self.convergence_tracker.best_loss {
            self.convergence_tracker.best_loss = loss;
            self.convergence_tracker.steps_since_improvement = 0;
            self.convergence_tracker.plateau_length = 0;
        } else {
            self.convergence_tracker.steps_since_improvement += 1;
            self.convergence_tracker.plateau_length += 1;
        }

        if self.convergence_tracker.loss_history.len() >= 10 {
            // Take the ten most recent losses and restore chronological order
            // so the fitted slope is negative when the loss is decreasing.
            let mut recent_losses: Vec<f32> = self
                .convergence_tracker
                .loss_history
                .iter()
                .rev()
                .take(10)
                .cloned()
                .collect();
            recent_losses.reverse();

            let n = recent_losses.len() as f32;
            let x_mean = (n - 1.0) / 2.0;
            let y_mean = recent_losses.iter().sum::<f32>() / n;

            let numerator: f32 = recent_losses
                .iter()
                .enumerate()
                .map(|(i, &y)| (i as f32 - x_mean) * (y - y_mean))
                .sum();

            let denominator: f32 = (0..recent_losses.len())
                .map(|i| (i as f32 - x_mean).powi(2))
                .sum();

            if denominator > 1e-8 {
                self.convergence_tracker.convergence_rate = numerator / denominator;
            }
        }

        self.convergence_tracker.is_converged = self.convergence_tracker.plateau_length > 1000
            || (self.convergence_tracker.convergence_rate.abs() < 1e-6
                && self.convergence_tracker.loss_history.len() > 100);

        Ok(())
    }

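    /// Builds an [`AnalysisReport`] from the current statistics and the
    /// derived recommendations.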
    pub fn generate_report(&self) -> AnalysisReport {
        AnalysisReport {
            total_steps: self.step_history.len(),
            gradient_stats: self.gradient_stats.clone(),
            parameter_stats: self.parameter_stats.clone(),
            convergence_tracker: self.convergence_tracker.clone(),
            recommendations: self.generate_recommendations(),
        }
    }

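    /// Derives heuristic recommendations from the collected statistics, e.g.
    /// gradient clipping after repeated explosions, or learning-rate changes
    /// when update ratios drift out of a healthy range.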
    fn generate_recommendations(&self) -> Vec<OptimizationRecommendation> {
        let mut recommendations = Vec::new();

        if self.gradient_stats.explosion_count > 10 {
            recommendations.push(OptimizationRecommendation {
                category: RecommendationCategory::GradientNorms,
                severity: Severity::High,
                message: "Frequent gradient explosions detected. Consider gradient clipping or reducing the learning rate.".to_string(),
                suggested_actions: vec![
                    "Add gradient clipping with max_norm=1.0".to_string(),
                    "Reduce learning rate by a factor of 10".to_string(),
                    "Use adaptive optimizers like Adam".to_string(),
                ],
            });
        }

        if self.gradient_stats.vanishing_count > 10 {
            recommendations.push(OptimizationRecommendation {
                category: RecommendationCategory::GradientNorms,
                severity: Severity::Medium,
                message: "Frequent vanishing gradients detected. The model may be too deep or have saturation issues.".to_string(),
                suggested_actions: vec![
                    "Check activation functions for saturation".to_string(),
                    "Consider batch normalization".to_string(),
                    "Use residual connections".to_string(),
                ],
            });
        }

        if self.parameter_stats.average_update_ratio > 0.1 {
            recommendations.push(OptimizationRecommendation {
                category: RecommendationCategory::LearningRate,
                severity: Severity::Medium,
                message: "Update ratios are high. The learning rate might be too large.".to_string(),
                suggested_actions: vec![
                    "Reduce learning rate by a factor of 2-5".to_string(),
                    "Use learning rate scheduling".to_string(),
                ],
            });
        } else if self.parameter_stats.average_update_ratio < 0.001 {
            recommendations.push(OptimizationRecommendation {
                category: RecommendationCategory::LearningRate,
                severity: Severity::Low,
                message: "Update ratios are very small. The learning rate might be too small."
                    .to_string(),
                suggested_actions: vec![
                    "Increase learning rate by a factor of 2-10".to_string(),
                    "Consider a warmup schedule".to_string(),
                ],
            });
        }

        if self.convergence_tracker.plateau_length > 500 {
            recommendations.push(OptimizationRecommendation {
                category: RecommendationCategory::Convergence,
                severity: Severity::Medium,
                message: "Training has plateaued for many steps.".to_string(),
                suggested_actions: vec![
                    "Reduce learning rate".to_string(),
                    "Add regularization".to_string(),
                    "Consider early stopping".to_string(),
                ],
            });
        }

        recommendations
    }

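    /// Returns the most recent `num_steps` steps as [`GradientFlowPoint`]s,
    /// newest first.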
    pub fn get_gradient_flow_data(&self, num_steps: usize) -> Vec<GradientFlowPoint> {
        self.step_history
            .iter()
            .rev()
            .take(num_steps)
            .map(|step| GradientFlowPoint {
                step: step.step,
                gradient_norms: step.gradient_norms.clone(),
                parameter_norms: step.parameter_norms.clone(),
                update_norms: step.update_norms.clone(),
                loss: step.loss,
            })
            .collect()
    }
}

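/// Summary produced by [`OptimizerAnalyzer::generate_report`].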
#[derive(Debug, Clone)]
pub struct AnalysisReport {
    pub total_steps: usize,
    pub gradient_stats: GradientStatistics,
    pub parameter_stats: ParameterStatistics,
    pub convergence_tracker: ConvergenceTracker,
    pub recommendations: Vec<OptimizationRecommendation>,
}

#[derive(Debug, Clone)]
pub struct OptimizationRecommendation {
    pub category: RecommendationCategory,
    pub severity: Severity,
    pub message: String,
    pub suggested_actions: Vec<String>,
}

#[derive(Debug, Clone, PartialEq)]
pub enum RecommendationCategory {
    GradientNorms,
    LearningRate,
    Convergence,
    Stability,
    Performance,
}

#[derive(Debug, Clone, PartialEq)]
pub enum Severity {
    Low,
    Medium,
    High,
    Critical,
}

#[derive(Debug, Clone)]
pub struct GradientFlowPoint {
    pub step: usize,
    pub gradient_norms: Vec<f32>,
    pub parameter_norms: Vec<f32>,
    pub update_norms: Vec<f32>,
    pub loss: Option<f32>,
}

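/// One-at-a-time hyperparameter sensitivity analysis: each parameter is
/// swept over a set of test values while an evaluator scores the resulting
/// performance (higher is better).
///
/// A minimal sketch, mirroring the test at the bottom of this file:
///
/// ```ignore
/// let mut sensitivity = HyperparameterSensitivity::new(HashMap::new());
/// let result = sensitivity.analyze_parameter("lr", vec![0.001, 0.01, 0.1], |lr| Ok(1.0 / lr))?;
/// println!("sensitivity: {:.3}", result.sensitivity_score);
/// ```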
pub struct HyperparameterSensitivity {
    sensitivity_data: HashMap<String, SensitivityResult>,
    base_config: HashMap<String, f32>,
}

#[derive(Debug, Clone)]
pub struct SensitivityResult {
    pub name: String,
    pub test_values: Vec<f32>,
    pub performance_metrics: Vec<f32>,
    pub sensitivity_score: f32,
    pub optimal_value: f32,
}

impl HyperparameterSensitivity {
    pub fn new(base_config: HashMap<String, f32>) -> Self {
        Self {
            sensitivity_data: HashMap::new(),
            base_config,
        }
    }

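    /// Evaluates `performance_evaluator` at each test value, then derives a
    /// sensitivity score as the normalized spread `(max - min) / |max|` of
    /// the observed performance, along with the best-performing value.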
    pub fn analyze_parameter(
        &mut self,
        param_name: &str,
        test_values: Vec<f32>,
        performance_evaluator: impl Fn(f32) -> Result<f32>,
    ) -> Result<SensitivityResult> {
        let mut performance_metrics = Vec::new();

        for &value in &test_values {
            let performance = performance_evaluator(value)?;
            performance_metrics.push(performance);
        }

        let max_perf = performance_metrics
            .iter()
            .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        let min_perf = performance_metrics
            .iter()
            .fold(f32::INFINITY, |a, &b| a.min(b));
        // Guard against division by zero when the best performance is ~0.
        let sensitivity_score = if max_perf > min_perf {
            (max_perf - min_perf) / max_perf.abs().max(1e-12)
        } else {
            0.0
        };

        let optimal_idx = performance_metrics
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0);
        let optimal_value = test_values[optimal_idx];

        let result = SensitivityResult {
            name: param_name.to_string(),
            test_values,
            performance_metrics,
            sensitivity_score,
            optimal_value,
        };

        self.sensitivity_data
            .insert(param_name.to_string(), result.clone());
        Ok(result)
    }

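    /// Returns `(parameter, sensitivity_score)` pairs sorted from most to
    /// least sensitive.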
    pub fn get_sensitivity_ranking(&self) -> Vec<(&str, f32)> {
        let mut ranking: Vec<_> = self
            .sensitivity_data
            .iter()
            .map(|(name, result)| (name.as_str(), result.sensitivity_score))
            .collect();

        ranking.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        ranking
    }

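    /// Summarizes all analyzed parameters into a [`SensitivityReport`].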
    pub fn generate_sensitivity_report(&self) -> SensitivityReport {
        let ranking = self.get_sensitivity_ranking();
        let most_sensitive = ranking
            .first()
            .map(|(name, score)| (name.to_string(), *score));
        let least_sensitive = ranking
            .last()
            .map(|(name, score)| (name.to_string(), *score));

        SensitivityReport {
            analyzed_parameters: self.sensitivity_data.keys().cloned().collect(),
            sensitivity_ranking: ranking
                .into_iter()
                .map(|(n, s)| (n.to_string(), s))
                .collect(),
            most_sensitive_parameter: most_sensitive,
            least_sensitive_parameter: least_sensitive,
            recommendations: self.generate_sensitivity_recommendations(),
        }
    }

    fn generate_sensitivity_recommendations(&self) -> Vec<String> {
        let mut recommendations = Vec::new();
        let ranking = self.get_sensitivity_ranking();

        if let Some((most_sensitive, score)) = ranking.first() {
            if *score > 0.5 {
                recommendations.push(format!(
                    "Parameter '{}' is highly sensitive (score: {:.3}). Fine-tune carefully.",
                    most_sensitive, score
                ));
            }
        }

        if let Some((least_sensitive, score)) = ranking.last() {
            if *score < 0.1 {
                recommendations.push(format!(
                    "Parameter '{}' has low sensitivity (score: {:.3}). Consider using default values.",
                    least_sensitive, score
                ));
            }
        }

        recommendations
    }
}

#[derive(Debug, Clone)]
pub struct SensitivityReport {
    pub analyzed_parameters: Vec<String>,
    pub sensitivity_ranking: Vec<(String, f32)>,
    pub most_sensitive_parameter: Option<(String, f32)>,
    pub least_sensitive_parameter: Option<(String, f32)>,
    pub recommendations: Vec<String>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::randn;

    #[test]
    fn test_optimizer_analyzer_creation() {
        let analyzer = OptimizerAnalyzer::new(None);
        assert_eq!(analyzer.step_history.len(), 0);
    }

    #[test]
    fn test_analyzer_step() {
        let mut analyzer = OptimizerAnalyzer::new(None);
        let params = vec![Arc::new(RwLock::new(randn::<f32>(&[10, 10]).unwrap()))];

        let result = analyzer.analyze_step(1, &params, &[0.01], Some(0.5));
        assert!(result.is_ok());
        assert_eq!(analyzer.step_history.len(), 1);
    }

    #[test]
    fn test_sensitivity_analyzer() -> OptimizerResult<()> {
        let base_config = HashMap::from([("lr".to_string(), 0.01)]);
        let mut sensitivity = HyperparameterSensitivity::new(base_config);

        let test_values = vec![0.001, 0.01, 0.1];
        let evaluator = |lr: f32| Ok(1.0 / lr);
        let _result = sensitivity.analyze_parameter("lr", test_values, evaluator)?;

        let ranking = sensitivity.get_sensitivity_ranking();
        assert_eq!(ranking.len(), 1);
        assert_eq!(ranking[0].0, "lr");
        Ok(())
    }
}