// File: ruvector_sona/training/metrics.rs
1//! Training Metrics for SONA
2//!
3//! Comprehensive analytics for training sessions.
4
5use serde::{Deserialize, Serialize};
6
/// Training metrics collection.
///
/// Accumulates per-pipeline counters plus a rolling window of quality
/// samples (capped at 10 000 entries by `add_quality_sample` / `merge`,
/// oldest evicted first) used for averages and percentiles.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct TrainingMetrics {
    /// Pipeline/agent name
    pub name: String,
    /// Total examples processed
    pub total_examples: usize,
    /// Total training sessions
    pub training_sessions: u64,
    /// Patterns learned
    pub patterns_learned: usize,
    /// Rolling buffer of quality samples for averaging/percentiles
    /// (bounded at 10 000; see `add_quality_sample`)
    pub quality_samples: Vec<f32>,
    /// Validation quality (if validation was run)
    pub validation_quality: Option<f32>,
    /// Performance metrics
    pub performance: PerformanceMetrics,
}
25
26impl TrainingMetrics {
27    /// Create new metrics
28    pub fn new(name: &str) -> Self {
29        Self {
30            name: name.to_string(),
31            ..Default::default()
32        }
33    }
34
35    /// Add quality sample
36    pub fn add_quality_sample(&mut self, quality: f32) {
37        self.quality_samples.push(quality);
38        // Keep last 10000 samples
39        if self.quality_samples.len() > 10000 {
40            self.quality_samples.remove(0);
41        }
42    }
43
44    /// Get average quality
45    pub fn avg_quality(&self) -> f32 {
46        if self.quality_samples.is_empty() {
47            0.0
48        } else {
49            self.quality_samples.iter().sum::<f32>() / self.quality_samples.len() as f32
50        }
51    }
52
53    /// Get quality percentile
54    pub fn quality_percentile(&self, percentile: f32) -> f32 {
55        if self.quality_samples.is_empty() {
56            return 0.0;
57        }
58
59        let mut sorted = self.quality_samples.clone();
60        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
61
62        let idx = ((percentile / 100.0) * (sorted.len() - 1) as f32) as usize;
63        sorted[idx.min(sorted.len() - 1)]
64    }
65
66    /// Get quality statistics
67    pub fn quality_stats(&self) -> QualityMetrics {
68        if self.quality_samples.is_empty() {
69            return QualityMetrics::default();
70        }
71
72        let avg = self.avg_quality();
73        let min = self
74            .quality_samples
75            .iter()
76            .cloned()
77            .fold(f32::MAX, f32::min);
78        let max = self
79            .quality_samples
80            .iter()
81            .cloned()
82            .fold(f32::MIN, f32::max);
83
84        let variance = self
85            .quality_samples
86            .iter()
87            .map(|q| (q - avg).powi(2))
88            .sum::<f32>()
89            / self.quality_samples.len() as f32;
90        let std_dev = variance.sqrt();
91
92        QualityMetrics {
93            avg,
94            min,
95            max,
96            std_dev,
97            p25: self.quality_percentile(25.0),
98            p50: self.quality_percentile(50.0),
99            p75: self.quality_percentile(75.0),
100            p95: self.quality_percentile(95.0),
101            sample_count: self.quality_samples.len(),
102        }
103    }
104
105    /// Reset metrics
106    pub fn reset(&mut self) {
107        self.total_examples = 0;
108        self.training_sessions = 0;
109        self.patterns_learned = 0;
110        self.quality_samples.clear();
111        self.validation_quality = None;
112        self.performance = PerformanceMetrics::default();
113    }
114
115    /// Merge with another metrics instance
116    pub fn merge(&mut self, other: &TrainingMetrics) {
117        self.total_examples += other.total_examples;
118        self.training_sessions += other.training_sessions;
119        self.patterns_learned = other.patterns_learned; // Take latest
120        self.quality_samples.extend(&other.quality_samples);
121
122        // Keep last 10000
123        if self.quality_samples.len() > 10000 {
124            let excess = self.quality_samples.len() - 10000;
125            self.quality_samples.drain(0..excess);
126        }
127    }
128}
129
/// Quality metrics summary.
///
/// Snapshot produced by `TrainingMetrics::quality_stats`; percentiles use
/// the nearest-rank (truncating-index) method over the sorted samples.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct QualityMetrics {
    /// Average quality
    pub avg: f32,
    /// Minimum quality
    pub min: f32,
    /// Maximum quality
    pub max: f32,
    /// Standard deviation (population, i.e. divided by n)
    pub std_dev: f32,
    /// 25th percentile
    pub p25: f32,
    /// 50th percentile (median)
    pub p50: f32,
    /// 75th percentile
    pub p75: f32,
    /// 95th percentile
    pub p95: f32,
    /// Number of samples the summary was computed from
    pub sample_count: usize,
}
152
153impl std::fmt::Display for QualityMetrics {
154    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155        write!(
156            f,
157            "avg={:.4}, std={:.4}, min={:.4}, max={:.4}, p50={:.4}, p95={:.4} (n={})",
158            self.avg, self.std_dev, self.min, self.max, self.p50, self.p95, self.sample_count
159        )
160    }
161}
162
/// Performance metrics.
///
/// Timing and throughput figures for a training run; throughput fields are
/// derived via `calculate_throughput`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PerformanceMetrics {
    /// Total training time in seconds
    pub total_training_secs: f64,
    /// Average batch processing time in milliseconds
    pub avg_batch_time_ms: f64,
    /// Average example processing time in microseconds
    pub avg_example_time_us: f64,
    /// Peak memory usage in MB
    pub peak_memory_mb: usize,
    /// Examples per second throughput
    pub examples_per_sec: f64,
    /// Pattern extraction time in milliseconds
    pub pattern_extraction_ms: f64,
}
179
180impl PerformanceMetrics {
181    /// Calculate throughput
182    pub fn calculate_throughput(&mut self, examples: usize, duration_secs: f64) {
183        if duration_secs > 0.0 {
184            self.examples_per_sec = examples as f64 / duration_secs;
185            self.avg_example_time_us = (duration_secs * 1_000_000.0) / examples as f64;
186        }
187    }
188}
189
/// Epoch statistics.
///
/// Per-epoch record collected into `TrainingResult::epoch_stats`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EpochStats {
    /// Epoch number (0-indexed; `Display` renders it 1-indexed)
    pub epoch: usize,
    /// Examples processed in this epoch
    pub examples_processed: usize,
    /// Average quality for this epoch
    pub avg_quality: f32,
    /// Duration in seconds
    pub duration_secs: f64,
}
202
203impl std::fmt::Display for EpochStats {
204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205        write!(
206            f,
207            "Epoch {}: {} examples, avg_quality={:.4}, {:.2}s",
208            self.epoch + 1,
209            self.examples_processed,
210            self.avg_quality,
211            self.duration_secs
212        )
213    }
214}
215
/// Training result summary.
///
/// Final report for a whole training run, including per-epoch history used
/// by the `quality_improved` / `quality_improvement` accessors.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrainingResult {
    /// Pipeline name
    pub pipeline_name: String,
    /// Number of epochs completed
    pub epochs_completed: usize,
    /// Total examples processed
    pub total_examples: usize,
    /// Patterns learned
    pub patterns_learned: usize,
    /// Final average quality
    pub final_avg_quality: f32,
    /// Total duration in seconds
    pub total_duration_secs: f64,
    /// Per-epoch statistics
    pub epoch_stats: Vec<EpochStats>,
    /// Validation quality (if validation was run)
    pub validation_quality: Option<f32>,
}
236
237impl TrainingResult {
238    /// Get examples per second
239    pub fn examples_per_sec(&self) -> f64 {
240        if self.total_duration_secs > 0.0 {
241            self.total_examples as f64 / self.total_duration_secs
242        } else {
243            0.0
244        }
245    }
246
247    /// Get average epoch duration
248    pub fn avg_epoch_duration(&self) -> f64 {
249        if self.epochs_completed > 0 {
250            self.total_duration_secs / self.epochs_completed as f64
251        } else {
252            0.0
253        }
254    }
255
256    /// Check if training improved quality
257    pub fn quality_improved(&self) -> bool {
258        if self.epoch_stats.len() < 2 {
259            return false;
260        }
261        let first = self.epoch_stats.first().unwrap().avg_quality;
262        let last = self.epoch_stats.last().unwrap().avg_quality;
263        last > first
264    }
265
266    /// Get quality improvement
267    pub fn quality_improvement(&self) -> f32 {
268        if self.epoch_stats.len() < 2 {
269            return 0.0;
270        }
271        let first = self.epoch_stats.first().unwrap().avg_quality;
272        let last = self.epoch_stats.last().unwrap().avg_quality;
273        last - first
274    }
275}
276
277impl std::fmt::Display for TrainingResult {
278    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
279        write!(
280            f,
281            "TrainingResult(pipeline={}, epochs={}, examples={}, patterns={}, \
282             final_quality={:.4}, duration={:.2}s, throughput={:.1}/s)",
283            self.pipeline_name,
284            self.epochs_completed,
285            self.total_examples,
286            self.patterns_learned,
287            self.final_avg_quality,
288            self.total_duration_secs,
289            self.examples_per_sec()
290        )
291    }
292}
293
/// Comparison metrics between training runs.
///
/// Produced by `TrainingComparison::compare`; all differences are
/// `comparison - baseline`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrainingComparison {
    /// Baseline result name
    pub baseline_name: String,
    /// Comparison result name
    pub comparison_name: String,
    /// Quality difference (comparison - baseline)
    pub quality_diff: f32,
    /// Quality improvement percentage (0.0 when baseline quality is not positive)
    pub quality_improvement_pct: f32,
    /// Throughput difference in examples/sec
    pub throughput_diff: f64,
    /// Duration difference in seconds
    pub duration_diff: f64,
}
310
311impl TrainingComparison {
312    /// Compare two training results
313    pub fn compare(baseline: &TrainingResult, comparison: &TrainingResult) -> Self {
314        let quality_diff = comparison.final_avg_quality - baseline.final_avg_quality;
315        let quality_improvement_pct = if baseline.final_avg_quality > 0.0 {
316            (quality_diff / baseline.final_avg_quality) * 100.0
317        } else {
318            0.0
319        };
320
321        Self {
322            baseline_name: baseline.pipeline_name.clone(),
323            comparison_name: comparison.pipeline_name.clone(),
324            quality_diff,
325            quality_improvement_pct,
326            throughput_diff: comparison.examples_per_sec() - baseline.examples_per_sec(),
327            duration_diff: comparison.total_duration_secs - baseline.total_duration_secs,
328        }
329    }
330}
331
332impl std::fmt::Display for TrainingComparison {
333    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
334        let quality_sign = if self.quality_diff >= 0.0 { "+" } else { "" };
335        let throughput_sign = if self.throughput_diff >= 0.0 { "+" } else { "" };
336
337        write!(
338            f,
339            "Comparison {} vs {}: quality {}{:.4} ({}{:.1}%), throughput {}{:.1}/s",
340            self.comparison_name,
341            self.baseline_name,
342            quality_sign,
343            self.quality_diff,
344            quality_sign,
345            self.quality_improvement_pct,
346            throughput_sign,
347            self.throughput_diff
348        )
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_metrics_creation() {
        // Fresh metrics carry the given name and zeroed counters.
        let m = TrainingMetrics::new("test");
        assert_eq!(m.name, "test");
        assert_eq!(m.total_examples, 0);
    }

    #[test]
    fn test_quality_samples() {
        let mut m = TrainingMetrics::new("test");
        (0..10).for_each(|i| m.add_quality_sample(i as f32 / 10.0));

        // Mean of 0.0..=0.9 in steps of 0.1 is 0.45.
        assert_eq!(m.quality_samples.len(), 10);
        assert!((m.avg_quality() - 0.45).abs() < 0.01);
    }

    #[test]
    fn test_quality_percentiles() {
        let mut m = TrainingMetrics::new("test");
        (0..100).for_each(|i| m.add_quality_sample(i as f32 / 100.0));

        // Uniform samples: percentile should track the sample value.
        assert!((m.quality_percentile(50.0) - 0.5).abs() < 0.02);
        assert!((m.quality_percentile(95.0) - 0.95).abs() < 0.02);
    }

    #[test]
    fn test_quality_stats() {
        let mut m = TrainingMetrics::new("test");
        for q in [0.5_f32, 0.7, 0.9] {
            m.add_quality_sample(q);
        }

        let stats = m.quality_stats();
        assert!((stats.avg - 0.7).abs() < 0.01);
        assert!((stats.min - 0.5).abs() < 0.01);
        assert!((stats.max - 0.9).abs() < 0.01);
    }

    #[test]
    fn test_training_result() {
        // Per-epoch fixture data: (examples, quality, duration).
        let epoch_data = [(333_usize, 0.75_f32, 3.0_f64), (333, 0.80, 3.5), (334, 0.85, 3.5)];
        let epoch_stats: Vec<EpochStats> = epoch_data
            .iter()
            .enumerate()
            .map(|(i, &(examples, quality, secs))| EpochStats {
                epoch: i,
                examples_processed: examples,
                avg_quality: quality,
                duration_secs: secs,
            })
            .collect();

        let result = TrainingResult {
            pipeline_name: "test".into(),
            epochs_completed: 3,
            total_examples: 1000,
            patterns_learned: 50,
            final_avg_quality: 0.85,
            total_duration_secs: 10.0,
            epoch_stats,
            validation_quality: Some(0.82),
        };

        assert_eq!(result.examples_per_sec(), 100.0);
        assert!(result.quality_improved());
        assert!((result.quality_improvement() - 0.10).abs() < 0.01);
    }

    #[test]
    fn test_training_comparison() {
        let make = |name: &str, patterns: usize, quality: f32, secs: f64| TrainingResult {
            pipeline_name: name.into(),
            epochs_completed: 2,
            total_examples: 500,
            patterns_learned: patterns,
            final_avg_quality: quality,
            total_duration_secs: secs,
            epoch_stats: vec![],
            validation_quality: None,
        };

        let baseline = make("baseline", 25, 0.70, 5.0);
        let improved = make("improved", 30, 0.85, 4.0);

        let cmp = TrainingComparison::compare(&baseline, &improved);
        assert!((cmp.quality_diff - 0.15).abs() < 0.01);
        assert!(cmp.quality_improvement_pct > 20.0);
        assert!(cmp.throughput_diff > 0.0);
    }
}