1use super::metrics::{EpochStats, TrainingMetrics, TrainingResult};
6use super::templates::{DataSizeHint, TrainingMethod, TrainingTemplate};
7use crate::engine::SonaEngine;
8use crate::time_compat::Instant;
9use crate::types::SonaConfig;
10use serde::{Deserialize, Serialize};
11
12#[derive(Clone, Debug, Serialize, Deserialize)]
14pub struct TrainingExample {
15 pub embedding: Vec<f32>,
17 pub activations: Option<Vec<f32>>,
19 pub attention: Option<Vec<f32>>,
21 pub quality: f32,
23 pub reward: Option<f32>,
25 pub route: Option<String>,
27 pub context: Vec<String>,
29 pub weight: f32,
31 pub tags: Vec<String>,
33}
34
35impl TrainingExample {
36 pub fn new(embedding: Vec<f32>, quality: f32) -> Self {
38 Self {
39 embedding,
40 activations: None,
41 attention: None,
42 quality,
43 reward: None,
44 route: None,
45 context: Vec::new(),
46 weight: 1.0,
47 tags: Vec::new(),
48 }
49 }
50
51 pub fn with_activations(mut self, activations: Vec<f32>) -> Self {
53 self.activations = Some(activations);
54 self
55 }
56
57 pub fn with_attention(mut self, attention: Vec<f32>) -> Self {
59 self.attention = Some(attention);
60 self
61 }
62
63 pub fn with_reward(mut self, reward: f32) -> Self {
65 self.reward = Some(reward);
66 self
67 }
68
69 pub fn with_route(mut self, route: impl Into<String>) -> Self {
71 self.route = Some(route.into());
72 self
73 }
74
75 pub fn with_context(mut self, ctx: impl Into<String>) -> Self {
77 self.context.push(ctx.into());
78 self
79 }
80
81 pub fn with_weight(mut self, weight: f32) -> Self {
83 self.weight = weight;
84 self
85 }
86
87 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
89 self.tags.push(tag.into());
90 self
91 }
92
93 pub fn get_activations(&self) -> Vec<f32> {
95 self.activations
96 .clone()
97 .unwrap_or_else(|| self.embedding.clone())
98 }
99
100 pub fn get_attention(&self) -> Vec<f32> {
102 self.attention
103 .clone()
104 .unwrap_or_else(|| vec![1.0 / 64.0; 64])
105 }
106
107 pub fn get_reward(&self) -> f32 {
109 self.reward.unwrap_or(self.quality)
110 }
111}
112
113#[derive(Clone, Debug, Serialize, Deserialize)]
115pub struct BatchConfig {
116 pub batch_size: usize,
118 pub shuffle: bool,
120 pub drop_last: bool,
122 pub epochs: usize,
124 pub early_stopping_patience: Option<usize>,
126 pub min_quality_improvement: f32,
128}
129
130impl Default for BatchConfig {
131 fn default() -> Self {
132 Self {
133 batch_size: 32,
134 shuffle: true,
135 drop_last: false,
136 epochs: 1,
137 early_stopping_patience: None,
138 min_quality_improvement: 0.001,
139 }
140 }
141}
142
143impl BatchConfig {
144 pub fn single_pass() -> Self {
146 Self {
147 batch_size: usize::MAX,
148 shuffle: false,
149 drop_last: false,
150 epochs: 1,
151 early_stopping_patience: None,
152 min_quality_improvement: 0.0,
153 }
154 }
155
156 pub fn for_data_size(hint: &DataSizeHint) -> Self {
158 match hint {
159 DataSizeHint::Tiny => Self {
160 batch_size: 8,
161 epochs: 10,
162 early_stopping_patience: Some(3),
163 ..Default::default()
164 },
165 DataSizeHint::Small => Self {
166 batch_size: 16,
167 epochs: 5,
168 early_stopping_patience: Some(2),
169 ..Default::default()
170 },
171 DataSizeHint::Medium => Self {
172 batch_size: 32,
173 epochs: 3,
174 early_stopping_patience: Some(2),
175 ..Default::default()
176 },
177 DataSizeHint::Large => Self {
178 batch_size: 64,
179 epochs: 2,
180 ..Default::default()
181 },
182 DataSizeHint::Massive => Self {
183 batch_size: 128,
184 epochs: 1,
185 ..Default::default()
186 },
187 }
188 }
189}
190
191#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
193pub enum PipelineStage {
194 Idle,
196 Preprocessing,
198 Training,
200 Validation,
202 PatternExtraction,
204 Export,
206 Completed,
208 Failed,
210}
211
212impl std::fmt::Display for PipelineStage {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 match self {
215 PipelineStage::Idle => write!(f, "idle"),
216 PipelineStage::Preprocessing => write!(f, "preprocessing"),
217 PipelineStage::Training => write!(f, "training"),
218 PipelineStage::Validation => write!(f, "validation"),
219 PipelineStage::PatternExtraction => write!(f, "pattern_extraction"),
220 PipelineStage::Export => write!(f, "export"),
221 PipelineStage::Completed => write!(f, "completed"),
222 PipelineStage::Failed => write!(f, "failed"),
223 }
224 }
225}
226
227pub trait TrainingCallback: Send + Sync {
229 fn on_stage_change(&self, _stage: &PipelineStage) {}
231
232 fn on_batch_complete(&self, _batch_idx: usize, _total_batches: usize, _avg_quality: f32) {}
234
235 fn on_epoch_complete(&self, _epoch: usize, _stats: &EpochStats) {}
237
238 fn on_training_complete(&self, _result: &TrainingResult) {}
240
241 fn on_error(&self, _error: &str) {}
243}
244
245pub struct NoOpCallback;
247impl TrainingCallback for NoOpCallback {}
248
/// Callback that carries a prefix used to label its log lines.
pub struct LoggingCallback {
    /// Text placed in square brackets at the start of each log line.
    prefix: String,
}

