1use serde::{Deserialize, Serialize};
6
/// Counters and quality samples collected under a name.
///
/// The derived `Default` yields an empty name, zeroed counters, no samples,
/// `validation_quality == None`, and default performance metrics.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct TrainingMetrics {
    /// Label for this metrics set; `new` copies it from its argument and `reset` leaves it as is.
    pub name: String,
    /// Example counter; `merge` adds the other side's value and `reset` zeroes it.
    pub total_examples: usize,
    /// Session counter; `merge` adds the other side's value and `reset` zeroes it.
    pub training_sessions: u64,
    /// Pattern count; `merge` overwrites it with the other side's value (it is not summed).
    pub patterns_learned: usize,
    /// Quality samples, oldest first; `add_quality_sample` and `merge` keep at most 10000 entries.
    pub quality_samples: Vec<f32>,
    /// Optional validation score; cleared by `reset` and not touched by `merge`.
    // NOTE(review): no code visible in this file ever sets this to `Some` — confirm the writer.
    pub validation_quality: Option<f32>,
    /// Timing/throughput figures; replaced with the default by `reset`, not touched by `merge`.
    pub performance: PerformanceMetrics,
}
25
26impl TrainingMetrics {
27 pub fn new(name: &str) -> Self {
29 Self {
30 name: name.to_string(),
31 ..Default::default()
32 }
33 }
34
35 pub fn add_quality_sample(&mut self, quality: f32) {
37 self.quality_samples.push(quality);
38 if self.quality_samples.len() > 10000 {
40 self.quality_samples.remove(0);
41 }
42 }
43
44 pub fn avg_quality(&self) -> f32 {
46 if self.quality_samples.is_empty() {
47 0.0
48 } else {
49 self.quality_samples.iter().sum::<f32>() / self.quality_samples.len() as f32
50 }
51 }
52
53 pub fn quality_percentile(&self, percentile: f32) -> f32 {
55 if self.quality_samples.is_empty() {
56 return 0.0;
57 }
58
59 let mut sorted = self.quality_samples.clone();
60 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
61
62 let idx = ((percentile / 100.0) * (sorted.len() - 1) as f32) as usize;
63 sorted[idx.min(sorted.len() - 1)]
64 }
65
66 pub fn quality_stats(&self) -> QualityMetrics {
68 if self.quality_samples.is_empty() {
69 return QualityMetrics::default();
70 }
71
72 let avg = self.avg_quality();
73 let min = self
74 .quality_samples
75 .iter()
76 .cloned()
77 .fold(f32::MAX, f32::min);
78 let max = self
79 .quality_samples
80 .iter()
81 .cloned()
82 .fold(f32::MIN, f32::max);
83
84 let variance = self
85 .quality_samples
86 .iter()
87 .map(|q| (q - avg).powi(2))
88 .sum::<f32>()
89 / self.quality_samples.len() as f32;
90 let std_dev = variance.sqrt();
91
92 QualityMetrics {
93 avg,
94 min,
95 max,
96 std_dev,
97 p25: self.quality_percentile(25.0),
98 p50: self.quality_percentile(50.0),
99 p75: self.quality_percentile(75.0),
100 p95: self.quality_percentile(95.0),
101 sample_count: self.quality_samples.len(),
102 }
103 }
104
105 pub fn reset(&mut self) {
107 self.total_examples = 0;
108 self.training_sessions = 0;
109 self.patterns_learned = 0;
110 self.quality_samples.clear();
111 self.validation_quality = None;
112 self.performance = PerformanceMetrics::default();
113 }
114
115 pub fn merge(&mut self, other: &TrainingMetrics) {
117 self.total_examples += other.total_examples;
118 self.training_sessions += other.training_sessions;
119 self.patterns_learned = other.patterns_learned; self.quality_samples.extend(&other.quality_samples);
121
122 if self.quality_samples.len() > 10000 {
124 let excess = self.quality_samples.len() - 10000;
125 self.quality_samples.drain(0..excess);
126 }
127 }
128}
129
/// Summary statistics over a set of quality samples, as produced by
/// `TrainingMetrics::quality_stats`.
///
/// The derived `Default` (all zeros) is what `quality_stats` returns when
/// there are no samples.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct QualityMetrics {
    /// Arithmetic mean of the samples.
    pub avg: f32,
    /// Smallest sample.
    pub min: f32,
    /// Largest sample.
    pub max: f32,
    /// Population standard deviation (square root of the mean squared deviation).
    pub std_dev: f32,
    /// 25th-percentile sample.
    pub p25: f32,
    /// 50th-percentile sample.
    pub p50: f32,
    /// 75th-percentile sample.
    pub p75: f32,
    /// 95th-percentile sample.
    pub p95: f32,
    /// Number of samples the statistics were computed from.
    pub sample_count: usize,
}
152
153impl std::fmt::Display for QualityMetrics {
154 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155 write!(
156 f,
157 "avg={:.4}, std={:.4}, min={:.4}, max={:.4}, p50={:.4}, p95={:.4} (n={})",
158 self.avg, self.std_dev, self.min, self.max, self.p50, self.p95, self.sample_count
159 )
160 }
161}
162
/// Timing and throughput figures for training.
///
/// Only `examples_per_sec` and `avg_example_time_us` are written by code
/// visible in this file (`calculate_throughput`).
// NOTE(review): units below are read off the field-name suffixes — confirm against the writers.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PerformanceMetrics {
    /// Total training time; presumably seconds.
    pub total_training_secs: f64,
    /// Average time per batch; presumably milliseconds.
    pub avg_batch_time_ms: f64,
    /// Average time per example in microseconds; set by `calculate_throughput`.
    pub avg_example_time_us: f64,
    /// Peak memory use; presumably megabytes.
    pub peak_memory_mb: usize,
    /// Examples processed per second; set by `calculate_throughput`.
    pub examples_per_sec: f64,
    /// Time spent extracting patterns; presumably milliseconds.
    pub pattern_extraction_ms: f64,
}
179
180impl PerformanceMetrics {
181 pub fn calculate_throughput(&mut self, examples: usize, duration_secs: f64) {
183 if duration_secs > 0.0 {
184 self.examples_per_sec = examples as f64 / duration_secs;
185 self.avg_example_time_us = (duration_secs * 1_000_000.0) / examples as f64;
186 }
187 }
188}
189
/// Figures recorded for a single training epoch.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EpochStats {
    /// Epoch index; the `Display` impl prints `epoch + 1`, so this is presumably zero-based.
    pub epoch: usize,
    /// Number of examples processed in this epoch.
    pub examples_processed: usize,
    /// Average quality for this epoch; `TrainingResult` compares the first and last epochs' values.
    pub avg_quality: f32,
    /// Epoch duration; printed by `Display` with an `s` suffix.
    pub duration_secs: f64,
}
202
203impl std::fmt::Display for EpochStats {
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 write!(
206 f,
207 "Epoch {}: {} examples, avg_quality={:.4}, {:.2}s",
208 self.epoch + 1,
209 self.examples_processed,
210 self.avg_quality,
211 self.duration_secs
212 )
213 }
214}
215
/// Outcome of a training pipeline run.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrainingResult {
    /// Name of the pipeline that produced this result.
    pub pipeline_name: String,
    /// Number of completed epochs; divisor used by `avg_epoch_duration`.
    pub epochs_completed: usize,
    /// Total examples; numerator used by `examples_per_sec`.
    pub total_examples: usize,
    /// Number of patterns learned.
    pub patterns_learned: usize,
    /// Final average quality; the value `TrainingComparison::compare` diffs.
    pub final_avg_quality: f32,
    /// Total duration in seconds; used by `examples_per_sec` and `avg_epoch_duration`.
    pub total_duration_secs: f64,
    /// Per-epoch statistics; `quality_improved`/`quality_improvement` compare the first and last entries.
    pub epoch_stats: Vec<EpochStats>,
    /// Optional validation score; not read by any code visible in this file outside tests.
    pub validation_quality: Option<f32>,
}
236
237impl TrainingResult {
238 pub fn examples_per_sec(&self) -> f64 {
240 if self.total_duration_secs > 0.0 {
241 self.total_examples as f64 / self.total_duration_secs
242 } else {
243 0.0
244 }
245 }
246
247 pub fn avg_epoch_duration(&self) -> f64 {
249 if self.epochs_completed > 0 {
250 self.total_duration_secs / self.epochs_completed as f64
251 } else {
252 0.0
253 }
254 }
255
256 pub fn quality_improved(&self) -> bool {
258 if self.epoch_stats.len() < 2 {
259 return false;
260 }
261 let first = self.epoch_stats.first().unwrap().avg_quality;
262 let last = self.epoch_stats.last().unwrap().avg_quality;
263 last > first
264 }
265
266 pub fn quality_improvement(&self) -> f32 {
268 if self.epoch_stats.len() < 2 {
269 return 0.0;
270 }
271 let first = self.epoch_stats.first().unwrap().avg_quality;
272 let last = self.epoch_stats.last().unwrap().avg_quality;
273 last - first
274 }
275}
276
277impl std::fmt::Display for TrainingResult {
278 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
279 write!(
280 f,
281 "TrainingResult(pipeline={}, epochs={}, examples={}, patterns={}, \
282 final_quality={:.4}, duration={:.2}s, throughput={:.1}/s)",
283 self.pipeline_name,
284 self.epochs_completed,
285 self.total_examples,
286 self.patterns_learned,
287 self.final_avg_quality,
288 self.total_duration_secs,
289 self.examples_per_sec()
290 )
291 }
292}
293
/// Differences between two training results, as computed by
/// `TrainingComparison::compare` (every diff is `comparison - baseline`).
#[derive(Clone, Debug, Serialize, Deserialize)]
// NOTE(review): presumably silences warnings while nothing outside tests uses this type — confirm.
#[allow(dead_code)]
pub struct TrainingComparison {
    /// Pipeline name of the baseline result.
    pub baseline_name: String,
    /// Pipeline name of the result being compared against the baseline.
    pub comparison_name: String,
    /// Difference in `final_avg_quality`.
    pub quality_diff: f32,
    /// `quality_diff` as a percentage of the baseline quality; `0.0` when the baseline quality is not positive.
    pub quality_improvement_pct: f32,
    /// Difference in `examples_per_sec()`.
    pub throughput_diff: f64,
    /// Difference in `total_duration_secs`.
    pub duration_diff: f64,
}
311
312#[allow(dead_code)]
313impl TrainingComparison {
314 pub fn compare(baseline: &TrainingResult, comparison: &TrainingResult) -> Self {
316 let quality_diff = comparison.final_avg_quality - baseline.final_avg_quality;
317 let quality_improvement_pct = if baseline.final_avg_quality > 0.0 {
318 (quality_diff / baseline.final_avg_quality) * 100.0
319 } else {
320 0.0
321 };
322
323 Self {
324 baseline_name: baseline.pipeline_name.clone(),
325 comparison_name: comparison.pipeline_name.clone(),
326 quality_diff,
327 quality_improvement_pct,
328 throughput_diff: comparison.examples_per_sec() - baseline.examples_per_sec(),
329 duration_diff: comparison.total_duration_secs - baseline.total_duration_secs,
330 }
331 }
332}
333
334impl std::fmt::Display for TrainingComparison {
335 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
336 let quality_sign = if self.quality_diff >= 0.0 { "+" } else { "" };
337 let throughput_sign = if self.throughput_diff >= 0.0 { "+" } else { "" };
338
339 write!(
340 f,
341 "Comparison {} vs {}: quality {}{:.4} ({}{:.1}%), throughput {}{:.1}/s",
342 self.comparison_name,
343 self.baseline_name,
344 quality_sign,
345 self.quality_diff,
346 quality_sign,
347 self.quality_improvement_pct,
348 throughput_sign,
349 self.throughput_diff
350 )
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns metrics named "test" holding the samples `0/count .. (count-1)/count`.
    fn ramp(count: usize) -> TrainingMetrics {
        let mut metrics = TrainingMetrics::new("test");
        (0..count)
            .map(|step| step as f32 / count as f32)
            .for_each(|sample| metrics.add_quality_sample(sample));
        metrics
    }

    /// Builds a two-epoch, 500-example result with no per-epoch stats and no validation score.
    fn summary_only(name: &str, patterns: usize, quality: f32, secs: f64) -> TrainingResult {
        TrainingResult {
            pipeline_name: name.into(),
            epochs_completed: 2,
            total_examples: 500,
            patterns_learned: patterns,
            final_avg_quality: quality,
            total_duration_secs: secs,
            epoch_stats: Vec::new(),
            validation_quality: None,
        }
    }

    /// A freshly created metrics set carries the given name and zero examples.
    #[test]
    fn test_metrics_creation() {
        let fresh = TrainingMetrics::new("test");
        assert_eq!(fresh.name, "test");
        assert_eq!(fresh.total_examples, 0);
    }

    /// Ten evenly spaced samples are all stored and average to about 0.45.
    #[test]
    fn test_quality_samples() {
        let filled = ramp(10);

        assert_eq!(filled.quality_samples.len(), 10);
        assert!((filled.avg_quality() - 0.45).abs() < 0.01);
    }

    /// Percentiles over one hundred evenly spaced samples land near the expected values.
    #[test]
    fn test_quality_percentiles() {
        let filled = ramp(100);

        assert!((filled.quality_percentile(50.0) - 0.5).abs() < 0.02);
        assert!((filled.quality_percentile(95.0) - 0.95).abs() < 0.02);
    }

    /// Mean, minimum and maximum of three samples.
    #[test]
    fn test_quality_stats() {
        let mut three = TrainingMetrics::new("test");
        for sample in [0.5_f32, 0.7, 0.9].iter().copied() {
            three.add_quality_sample(sample);
        }

        let summary = three.quality_stats();
        assert!((summary.avg - 0.7).abs() < 0.01);
        assert!((summary.min - 0.5).abs() < 0.01);
        assert!((summary.max - 0.9).abs() < 0.01);
    }

    /// Throughput and first-to-last quality trend of a three-epoch result.
    #[test]
    fn test_training_result() {
        let per_epoch: [(usize, f32, f64); 3] = [(333, 0.75, 3.0), (333, 0.80, 3.5), (334, 0.85, 3.5)];
        let epoch_stats = per_epoch
            .iter()
            .enumerate()
            .map(|(epoch, &(examples_processed, avg_quality, duration_secs))| EpochStats {
                epoch,
                examples_processed,
                avg_quality,
                duration_secs,
            })
            .collect();

        let outcome = TrainingResult {
            pipeline_name: "test".into(),
            epochs_completed: 3,
            total_examples: 1000,
            patterns_learned: 50,
            final_avg_quality: 0.85,
            total_duration_secs: 10.0,
            epoch_stats,
            validation_quality: Some(0.82),
        };

        assert_eq!(outcome.examples_per_sec(), 100.0);
        assert!(outcome.quality_improved());
        assert!((outcome.quality_improvement() - 0.10).abs() < 0.01);
    }

    /// A faster, higher-quality run compares favourably against the baseline.
    #[test]
    fn test_training_comparison() {
        let baseline = summary_only("baseline", 25, 0.70, 5.0);
        let improved = summary_only("improved", 30, 0.85, 4.0);

        let delta = TrainingComparison::compare(&baseline, &improved);
        assert!((delta.quality_diff - 0.15).abs() < 0.01);
        assert!(delta.quality_improvement_pct > 20.0);
        assert!(delta.throughput_diff > 0.0);
    }
}