tensorlogic_train/utils.rs

//! Utility functions for model introspection, training analysis, and debugging.
//!
//! This module provides tools for:
//! - Model parameter analysis and visualization
//! - Gradient statistics computation
//! - Training time estimation
//! - Model summary generation

use crate::error::{TrainError, TrainResult};
use crate::model::Model;
use scirs2_core::ndarray::Array1;
use std::collections::HashMap;

/// Model parameter statistics for a single layer or the entire model.
#[derive(Debug, Clone)]
pub struct ParameterStats {
    /// Number of parameters
    pub count: usize,
    /// Mean of parameter values
    pub mean: f64,
    /// Standard deviation of parameter values
    pub std: f64,
    /// Minimum parameter value
    pub min: f64,
    /// Maximum parameter value
    pub max: f64,
    /// Percentage of parameters that are near zero (|x| < 1e-10)
    pub sparsity: f64,
}

impl ParameterStats {
    /// Compute statistics from a parameter array.
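    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the crate is importable as `tensorlogic_train`
    /// and that `scirs2_core::ndarray::Array1` re-exports the standard `ndarray`
    /// API; marked `ignore` since those paths are assumptions:
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Array1;
    /// use tensorlogic_train::utils::ParameterStats;
    ///
    /// let params = Array1::from_vec(vec![0.0, 1.0, -1.0, 2.0]);
    /// let stats = ParameterStats::from_array(&params);
    /// assert_eq!(stats.count, 4);
    /// assert_eq!(stats.sparsity, 25.0); // one exact zero out of four
    /// ```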
    pub fn from_array(params: &Array1<f64>) -> Self {
        let count = params.len();
        if count == 0 {
            return Self {
                count: 0,
                mean: 0.0,
                std: 0.0,
                min: 0.0,
                max: 0.0,
                sparsity: 0.0,
            };
        }

        let mean = params.mean().unwrap_or(0.0);
        let variance = params.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / count as f64;
        let std = variance.sqrt();

        let min = params.iter().cloned().fold(f64::INFINITY, f64::min);
        let max = params.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

        let zeros = params.iter().filter(|&&x| x.abs() < 1e-10).count();
        let sparsity = zeros as f64 / count as f64 * 100.0;

        Self {
            count,
            mean,
            std,
            min,
            max,
            sparsity,
        }
    }

    /// Pretty print the statistics.
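    ///
    /// # Example
    ///
    /// A minimal sketch (crate paths assumed, hence `ignore`):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Array1;
    /// use tensorlogic_train::utils::ParameterStats;
    ///
    /// let stats = ParameterStats::from_array(&Array1::from_vec(vec![1.0, 2.0]));
    /// println!("{}", stats.summary());
    /// ```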
    pub fn summary(&self) -> String {
        format!(
            "Parameters: {}\n\
             Mean: {:.6}, Std: {:.6}\n\
             Min: {:.6}, Max: {:.6}\n\
             Sparsity: {:.2}%",
            self.count, self.mean, self.std, self.min, self.max, self.sparsity
        )
    }
}

/// Model summary containing layer-wise parameter information.
#[derive(Debug, Clone)]
pub struct ModelSummary {
    /// Total number of parameters
    pub total_params: usize,
    /// Trainable parameters
    pub trainable_params: usize,
    /// Layer-wise statistics
    pub layer_stats: HashMap<String, ParameterStats>,
    /// Overall model statistics
    pub overall_stats: ParameterStats,
}

impl ModelSummary {
    /// Generate a model summary from a model's state dict.
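    ///
    /// # Example
    ///
    /// A sketch assuming `MyModel` is some type implementing the crate's
    /// `Model` trait; `MyModel` is hypothetical and shown only for
    /// illustration:
    ///
    /// ```ignore
    /// let model = MyModel::new();
    /// let summary = ModelSummary::from_model(&model)?;
    /// println!("total parameters: {}", summary.total_params);
    /// summary.print();
    /// ```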
    pub fn from_model<M: Model>(model: &M) -> TrainResult<Self> {
        let state_dict = model.state_dict();
        let mut total_params = 0;
        let mut layer_stats = HashMap::new();
        let mut all_params = Vec::new();

        for (name, params) in state_dict.iter() {
            total_params += params.len();
            all_params.extend(params.iter());
            let params_array = Array1::from_vec(params.clone());
            layer_stats.insert(name.clone(), ParameterStats::from_array(&params_array));
        }

        let overall_stats = ParameterStats::from_array(&Array1::from_vec(all_params));
        let trainable_params = total_params; // Assuming all params are trainable by default

        Ok(Self {
            total_params,
            trainable_params,
            layer_stats,
            overall_stats,
        })
    }

    /// Print a formatted summary of the model.
    pub fn print(&self) {
        println!("=================================================================");
        println!("Model Summary");
        println!("=================================================================");
        println!("Total Parameters: {}", self.total_params);
        println!("Trainable Parameters: {}", self.trainable_params);
        println!("-----------------------------------------------------------------");
        println!("Overall Statistics:");
        println!("{}", self.overall_stats.summary());
        println!("-----------------------------------------------------------------");
        println!("Layer-wise Statistics:");
        for (name, stats) in &self.layer_stats {
            println!("\n{}: {} parameters", name, stats.count);
            println!("  Mean: {:.6}, Std: {:.6}", stats.mean, stats.std);
            println!("  Range: [{:.6}, {:.6}]", stats.min, stats.max);
            if stats.sparsity > 0.0 {
                println!("  Sparsity: {:.2}%", stats.sparsity);
            }
        }
        println!("=================================================================");
    }
}

/// Gradient statistics for monitoring gradient flow.
#[derive(Debug, Clone)]
pub struct GradientStats {
    /// Layer name
    pub layer_name: String,
    /// L2 norm of gradients
    pub norm: f64,
    /// Mean gradient value
    pub mean: f64,
    /// Standard deviation of gradients
    pub std: f64,
    /// Maximum absolute gradient
    pub max_abs: f64,
}

impl GradientStats {
    /// Compute gradient statistics from a gradient array.
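    ///
    /// # Example
    ///
    /// A minimal sketch; the crate paths are assumptions, so the doctest is
    /// marked `ignore`:
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Array1;
    /// use tensorlogic_train::utils::GradientStats;
    ///
    /// let grads = Array1::from_vec(vec![0.1, -0.2, 0.05]);
    /// let stats = GradientStats::compute("dense_0".to_string(), &grads);
    /// assert!(stats.norm > 0.0);
    /// // Flag pathological gradient flow with caller-chosen thresholds.
    /// assert!(!stats.is_vanishing(1e-7));
    /// assert!(!stats.is_exploding(1e3));
    /// ```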
    pub fn compute(layer_name: String, grads: &Array1<f64>) -> Self {
        let n = grads.len();
        let norm = grads.iter().map(|&g| g * g).sum::<f64>().sqrt();
        let mean = grads.mean().unwrap_or(0.0);
        // Guard against an empty gradient array, which would otherwise
        // yield NaN from the 0.0 / 0.0 variance division.
        let variance = if n > 0 {
            grads.iter().map(|&g| (g - mean).powi(2)).sum::<f64>() / n as f64
        } else {
            0.0
        };
        let std = variance.sqrt();
        let max_abs = grads.iter().map(|&g| g.abs()).fold(0.0, f64::max);

        Self {
            layer_name,
            norm,
            mean,
            std,
            max_abs,
        }
    }

    /// Check if gradients are vanishing (too small).
    pub fn is_vanishing(&self, threshold: f64) -> bool {
        self.norm < threshold
    }

    /// Check if gradients are exploding (too large).
    pub fn is_exploding(&self, threshold: f64) -> bool {
        self.norm > threshold
    }
}

/// Compute gradient statistics for all layers in a gradient dictionary.
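///
/// # Example
///
/// A minimal sketch (crate paths assumed, hence `ignore`):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::Array1;
/// use tensorlogic_train::utils::compute_gradient_stats;
///
/// let mut gradients = HashMap::new();
/// gradients.insert("layer1".to_string(), Array1::from_vec(vec![0.1, 0.2]));
/// gradients.insert("layer2".to_string(), Array1::from_vec(vec![1e-9]));
/// let stats = compute_gradient_stats(&gradients);
/// assert_eq!(stats.len(), 2);
/// ```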
pub fn compute_gradient_stats(gradients: &HashMap<String, Array1<f64>>) -> Vec<GradientStats> {
    gradients
        .iter()
        .map(|(name, grads)| GradientStats::compute(name.clone(), grads))
        .collect()
}

/// Print a formatted report of gradient statistics.
pub fn print_gradient_report(stats: &[GradientStats]) {
    println!("=================================================================");
    println!("Gradient Statistics");
    println!("=================================================================");
    for stat in stats {
        println!("Layer: {}", stat.layer_name);
        println!("  Norm: {:.6}", stat.norm);
        println!("  Mean: {:.6}, Std: {:.6}", stat.mean, stat.std);
        println!("  Max(abs): {:.6}", stat.max_abs);

        if stat.is_vanishing(1e-7) {
            println!("  ⚠️  WARNING: Vanishing gradients detected!");
        }
        if stat.is_exploding(1e3) {
            println!("  ⚠️  WARNING: Exploding gradients detected!");
        }
        println!();
    }
    println!("=================================================================");
}

/// Training time estimation based on iteration timing.
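///
/// # Example
///
/// A minimal sketch of the intended workflow (crate path assumed, hence
/// `ignore`):
///
/// ```ignore
/// use tensorlogic_train::utils::TimeEstimator;
///
/// let mut estimator = TimeEstimator::new(1_000);
/// // After each batch, report how many samples it covered and how long it took.
/// estimator.update(100, 2.0);
/// assert_eq!(estimator.throughput(), 50.0); // 100 samples / 2 s
/// assert!((estimator.progress() - 10.0).abs() < 1e-9); // 100 of 1000 samples
/// println!("ETA: {}", estimator.remaining_time_formatted()); // "18s"
/// ```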
#[derive(Debug, Clone)]
pub struct TimeEstimator {
    /// Number of samples processed so far
    samples_processed: usize,
    /// Total time elapsed (in seconds)
    time_elapsed: f64,
    /// Total number of samples to process
    total_samples: usize,
}

impl TimeEstimator {
    /// Create a new time estimator.
    pub fn new(total_samples: usize) -> Self {
        Self {
            samples_processed: 0,
            time_elapsed: 0.0,
            total_samples,
        }
    }

    /// Update with the number of samples processed in this iteration and the time taken.
    pub fn update(&mut self, samples: usize, time_seconds: f64) {
        self.samples_processed += samples;
        self.time_elapsed += time_seconds;
    }

    /// Get the current throughput (samples per second).
    pub fn throughput(&self) -> f64 {
        if self.time_elapsed > 0.0 {
            self.samples_processed as f64 / self.time_elapsed
        } else {
            0.0
        }
    }

    /// Estimate remaining time in seconds.
    pub fn remaining_time(&self) -> f64 {
        let throughput = self.throughput();
        if throughput > 0.0 {
            let remaining_samples = self.total_samples.saturating_sub(self.samples_processed);
            remaining_samples as f64 / throughput
        } else {
            0.0
        }
    }

    /// Format remaining time as a human-readable string.
    pub fn remaining_time_formatted(&self) -> String {
        let seconds = self.remaining_time();
        format_duration(seconds)
    }

    /// Get progress percentage.
    pub fn progress(&self) -> f64 {
        if self.total_samples > 0 {
            (self.samples_processed as f64 / self.total_samples as f64 * 100.0).min(100.0)
        } else {
            0.0
        }
    }
}

/// Format a duration in seconds to a human-readable string.
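///
/// # Example
///
/// A minimal sketch (crate path assumed, hence `ignore`):
///
/// ```ignore
/// use tensorlogic_train::utils::format_duration;
///
/// assert_eq!(format_duration(42.0), "42s");
/// assert_eq!(format_duration(125.0), "2m 5s");
/// assert_eq!(format_duration(3661.0), "1h 1m 1s");
/// ```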
pub fn format_duration(seconds: f64) -> String {
    let total_seconds = seconds as u64;
    let hours = total_seconds / 3600;
    let minutes = (total_seconds % 3600) / 60;
    let secs = total_seconds % 60;

    if hours > 0 {
        format!("{}h {}m {}s", hours, minutes, secs)
    } else if minutes > 0 {
        format!("{}m {}s", minutes, secs)
    } else {
        format!("{}s", secs)
    }
}

/// Compare two models and report differences in parameters.
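///
/// # Example
///
/// A sketch assuming `MyModel` is some type implementing the crate's `Model`
/// trait; `MyModel` is hypothetical and shown only for illustration:
///
/// ```ignore
/// let before = MyModel::new();
/// let mut after = before.clone();
/// // ... train `after` for a few steps ...
/// let diffs = compare_models(&before, &after)?;
/// for (layer, diff) in &diffs {
///     println!("{}: mean |diff| = {:.6}", layer, diff.mean_abs_diff);
/// }
/// ```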
pub fn compare_models<M: Model>(
    model1: &M,
    model2: &M,
) -> TrainResult<HashMap<String, ParameterDifference>> {
    let state1 = model1.state_dict();
    let state2 = model2.state_dict();

    let mut differences = HashMap::new();

    for (name, params1) in state1.iter() {
        if let Some(params2) = state2.get(name) {
            if params1.len() != params2.len() {
                return Err(TrainError::ModelError(format!(
                    "Parameter size mismatch for layer '{}': {} vs {}",
                    name,
                    params1.len(),
                    params2.len()
                )));
            }

            let params1_array = Array1::from_vec(params1.clone());
            let params2_array = Array1::from_vec(params2.clone());
            let diff = ParameterDifference::compute(&params1_array, &params2_array);
            differences.insert(name.clone(), diff);
        } else {
            return Err(TrainError::ModelError(format!(
                "Layer '{}' not found in second model",
                name
            )));
        }
    }

    Ok(differences)
}

/// Statistics about parameter differences between two models.
#[derive(Debug, Clone)]
pub struct ParameterDifference {
    /// Mean absolute difference
    pub mean_abs_diff: f64,
    /// Maximum absolute difference
    pub max_abs_diff: f64,
    /// Relative change (mean abs diff / mean abs value)
    pub relative_change: f64,
    /// Cosine similarity between parameter vectors
    pub cosine_similarity: f64,
}

impl ParameterDifference {
    /// Compute parameter difference statistics.
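    ///
    /// # Example
    ///
    /// A minimal sketch (crate paths assumed, hence `ignore`):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Array1;
    /// use tensorlogic_train::utils::ParameterDifference;
    ///
    /// let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
    /// let b = Array1::from_vec(vec![1.5, 2.0, 3.0]);
    /// let diff = ParameterDifference::compute(&a, &b);
    /// assert!((diff.max_abs_diff - 0.5).abs() < 1e-12);
    /// assert!(diff.cosine_similarity > 0.99); // nearly parallel vectors
    /// ```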
    pub fn compute(params1: &Array1<f64>, params2: &Array1<f64>) -> Self {
        let diff: Array1<f64> = params1 - params2;
        let abs_diff = diff.mapv(f64::abs);

        let mean_abs_diff = abs_diff.mean().unwrap_or(0.0);
        let max_abs_diff = abs_diff.iter().cloned().fold(0.0, f64::max);

        let mean_abs_value = params1.mapv(f64::abs).mean().unwrap_or(1.0);
        let relative_change = if mean_abs_value > 0.0 {
            mean_abs_diff / mean_abs_value
        } else {
            0.0
        };

        // Cosine similarity
        let dot_product = params1
            .iter()
            .zip(params2.iter())
            .map(|(&a, &b)| a * b)
            .sum::<f64>();
        let norm1 = params1.iter().map(|&x| x * x).sum::<f64>().sqrt();
        let norm2 = params2.iter().map(|&x| x * x).sum::<f64>().sqrt();
        let cosine_similarity = if norm1 > 0.0 && norm2 > 0.0 {
            dot_product / (norm1 * norm2)
        } else {
            0.0
        };

        Self {
            mean_abs_diff,
            max_abs_diff,
            relative_change,
            cosine_similarity,
        }
    }
}

/// Learning rate range test analyzer for finding optimal learning rates.
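///
/// # Example
///
/// A minimal sketch (crate path assumed, hence `ignore`):
///
/// ```ignore
/// use tensorlogic_train::utils::LrRangeTestAnalyzer;
///
/// let lrs = vec![1e-5, 1e-4, 1e-3, 1e-2];
/// let losses = vec![2.0, 1.5, 0.6, 1.2]; // loss bottoms out at 1e-3
/// let analyzer = LrRangeTestAnalyzer::new(lrs, losses).unwrap();
/// assert_eq!(analyzer.lr_at_min_loss(), Some(1e-3));
/// let suggested = analyzer.suggest_lr(); // steepest loss decrease
/// ```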
#[derive(Debug, Clone)]
pub struct LrRangeTestAnalyzer {
    /// Learning rates tested
    pub learning_rates: Vec<f64>,
    /// Losses observed at each learning rate
    pub losses: Vec<f64>,
}

impl LrRangeTestAnalyzer {
    /// Create a new analyzer.
    pub fn new(learning_rates: Vec<f64>, losses: Vec<f64>) -> TrainResult<Self> {
        if learning_rates.len() != losses.len() {
            return Err(TrainError::ConfigError(
                "Learning rates and losses must have the same length".to_string(),
            ));
        }

        Ok(Self {
            learning_rates,
            losses,
        })
    }

    /// Find the learning rate with the steepest loss decrease.
    pub fn suggest_lr(&self) -> Option<f64> {
        if self.losses.len() < 2 {
            return None;
        }

        // Compute the rate of loss decrease between consecutive measurements.
        let mut max_gradient = f64::NEG_INFINITY;
        let mut best_idx = 0;

        for i in 1..self.losses.len() {
            let lr_step = (self.learning_rates[i] - self.learning_rates[i - 1]).abs();
            // Skip duplicated learning rates to avoid dividing by zero.
            if lr_step == 0.0 {
                continue;
            }
            let gradient = (self.losses[i - 1] - self.losses[i]) / lr_step;

            if gradient > max_gradient {
                max_gradient = gradient;
                best_idx = i;
            }
        }

        Some(self.learning_rates[best_idx])
    }

    /// Find the learning rate at minimum loss.
    pub fn lr_at_min_loss(&self) -> Option<f64> {
        self.losses
            .iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| self.learning_rates[idx])
    }

    /// Plot the LR range test results (returns a simple ASCII plot).
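    ///
    /// # Example
    ///
    /// A minimal sketch (crate path assumed, hence `ignore`):
    ///
    /// ```ignore
    /// let analyzer = LrRangeTestAnalyzer::new(
    ///     vec![1e-4, 1e-3, 1e-2],
    ///     vec![1.0, 0.4, 0.9],
    /// ).unwrap();
    /// // Render a 60x10 character chart of the loss curve.
    /// println!("{}", analyzer.plot_ascii(60, 10));
    /// ```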
    pub fn plot_ascii(&self, width: usize, height: usize) -> String {
        if self.losses.is_empty() {
            return "No data to plot".to_string();
        }
        // Guard against zero dimensions, which would underflow `height - 1` below.
        if width == 0 || height == 0 {
            return "Plot dimensions must be non-zero".to_string();
        }

        let min_loss = self.losses.iter().cloned().fold(f64::INFINITY, f64::min);
        let max_loss = self
            .losses
            .iter()
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max);
        let loss_range = max_loss - min_loss;

        let mut plot = vec![vec![' '; width]; height];

        // Plot points
        for (i, &loss) in self.losses.iter().enumerate() {
            let x = (i * width) / self.losses.len().max(1);
            let normalized = if loss_range > 0.0 {
                (max_loss - loss) / loss_range
            } else {
                0.5
            };
            let y = ((normalized * (height - 1) as f64) as usize).min(height - 1);

            if x < width && y < height {
                plot[y][x] = '*';
            }
        }

        // Convert to string
        let mut result = String::new();
        result.push_str(&format!(
            "Learning Rate Range Test (Loss: {:.4} - {:.4})\n",
            min_loss, max_loss
        ));
        result.push_str(&format!(
            "Suggested LR: {:.2e}\n\n",
            self.suggest_lr().unwrap_or(0.0)
        ));

        for row in plot {
            result.push_str(&row.iter().collect::<String>());
            result.push('\n');
        }

        result
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_parameter_stats() {
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let stats = ParameterStats::from_array(&params);

        assert_eq!(stats.count, 5);
        assert!((stats.mean - 3.0).abs() < 1e-6);
        assert!(stats.std > 0.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
    }

    #[test]
    fn test_parameter_stats_with_zeros() {
        let params = Array1::from_vec(vec![0.0, 0.0, 1.0, 2.0]);
        let stats = ParameterStats::from_array(&params);

        assert_eq!(stats.count, 4);
        assert_eq!(stats.sparsity, 50.0); // 2 out of 4 are zeros
    }

    #[test]
    fn test_gradient_stats() {
        let grads = Array1::from_vec(vec![0.1, 0.2, -0.1, 0.3]);
        let stats = GradientStats::compute("test_layer".to_string(), &grads);

        assert_eq!(stats.layer_name, "test_layer");
        assert!(stats.norm > 0.0);
        assert!(!stats.is_vanishing(1e-8));
        assert!(!stats.is_exploding(1e3));
    }

    #[test]
    fn test_gradient_stats_vanishing() {
        let grads = Array1::from_vec(vec![1e-10, 1e-9, -1e-10]);
        let stats = GradientStats::compute("vanishing".to_string(), &grads);

        assert!(stats.is_vanishing(1e-7));
    }

    #[test]
    fn test_gradient_stats_exploding() {
        let grads = Array1::from_vec(vec![1e5, 1e6, -1e5]);
        let stats = GradientStats::compute("exploding".to_string(), &grads);

        assert!(stats.is_exploding(1e3));
    }

    #[test]
    fn test_time_estimator() {
        let mut estimator = TimeEstimator::new(1000);

        estimator.update(100, 10.0); // 10 seconds for 100 samples
        assert!((estimator.throughput() - 10.0).abs() < 0.1); // 10 samples/sec
        assert!((estimator.progress() - 10.0).abs() < 0.1); // 10% progress

        let remaining = estimator.remaining_time();
        assert!((remaining - 90.0).abs() < 1.0); // ~90 seconds remaining
    }

    #[test]
    fn test_format_duration() {
        assert_eq!(format_duration(30.0), "30s");
        assert_eq!(format_duration(90.0), "1m 30s");
        assert_eq!(format_duration(3665.0), "1h 1m 5s");
    }

    #[test]
    fn test_parameter_difference() {
        let params1 = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let params2 = Array1::from_vec(vec![1.1, 2.1, 3.1]);

        let diff = ParameterDifference::compute(&params1, &params2);

        assert!((diff.mean_abs_diff - 0.1).abs() < 1e-6);
        assert!((diff.max_abs_diff - 0.1).abs() < 1e-6);
        assert!(diff.cosine_similarity > 0.99); // Very similar vectors
    }

    #[test]
    fn test_lr_range_test_analyzer() {
        let lrs = vec![1e-4, 1e-3, 1e-2, 1e-1];
        let losses = vec![1.0, 0.5, 0.3, 0.8]; // Min at 1e-2

        let analyzer = LrRangeTestAnalyzer::new(lrs.clone(), losses).unwrap();

        let min_lr = analyzer.lr_at_min_loss();
        assert_eq!(min_lr, Some(1e-2));

        let suggested = analyzer.suggest_lr();
        assert!(suggested.is_some());
    }

    #[test]
    fn test_lr_range_test_analyzer_invalid() {
        let lrs = vec![1e-4, 1e-3];
        let losses = vec![1.0]; // Mismatched length

        let result = LrRangeTestAnalyzer::new(lrs, losses);
        assert!(result.is_err());
    }

    #[test]
    fn test_compute_gradient_stats() {
        let mut gradients = HashMap::new();
        gradients.insert("layer1".to_string(), Array1::from_vec(vec![0.1, 0.2, 0.3]));
        gradients.insert("layer2".to_string(), Array1::from_vec(vec![1e-10, 1e-9]));

        let stats = compute_gradient_stats(&gradients);
        assert_eq!(stats.len(), 2);

        // Find the vanishing layer
        let vanishing = stats.iter().find(|s| s.layer_name == "layer2").unwrap();
        assert!(vanishing.is_vanishing(1e-7));
    }
}