use crate::error::{TrainError, TrainResult};
use crate::model::Model;
use scirs2_core::ndarray::Array1;
use std::collections::HashMap;

/// Summary statistics for a flat vector of model parameters.
#[derive(Debug, Clone)]
pub struct ParameterStats {
    /// Number of parameters.
    pub count: usize,
    /// Mean parameter value.
    pub mean: f64,
    /// Population standard deviation.
    pub std: f64,
    /// Minimum parameter value.
    pub min: f64,
    /// Maximum parameter value.
    pub max: f64,
    /// Percentage of parameters with magnitude below 1e-10.
    pub sparsity: f64,
}

impl ParameterStats {
    /// Computes statistics from a 1-D parameter array. An empty array
    /// yields all-zero statistics.
    pub fn from_array(params: &Array1<f64>) -> Self {
        let count = params.len();
        if count == 0 {
            return Self {
                count: 0,
                mean: 0.0,
                std: 0.0,
                min: 0.0,
                max: 0.0,
                sparsity: 0.0,
            };
        }

        let mean = params.mean().unwrap_or(0.0);
        let variance = params.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / count as f64;
        let std = variance.sqrt();

        let min = params.iter().cloned().fold(f64::INFINITY, f64::min);
        let max = params.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

        // Values with magnitude below 1e-10 are counted as zeros.
        let zeros = params.iter().filter(|&&x| x.abs() < 1e-10).count();
        let sparsity = zeros as f64 / count as f64 * 100.0;

        Self {
            count,
            mean,
            std,
            min,
            max,
            sparsity,
        }
    }

    /// Returns a human-readable multi-line summary.
    pub fn summary(&self) -> String {
        format!(
            "Parameters: {}\n\
             Mean: {:.6}, Std: {:.6}\n\
             Min: {:.6}, Max: {:.6}\n\
             Sparsity: {:.2}%",
            self.count, self.mean, self.std, self.min, self.max, self.sparsity
        )
    }
}

/// Aggregate view of a model's parameters, overall and per layer.
#[derive(Debug, Clone)]
pub struct ModelSummary {
    /// Total number of parameters.
    pub total_params: usize,
    /// Number of trainable parameters.
    pub trainable_params: usize,
    /// Per-layer statistics, keyed by layer name.
    pub layer_stats: HashMap<String, ParameterStats>,
    /// Statistics over all parameters combined.
    pub overall_stats: ParameterStats,
}

impl ModelSummary {
    /// Builds a summary from a model's state dict.
    pub fn from_model<M: Model>(model: &M) -> TrainResult<Self> {
        let state_dict = model.state_dict();
        let mut total_params = 0;
        let mut layer_stats = HashMap::new();
        let mut all_params = Vec::new();

        for (name, params) in state_dict.iter() {
            total_params += params.len();
            all_params.extend(params.iter());
            let params_array = Array1::from_vec(params.clone());
            layer_stats.insert(name.clone(), ParameterStats::from_array(&params_array));
        }

        let overall_stats = ParameterStats::from_array(&Array1::from_vec(all_params));
        // All parameters are treated as trainable.
        let trainable_params = total_params;

        Ok(Self {
            total_params,
            trainable_params,
            layer_stats,
            overall_stats,
        })
    }

    /// Prints the summary to stdout.
    pub fn print(&self) {
        println!("=================================================================");
        println!("Model Summary");
        println!("=================================================================");
        println!("Total Parameters: {}", self.total_params);
        println!("Trainable Parameters: {}", self.trainable_params);
        println!("-----------------------------------------------------------------");
        println!("Overall Statistics:");
        println!("{}", self.overall_stats.summary());
        println!("-----------------------------------------------------------------");
        println!("Layer-wise Statistics:");
        for (name, stats) in &self.layer_stats {
            println!("\n{}: {} parameters", name, stats.count);
            println!("  Mean: {:.6}, Std: {:.6}", stats.mean, stats.std);
            println!("  Range: [{:.6}, {:.6}]", stats.min, stats.max);
            if stats.sparsity > 0.0 {
                println!("  Sparsity: {:.2}%", stats.sparsity);
            }
        }
        println!("=================================================================");
    }
}
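
// Intended usage, as a sketch; `MyModel` is a hypothetical type implementing
// this crate's `Model` trait with a `state_dict()` of flat parameter vectors:
//
//     let model = MyModel::new();
//     let summary = ModelSummary::from_model(&model)?;
//     summary.print();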

/// Per-layer gradient statistics used to diagnose training problems.
#[derive(Debug, Clone)]
pub struct GradientStats {
    /// Name of the layer the gradients belong to.
    pub layer_name: String,
    /// L2 norm of the gradient vector.
    pub norm: f64,
    /// Mean gradient value.
    pub mean: f64,
    /// Population standard deviation of the gradients.
    pub std: f64,
    /// Largest absolute gradient value.
    pub max_abs: f64,
}

impl GradientStats {
    /// Computes gradient statistics for one layer.
    pub fn compute(layer_name: String, grads: &Array1<f64>) -> Self {
        let norm = grads.iter().map(|&g| g * g).sum::<f64>().sqrt();
        let mean = grads.mean().unwrap_or(0.0);
        // `max(1)` guards against a NaN std for an empty gradient vector.
        let variance =
            grads.iter().map(|&g| (g - mean).powi(2)).sum::<f64>() / grads.len().max(1) as f64;
        let std = variance.sqrt();
        let max_abs = grads.iter().map(|&g| g.abs()).fold(0.0, f64::max);

        Self {
            layer_name,
            norm,
            mean,
            std,
            max_abs,
        }
    }

    /// Returns `true` if the gradient norm is below `threshold`.
    pub fn is_vanishing(&self, threshold: f64) -> bool {
        self.norm < threshold
    }

    /// Returns `true` if the gradient norm exceeds `threshold`.
    pub fn is_exploding(&self, threshold: f64) -> bool {
        self.norm > threshold
    }
}

/// Computes gradient statistics for every layer in the map.
pub fn compute_gradient_stats(gradients: &HashMap<String, Array1<f64>>) -> Vec<GradientStats> {
    gradients
        .iter()
        .map(|(name, grads)| GradientStats::compute(name.clone(), grads))
        .collect()
}

/// Prints a gradient report to stdout, flagging suspected vanishing
/// (norm < 1e-7) and exploding (norm > 1e3) gradients.
pub fn print_gradient_report(stats: &[GradientStats]) {
    println!("=================================================================");
    println!("Gradient Statistics");
    println!("=================================================================");
    for stat in stats {
        println!("Layer: {}", stat.layer_name);
        println!("  Norm: {:.6}", stat.norm);
        println!("  Mean: {:.6}, Std: {:.6}", stat.mean, stat.std);
        println!("  Max(abs): {:.6}", stat.max_abs);

        if stat.is_vanishing(1e-7) {
            println!("  ⚠️ WARNING: Vanishing gradients detected!");
        }
        if stat.is_exploding(1e3) {
            println!("  ⚠️ WARNING: Exploding gradients detected!");
        }
        println!();
    }
    println!("=================================================================");
}

/// Tracks training throughput and estimates remaining time.
#[derive(Debug, Clone)]
pub struct TimeEstimator {
    /// Samples processed so far.
    samples_processed: usize,
    /// Wall-clock time elapsed so far, in seconds.
    time_elapsed: f64,
    /// Total number of samples expected.
    total_samples: usize,
}

impl TimeEstimator {
    /// Creates an estimator for a run over `total_samples` samples.
    pub fn new(total_samples: usize) -> Self {
        Self {
            samples_processed: 0,
            time_elapsed: 0.0,
            total_samples,
        }
    }

    /// Records that `samples` more samples were processed in `time_seconds`.
    pub fn update(&mut self, samples: usize, time_seconds: f64) {
        self.samples_processed += samples;
        self.time_elapsed += time_seconds;
    }

    /// Average throughput in samples per second.
    pub fn throughput(&self) -> f64 {
        if self.time_elapsed > 0.0 {
            self.samples_processed as f64 / self.time_elapsed
        } else {
            0.0
        }
    }

    /// Estimated remaining time in seconds.
    pub fn remaining_time(&self) -> f64 {
        let throughput = self.throughput();
        if throughput > 0.0 {
            let remaining_samples = self.total_samples.saturating_sub(self.samples_processed);
            remaining_samples as f64 / throughput
        } else {
            0.0
        }
    }

    /// Estimated remaining time as a formatted string.
    pub fn remaining_time_formatted(&self) -> String {
        let seconds = self.remaining_time();
        format_duration(seconds)
    }

    /// Progress as a percentage in `[0, 100]`.
    pub fn progress(&self) -> f64 {
        if self.total_samples > 0 {
            (self.samples_processed as f64 / self.total_samples as f64 * 100.0).min(100.0)
        } else {
            0.0
        }
    }
}
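
// Typical loop usage, as a sketch; `train_step`, `batches`, and `batch_size`
// are hypothetical, and `std::time::Instant` is one possible clock source:
//
//     let mut eta = TimeEstimator::new(num_samples);
//     for batch in batches {
//         let start = std::time::Instant::now();
//         train_step(batch);
//         eta.update(batch_size, start.elapsed().as_secs_f64());
//         println!("{:.1}% done, ETA {}", eta.progress(), eta.remaining_time_formatted());
//     }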

/// Formats a duration in seconds as `"Xh Ym Zs"`, omitting leading zero units.
pub fn format_duration(seconds: f64) -> String {
    let total_seconds = seconds as u64;
    let hours = total_seconds / 3600;
    let minutes = (total_seconds % 3600) / 60;
    let secs = total_seconds % 60;

    if hours > 0 {
        format!("{}h {}m {}s", hours, minutes, secs)
    } else if minutes > 0 {
        format!("{}m {}s", minutes, secs)
    } else {
        format!("{}s", secs)
    }
}

/// Compares the parameters of two models layer by layer. Errors if the
/// models have mismatched layer names or parameter sizes.
pub fn compare_models<M: Model>(
    model1: &M,
    model2: &M,
) -> TrainResult<HashMap<String, ParameterDifference>> {
    let state1 = model1.state_dict();
    let state2 = model2.state_dict();

    let mut differences = HashMap::new();

    for (name, params1) in state1.iter() {
        if let Some(params2) = state2.get(name) {
            if params1.len() != params2.len() {
                return Err(TrainError::ModelError(format!(
                    "Parameter size mismatch for layer '{}': {} vs {}",
                    name,
                    params1.len(),
                    params2.len()
                )));
            }

            let params1_array = Array1::from_vec(params1.clone());
            let params2_array = Array1::from_vec(params2.clone());
            let diff = ParameterDifference::compute(&params1_array, &params2_array);
            differences.insert(name.clone(), diff);
        } else {
            return Err(TrainError::ModelError(format!(
                "Layer '{}' not found in second model",
                name
            )));
        }
    }

    Ok(differences)
}
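
// Example usage, as a sketch; `model_before` and `model_after` are
// hypothetical checkpoints of the same model type:
//
//     let diffs = compare_models(&model_before, &model_after)?;
//     for (layer, d) in &diffs {
//         println!("{}: rel. change {:.4}, cosine sim. {:.4}",
//                  layer, d.relative_change, d.cosine_similarity);
//     }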

/// Element-wise differences between two parameter vectors.
#[derive(Debug, Clone)]
pub struct ParameterDifference {
    /// Mean absolute difference.
    pub mean_abs_diff: f64,
    /// Largest absolute difference.
    pub max_abs_diff: f64,
    /// Mean absolute difference relative to the first vector's mean magnitude.
    pub relative_change: f64,
    /// Cosine similarity between the two vectors.
    pub cosine_similarity: f64,
}

impl ParameterDifference {
    /// Computes difference metrics between two equally sized vectors.
    pub fn compute(params1: &Array1<f64>, params2: &Array1<f64>) -> Self {
        let diff: Array1<f64> = params1 - params2;
        let abs_diff = diff.mapv(f64::abs);

        let mean_abs_diff = abs_diff.mean().unwrap_or(0.0);
        let max_abs_diff = abs_diff.iter().cloned().fold(0.0, f64::max);

        let mean_abs_value = params1.mapv(f64::abs).mean().unwrap_or(1.0);
        let relative_change = if mean_abs_value > 0.0 {
            mean_abs_diff / mean_abs_value
        } else {
            0.0
        };

        // cos(theta) = <p1, p2> / (||p1|| * ||p2||), defined as 0.0 when
        // either vector has zero norm.
        let dot_product = params1
            .iter()
            .zip(params2.iter())
            .map(|(&a, &b)| a * b)
            .sum::<f64>();
        let norm1 = params1.iter().map(|&x| x * x).sum::<f64>().sqrt();
        let norm2 = params2.iter().map(|&x| x * x).sum::<f64>().sqrt();
        let cosine_similarity = if norm1 > 0.0 && norm2 > 0.0 {
            dot_product / (norm1 * norm2)
        } else {
            0.0
        };

        Self {
            mean_abs_diff,
            max_abs_diff,
            relative_change,
            cosine_similarity,
        }
    }
}

/// Holds the results of a learning-rate range test for analysis.
#[derive(Debug, Clone)]
pub struct LrRangeTestAnalyzer {
    /// Learning rates used in the sweep, in sweep order.
    pub learning_rates: Vec<f64>,
    /// Loss recorded at each learning rate.
    pub losses: Vec<f64>,
}

impl LrRangeTestAnalyzer {
    /// Creates an analyzer; the two vectors must have the same length.
    pub fn new(learning_rates: Vec<f64>, losses: Vec<f64>) -> TrainResult<Self> {
        if learning_rates.len() != losses.len() {
            return Err(TrainError::ConfigError(
                "Learning rates and losses must have the same length".to_string(),
            ));
        }

        Ok(Self {
            learning_rates,
            losses,
        })
    }

    /// Suggests the learning rate at which the loss was dropping fastest.
    pub fn suggest_lr(&self) -> Option<f64> {
        if self.losses.len() < 2 {
            return None;
        }

        // Find the steepest descent of the loss curve.
        let mut max_gradient = f64::NEG_INFINITY;
        let mut best_idx = 0;

        for i in 1..self.losses.len() {
            let gradient = (self.losses[i - 1] - self.losses[i])
                / (self.learning_rates[i] - self.learning_rates[i - 1]).abs();

            if gradient > max_gradient {
                max_gradient = gradient;
                best_idx = i;
            }
        }

        Some(self.learning_rates[best_idx])
    }

    /// Returns the learning rate at which the minimum loss was observed.
    pub fn lr_at_min_loss(&self) -> Option<f64> {
        self.losses
            .iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| self.learning_rates[idx])
    }

    /// Renders the loss curve as an ASCII plot of the given dimensions.
    pub fn plot_ascii(&self, width: usize, height: usize) -> String {
        if self.losses.is_empty() {
            return "No data to plot".to_string();
        }

        let min_loss = self.losses.iter().cloned().fold(f64::INFINITY, f64::min);
        let max_loss = self
            .losses
            .iter()
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max);
        let loss_range = max_loss - min_loss;

        let mut plot = vec![vec![' '; width]; height];

        // Map each (index, loss) pair onto the character grid.
        for (i, &loss) in self.losses.iter().enumerate() {
            let x = (i * width) / self.losses.len().max(1);
            let normalized = if loss_range > 0.0 {
                (max_loss - loss) / loss_range
            } else {
                0.5
            };
            let y = ((normalized * (height - 1) as f64) as usize).min(height - 1);

            if x < width && y < height {
                plot[y][x] = '*';
            }
        }

        // Header with the loss range and the suggested learning rate.
        let mut result = String::new();
        result.push_str(&format!(
            "Learning Rate Range Test (Loss: {:.4} - {:.4})\n",
            min_loss, max_loss
        ));
        result.push_str(&format!(
            "Suggested LR: {:.2e}\n\n",
            self.suggest_lr().unwrap_or(0.0)
        ));

        for row in plot {
            result.push_str(&row.iter().collect::<String>());
            result.push('\n');
        }

        result
    }
}
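
// Example, as a sketch; `lrs` and `losses` would come from an LR range
// test sweep recorded elsewhere:
//
//     let analyzer = LrRangeTestAnalyzer::new(lrs, losses)?;
//     println!("{}", analyzer.plot_ascii(60, 15));
//     if let Some(lr) = analyzer.suggest_lr() {
//         println!("Suggested LR: {:.2e}", lr);
//     }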

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_parameter_stats() {
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let stats = ParameterStats::from_array(&params);

        assert_eq!(stats.count, 5);
        assert!((stats.mean - 3.0).abs() < 1e-6);
        assert!(stats.std > 0.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
    }
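
    // Added check (not in the original suite): the empty-input branch of
    // `from_array` should return all-zero statistics.
    #[test]
    fn test_parameter_stats_empty() {
        let params = Array1::<f64>::from_vec(vec![]);
        let stats = ParameterStats::from_array(&params);

        assert_eq!(stats.count, 0);
        assert_eq!(stats.mean, 0.0);
        assert_eq!(stats.std, 0.0);
    }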

    #[test]
    fn test_parameter_stats_with_zeros() {
        let params = Array1::from_vec(vec![0.0, 0.0, 1.0, 2.0]);
        let stats = ParameterStats::from_array(&params);

        assert_eq!(stats.count, 4);
        assert_eq!(stats.sparsity, 50.0);
    }

    #[test]
    fn test_gradient_stats() {
        let grads = Array1::from_vec(vec![0.1, 0.2, -0.1, 0.3]);
        let stats = GradientStats::compute("test_layer".to_string(), &grads);

        assert_eq!(stats.layer_name, "test_layer");
        assert!(stats.norm > 0.0);
        assert!(!stats.is_vanishing(1e-8));
        assert!(!stats.is_exploding(1e3));
    }

    #[test]
    fn test_gradient_stats_vanishing() {
        let grads = Array1::from_vec(vec![1e-10, 1e-9, -1e-10]);
        let stats = GradientStats::compute("vanishing".to_string(), &grads);

        assert!(stats.is_vanishing(1e-7));
    }

    #[test]
    fn test_gradient_stats_exploding() {
        let grads = Array1::from_vec(vec![1e5, 1e6, -1e5]);
        let stats = GradientStats::compute("exploding".to_string(), &grads);

        assert!(stats.is_exploding(1e3));
    }

    #[test]
    fn test_time_estimator() {
        let mut estimator = TimeEstimator::new(1000);

        estimator.update(100, 10.0);
        assert!((estimator.throughput() - 10.0).abs() < 0.1);
        assert!((estimator.progress() - 10.0).abs() < 0.1);

        let remaining = estimator.remaining_time();
        assert!((remaining - 90.0).abs() < 1.0);
    }

    #[test]
    fn test_format_duration() {
        assert_eq!(format_duration(30.0), "30s");
        assert_eq!(format_duration(90.0), "1m 30s");
        assert_eq!(format_duration(3665.0), "1h 1m 5s");
    }

    #[test]
    fn test_parameter_difference() {
        let params1 = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let params2 = Array1::from_vec(vec![1.1, 2.1, 3.1]);

        let diff = ParameterDifference::compute(&params1, &params2);

        assert!((diff.mean_abs_diff - 0.1).abs() < 1e-6);
        assert!((diff.max_abs_diff - 0.1).abs() < 1e-6);
        assert!(diff.cosine_similarity > 0.99);
    }
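
    // Added check (not in the original suite): orthogonal vectors should
    // have zero cosine similarity and unit max absolute difference.
    #[test]
    fn test_parameter_difference_orthogonal() {
        let params1 = Array1::from_vec(vec![1.0, 0.0]);
        let params2 = Array1::from_vec(vec![0.0, 1.0]);

        let diff = ParameterDifference::compute(&params1, &params2);

        assert!(diff.cosine_similarity.abs() < 1e-12);
        assert!((diff.max_abs_diff - 1.0).abs() < 1e-12);
    }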

    #[test]
    fn test_lr_range_test_analyzer() {
        let lrs = vec![1e-4, 1e-3, 1e-2, 1e-1];
        let losses = vec![1.0, 0.5, 0.3, 0.8];
        let analyzer = LrRangeTestAnalyzer::new(lrs.clone(), losses).unwrap();

        let min_lr = analyzer.lr_at_min_loss();
        assert_eq!(min_lr, Some(1e-2));

        let suggested = analyzer.suggest_lr();
        assert!(suggested.is_some());
    }
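
    // Added check (not in the original suite): the ASCII plot should
    // contain plotted points for non-empty input.
    #[test]
    fn test_lr_plot_ascii() {
        let lrs = vec![1e-4, 1e-3, 1e-2, 1e-1];
        let losses = vec![1.0, 0.5, 0.3, 0.8];
        let analyzer = LrRangeTestAnalyzer::new(lrs, losses).unwrap();

        let plot = analyzer.plot_ascii(40, 10);
        assert!(plot.contains('*'));
    }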

    #[test]
    fn test_lr_range_test_analyzer_invalid() {
        let lrs = vec![1e-4, 1e-3];
        let losses = vec![1.0];

        let result = LrRangeTestAnalyzer::new(lrs, losses);
        assert!(result.is_err());
    }

    #[test]
    fn test_compute_gradient_stats() {
        let mut gradients = HashMap::new();
        gradients.insert("layer1".to_string(), Array1::from_vec(vec![0.1, 0.2, 0.3]));
        gradients.insert("layer2".to_string(), Array1::from_vec(vec![1e-10, 1e-9]));

        let stats = compute_gradient_stats(&gradients);
        assert_eq!(stats.len(), 2);

        let vanishing = stats.iter().find(|s| s.layer_name == "layer2").unwrap();
        assert!(vanishing.is_vanishing(1e-7));
    }
}