1use super::metrics::{EpochStats, TrainingMetrics, TrainingResult};
6use super::templates::{DataSizeHint, TrainingMethod, TrainingTemplate};
7use crate::engine::SonaEngine;
8use crate::time_compat::Instant;
9use crate::types::SonaConfig;
10use serde::{Deserialize, Serialize};
11
12#[derive(Clone, Debug, Serialize, Deserialize)]
14pub struct TrainingExample {
15 pub embedding: Vec<f32>,
17 pub activations: Option<Vec<f32>>,
19 pub attention: Option<Vec<f32>>,
21 pub quality: f32,
23 pub reward: Option<f32>,
25 pub route: Option<String>,
27 pub context: Vec<String>,
29 pub weight: f32,
31 pub tags: Vec<String>,
33}
34
35impl TrainingExample {
36 pub fn new(embedding: Vec<f32>, quality: f32) -> Self {
38 Self {
39 embedding,
40 activations: None,
41 attention: None,
42 quality,
43 reward: None,
44 route: None,
45 context: Vec::new(),
46 weight: 1.0,
47 tags: Vec::new(),
48 }
49 }
50
51 pub fn with_activations(mut self, activations: Vec<f32>) -> Self {
53 self.activations = Some(activations);
54 self
55 }
56
57 pub fn with_attention(mut self, attention: Vec<f32>) -> Self {
59 self.attention = Some(attention);
60 self
61 }
62
63 pub fn with_reward(mut self, reward: f32) -> Self {
65 self.reward = Some(reward);
66 self
67 }
68
69 pub fn with_route(mut self, route: impl Into<String>) -> Self {
71 self.route = Some(route.into());
72 self
73 }
74
75 pub fn with_context(mut self, ctx: impl Into<String>) -> Self {
77 self.context.push(ctx.into());
78 self
79 }
80
81 pub fn with_weight(mut self, weight: f32) -> Self {
83 self.weight = weight;
84 self
85 }
86
87 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
89 self.tags.push(tag.into());
90 self
91 }
92
93 pub fn get_activations(&self) -> Vec<f32> {
95 self.activations
96 .clone()
97 .unwrap_or_else(|| self.embedding.clone())
98 }
99
100 pub fn get_attention(&self) -> Vec<f32> {
102 self.attention
103 .clone()
104 .unwrap_or_else(|| vec![1.0 / 64.0; 64])
105 }
106
107 pub fn get_reward(&self) -> f32 {
109 self.reward.unwrap_or(self.quality)
110 }
111}
112
113#[derive(Clone, Debug, Serialize, Deserialize)]
115pub struct BatchConfig {
116 pub batch_size: usize,
118 pub shuffle: bool,
120 pub drop_last: bool,
122 pub epochs: usize,
124 pub early_stopping_patience: Option<usize>,
126 pub min_quality_improvement: f32,
128}
129
130impl Default for BatchConfig {
131 fn default() -> Self {
132 Self {
133 batch_size: 32,
134 shuffle: true,
135 drop_last: false,
136 epochs: 1,
137 early_stopping_patience: None,
138 min_quality_improvement: 0.001,
139 }
140 }
141}
142
143impl BatchConfig {
144 pub fn single_pass() -> Self {
146 Self {
147 batch_size: usize::MAX,
148 shuffle: false,
149 drop_last: false,
150 epochs: 1,
151 early_stopping_patience: None,
152 min_quality_improvement: 0.0,
153 }
154 }
155
156 pub fn for_data_size(hint: &DataSizeHint) -> Self {
158 match hint {
159 DataSizeHint::Tiny => Self {
160 batch_size: 8,
161 epochs: 10,
162 early_stopping_patience: Some(3),
163 ..Default::default()
164 },
165 DataSizeHint::Small => Self {
166 batch_size: 16,
167 epochs: 5,
168 early_stopping_patience: Some(2),
169 ..Default::default()
170 },
171 DataSizeHint::Medium => Self {
172 batch_size: 32,
173 epochs: 3,
174 early_stopping_patience: Some(2),
175 ..Default::default()
176 },
177 DataSizeHint::Large => Self {
178 batch_size: 64,
179 epochs: 2,
180 ..Default::default()
181 },
182 DataSizeHint::Massive => Self {
183 batch_size: 128,
184 epochs: 1,
185 ..Default::default()
186 },
187 }
188 }
189}
190
191#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
193pub enum PipelineStage {
194 Idle,
196 Preprocessing,
198 Training,
200 Validation,
202 PatternExtraction,
204 Export,
206 Completed,
208 Failed,
210}
211
212impl std::fmt::Display for PipelineStage {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 match self {
215 PipelineStage::Idle => write!(f, "idle"),
216 PipelineStage::Preprocessing => write!(f, "preprocessing"),
217 PipelineStage::Training => write!(f, "training"),
218 PipelineStage::Validation => write!(f, "validation"),
219 PipelineStage::PatternExtraction => write!(f, "pattern_extraction"),
220 PipelineStage::Export => write!(f, "export"),
221 PipelineStage::Completed => write!(f, "completed"),
222 PipelineStage::Failed => write!(f, "failed"),
223 }
224 }
225}
226
227pub trait TrainingCallback: Send + Sync {
229 fn on_stage_change(&self, _stage: &PipelineStage) {}
231
232 fn on_batch_complete(&self, _batch_idx: usize, _total_batches: usize, _avg_quality: f32) {}
234
235 fn on_epoch_complete(&self, _epoch: usize, _stats: &EpochStats) {}
237
238 fn on_training_complete(&self, _result: &TrainingResult) {}
240
241 fn on_error(&self, _error: &str) {}
243}
244
245pub struct NoOpCallback;
247impl TrainingCallback for NoOpCallback {}
248
/// Callback that prints progress to stdout and errors to stderr. Every line
/// starts with `[prefix]`.
// NOTE(review): `prefix` is read by the `TrainingCallback` impl, so this
// `allow` may no longer be needed; confirm before removing.
#[allow(dead_code)]
pub struct LoggingCallback {
    /// Text shown inside the leading `[...]` tag of each output line.
    prefix: String,
}
254
255#[allow(dead_code)]
256impl LoggingCallback {
257 pub fn new(prefix: impl Into<String>) -> Self {
259 Self {
260 prefix: prefix.into(),
261 }
262 }
263}
264
265impl TrainingCallback for LoggingCallback {
266 fn on_stage_change(&self, stage: &PipelineStage) {
267 println!("[{}] Stage: {}", self.prefix, stage);
268 }
269
270 fn on_batch_complete(&self, batch_idx: usize, total_batches: usize, avg_quality: f32) {
271 if batch_idx % 10 == 0 || batch_idx == total_batches - 1 {
272 println!(
273 "[{}] Batch {}/{}: avg_quality={:.4}",
274 self.prefix,
275 batch_idx + 1,
276 total_batches,
277 avg_quality
278 );
279 }
280 }
281
282 fn on_epoch_complete(&self, epoch: usize, stats: &EpochStats) {
283 println!(
284 "[{}] Epoch {}: examples={}, avg_quality={:.4}, duration={:.2}s",
285 self.prefix,
286 epoch + 1,
287 stats.examples_processed,
288 stats.avg_quality,
289 stats.duration_secs
290 );
291 }
292
293 fn on_training_complete(&self, result: &TrainingResult) {
294 println!(
295 "[{}] Training complete: epochs={}, patterns={}, final_quality={:.4}",
296 self.prefix, result.epochs_completed, result.patterns_learned, result.final_avg_quality
297 );
298 }
299
300 fn on_error(&self, error: &str) {
301 eprintln!("[{}] ERROR: {}", self.prefix, error);
302 }
303}
304
305pub struct TrainingPipeline {
307 name: String,
309 engine: SonaEngine,
311 batch_config: BatchConfig,
313 training_method: TrainingMethod,
315 stage: PipelineStage,
317 examples: Vec<TrainingExample>,
319 validation_examples: Vec<TrainingExample>,
321 metrics: TrainingMetrics,
323 callback: Box<dyn TrainingCallback>,
325 extract_patterns: bool,
327}
328
329impl TrainingPipeline {
330 pub fn new(name: impl Into<String>, config: SonaConfig) -> Self {
332 let name = name.into();
333 Self {
334 name: name.clone(),
335 engine: SonaEngine::with_config(config),
336 batch_config: BatchConfig::default(),
337 training_method: TrainingMethod::default(),
338 stage: PipelineStage::Idle,
339 examples: Vec::new(),
340 validation_examples: Vec::new(),
341 metrics: TrainingMetrics::new(&name),
342 callback: Box::new(NoOpCallback),
343 extract_patterns: true,
344 }
345 }
346
347 pub fn from_template(template: TrainingTemplate) -> Self {
349 let batch_config = BatchConfig::for_data_size(&template.expected_data_size);
350 let mut pipeline = Self::new(&template.name, template.sona_config);
351 pipeline.batch_config = batch_config;
352 pipeline.training_method = template.training_method;
353 pipeline
354 }
355
356 pub fn with_batch_config(mut self, config: BatchConfig) -> Self {
358 self.batch_config = config;
359 self
360 }
361
362 pub fn with_training_method(mut self, method: TrainingMethod) -> Self {
364 self.training_method = method;
365 self
366 }
367
368 pub fn with_callback<C: TrainingCallback + 'static>(mut self, callback: C) -> Self {
370 self.callback = Box::new(callback);
371 self
372 }
373
374 pub fn with_pattern_extraction(mut self, enabled: bool) -> Self {
376 self.extract_patterns = enabled;
377 self
378 }
379
380 pub fn add_example(&mut self, example: TrainingExample) {
382 self.examples.push(example);
383 }
384
385 pub fn add_examples(&mut self, examples: impl IntoIterator<Item = TrainingExample>) {
387 self.examples.extend(examples);
388 }
389
390 pub fn add_validation_example(&mut self, example: TrainingExample) {
392 self.validation_examples.push(example);
393 }
394
395 pub fn stage(&self) -> &PipelineStage {
397 &self.stage
398 }
399
400 pub fn example_count(&self) -> usize {
402 self.examples.len()
403 }
404
405 pub fn metrics(&self) -> &TrainingMetrics {
407 &self.metrics
408 }
409
410 pub fn engine(&self) -> &SonaEngine {
412 &self.engine
413 }
414
415 pub fn engine_mut(&mut self) -> &mut SonaEngine {
417 &mut self.engine
418 }
419
420 pub fn train(&mut self) -> Result<TrainingResult, String> {
422 let start = Instant::now();
423
424 self.set_stage(PipelineStage::Preprocessing);
426 self.preprocess()?;
427
428 self.set_stage(PipelineStage::Training);
430 let epoch_stats = self.run_training()?;
431
432 if !self.validation_examples.is_empty() {
434 self.set_stage(PipelineStage::Validation);
435 self.run_validation()?;
436 }
437
438 if self.extract_patterns {
440 self.set_stage(PipelineStage::PatternExtraction);
441 self.engine.force_learn();
442 }
443
444 self.set_stage(PipelineStage::Completed);
445
446 let result = TrainingResult {
447 pipeline_name: self.name.clone(),
448 epochs_completed: epoch_stats.len(),
449 total_examples: self.metrics.total_examples,
450 patterns_learned: self.metrics.patterns_learned,
451 final_avg_quality: self.metrics.avg_quality(),
452 total_duration_secs: start.elapsed().as_secs_f64(),
453 epoch_stats,
454 validation_quality: self.metrics.validation_quality,
455 };
456
457 self.callback.on_training_complete(&result);
458 Ok(result)
459 }
460
461 fn set_stage(&mut self, stage: PipelineStage) {
463 self.stage = stage.clone();
464 self.callback.on_stage_change(&stage);
465 }
466
467 fn preprocess(&mut self) -> Result<(), String> {
469 if self.examples.is_empty() {
470 return Err("No training examples provided".into());
471 }
472
473 if self.batch_config.shuffle {
475 use rand::seq::SliceRandom;
476 let mut rng = rand::thread_rng();
477 self.examples.shuffle(&mut rng);
478 }
479
480 Ok(())
481 }
482
483 fn run_training(&mut self) -> Result<Vec<EpochStats>, String> {
485 let mut all_epoch_stats = Vec::new();
486 let mut best_quality = 0.0f32;
487 let mut patience_counter = 0usize;
488
489 for epoch in 0..self.batch_config.epochs {
490 let epoch_start = Instant::now();
491 let mut epoch_quality_sum = 0.0f32;
492 let mut epoch_examples = 0usize;
493
494 let batch_size = self.batch_config.batch_size;
496 let total_examples = self.examples.len();
497 let mut batch_indices: Vec<(usize, usize)> = Vec::new();
498 let mut start = 0;
499 while start < total_examples {
500 let end = (start + batch_size).min(total_examples);
501 if end > start && (!self.batch_config.drop_last || end - start == batch_size) {
502 batch_indices.push((start, end));
503 }
504 start = end;
505 }
506 let total_batches = batch_indices.len();
507
508 for (batch_idx, (start, end)) in batch_indices.into_iter().enumerate() {
509 let batch_quality = self.train_batch_range(start, end)?;
510 let batch_len = end - start;
511 epoch_quality_sum += batch_quality * batch_len as f32;
512 epoch_examples += batch_len;
513
514 self.callback.on_batch_complete(
515 batch_idx,
516 total_batches,
517 epoch_quality_sum / epoch_examples as f32,
518 );
519 }
520
521 let epoch_avg_quality = if epoch_examples > 0 {
522 epoch_quality_sum / epoch_examples as f32
523 } else {
524 0.0
525 };
526
527 let epoch_stats = EpochStats {
528 epoch,
529 examples_processed: epoch_examples,
530 avg_quality: epoch_avg_quality,
531 duration_secs: epoch_start.elapsed().as_secs_f64(),
532 };
533
534 self.callback.on_epoch_complete(epoch, &epoch_stats);
535 all_epoch_stats.push(epoch_stats);
536
537 if let Some(patience) = self.batch_config.early_stopping_patience {
539 let improvement = epoch_avg_quality - best_quality;
540 if improvement > self.batch_config.min_quality_improvement {
541 best_quality = epoch_avg_quality;
542 patience_counter = 0;
543 } else {
544 patience_counter += 1;
545 if patience_counter >= patience {
546 break; }
548 }
549 }
550
551 if self.batch_config.shuffle && epoch + 1 < self.batch_config.epochs {
553 use rand::seq::SliceRandom;
554 let mut rng = rand::thread_rng();
555 self.examples.shuffle(&mut rng);
556 }
557 }
558
559 Ok(all_epoch_stats)
560 }
561
562 fn train_batch_range(&mut self, start: usize, end: usize) -> Result<f32, String> {
564 let mut quality_sum = 0.0f32;
565 let batch_len = end - start;
566
567 for idx in start..end {
568 let example = &self.examples[idx];
569
570 let mut builder = self.engine.begin_trajectory(example.embedding.clone());
572
573 if let Some(ref route) = example.route {
575 builder.set_model_route(route);
576 }
577
578 for ctx in &example.context {
580 builder.add_context(ctx);
581 }
582
583 builder.add_step(
585 example.get_activations(),
586 example.get_attention(),
587 example.get_reward() * example.weight,
588 );
589
590 self.engine.end_trajectory(builder, example.quality);
592
593 quality_sum += example.quality;
594 self.metrics.total_examples += 1;
595 self.metrics.add_quality_sample(example.quality);
596 }
597
598 self.engine.tick();
600
601 Ok(quality_sum / batch_len as f32)
602 }
603
604 fn run_validation(&mut self) -> Result<(), String> {
606 let mut quality_sum = 0.0f32;
607
608 for example in &self.validation_examples {
609 let mut output = vec![0.0f32; example.embedding.len()];
611 self.engine
612 .apply_micro_lora(&example.embedding, &mut output);
613
614 quality_sum += example.quality;
617 }
618
619 self.metrics.validation_quality = Some(quality_sum / self.validation_examples.len() as f32);
620
621 Ok(())
622 }
623
624 pub fn clear_examples(&mut self) {
626 self.examples.clear();
627 self.validation_examples.clear();
628 }
629
630 pub fn reset(&mut self) {
632 self.clear_examples();
633 self.metrics = TrainingMetrics::new(&self.name);
634 self.stage = PipelineStage::Idle;
635 }
636}
637
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a pipeline named "test" with the default engine configuration.
    fn default_pipeline() -> TrainingPipeline {
        TrainingPipeline::new("test", SonaConfig::default())
    }

    #[test]
    fn test_training_example() {
        let example = TrainingExample::new(vec![0.1; 256], 0.8)
            .with_route("test")
            .with_context("ctx1")
            .with_weight(1.5)
            .with_tag("test");

        assert_eq!(example.route.as_deref(), Some("test"));
        assert_eq!(example.quality, 0.8);
        assert_eq!(example.weight, 1.5);
    }

    #[test]
    fn test_batch_config() {
        let small = BatchConfig::for_data_size(&DataSizeHint::Small);
        assert_eq!((small.batch_size, small.epochs), (16, 5));
    }

    #[test]
    fn test_pipeline_creation() {
        let fresh = default_pipeline();
        assert_eq!(fresh.example_count(), 0);
        assert_eq!(*fresh.stage(), PipelineStage::Idle);
    }

    #[test]
    fn test_pipeline_from_template() {
        let pipeline =
            TrainingPipeline::from_template(TrainingTemplate::code_agent().with_hidden_dim(256));
        assert_eq!(pipeline.name, "code-agent");
    }

    #[test]
    fn test_pipeline_training() {
        let config = BatchConfig {
            batch_size: 2,
            epochs: 2,
            ..Default::default()
        };
        let mut pipeline = default_pipeline().with_batch_config(config);

        // Five examples with rising embeddings and qualities.
        pipeline.add_examples((0..5).map(|step| {
            let step = step as f32;
            TrainingExample::new(vec![step * 0.1; 256], 0.7 + step * 0.05)
        }));

        let outcome = pipeline.train().unwrap();
        assert!(outcome.total_examples > 0);
        assert_eq!(outcome.epochs_completed, 2);
    }

    #[test]
    fn test_pipeline_with_validation() {
        let mut pipeline = default_pipeline().with_batch_config(BatchConfig::single_pass());

        pipeline.add_example(TrainingExample::new(vec![0.1; 256], 0.8));
        pipeline.add_validation_example(TrainingExample::new(vec![0.2; 256], 0.9));

        let outcome = pipeline.train().unwrap();
        assert!(outcome.validation_quality.is_some());
    }
}
709}