1use serde::{Deserialize, Serialize};
6
7#[derive(Clone, Debug, Default, Serialize, Deserialize)]
9pub struct TrainingMetrics {
10 pub name: String,
12 pub total_examples: usize,
14 pub training_sessions: u64,
16 pub patterns_learned: usize,
18 pub quality_samples: Vec<f32>,
20 pub validation_quality: Option<f32>,
22 pub performance: PerformanceMetrics,
24}
25
26impl TrainingMetrics {
27 pub fn new(name: &str) -> Self {
29 Self {
30 name: name.to_string(),
31 ..Default::default()
32 }
33 }
34
35 pub fn add_quality_sample(&mut self, quality: f32) {
37 self.quality_samples.push(quality);
38 if self.quality_samples.len() > 10000 {
40 self.quality_samples.remove(0);
41 }
42 }
43
44 pub fn avg_quality(&self) -> f32 {
46 if self.quality_samples.is_empty() {
47 0.0
48 } else {
49 self.quality_samples.iter().sum::<f32>() / self.quality_samples.len() as f32
50 }
51 }
52
53 pub fn quality_percentile(&self, percentile: f32) -> f32 {
55 if self.quality_samples.is_empty() {
56 return 0.0;
57 }
58
59 let mut sorted = self.quality_samples.clone();
60 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
61
62 let idx = ((percentile / 100.0) * (sorted.len() - 1) as f32) as usize;
63 sorted[idx.min(sorted.len() - 1)]
64 }
65
66 pub fn quality_stats(&self) -> QualityMetrics {
68 if self.quality_samples.is_empty() {
69 return QualityMetrics::default();
70 }
71
72 let avg = self.avg_quality();
73 let min = self
74 .quality_samples
75 .iter()
76 .cloned()
77 .fold(f32::MAX, f32::min);
78 let max = self
79 .quality_samples
80 .iter()
81 .cloned()
82 .fold(f32::MIN, f32::max);
83
84 let variance = self
85 .quality_samples
86 .iter()
87 .map(|q| (q - avg).powi(2))
88 .sum::<f32>()
89 / self.quality_samples.len() as f32;
90 let std_dev = variance.sqrt();
91
92 QualityMetrics {
93 avg,
94 min,
95 max,
96 std_dev,
97 p25: self.quality_percentile(25.0),
98 p50: self.quality_percentile(50.0),
99 p75: self.quality_percentile(75.0),
100 p95: self.quality_percentile(95.0),
101 sample_count: self.quality_samples.len(),
102 }
103 }
104
105 pub fn reset(&mut self) {
107 self.total_examples = 0;
108 self.training_sessions = 0;
109 self.patterns_learned = 0;
110 self.quality_samples.clear();
111 self.validation_quality = None;
112 self.performance = PerformanceMetrics::default();
113 }
114
115 pub fn merge(&mut self, other: &TrainingMetrics) {
117 self.total_examples += other.total_examples;
118 self.training_sessions += other.training_sessions;
119 self.patterns_learned = other.patterns_learned; self.quality_samples.extend(&other.quality_samples);
121
122 if self.quality_samples.len() > 10000 {
124 let excess = self.quality_samples.len() - 10000;
125 self.quality_samples.drain(0..excess);
126 }
127 }
128}
129
130#[derive(Clone, Debug, Default, Serialize, Deserialize)]
132pub struct QualityMetrics {
133 pub avg: f32,
135 pub min: f32,
137 pub max: f32,
139 pub std_dev: f32,
141 pub p25: f32,
143 pub p50: f32,
145 pub p75: f32,
147 pub p95: f32,
149 pub sample_count: usize,
151}
152
153impl std::fmt::Display for QualityMetrics {
154 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155 write!(
156 f,
157 "avg={:.4}, std={:.4}, min={:.4}, max={:.4}, p50={:.4}, p95={:.4} (n={})",
158 self.avg, self.std_dev, self.min, self.max, self.p50, self.p95, self.sample_count
159 )
160 }
161}
162
163#[derive(Clone, Debug, Default, Serialize, Deserialize)]
165pub struct PerformanceMetrics {
166 pub total_training_secs: f64,
168 pub avg_batch_time_ms: f64,
170 pub avg_example_time_us: f64,
172 pub peak_memory_mb: usize,
174 pub examples_per_sec: f64,
176 pub pattern_extraction_ms: f64,
178}
179
180impl PerformanceMetrics {
181 pub fn calculate_throughput(&mut self, examples: usize, duration_secs: f64) {
183 if duration_secs > 0.0 {
184 self.examples_per_sec = examples as f64 / duration_secs;
185 self.avg_example_time_us = (duration_secs * 1_000_000.0) / examples as f64;
186 }
187 }
188}
189
190#[derive(Clone, Debug, Serialize, Deserialize)]
192pub struct EpochStats {
193 pub epoch: usize,
195 pub examples_processed: usize,
197 pub avg_quality: f32,
199 pub duration_secs: f64,
201}
202
203impl std::fmt::Display for EpochStats {
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 write!(
206 f,
207 "Epoch {}: {} examples, avg_quality={:.4}, {:.2}s",
208 self.epoch + 1,
209 self.examples_processed,
210 self.avg_quality,
211 self.duration_secs
212 )
213 }
214}
215
216#[derive(Clone, Debug, Serialize, Deserialize)]
218pub struct TrainingResult {
219 pub pipeline_name: String,
221 pub epochs_completed: usize,
223 pub total_examples: usize,
225 pub patterns_learned: usize,
227 pub final_avg_quality: f32,
229 pub total_duration_secs: f64,
231 pub epoch_stats: Vec<EpochStats>,
233 pub validation_quality: Option<f32>,
235}
236
237impl TrainingResult {
238 pub fn examples_per_sec(&self) -> f64 {
240 if self.total_duration_secs > 0.0 {
241 self.total_examples as f64 / self.total_duration_secs
242 } else {
243 0.0
244 }
245 }
246
247 pub fn avg_epoch_duration(&self) -> f64 {
249 if self.epochs_completed > 0 {
250 self.total_duration_secs / self.epochs_completed as f64
251 } else {
252 0.0
253 }
254 }
255
256 pub fn quality_improved(&self) -> bool {
258 if self.epoch_stats.len() < 2 {
259 return false;
260 }
261 let first = self.epoch_stats.first().unwrap().avg_quality;
262 let last = self.epoch_stats.last().unwrap().avg_quality;
263 last > first
264 }
265
266 pub fn quality_improvement(&self) -> f32 {
268 if self.epoch_stats.len() < 2 {
269 return 0.0;
270 }
271 let first = self.epoch_stats.first().unwrap().avg_quality;
272 let last = self.epoch_stats.last().unwrap().avg_quality;
273 last - first
274 }
275}
276
277impl std::fmt::Display for TrainingResult {
278 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
279 write!(
280 f,
281 "TrainingResult(pipeline={}, epochs={}, examples={}, patterns={}, \
282 final_quality={:.4}, duration={:.2}s, throughput={:.1}/s)",
283 self.pipeline_name,
284 self.epochs_completed,
285 self.total_examples,
286 self.patterns_learned,
287 self.final_avg_quality,
288 self.total_duration_secs,
289 self.examples_per_sec()
290 )
291 }
292}
293
294#[derive(Clone, Debug, Serialize, Deserialize)]
296pub struct TrainingComparison {
297 pub baseline_name: String,
299 pub comparison_name: String,
301 pub quality_diff: f32,
303 pub quality_improvement_pct: f32,
305 pub throughput_diff: f64,
307 pub duration_diff: f64,
309}
310
311impl TrainingComparison {
312 pub fn compare(baseline: &TrainingResult, comparison: &TrainingResult) -> Self {
314 let quality_diff = comparison.final_avg_quality - baseline.final_avg_quality;
315 let quality_improvement_pct = if baseline.final_avg_quality > 0.0 {
316 (quality_diff / baseline.final_avg_quality) * 100.0
317 } else {
318 0.0
319 };
320
321 Self {
322 baseline_name: baseline.pipeline_name.clone(),
323 comparison_name: comparison.pipeline_name.clone(),
324 quality_diff,
325 quality_improvement_pct,
326 throughput_diff: comparison.examples_per_sec() - baseline.examples_per_sec(),
327 duration_diff: comparison.total_duration_secs - baseline.total_duration_secs,
328 }
329 }
330}
331
332impl std::fmt::Display for TrainingComparison {
333 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
334 let quality_sign = if self.quality_diff >= 0.0 { "+" } else { "" };
335 let throughput_sign = if self.throughput_diff >= 0.0 { "+" } else { "" };
336
337 write!(
338 f,
339 "Comparison {} vs {}: quality {}{:.4} ({}{:.1}%), throughput {}{:.1}/s",
340 self.comparison_name,
341 self.baseline_name,
342 quality_sign,
343 self.quality_diff,
344 quality_sign,
345 self.quality_improvement_pct,
346 throughput_sign,
347 self.throughput_diff
348 )
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` lies strictly within `tolerance` of `expected`.
    fn close(actual: f32, expected: f32, tolerance: f32) -> bool {
        (actual - expected).abs() < tolerance
    }

    /// Builds a metrics instance named "test" and feeds it `samples` in order.
    fn metrics_with(samples: impl IntoIterator<Item = f32>) -> TrainingMetrics {
        let mut metrics = TrainingMetrics::new("test");
        for sample in samples {
            metrics.add_quality_sample(sample);
        }
        metrics
    }

    /// Shorthand for one `EpochStats` entry.
    fn epoch(epoch: usize, examples_processed: usize, avg_quality: f32, duration_secs: f64) -> EpochStats {
        EpochStats {
            epoch,
            examples_processed,
            avg_quality,
            duration_secs,
        }
    }

    /// A two-epoch, 500-example result without epoch stats or validation.
    fn bare_result(
        name: &str,
        patterns_learned: usize,
        final_avg_quality: f32,
        total_duration_secs: f64,
    ) -> TrainingResult {
        TrainingResult {
            pipeline_name: name.into(),
            epochs_completed: 2,
            total_examples: 500,
            patterns_learned,
            final_avg_quality,
            total_duration_secs,
            epoch_stats: vec![],
            validation_quality: None,
        }
    }

    #[test]
    fn test_metrics_creation() {
        let metrics = TrainingMetrics::new("test");

        assert_eq!(metrics.name, "test");
        assert_eq!(metrics.total_examples, 0);
    }

    #[test]
    fn test_quality_samples() {
        let metrics = metrics_with((0..10).map(|i| i as f32 / 10.0));

        assert_eq!(metrics.quality_samples.len(), 10);
        assert!(close(metrics.avg_quality(), 0.45, 0.01));
    }

    #[test]
    fn test_quality_percentiles() {
        let metrics = metrics_with((0..100).map(|i| i as f32 / 100.0));

        assert!(close(metrics.quality_percentile(50.0), 0.5, 0.02));
        assert!(close(metrics.quality_percentile(95.0), 0.95, 0.02));
    }

    #[test]
    fn test_quality_stats() {
        let stats = metrics_with([0.5, 0.7, 0.9]).quality_stats();

        assert!(close(stats.avg, 0.7, 0.01));
        assert!(close(stats.min, 0.5, 0.01));
        assert!(close(stats.max, 0.9, 0.01));
    }

    #[test]
    fn test_training_result() {
        let result = TrainingResult {
            pipeline_name: "test".into(),
            epochs_completed: 3,
            total_examples: 1000,
            patterns_learned: 50,
            final_avg_quality: 0.85,
            total_duration_secs: 10.0,
            epoch_stats: vec![
                epoch(0, 333, 0.75, 3.0),
                epoch(1, 333, 0.80, 3.5),
                epoch(2, 334, 0.85, 3.5),
            ],
            validation_quality: Some(0.82),
        };

        assert_eq!(result.examples_per_sec(), 100.0);
        assert!(result.quality_improved());
        assert!(close(result.quality_improvement(), 0.10, 0.01));
    }

    #[test]
    fn test_training_comparison() {
        let baseline = bare_result("baseline", 25, 0.70, 5.0);
        let improved = bare_result("improved", 30, 0.85, 4.0);

        let comparison = TrainingComparison::compare(&baseline, &improved);

        assert!(close(comparison.quality_diff, 0.15, 0.01));
        assert!(comparison.quality_improvement_pct > 20.0);
        assert!(comparison.throughput_diff > 0.0);
    }
}