1use serde::{Deserialize, Serialize};
6
7#[derive(Clone, Debug, Default, Serialize, Deserialize)]
9pub struct TrainingMetrics {
10 pub name: String,
12 pub total_examples: usize,
14 pub training_sessions: u64,
16 pub patterns_learned: usize,
18 pub quality_samples: Vec<f32>,
20 pub validation_quality: Option<f32>,
22 pub performance: PerformanceMetrics,
24}
25
26impl TrainingMetrics {
27 pub fn new(name: &str) -> Self {
29 Self {
30 name: name.to_string(),
31 ..Default::default()
32 }
33 }
34
35 pub fn add_quality_sample(&mut self, quality: f32) {
37 self.quality_samples.push(quality);
38 if self.quality_samples.len() > 10000 {
40 self.quality_samples.remove(0);
41 }
42 }
43
44 pub fn avg_quality(&self) -> f32 {
46 if self.quality_samples.is_empty() {
47 0.0
48 } else {
49 self.quality_samples.iter().sum::<f32>() / self.quality_samples.len() as f32
50 }
51 }
52
53 pub fn quality_percentile(&self, percentile: f32) -> f32 {
55 if self.quality_samples.is_empty() {
56 return 0.0;
57 }
58
59 let mut sorted = self.quality_samples.clone();
60 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
61
62 let idx = ((percentile / 100.0) * (sorted.len() - 1) as f32) as usize;
63 sorted[idx.min(sorted.len() - 1)]
64 }
65
66 pub fn quality_stats(&self) -> QualityMetrics {
68 if self.quality_samples.is_empty() {
69 return QualityMetrics::default();
70 }
71
72 let avg = self.avg_quality();
73 let min = self.quality_samples.iter().cloned().fold(f32::MAX, f32::min);
74 let max = self.quality_samples.iter().cloned().fold(f32::MIN, f32::max);
75
76 let variance = self.quality_samples.iter()
77 .map(|q| (q - avg).powi(2))
78 .sum::<f32>() / self.quality_samples.len() as f32;
79 let std_dev = variance.sqrt();
80
81 QualityMetrics {
82 avg,
83 min,
84 max,
85 std_dev,
86 p25: self.quality_percentile(25.0),
87 p50: self.quality_percentile(50.0),
88 p75: self.quality_percentile(75.0),
89 p95: self.quality_percentile(95.0),
90 sample_count: self.quality_samples.len(),
91 }
92 }
93
94 pub fn reset(&mut self) {
96 self.total_examples = 0;
97 self.training_sessions = 0;
98 self.patterns_learned = 0;
99 self.quality_samples.clear();
100 self.validation_quality = None;
101 self.performance = PerformanceMetrics::default();
102 }
103
104 pub fn merge(&mut self, other: &TrainingMetrics) {
106 self.total_examples += other.total_examples;
107 self.training_sessions += other.training_sessions;
108 self.patterns_learned = other.patterns_learned; self.quality_samples.extend(&other.quality_samples);
110
111 if self.quality_samples.len() > 10000 {
113 let excess = self.quality_samples.len() - 10000;
114 self.quality_samples.drain(0..excess);
115 }
116 }
117}
118
119#[derive(Clone, Debug, Default, Serialize, Deserialize)]
121pub struct QualityMetrics {
122 pub avg: f32,
124 pub min: f32,
126 pub max: f32,
128 pub std_dev: f32,
130 pub p25: f32,
132 pub p50: f32,
134 pub p75: f32,
136 pub p95: f32,
138 pub sample_count: usize,
140}
141
142impl std::fmt::Display for QualityMetrics {
143 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144 write!(
145 f,
146 "avg={:.4}, std={:.4}, min={:.4}, max={:.4}, p50={:.4}, p95={:.4} (n={})",
147 self.avg, self.std_dev, self.min, self.max, self.p50, self.p95, self.sample_count
148 )
149 }
150}
151
152#[derive(Clone, Debug, Default, Serialize, Deserialize)]
154pub struct PerformanceMetrics {
155 pub total_training_secs: f64,
157 pub avg_batch_time_ms: f64,
159 pub avg_example_time_us: f64,
161 pub peak_memory_mb: usize,
163 pub examples_per_sec: f64,
165 pub pattern_extraction_ms: f64,
167}
168
169impl PerformanceMetrics {
170 pub fn calculate_throughput(&mut self, examples: usize, duration_secs: f64) {
172 if duration_secs > 0.0 {
173 self.examples_per_sec = examples as f64 / duration_secs;
174 self.avg_example_time_us = (duration_secs * 1_000_000.0) / examples as f64;
175 }
176 }
177}
178
179#[derive(Clone, Debug, Serialize, Deserialize)]
181pub struct EpochStats {
182 pub epoch: usize,
184 pub examples_processed: usize,
186 pub avg_quality: f32,
188 pub duration_secs: f64,
190}
191
192impl std::fmt::Display for EpochStats {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 write!(
195 f,
196 "Epoch {}: {} examples, avg_quality={:.4}, {:.2}s",
197 self.epoch + 1, self.examples_processed, self.avg_quality, self.duration_secs
198 )
199 }
200}
201
202#[derive(Clone, Debug, Serialize, Deserialize)]
204pub struct TrainingResult {
205 pub pipeline_name: String,
207 pub epochs_completed: usize,
209 pub total_examples: usize,
211 pub patterns_learned: usize,
213 pub final_avg_quality: f32,
215 pub total_duration_secs: f64,
217 pub epoch_stats: Vec<EpochStats>,
219 pub validation_quality: Option<f32>,
221}
222
223impl TrainingResult {
224 pub fn examples_per_sec(&self) -> f64 {
226 if self.total_duration_secs > 0.0 {
227 self.total_examples as f64 / self.total_duration_secs
228 } else {
229 0.0
230 }
231 }
232
233 pub fn avg_epoch_duration(&self) -> f64 {
235 if self.epochs_completed > 0 {
236 self.total_duration_secs / self.epochs_completed as f64
237 } else {
238 0.0
239 }
240 }
241
242 pub fn quality_improved(&self) -> bool {
244 if self.epoch_stats.len() < 2 {
245 return false;
246 }
247 let first = self.epoch_stats.first().unwrap().avg_quality;
248 let last = self.epoch_stats.last().unwrap().avg_quality;
249 last > first
250 }
251
252 pub fn quality_improvement(&self) -> f32 {
254 if self.epoch_stats.len() < 2 {
255 return 0.0;
256 }
257 let first = self.epoch_stats.first().unwrap().avg_quality;
258 let last = self.epoch_stats.last().unwrap().avg_quality;
259 last - first
260 }
261}
262
263impl std::fmt::Display for TrainingResult {
264 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
265 write!(
266 f,
267 "TrainingResult(pipeline={}, epochs={}, examples={}, patterns={}, \
268 final_quality={:.4}, duration={:.2}s, throughput={:.1}/s)",
269 self.pipeline_name,
270 self.epochs_completed,
271 self.total_examples,
272 self.patterns_learned,
273 self.final_avg_quality,
274 self.total_duration_secs,
275 self.examples_per_sec()
276 )
277 }
278}
279
280#[derive(Clone, Debug, Serialize, Deserialize)]
282pub struct TrainingComparison {
283 pub baseline_name: String,
285 pub comparison_name: String,
287 pub quality_diff: f32,
289 pub quality_improvement_pct: f32,
291 pub throughput_diff: f64,
293 pub duration_diff: f64,
295}
296
297impl TrainingComparison {
298 pub fn compare(baseline: &TrainingResult, comparison: &TrainingResult) -> Self {
300 let quality_diff = comparison.final_avg_quality - baseline.final_avg_quality;
301 let quality_improvement_pct = if baseline.final_avg_quality > 0.0 {
302 (quality_diff / baseline.final_avg_quality) * 100.0
303 } else {
304 0.0
305 };
306
307 Self {
308 baseline_name: baseline.pipeline_name.clone(),
309 comparison_name: comparison.pipeline_name.clone(),
310 quality_diff,
311 quality_improvement_pct,
312 throughput_diff: comparison.examples_per_sec() - baseline.examples_per_sec(),
313 duration_diff: comparison.total_duration_secs - baseline.total_duration_secs,
314 }
315 }
316}
317
318impl std::fmt::Display for TrainingComparison {
319 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
320 let quality_sign = if self.quality_diff >= 0.0 { "+" } else { "" };
321 let throughput_sign = if self.throughput_diff >= 0.0 { "+" } else { "" };
322
323 write!(
324 f,
325 "Comparison {} vs {}: quality {}{:.4} ({}{:.1}%), throughput {}{:.1}/s",
326 self.comparison_name,
327 self.baseline_name,
328 quality_sign, self.quality_diff,
329 quality_sign, self.quality_improvement_pct,
330 throughput_sign, self.throughput_diff
331 )
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a metrics record named "test" pre-loaded with `samples`.
    fn metrics_with(samples: impl IntoIterator<Item = f32>) -> TrainingMetrics {
        let mut metrics = TrainingMetrics::new("test");
        for sample in samples {
            metrics.add_quality_sample(sample);
        }
        metrics
    }

    /// `true` when `actual` lies strictly within `tolerance` of `expected`.
    fn close(actual: f32, expected: f32, tolerance: f32) -> bool {
        (actual - expected).abs() < tolerance
    }

    /// Shorthand for one epoch record.
    fn epoch(index: usize, examples: usize, quality: f32, secs: f64) -> EpochStats {
        EpochStats {
            epoch: index,
            examples_processed: examples,
            avg_quality: quality,
            duration_secs: secs,
        }
    }

    /// Two-epoch, 500-example result with no per-epoch records.
    fn bare_result(name: &str, patterns: usize, quality: f32, secs: f64) -> TrainingResult {
        TrainingResult {
            pipeline_name: name.into(),
            epochs_completed: 2,
            total_examples: 500,
            patterns_learned: patterns,
            final_avg_quality: quality,
            total_duration_secs: secs,
            epoch_stats: vec![],
            validation_quality: None,
        }
    }

    #[test]
    fn test_metrics_creation() {
        let metrics = TrainingMetrics::new("test");
        assert_eq!(metrics.name, "test");
        assert_eq!(metrics.total_examples, 0);
    }

    #[test]
    fn test_quality_samples() {
        let metrics = metrics_with((0..10).map(|i| i as f32 / 10.0));

        assert_eq!(metrics.quality_samples.len(), 10);
        assert!(close(metrics.avg_quality(), 0.45, 0.01));
    }

    #[test]
    fn test_quality_percentiles() {
        let metrics = metrics_with((0..100).map(|i| i as f32 / 100.0));

        assert!(close(metrics.quality_percentile(50.0), 0.5, 0.02));
        assert!(close(metrics.quality_percentile(95.0), 0.95, 0.02));
    }

    #[test]
    fn test_quality_stats() {
        let stats = metrics_with(vec![0.5, 0.7, 0.9]).quality_stats();

        assert!(close(stats.avg, 0.7, 0.01));
        assert!(close(stats.min, 0.5, 0.01));
        assert!(close(stats.max, 0.9, 0.01));
    }

    #[test]
    fn test_training_result() {
        let result = TrainingResult {
            pipeline_name: "test".into(),
            epochs_completed: 3,
            total_examples: 1000,
            patterns_learned: 50,
            final_avg_quality: 0.85,
            total_duration_secs: 10.0,
            epoch_stats: vec![
                epoch(0, 333, 0.75, 3.0),
                epoch(1, 333, 0.80, 3.5),
                epoch(2, 334, 0.85, 3.5),
            ],
            validation_quality: Some(0.82),
        };

        assert_eq!(result.examples_per_sec(), 100.0);
        assert!(result.quality_improved());
        assert!(close(result.quality_improvement(), 0.10, 0.01));
    }

    #[test]
    fn test_training_comparison() {
        let baseline = bare_result("baseline", 25, 0.70, 5.0);
        let improved = bare_result("improved", 30, 0.85, 4.0);

        let comparison = TrainingComparison::compare(&baseline, &improved);
        assert!(close(comparison.quality_diff, 0.15, 0.01));
        assert!(comparison.quality_improvement_pct > 20.0);
        assert!(comparison.throughput_diff > 0.0);
    }
}
434}