// ruvector_sona/training/metrics.rs

1//! Training Metrics for SONA
2//!
3//! Comprehensive analytics for training sessions.
4
5use serde::{Deserialize, Serialize};
6
7/// Training metrics collection
/// Training metrics collection.
///
/// Accumulates per-pipeline counters plus a rolling window of quality
/// samples (capped at 10_000 entries by `add_quality_sample` / `merge`).
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct TrainingMetrics {
    /// Pipeline/agent name
    pub name: String,
    /// Total examples processed across all sessions
    pub total_examples: usize,
    /// Total training sessions run
    pub training_sessions: u64,
    /// Patterns learned (latest snapshot; `merge` overwrites rather than sums)
    pub patterns_learned: usize,
    /// Rolling window of quality samples used for averaging/percentiles
    pub quality_samples: Vec<f32>,
    /// Validation quality (if validation was run)
    pub validation_quality: Option<f32>,
    /// Timing and throughput metrics
    pub performance: PerformanceMetrics,
}
25
26impl TrainingMetrics {
27    /// Create new metrics
28    pub fn new(name: &str) -> Self {
29        Self {
30            name: name.to_string(),
31            ..Default::default()
32        }
33    }
34
35    /// Add quality sample
36    pub fn add_quality_sample(&mut self, quality: f32) {
37        self.quality_samples.push(quality);
38        // Keep last 10000 samples
39        if self.quality_samples.len() > 10000 {
40            self.quality_samples.remove(0);
41        }
42    }
43
44    /// Get average quality
45    pub fn avg_quality(&self) -> f32 {
46        if self.quality_samples.is_empty() {
47            0.0
48        } else {
49            self.quality_samples.iter().sum::<f32>() / self.quality_samples.len() as f32
50        }
51    }
52
53    /// Get quality percentile
54    pub fn quality_percentile(&self, percentile: f32) -> f32 {
55        if self.quality_samples.is_empty() {
56            return 0.0;
57        }
58
59        let mut sorted = self.quality_samples.clone();
60        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
61
62        let idx = ((percentile / 100.0) * (sorted.len() - 1) as f32) as usize;
63        sorted[idx.min(sorted.len() - 1)]
64    }
65
66    /// Get quality statistics
67    pub fn quality_stats(&self) -> QualityMetrics {
68        if self.quality_samples.is_empty() {
69            return QualityMetrics::default();
70        }
71
72        let avg = self.avg_quality();
73        let min = self.quality_samples.iter().cloned().fold(f32::MAX, f32::min);
74        let max = self.quality_samples.iter().cloned().fold(f32::MIN, f32::max);
75
76        let variance = self.quality_samples.iter()
77            .map(|q| (q - avg).powi(2))
78            .sum::<f32>() / self.quality_samples.len() as f32;
79        let std_dev = variance.sqrt();
80
81        QualityMetrics {
82            avg,
83            min,
84            max,
85            std_dev,
86            p25: self.quality_percentile(25.0),
87            p50: self.quality_percentile(50.0),
88            p75: self.quality_percentile(75.0),
89            p95: self.quality_percentile(95.0),
90            sample_count: self.quality_samples.len(),
91        }
92    }
93
94    /// Reset metrics
95    pub fn reset(&mut self) {
96        self.total_examples = 0;
97        self.training_sessions = 0;
98        self.patterns_learned = 0;
99        self.quality_samples.clear();
100        self.validation_quality = None;
101        self.performance = PerformanceMetrics::default();
102    }
103
104    /// Merge with another metrics instance
105    pub fn merge(&mut self, other: &TrainingMetrics) {
106        self.total_examples += other.total_examples;
107        self.training_sessions += other.training_sessions;
108        self.patterns_learned = other.patterns_learned; // Take latest
109        self.quality_samples.extend(&other.quality_samples);
110
111        // Keep last 10000
112        if self.quality_samples.len() > 10000 {
113            let excess = self.quality_samples.len() - 10000;
114            self.quality_samples.drain(0..excess);
115        }
116    }
117}
118
119/// Quality metrics summary
/// Quality metrics summary produced by `TrainingMetrics::quality_stats`.
///
/// Percentiles use nearest-rank (truncating) selection over the
/// retained sample window.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct QualityMetrics {
    /// Average quality
    pub avg: f32,
    /// Minimum quality
    pub min: f32,
    /// Maximum quality
    pub max: f32,
    /// Standard deviation (population, over the retained window)
    pub std_dev: f32,
    /// 25th percentile
    pub p25: f32,
    /// 50th percentile (median)
    pub p50: f32,
    /// 75th percentile
    pub p75: f32,
    /// 95th percentile
    pub p95: f32,
    /// Number of samples the statistics were computed from
    pub sample_count: usize,
}
141
142impl std::fmt::Display for QualityMetrics {
143    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144        write!(
145            f,
146            "avg={:.4}, std={:.4}, min={:.4}, max={:.4}, p50={:.4}, p95={:.4} (n={})",
147            self.avg, self.std_dev, self.min, self.max, self.p50, self.p95, self.sample_count
148        )
149    }
150}
151
152/// Performance metrics
/// Performance metrics for a training run.
///
/// Throughput fields are populated via `calculate_throughput`; the rest
/// are set by the training driver (not visible in this module).
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PerformanceMetrics {
    /// Total training time in seconds
    pub total_training_secs: f64,
    /// Average batch processing time in milliseconds
    pub avg_batch_time_ms: f64,
    /// Average example processing time in microseconds
    pub avg_example_time_us: f64,
    /// Peak memory usage in MB
    pub peak_memory_mb: usize,
    /// Examples per second throughput
    pub examples_per_sec: f64,
    /// Pattern extraction time in milliseconds
    pub pattern_extraction_ms: f64,
}
168
169impl PerformanceMetrics {
170    /// Calculate throughput
171    pub fn calculate_throughput(&mut self, examples: usize, duration_secs: f64) {
172        if duration_secs > 0.0 {
173            self.examples_per_sec = examples as f64 / duration_secs;
174            self.avg_example_time_us = (duration_secs * 1_000_000.0) / examples as f64;
175        }
176    }
177}
178
179/// Epoch statistics
/// Per-epoch statistics recorded during a training run.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EpochStats {
    /// Epoch number (0-indexed; `Display` renders it 1-indexed)
    pub epoch: usize,
    /// Examples processed in this epoch
    pub examples_processed: usize,
    /// Average quality for this epoch
    pub avg_quality: f32,
    /// Duration in seconds
    pub duration_secs: f64,
}
191
192impl std::fmt::Display for EpochStats {
193    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194        write!(
195            f,
196            "Epoch {}: {} examples, avg_quality={:.4}, {:.2}s",
197            self.epoch + 1, self.examples_processed, self.avg_quality, self.duration_secs
198        )
199    }
200}
201
202/// Training result summary
/// Training result summary for a completed run.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrainingResult {
    /// Pipeline name
    pub pipeline_name: String,
    /// Number of epochs completed
    pub epochs_completed: usize,
    /// Total examples processed
    pub total_examples: usize,
    /// Patterns learned
    pub patterns_learned: usize,
    /// Final average quality
    pub final_avg_quality: f32,
    /// Total duration in seconds
    pub total_duration_secs: f64,
    /// Per-epoch statistics (may be empty if not tracked)
    pub epoch_stats: Vec<EpochStats>,
    /// Validation quality (if validation was run)
    pub validation_quality: Option<f32>,
}
222
223impl TrainingResult {
224    /// Get examples per second
225    pub fn examples_per_sec(&self) -> f64 {
226        if self.total_duration_secs > 0.0 {
227            self.total_examples as f64 / self.total_duration_secs
228        } else {
229            0.0
230        }
231    }
232
233    /// Get average epoch duration
234    pub fn avg_epoch_duration(&self) -> f64 {
235        if self.epochs_completed > 0 {
236            self.total_duration_secs / self.epochs_completed as f64
237        } else {
238            0.0
239        }
240    }
241
242    /// Check if training improved quality
243    pub fn quality_improved(&self) -> bool {
244        if self.epoch_stats.len() < 2 {
245            return false;
246        }
247        let first = self.epoch_stats.first().unwrap().avg_quality;
248        let last = self.epoch_stats.last().unwrap().avg_quality;
249        last > first
250    }
251
252    /// Get quality improvement
253    pub fn quality_improvement(&self) -> f32 {
254        if self.epoch_stats.len() < 2 {
255            return 0.0;
256        }
257        let first = self.epoch_stats.first().unwrap().avg_quality;
258        let last = self.epoch_stats.last().unwrap().avg_quality;
259        last - first
260    }
261}
262
263impl std::fmt::Display for TrainingResult {
264    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
265        write!(
266            f,
267            "TrainingResult(pipeline={}, epochs={}, examples={}, patterns={}, \
268             final_quality={:.4}, duration={:.2}s, throughput={:.1}/s)",
269            self.pipeline_name,
270            self.epochs_completed,
271            self.total_examples,
272            self.patterns_learned,
273            self.final_avg_quality,
274            self.total_duration_secs,
275            self.examples_per_sec()
276        )
277    }
278}
279
280/// Comparison metrics between training runs
/// Comparison metrics between two training runs.
///
/// All `*_diff` fields are signed as `comparison - baseline`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrainingComparison {
    /// Baseline result name
    pub baseline_name: String,
    /// Comparison result name
    pub comparison_name: String,
    /// Quality difference (comparison - baseline)
    pub quality_diff: f32,
    /// Quality improvement percentage (0.0 when baseline quality is not positive)
    pub quality_improvement_pct: f32,
    /// Throughput difference in examples/sec
    pub throughput_diff: f64,
    /// Duration difference in seconds
    pub duration_diff: f64,
}
296
297impl TrainingComparison {
298    /// Compare two training results
299    pub fn compare(baseline: &TrainingResult, comparison: &TrainingResult) -> Self {
300        let quality_diff = comparison.final_avg_quality - baseline.final_avg_quality;
301        let quality_improvement_pct = if baseline.final_avg_quality > 0.0 {
302            (quality_diff / baseline.final_avg_quality) * 100.0
303        } else {
304            0.0
305        };
306
307        Self {
308            baseline_name: baseline.pipeline_name.clone(),
309            comparison_name: comparison.pipeline_name.clone(),
310            quality_diff,
311            quality_improvement_pct,
312            throughput_diff: comparison.examples_per_sec() - baseline.examples_per_sec(),
313            duration_diff: comparison.total_duration_secs - baseline.total_duration_secs,
314        }
315    }
316}
317
318impl std::fmt::Display for TrainingComparison {
319    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
320        let quality_sign = if self.quality_diff >= 0.0 { "+" } else { "" };
321        let throughput_sign = if self.throughput_diff >= 0.0 { "+" } else { "" };
322
323        write!(
324            f,
325            "Comparison {} vs {}: quality {}{:.4} ({}{:.1}%), throughput {}{:.1}/s",
326            self.comparison_name,
327            self.baseline_name,
328            quality_sign, self.quality_diff,
329            quality_sign, self.quality_improvement_pct,
330            throughput_sign, self.throughput_diff
331        )
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor stores the name and leaves counters at their defaults.
    #[test]
    fn test_metrics_creation() {
        let metrics = TrainingMetrics::new("test");
        assert_eq!(metrics.name, "test");
        assert_eq!(metrics.total_examples, 0);
    }

    // Samples 0.0..0.9 in 0.1 steps: mean is 0.45.
    #[test]
    fn test_quality_samples() {
        let mut metrics = TrainingMetrics::new("test");

        for i in 0..10 {
            metrics.add_quality_sample(i as f32 / 10.0);
        }

        assert_eq!(metrics.quality_samples.len(), 10);
        assert!((metrics.avg_quality() - 0.45).abs() < 0.01);
    }

    // Uniform samples 0.00..0.99: percentile p should land near p/100
    // (nearest-rank selection, hence the 0.02 tolerance).
    #[test]
    fn test_quality_percentiles() {
        let mut metrics = TrainingMetrics::new("test");

        for i in 0..100 {
            metrics.add_quality_sample(i as f32 / 100.0);
        }

        assert!((metrics.quality_percentile(50.0) - 0.5).abs() < 0.02);
        assert!((metrics.quality_percentile(95.0) - 0.95).abs() < 0.02);
    }

    // Three known samples: avg/min/max of the summary must match.
    #[test]
    fn test_quality_stats() {
        let mut metrics = TrainingMetrics::new("test");
        metrics.add_quality_sample(0.5);
        metrics.add_quality_sample(0.7);
        metrics.add_quality_sample(0.9);

        let stats = metrics.quality_stats();
        assert!((stats.avg - 0.7).abs() < 0.01);
        assert!((stats.min - 0.5).abs() < 0.01);
        assert!((stats.max - 0.9).abs() < 0.01);
    }

    // 1000 examples over 10s -> 100/s; epoch quality rises 0.75 -> 0.85.
    #[test]
    fn test_training_result() {
        let result = TrainingResult {
            pipeline_name: "test".into(),
            epochs_completed: 3,
            total_examples: 1000,
            patterns_learned: 50,
            final_avg_quality: 0.85,
            total_duration_secs: 10.0,
            epoch_stats: vec![
                EpochStats { epoch: 0, examples_processed: 333, avg_quality: 0.75, duration_secs: 3.0 },
                EpochStats { epoch: 1, examples_processed: 333, avg_quality: 0.80, duration_secs: 3.5 },
                EpochStats { epoch: 2, examples_processed: 334, avg_quality: 0.85, duration_secs: 3.5 },
            ],
            validation_quality: Some(0.82),
        };

        assert_eq!(result.examples_per_sec(), 100.0);
        assert!(result.quality_improved());
        assert!((result.quality_improvement() - 0.10).abs() < 0.01);
    }

    // 0.70 -> 0.85 quality is a +0.15 diff (~21.4%); the faster run also
    // has positive throughput diff.
    #[test]
    fn test_training_comparison() {
        let baseline = TrainingResult {
            pipeline_name: "baseline".into(),
            epochs_completed: 2,
            total_examples: 500,
            patterns_learned: 25,
            final_avg_quality: 0.70,
            total_duration_secs: 5.0,
            epoch_stats: vec![],
            validation_quality: None,
        };

        let improved = TrainingResult {
            pipeline_name: "improved".into(),
            epochs_completed: 2,
            total_examples: 500,
            patterns_learned: 30,
            final_avg_quality: 0.85,
            total_duration_secs: 4.0,
            epoch_stats: vec![],
            validation_quality: None,
        };

        let comparison = TrainingComparison::compare(&baseline, &improved);
        assert!((comparison.quality_diff - 0.15).abs() < 0.01);
        assert!(comparison.quality_improvement_pct > 20.0);
        assert!(comparison.throughput_diff > 0.0);
    }
}