1use crate::engine::SonaEngine;
6use crate::types::SonaConfig;
7use super::templates::{TrainingTemplate, TrainingMethod, DataSizeHint};
8use super::metrics::{TrainingMetrics, TrainingResult, EpochStats};
9use serde::{Deserialize, Serialize};
10use std::time::Instant;
11
12#[derive(Clone, Debug, Serialize, Deserialize)]
14pub struct TrainingExample {
15 pub embedding: Vec<f32>,
17 pub activations: Option<Vec<f32>>,
19 pub attention: Option<Vec<f32>>,
21 pub quality: f32,
23 pub reward: Option<f32>,
25 pub route: Option<String>,
27 pub context: Vec<String>,
29 pub weight: f32,
31 pub tags: Vec<String>,
33}
34
35impl TrainingExample {
36 pub fn new(embedding: Vec<f32>, quality: f32) -> Self {
38 Self {
39 embedding,
40 activations: None,
41 attention: None,
42 quality,
43 reward: None,
44 route: None,
45 context: Vec::new(),
46 weight: 1.0,
47 tags: Vec::new(),
48 }
49 }
50
51 pub fn with_activations(mut self, activations: Vec<f32>) -> Self {
53 self.activations = Some(activations);
54 self
55 }
56
57 pub fn with_attention(mut self, attention: Vec<f32>) -> Self {
59 self.attention = Some(attention);
60 self
61 }
62
63 pub fn with_reward(mut self, reward: f32) -> Self {
65 self.reward = Some(reward);
66 self
67 }
68
69 pub fn with_route(mut self, route: impl Into<String>) -> Self {
71 self.route = Some(route.into());
72 self
73 }
74
75 pub fn with_context(mut self, ctx: impl Into<String>) -> Self {
77 self.context.push(ctx.into());
78 self
79 }
80
81 pub fn with_weight(mut self, weight: f32) -> Self {
83 self.weight = weight;
84 self
85 }
86
87 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
89 self.tags.push(tag.into());
90 self
91 }
92
93 pub fn get_activations(&self) -> Vec<f32> {
95 self.activations.clone().unwrap_or_else(|| self.embedding.clone())
96 }
97
98 pub fn get_attention(&self) -> Vec<f32> {
100 self.attention.clone().unwrap_or_else(|| vec![1.0 / 64.0; 64])
101 }
102
103 pub fn get_reward(&self) -> f32 {
105 self.reward.unwrap_or(self.quality)
106 }
107}
108
109#[derive(Clone, Debug, Serialize, Deserialize)]
111pub struct BatchConfig {
112 pub batch_size: usize,
114 pub shuffle: bool,
116 pub drop_last: bool,
118 pub epochs: usize,
120 pub early_stopping_patience: Option<usize>,
122 pub min_quality_improvement: f32,
124}
125
126impl Default for BatchConfig {
127 fn default() -> Self {
128 Self {
129 batch_size: 32,
130 shuffle: true,
131 drop_last: false,
132 epochs: 1,
133 early_stopping_patience: None,
134 min_quality_improvement: 0.001,
135 }
136 }
137}
138
139impl BatchConfig {
140 pub fn single_pass() -> Self {
142 Self {
143 batch_size: usize::MAX,
144 shuffle: false,
145 drop_last: false,
146 epochs: 1,
147 early_stopping_patience: None,
148 min_quality_improvement: 0.0,
149 }
150 }
151
152 pub fn for_data_size(hint: &DataSizeHint) -> Self {
154 match hint {
155 DataSizeHint::Tiny => Self {
156 batch_size: 8,
157 epochs: 10,
158 early_stopping_patience: Some(3),
159 ..Default::default()
160 },
161 DataSizeHint::Small => Self {
162 batch_size: 16,
163 epochs: 5,
164 early_stopping_patience: Some(2),
165 ..Default::default()
166 },
167 DataSizeHint::Medium => Self {
168 batch_size: 32,
169 epochs: 3,
170 early_stopping_patience: Some(2),
171 ..Default::default()
172 },
173 DataSizeHint::Large => Self {
174 batch_size: 64,
175 epochs: 2,
176 ..Default::default()
177 },
178 DataSizeHint::Massive => Self {
179 batch_size: 128,
180 epochs: 1,
181 ..Default::default()
182 },
183 }
184 }
185}
186
187#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
189pub enum PipelineStage {
190 Idle,
192 Preprocessing,
194 Training,
196 Validation,
198 PatternExtraction,
200 Export,
202 Completed,
204 Failed,
206}
207
208impl std::fmt::Display for PipelineStage {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 match self {
211 PipelineStage::Idle => write!(f, "idle"),
212 PipelineStage::Preprocessing => write!(f, "preprocessing"),
213 PipelineStage::Training => write!(f, "training"),
214 PipelineStage::Validation => write!(f, "validation"),
215 PipelineStage::PatternExtraction => write!(f, "pattern_extraction"),
216 PipelineStage::Export => write!(f, "export"),
217 PipelineStage::Completed => write!(f, "completed"),
218 PipelineStage::Failed => write!(f, "failed"),
219 }
220 }
221}
222
223pub trait TrainingCallback: Send + Sync {
225 fn on_stage_change(&self, _stage: &PipelineStage) {}
227
228 fn on_batch_complete(&self, _batch_idx: usize, _total_batches: usize, _avg_quality: f32) {}
230
231 fn on_epoch_complete(&self, _epoch: usize, _stats: &EpochStats) {}
233
234 fn on_training_complete(&self, _result: &TrainingResult) {}
236
237 fn on_error(&self, _error: &str) {}
239}
240
241pub struct NoOpCallback;
243impl TrainingCallback for NoOpCallback {}
244
/// A [`TrainingCallback`] that prints progress to stdout (errors to stderr),
/// starting every line with a caller-chosen prefix in square brackets.
pub struct LoggingCallback {
    /// Label printed in brackets at the start of each log line.
    prefix: String,
}