impl LoggingCallback {
    /// Creates a logging callback that labels its output with `prefix`.
    pub fn new(prefix: impl Into<String>) -> Self {
        let prefix: String = prefix.into();
        LoggingCallback { prefix }
    }
}
262
263impl TrainingCallback for LoggingCallback {
264 fn on_stage_change(&self, stage: &PipelineStage) {
265 println!("[{}] Stage: {}", self.prefix, stage);
266 }
267
268 fn on_batch_complete(&self, batch_idx: usize, total_batches: usize, avg_quality: f32) {
269 if batch_idx % 10 == 0 || batch_idx == total_batches - 1 {
270 println!(
271 "[{}] Batch {}/{}: avg_quality={:.4}",
272 self.prefix,
273 batch_idx + 1,
274 total_batches,
275 avg_quality
276 );
277 }
278 }
279
280 fn on_epoch_complete(&self, epoch: usize, stats: &EpochStats) {
281 println!(
282 "[{}] Epoch {}: examples={}, avg_quality={:.4}, duration={:.2}s",
283 self.prefix,
284 epoch + 1,
285 stats.examples_processed,
286 stats.avg_quality,
287 stats.duration_secs
288 );
289 }
290
291 fn on_training_complete(&self, result: &TrainingResult) {
292 println!(
293 "[{}] Training complete: epochs={}, patterns={}, final_quality={:.4}",
294 self.prefix, result.epochs_completed, result.patterns_learned, result.final_avg_quality
295 );
296 }
297
298 fn on_error(&self, error: &str) {
299 eprintln!("[{}] ERROR: {}", self.prefix, error);
300 }
301}
302
303pub struct TrainingPipeline {
305 name: String,
307 engine: SonaEngine,
309 batch_config: BatchConfig,
311 training_method: TrainingMethod,
313 stage: PipelineStage,
315 examples: Vec<TrainingExample>,
317 validation_examples: Vec<TrainingExample>,
319 metrics: TrainingMetrics,
321 callback: Box<dyn TrainingCallback>,
323 extract_patterns: bool,
325}
326
327impl TrainingPipeline {
328 pub fn new(name: impl Into<String>, config: SonaConfig) -> Self {
330 let name = name.into();
331 Self {
332 name: name.clone(),
333 engine: SonaEngine::with_config(config),
334 batch_config: BatchConfig::default(),
335 training_method: TrainingMethod::default(),
336 stage: PipelineStage::Idle,
337 examples: Vec::new(),
338 validation_examples: Vec::new(),
339 metrics: TrainingMetrics::new(&name),
340 callback: Box::new(NoOpCallback),
341 extract_patterns: true,
342 }
343 }
344
345 pub fn from_template(template: TrainingTemplate) -> Self {
347 let batch_config = BatchConfig::for_data_size(&template.expected_data_size);
348 let mut pipeline = Self::new(&template.name, template.sona_config);
349 pipeline.batch_config = batch_config;
350 pipeline.training_method = template.training_method;
351 pipeline
352 }
353
354 pub fn with_batch_config(mut self, config: BatchConfig) -> Self {
356 self.batch_config = config;
357 self
358 }
359
360 pub fn with_training_method(mut self, method: TrainingMethod) -> Self {
362 self.training_method = method;
363 self
364 }
365
366 pub fn with_callback<C: TrainingCallback + 'static>(mut self, callback: C) -> Self {
368 self.callback = Box::new(callback);
369 self
370 }
371
372 pub fn with_pattern_extraction(mut self, enabled: bool) -> Self {
374 self.extract_patterns = enabled;
375 self
376 }
377
378 pub fn add_example(&mut self, example: TrainingExample) {
380 self.examples.push(example);
381 }
382
383 pub fn add_examples(&mut self, examples: impl IntoIterator<Item = TrainingExample>) {
385 self.examples.extend(examples);
386 }
387
388 pub fn add_validation_example(&mut self, example: TrainingExample) {
390 self.validation_examples.push(example);
391 }
392
393 pub fn stage(&self) -> &PipelineStage {
395 &self.stage
396 }
397
398 pub fn example_count(&self) -> usize {
400 self.examples.len()
401 }
402
403 pub fn metrics(&self) -> &TrainingMetrics {
405 &self.metrics
406 }
407
408 pub fn engine(&self) -> &SonaEngine {
410 &self.engine
411 }
412
413 pub fn engine_mut(&mut self) -> &mut SonaEngine {
415 &mut self.engine
416 }
417
418 pub fn train(&mut self) -> Result<TrainingResult, String> {
420 let start = Instant::now();
421
422 self.set_stage(PipelineStage::Preprocessing);
424 self.preprocess()?;
425
426 self.set_stage(PipelineStage::Training);
428 let epoch_stats = self.run_training()?;
429
430 if !self.validation_examples.is_empty() {
432 self.set_stage(PipelineStage::Validation);
433 self.run_validation()?;
434 }
435
436 if self.extract_patterns {
438 self.set_stage(PipelineStage::PatternExtraction);
439 self.engine.force_learn();
440 }
441
442 self.set_stage(PipelineStage::Completed);
443
444 let result = TrainingResult {
445 pipeline_name: self.name.clone(),
446 epochs_completed: epoch_stats.len(),
447 total_examples: self.metrics.total_examples,
448 patterns_learned: self.metrics.patterns_learned,
449 final_avg_quality: self.metrics.avg_quality(),
450 total_duration_secs: start.elapsed().as_secs_f64(),
451 epoch_stats,
452 validation_quality: self.metrics.validation_quality,
453 };
454
455 self.callback.on_training_complete(&result);
456 Ok(result)
457 }
458
459 fn set_stage(&mut self, stage: PipelineStage) {
461 self.stage = stage.clone();
462 self.callback.on_stage_change(&stage);
463 }
464
465 fn preprocess(&mut self) -> Result<(), String> {
467 if self.examples.is_empty() {
468 return Err("No training examples provided".into());
469 }
470
471 if self.batch_config.shuffle {
473 use rand::seq::SliceRandom;
474 let mut rng = rand::thread_rng();
475 self.examples.shuffle(&mut rng);
476 }
477
478 Ok(())
479 }
480
481 fn run_training(&mut self) -> Result<Vec<EpochStats>, String> {
483 let mut all_epoch_stats = Vec::new();
484 let mut best_quality = 0.0f32;
485 let mut patience_counter = 0usize;
486
487 for epoch in 0..self.batch_config.epochs {
488 let epoch_start = Instant::now();
489 let mut epoch_quality_sum = 0.0f32;
490 let mut epoch_examples = 0usize;
491
492 let batch_size = self.batch_config.batch_size;
494 let total_examples = self.examples.len();
495 let mut batch_indices: Vec<(usize, usize)> = Vec::new();
496 let mut start = 0;
497 while start < total_examples {
498 let end = (start + batch_size).min(total_examples);
499 if end > start && (!self.batch_config.drop_last || end - start == batch_size) {
500 batch_indices.push((start, end));
501 }
502 start = end;
503 }
504 let total_batches = batch_indices.len();
505
506 for (batch_idx, (start, end)) in batch_indices.into_iter().enumerate() {
507 let batch_quality = self.train_batch_range(start, end)?;
508 let batch_len = end - start;
509 epoch_quality_sum += batch_quality * batch_len as f32;
510 epoch_examples += batch_len;
511
512 self.callback.on_batch_complete(
513 batch_idx,
514 total_batches,
515 epoch_quality_sum / epoch_examples as f32,
516 );
517 }
518
519 let epoch_avg_quality = if epoch_examples > 0 {
520 epoch_quality_sum / epoch_examples as f32
521 } else {
522 0.0
523 };
524
525 let epoch_stats = EpochStats {
526 epoch,
527 examples_processed: epoch_examples,
528 avg_quality: epoch_avg_quality,
529 duration_secs: epoch_start.elapsed().as_secs_f64(),
530 };
531
532 self.callback.on_epoch_complete(epoch, &epoch_stats);
533 all_epoch_stats.push(epoch_stats);
534
535 if let Some(patience) = self.batch_config.early_stopping_patience {
537 let improvement = epoch_avg_quality - best_quality;
538 if improvement > self.batch_config.min_quality_improvement {
539 best_quality = epoch_avg_quality;
540 patience_counter = 0;
541 } else {
542 patience_counter += 1;
543 if patience_counter >= patience {
544 break; }
546 }
547 }
548
549 if self.batch_config.shuffle && epoch + 1 < self.batch_config.epochs {
551 use rand::seq::SliceRandom;
552 let mut rng = rand::thread_rng();
553 self.examples.shuffle(&mut rng);
554 }
555 }
556
557 Ok(all_epoch_stats)
558 }
559
560 fn train_batch_range(&mut self, start: usize, end: usize) -> Result<f32, String> {
562 let mut quality_sum = 0.0f32;
563 let batch_len = end - start;
564
565 for idx in start..end {
566 let example = &self.examples[idx];
567
568 let mut builder = self.engine.begin_trajectory(example.embedding.clone());
570
571 if let Some(ref route) = example.route {
573 builder.set_model_route(route);
574 }
575
576 for ctx in &example.context {
578 builder.add_context(ctx);
579 }
580
581 builder.add_step(
583 example.get_activations(),
584 example.get_attention(),
585 example.get_reward() * example.weight,
586 );
587
588 self.engine.end_trajectory(builder, example.quality);
590
591 quality_sum += example.quality;
592 self.metrics.total_examples += 1;
593 self.metrics.add_quality_sample(example.quality);
594 }
595
596 self.engine.tick();
598
599 Ok(quality_sum / batch_len as f32)
600 }
601
602 fn run_validation(&mut self) -> Result<(), String> {
604 let mut quality_sum = 0.0f32;
605
606 for example in &self.validation_examples {
607 let mut output = vec![0.0f32; example.embedding.len()];
609 self.engine
610 .apply_micro_lora(&example.embedding, &mut output);
611
612 quality_sum += example.quality;
615 }
616
617 self.metrics.validation_quality = Some(quality_sum / self.validation_examples.len() as f32);
618
619 Ok(())
620 }
621
622 pub fn clear_examples(&mut self) {
624 self.examples.clear();
625 self.validation_examples.clear();
626 }
627
628 pub fn reset(&mut self) {
630 self.clear_examples();
631 self.metrics = TrainingMetrics::new(&self.name);
632 self.stage = PipelineStage::Idle;
633 }
634}
635
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods store the values they are given.
    #[test]
    fn test_training_example() {
        let built = TrainingExample::new(vec![0.1; 256], 0.8)
            .with_route("test")
            .with_context("ctx1")
            .with_weight(1.5)
            .with_tag("test");

        assert_eq!(built.quality, 0.8);
        assert_eq!(built.route, Some("test".into()));
        assert_eq!(built.weight, 1.5);
    }

    /// The `Small` hint selects batches of 16 and 5 epochs.
    #[test]
    fn test_batch_config() {
        let small = BatchConfig::for_data_size(&DataSizeHint::Small);
        assert_eq!(small.batch_size, 16);
        assert_eq!(small.epochs, 5);
    }

    /// A new pipeline is idle and holds no examples.
    #[test]
    fn test_pipeline_creation() {
        let fresh = TrainingPipeline::new("test", SonaConfig::default());
        assert_eq!(fresh.stage(), &PipelineStage::Idle);
        assert_eq!(fresh.example_count(), 0);
    }

    /// A pipeline built from a template takes over the template's name.
    #[test]
    fn test_pipeline_from_template() {
        let code_agent = TrainingTemplate::code_agent().with_hidden_dim(256);
        let built = TrainingPipeline::from_template(code_agent);
        assert_eq!(built.name, "code-agent");
    }

    /// Two epochs over five examples in batches of two complete normally.
    #[test]
    fn test_pipeline_training() {
        let config = BatchConfig {
            batch_size: 2,
            epochs: 2,
            ..Default::default()
        };
        let mut pipeline =
            TrainingPipeline::new("test", SonaConfig::default()).with_batch_config(config);

        for i in 0..5 {
            let embedding = vec![i as f32 * 0.1; 256];
            let quality = 0.7 + i as f32 * 0.05;
            pipeline.add_example(TrainingExample::new(embedding, quality));
        }

        let outcome = pipeline.train().unwrap();
        assert_eq!(outcome.epochs_completed, 2);
        assert!(outcome.total_examples > 0);
    }

    /// A validation example makes the result carry a validation quality.
    #[test]
    fn test_pipeline_with_validation() {
        let mut pipeline = TrainingPipeline::new("test", SonaConfig::default())
            .with_batch_config(BatchConfig::single_pass());

        let train_sample = TrainingExample::new(vec![0.1; 256], 0.8);
        let validation_sample = TrainingExample::new(vec![0.2; 256], 0.9);
        pipeline.add_example(train_sample);
        pipeline.add_validation_example(validation_sample);

        let outcome = pipeline.train().unwrap();
        assert!(outcome.validation_quality.is_some());
    }
}