impl LoggingCallback {
    /// Creates a logger whose output lines are tagged with `prefix`.
    pub fn new(prefix: impl Into<String>) -> Self {
        let prefix: String = prefix.into();
        Self { prefix }
    }
}
256
257impl TrainingCallback for LoggingCallback {
258 fn on_stage_change(&self, stage: &PipelineStage) {
259 println!("[{}] Stage: {}", self.prefix, stage);
260 }
261
262 fn on_batch_complete(&self, batch_idx: usize, total_batches: usize, avg_quality: f32) {
263 if batch_idx % 10 == 0 || batch_idx == total_batches - 1 {
264 println!(
265 "[{}] Batch {}/{}: avg_quality={:.4}",
266 self.prefix, batch_idx + 1, total_batches, avg_quality
267 );
268 }
269 }
270
271 fn on_epoch_complete(&self, epoch: usize, stats: &EpochStats) {
272 println!(
273 "[{}] Epoch {}: examples={}, avg_quality={:.4}, duration={:.2}s",
274 self.prefix, epoch + 1, stats.examples_processed, stats.avg_quality, stats.duration_secs
275 );
276 }
277
278 fn on_training_complete(&self, result: &TrainingResult) {
279 println!(
280 "[{}] Training complete: epochs={}, patterns={}, final_quality={:.4}",
281 self.prefix, result.epochs_completed, result.patterns_learned, result.final_avg_quality
282 );
283 }
284
285 fn on_error(&self, error: &str) {
286 eprintln!("[{}] ERROR: {}", self.prefix, error);
287 }
288}
289
290pub struct TrainingPipeline {
292 name: String,
294 engine: SonaEngine,
296 batch_config: BatchConfig,
298 training_method: TrainingMethod,
300 stage: PipelineStage,
302 examples: Vec<TrainingExample>,
304 validation_examples: Vec<TrainingExample>,
306 metrics: TrainingMetrics,
308 callback: Box<dyn TrainingCallback>,
310 extract_patterns: bool,
312}
313
314impl TrainingPipeline {
315 pub fn new(name: impl Into<String>, config: SonaConfig) -> Self {
317 let name = name.into();
318 Self {
319 name: name.clone(),
320 engine: SonaEngine::with_config(config),
321 batch_config: BatchConfig::default(),
322 training_method: TrainingMethod::default(),
323 stage: PipelineStage::Idle,
324 examples: Vec::new(),
325 validation_examples: Vec::new(),
326 metrics: TrainingMetrics::new(&name),
327 callback: Box::new(NoOpCallback),
328 extract_patterns: true,
329 }
330 }
331
332 pub fn from_template(template: TrainingTemplate) -> Self {
334 let batch_config = BatchConfig::for_data_size(&template.expected_data_size);
335 let mut pipeline = Self::new(&template.name, template.sona_config);
336 pipeline.batch_config = batch_config;
337 pipeline.training_method = template.training_method;
338 pipeline
339 }
340
341 pub fn with_batch_config(mut self, config: BatchConfig) -> Self {
343 self.batch_config = config;
344 self
345 }
346
347 pub fn with_training_method(mut self, method: TrainingMethod) -> Self {
349 self.training_method = method;
350 self
351 }
352
353 pub fn with_callback<C: TrainingCallback + 'static>(mut self, callback: C) -> Self {
355 self.callback = Box::new(callback);
356 self
357 }
358
359 pub fn with_pattern_extraction(mut self, enabled: bool) -> Self {
361 self.extract_patterns = enabled;
362 self
363 }
364
365 pub fn add_example(&mut self, example: TrainingExample) {
367 self.examples.push(example);
368 }
369
370 pub fn add_examples(&mut self, examples: impl IntoIterator<Item = TrainingExample>) {
372 self.examples.extend(examples);
373 }
374
375 pub fn add_validation_example(&mut self, example: TrainingExample) {
377 self.validation_examples.push(example);
378 }
379
380 pub fn stage(&self) -> &PipelineStage {
382 &self.stage
383 }
384
385 pub fn example_count(&self) -> usize {
387 self.examples.len()
388 }
389
390 pub fn metrics(&self) -> &TrainingMetrics {
392 &self.metrics
393 }
394
395 pub fn engine(&self) -> &SonaEngine {
397 &self.engine
398 }
399
400 pub fn engine_mut(&mut self) -> &mut SonaEngine {
402 &mut self.engine
403 }
404
405 pub fn train(&mut self) -> Result<TrainingResult, String> {
407 let start = Instant::now();
408
409 self.set_stage(PipelineStage::Preprocessing);
411 self.preprocess()?;
412
413 self.set_stage(PipelineStage::Training);
415 let epoch_stats = self.run_training()?;
416
417 if !self.validation_examples.is_empty() {
419 self.set_stage(PipelineStage::Validation);
420 self.run_validation()?;
421 }
422
423 if self.extract_patterns {
425 self.set_stage(PipelineStage::PatternExtraction);
426 self.engine.force_learn();
427 }
428
429 self.set_stage(PipelineStage::Completed);
430
431 let result = TrainingResult {
432 pipeline_name: self.name.clone(),
433 epochs_completed: epoch_stats.len(),
434 total_examples: self.metrics.total_examples,
435 patterns_learned: self.metrics.patterns_learned,
436 final_avg_quality: self.metrics.avg_quality(),
437 total_duration_secs: start.elapsed().as_secs_f64(),
438 epoch_stats,
439 validation_quality: self.metrics.validation_quality,
440 };
441
442 self.callback.on_training_complete(&result);
443 Ok(result)
444 }
445
446 fn set_stage(&mut self, stage: PipelineStage) {
448 self.stage = stage.clone();
449 self.callback.on_stage_change(&stage);
450 }
451
452 fn preprocess(&mut self) -> Result<(), String> {
454 if self.examples.is_empty() {
455 return Err("No training examples provided".into());
456 }
457
458 if self.batch_config.shuffle {
460 use rand::seq::SliceRandom;
461 let mut rng = rand::thread_rng();
462 self.examples.shuffle(&mut rng);
463 }
464
465 Ok(())
466 }
467
468 fn run_training(&mut self) -> Result<Vec<EpochStats>, String> {
470 let mut all_epoch_stats = Vec::new();
471 let mut best_quality = 0.0f32;
472 let mut patience_counter = 0usize;
473
474 for epoch in 0..self.batch_config.epochs {
475 let epoch_start = Instant::now();
476 let mut epoch_quality_sum = 0.0f32;
477 let mut epoch_examples = 0usize;
478
479 let batch_size = self.batch_config.batch_size;
481 let total_examples = self.examples.len();
482 let mut batch_indices: Vec<(usize, usize)> = Vec::new();
483 let mut start = 0;
484 while start < total_examples {
485 let end = (start + batch_size).min(total_examples);
486 if end > start && (!self.batch_config.drop_last || end - start == batch_size) {
487 batch_indices.push((start, end));
488 }
489 start = end;
490 }
491 let total_batches = batch_indices.len();
492
493 for (batch_idx, (start, end)) in batch_indices.into_iter().enumerate() {
494 let batch_quality = self.train_batch_range(start, end)?;
495 let batch_len = end - start;
496 epoch_quality_sum += batch_quality * batch_len as f32;
497 epoch_examples += batch_len;
498
499 self.callback.on_batch_complete(
500 batch_idx,
501 total_batches,
502 epoch_quality_sum / epoch_examples as f32,
503 );
504 }
505
506 let epoch_avg_quality = if epoch_examples > 0 {
507 epoch_quality_sum / epoch_examples as f32
508 } else {
509 0.0
510 };
511
512 let epoch_stats = EpochStats {
513 epoch,
514 examples_processed: epoch_examples,
515 avg_quality: epoch_avg_quality,
516 duration_secs: epoch_start.elapsed().as_secs_f64(),
517 };
518
519 self.callback.on_epoch_complete(epoch, &epoch_stats);
520 all_epoch_stats.push(epoch_stats);
521
522 if let Some(patience) = self.batch_config.early_stopping_patience {
524 let improvement = epoch_avg_quality - best_quality;
525 if improvement > self.batch_config.min_quality_improvement {
526 best_quality = epoch_avg_quality;
527 patience_counter = 0;
528 } else {
529 patience_counter += 1;
530 if patience_counter >= patience {
531 break; }
533 }
534 }
535
536 if self.batch_config.shuffle && epoch + 1 < self.batch_config.epochs {
538 use rand::seq::SliceRandom;
539 let mut rng = rand::thread_rng();
540 self.examples.shuffle(&mut rng);
541 }
542 }
543
544 Ok(all_epoch_stats)
545 }
546
547 fn train_batch_range(&mut self, start: usize, end: usize) -> Result<f32, String> {
549 let mut quality_sum = 0.0f32;
550 let batch_len = end - start;
551
552 for idx in start..end {
553 let example = &self.examples[idx];
554
555 let mut builder = self.engine.begin_trajectory(example.embedding.clone());
557
558 if let Some(ref route) = example.route {
560 builder.set_model_route(route);
561 }
562
563 for ctx in &example.context {
565 builder.add_context(ctx);
566 }
567
568 builder.add_step(
570 example.get_activations(),
571 example.get_attention(),
572 example.get_reward() * example.weight,
573 );
574
575 self.engine.end_trajectory(builder, example.quality);
577
578 quality_sum += example.quality;
579 self.metrics.total_examples += 1;
580 self.metrics.add_quality_sample(example.quality);
581 }
582
583 self.engine.tick();
585
586 Ok(quality_sum / batch_len as f32)
587 }
588
589 fn run_validation(&mut self) -> Result<(), String> {
591 let mut quality_sum = 0.0f32;
592
593 for example in &self.validation_examples {
594 let mut output = vec![0.0f32; example.embedding.len()];
596 self.engine.apply_micro_lora(&example.embedding, &mut output);
597
598 quality_sum += example.quality;
601 }
602
603 self.metrics.validation_quality = Some(
604 quality_sum / self.validation_examples.len() as f32
605 );
606
607 Ok(())
608 }
609
610 pub fn clear_examples(&mut self) {
612 self.examples.clear();
613 self.validation_examples.clear();
614 }
615
616 pub fn reset(&mut self) {
618 self.clear_examples();
619 self.metrics = TrainingMetrics::new(&self.name);
620 self.stage = PipelineStage::Idle;
621 }
622}
623
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_training_example() {
        let example = TrainingExample::new(vec![0.1; 256], 0.8)
            .with_route("test")
            .with_context("ctx1")
            .with_weight(1.5)
            .with_tag("test");

        assert_eq!(example.quality, 0.8);
        assert_eq!(example.route, Some("test".into()));
        assert_eq!(example.weight, 1.5);
        assert_eq!(example.context, vec!["ctx1".to_string()]);
        assert_eq!(example.tags, vec!["test".to_string()]);
    }

    #[test]
    fn test_training_example_fallbacks() {
        let example = TrainingExample::new(vec![0.5; 4], 0.6);
        // Unset optional signals fall back to embedding / uniform attention / quality.
        assert_eq!(example.get_activations(), vec![0.5; 4]);
        assert_eq!(example.get_attention(), vec![1.0 / 64.0; 64]);
        assert_eq!(example.get_reward(), 0.6);

        let example = example
            .with_activations(vec![1.0; 2])
            .with_attention(vec![0.5; 2])
            .with_reward(0.2);
        assert_eq!(example.get_activations(), vec![1.0; 2]);
        assert_eq!(example.get_attention(), vec![0.5; 2]);
        assert_eq!(example.get_reward(), 0.2);
    }

    #[test]
    fn test_batch_config() {
        let config = BatchConfig::for_data_size(&DataSizeHint::Small);
        assert_eq!(config.batch_size, 16);
        assert_eq!(config.epochs, 5);

        let single = BatchConfig::single_pass();
        assert_eq!(single.batch_size, usize::MAX);
        assert!(!single.shuffle);
        assert_eq!(single.epochs, 1);
        assert_eq!(single.early_stopping_patience, None);
    }

    #[test]
    fn test_pipeline_stage_display() {
        assert_eq!(PipelineStage::Idle.to_string(), "idle");
        assert_eq!(PipelineStage::PatternExtraction.to_string(), "pattern_extraction");
        assert_eq!(PipelineStage::Failed.to_string(), "failed");
    }

    #[test]
    fn test_pipeline_creation() {
        let pipeline = TrainingPipeline::new("test", SonaConfig::default());
        assert_eq!(pipeline.stage(), &PipelineStage::Idle);
        assert_eq!(pipeline.example_count(), 0);
    }

    #[test]
    fn test_pipeline_from_template() {
        let template = TrainingTemplate::code_agent().with_hidden_dim(256);
        let pipeline = TrainingPipeline::from_template(template);
        assert_eq!(pipeline.name, "code-agent");
    }

    #[test]
    fn test_pipeline_training() {
        let mut pipeline = TrainingPipeline::new("test", SonaConfig::default())
            .with_batch_config(BatchConfig {
                batch_size: 2,
                epochs: 2,
                ..Default::default()
            });

        pipeline.add_examples(
            (0..5).map(|i| TrainingExample::new(vec![i as f32 * 0.1; 256], 0.7 + i as f32 * 0.05)),
        );

        let result = pipeline.train().unwrap();
        assert_eq!(result.epochs_completed, 2);
        assert!(result.total_examples > 0);
        // Batches of 2 over 5 examples keep the trailing partial batch.
        assert!(result.epoch_stats.iter().all(|s| s.examples_processed == 5));
        assert_eq!(pipeline.stage(), &PipelineStage::Completed);
    }

    #[test]
    fn test_pipeline_drop_last() {
        let mut pipeline = TrainingPipeline::new("test", SonaConfig::default())
            .with_batch_config(BatchConfig {
                batch_size: 2,
                epochs: 1,
                shuffle: false,
                drop_last: true,
                ..Default::default()
            });

        pipeline.add_examples((0..5).map(|i| TrainingExample::new(vec![i as f32 * 0.1; 256], 0.5)));

        let result = pipeline.train().unwrap();
        assert_eq!(result.epoch_stats.len(), 1);
        assert_eq!(result.epoch_stats[0].examples_processed, 4);
    }

    #[test]
    fn test_pipeline_requires_examples() {
        let mut pipeline = TrainingPipeline::new("empty", SonaConfig::default());
        assert!(pipeline.train().is_err());
    }

    #[test]
    fn test_pipeline_with_validation() {
        let mut pipeline = TrainingPipeline::new("test", SonaConfig::default())
            .with_batch_config(BatchConfig::single_pass());

        pipeline.add_example(TrainingExample::new(vec![0.1; 256], 0.8));
        pipeline.add_validation_example(TrainingExample::new(vec![0.2; 256], 0.9));

        let result = pipeline.train().unwrap();
        assert_eq!(result.validation_quality, Some(0.9));
    }

    #[test]
    fn test_reset_clears_state() {
        let mut pipeline = TrainingPipeline::new("test", SonaConfig::default());
        pipeline.add_example(TrainingExample::new(vec![0.1; 256], 0.8));
        pipeline.add_validation_example(TrainingExample::new(vec![0.2; 256], 0.9));

        pipeline.reset();
        assert_eq!(pipeline.example_count(), 0);
        assert_eq!(pipeline.stage(), &PipelineStage::Idle);
    }
}