Skip to main content

ruvector_sona/training/
pipeline.rs

1//! Training Pipeline for SONA
2//!
3//! Structured training workflows with batching and callbacks.
4
5use super::metrics::{EpochStats, TrainingMetrics, TrainingResult};
6use super::templates::{DataSizeHint, TrainingMethod, TrainingTemplate};
7use crate::engine::SonaEngine;
8use crate::time_compat::Instant;
9use crate::types::SonaConfig;
10use serde::{Deserialize, Serialize};
11
12/// Training example with all data needed for learning
13#[derive(Clone, Debug, Serialize, Deserialize)]
14pub struct TrainingExample {
15    /// Input embedding
16    pub embedding: Vec<f32>,
17    /// Hidden activations (optional, defaults to embedding)
18    pub activations: Option<Vec<f32>>,
19    /// Attention weights (optional)
20    pub attention: Option<Vec<f32>>,
21    /// Quality score [0.0, 1.0]
22    pub quality: f32,
23    /// Reward signal (optional, defaults to quality)
24    pub reward: Option<f32>,
25    /// Model route identifier
26    pub route: Option<String>,
27    /// Context identifiers
28    pub context: Vec<String>,
29    /// Example weight for importance sampling
30    pub weight: f32,
31    /// Tags for filtering
32    pub tags: Vec<String>,
33}
34
35impl TrainingExample {
36    /// Create a new training example
37    pub fn new(embedding: Vec<f32>, quality: f32) -> Self {
38        Self {
39            embedding,
40            activations: None,
41            attention: None,
42            quality,
43            reward: None,
44            route: None,
45            context: Vec::new(),
46            weight: 1.0,
47            tags: Vec::new(),
48        }
49    }
50
51    /// Set activations
52    pub fn with_activations(mut self, activations: Vec<f32>) -> Self {
53        self.activations = Some(activations);
54        self
55    }
56
57    /// Set attention
58    pub fn with_attention(mut self, attention: Vec<f32>) -> Self {
59        self.attention = Some(attention);
60        self
61    }
62
63    /// Set reward
64    pub fn with_reward(mut self, reward: f32) -> Self {
65        self.reward = Some(reward);
66        self
67    }
68
69    /// Set route
70    pub fn with_route(mut self, route: impl Into<String>) -> Self {
71        self.route = Some(route.into());
72        self
73    }
74
75    /// Add context
76    pub fn with_context(mut self, ctx: impl Into<String>) -> Self {
77        self.context.push(ctx.into());
78        self
79    }
80
81    /// Set weight
82    pub fn with_weight(mut self, weight: f32) -> Self {
83        self.weight = weight;
84        self
85    }
86
87    /// Add tag
88    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
89        self.tags.push(tag.into());
90        self
91    }
92
93    /// Get activations or default to embedding
94    pub fn get_activations(&self) -> Vec<f32> {
95        self.activations
96            .clone()
97            .unwrap_or_else(|| self.embedding.clone())
98    }
99
100    /// Get attention or default
101    pub fn get_attention(&self) -> Vec<f32> {
102        self.attention
103            .clone()
104            .unwrap_or_else(|| vec![1.0 / 64.0; 64])
105    }
106
107    /// Get reward or default to quality
108    pub fn get_reward(&self) -> f32 {
109        self.reward.unwrap_or(self.quality)
110    }
111}
112
113/// Batch configuration for training
114#[derive(Clone, Debug, Serialize, Deserialize)]
115pub struct BatchConfig {
116    /// Batch size
117    pub batch_size: usize,
118    /// Shuffle examples
119    pub shuffle: bool,
120    /// Drop incomplete last batch
121    pub drop_last: bool,
122    /// Number of epochs
123    pub epochs: usize,
124    /// Early stopping patience (None = disabled)
125    pub early_stopping_patience: Option<usize>,
126    /// Minimum quality improvement for early stopping
127    pub min_quality_improvement: f32,
128}
129
130impl Default for BatchConfig {
131    fn default() -> Self {
132        Self {
133            batch_size: 32,
134            shuffle: true,
135            drop_last: false,
136            epochs: 1,
137            early_stopping_patience: None,
138            min_quality_improvement: 0.001,
139        }
140    }
141}
142
143impl BatchConfig {
144    /// Create config for single pass (no batching)
145    pub fn single_pass() -> Self {
146        Self {
147            batch_size: usize::MAX,
148            shuffle: false,
149            drop_last: false,
150            epochs: 1,
151            early_stopping_patience: None,
152            min_quality_improvement: 0.0,
153        }
154    }
155
156    /// Create config optimized for size hint
157    pub fn for_data_size(hint: &DataSizeHint) -> Self {
158        match hint {
159            DataSizeHint::Tiny => Self {
160                batch_size: 8,
161                epochs: 10,
162                early_stopping_patience: Some(3),
163                ..Default::default()
164            },
165            DataSizeHint::Small => Self {
166                batch_size: 16,
167                epochs: 5,
168                early_stopping_patience: Some(2),
169                ..Default::default()
170            },
171            DataSizeHint::Medium => Self {
172                batch_size: 32,
173                epochs: 3,
174                early_stopping_patience: Some(2),
175                ..Default::default()
176            },
177            DataSizeHint::Large => Self {
178                batch_size: 64,
179                epochs: 2,
180                ..Default::default()
181            },
182            DataSizeHint::Massive => Self {
183                batch_size: 128,
184                epochs: 1,
185                ..Default::default()
186            },
187        }
188    }
189}
190
191/// Pipeline stage for tracking progress
192#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
193pub enum PipelineStage {
194    /// Not started
195    Idle,
196    /// Loading and preprocessing data
197    Preprocessing,
198    /// Training in progress
199    Training,
200    /// Running validation
201    Validation,
202    /// Extracting patterns
203    PatternExtraction,
204    /// Exporting results
205    Export,
206    /// Completed successfully
207    Completed,
208    /// Failed with error
209    Failed,
210}
211
212impl std::fmt::Display for PipelineStage {
213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        match self {
215            PipelineStage::Idle => write!(f, "idle"),
216            PipelineStage::Preprocessing => write!(f, "preprocessing"),
217            PipelineStage::Training => write!(f, "training"),
218            PipelineStage::Validation => write!(f, "validation"),
219            PipelineStage::PatternExtraction => write!(f, "pattern_extraction"),
220            PipelineStage::Export => write!(f, "export"),
221            PipelineStage::Completed => write!(f, "completed"),
222            PipelineStage::Failed => write!(f, "failed"),
223        }
224    }
225}
226
227/// Callback trait for training events
228pub trait TrainingCallback: Send + Sync {
229    /// Called when stage changes
230    fn on_stage_change(&self, _stage: &PipelineStage) {}
231
232    /// Called after each batch
233    fn on_batch_complete(&self, _batch_idx: usize, _total_batches: usize, _avg_quality: f32) {}
234
235    /// Called after each epoch
236    fn on_epoch_complete(&self, _epoch: usize, _stats: &EpochStats) {}
237
238    /// Called when training completes
239    fn on_training_complete(&self, _result: &TrainingResult) {}
240
241    /// Called on error
242    fn on_error(&self, _error: &str) {}
243}
244
245/// No-op callback implementation
246pub struct NoOpCallback;
247impl TrainingCallback for NoOpCallback {}
248
/// Callback that prints training events to the console.
pub struct LoggingCallback {
    /// Label placed in square brackets at the start of every printed line.
    prefix: String,
}
253
254impl LoggingCallback {
255    /// Create with prefix
256    pub fn new(prefix: impl Into<String>) -> Self {
257        Self {
258            prefix: prefix.into(),
259        }
260    }
261}
262
263impl TrainingCallback for LoggingCallback {
264    fn on_stage_change(&self, stage: &PipelineStage) {
265        println!("[{}] Stage: {}", self.prefix, stage);
266    }
267
268    fn on_batch_complete(&self, batch_idx: usize, total_batches: usize, avg_quality: f32) {
269        if batch_idx % 10 == 0 || batch_idx == total_batches - 1 {
270            println!(
271                "[{}] Batch {}/{}: avg_quality={:.4}",
272                self.prefix,
273                batch_idx + 1,
274                total_batches,
275                avg_quality
276            );
277        }
278    }
279
280    fn on_epoch_complete(&self, epoch: usize, stats: &EpochStats) {
281        println!(
282            "[{}] Epoch {}: examples={}, avg_quality={:.4}, duration={:.2}s",
283            self.prefix,
284            epoch + 1,
285            stats.examples_processed,
286            stats.avg_quality,
287            stats.duration_secs
288        );
289    }
290
291    fn on_training_complete(&self, result: &TrainingResult) {
292        println!(
293            "[{}] Training complete: epochs={}, patterns={}, final_quality={:.4}",
294            self.prefix, result.epochs_completed, result.patterns_learned, result.final_avg_quality
295        );
296    }
297
298    fn on_error(&self, error: &str) {
299        eprintln!("[{}] ERROR: {}", self.prefix, error);
300    }
301}
302
303/// Training pipeline for structured training workflows
304pub struct TrainingPipeline {
305    /// Pipeline name
306    name: String,
307    /// SONA engine
308    engine: SonaEngine,
309    /// Batch configuration
310    batch_config: BatchConfig,
311    /// Training method
312    training_method: TrainingMethod,
313    /// Current stage
314    stage: PipelineStage,
315    /// Training examples buffer
316    examples: Vec<TrainingExample>,
317    /// Validation examples
318    validation_examples: Vec<TrainingExample>,
319    /// Training metrics
320    metrics: TrainingMetrics,
321    /// Callback
322    callback: Box<dyn TrainingCallback>,
323    /// Enable pattern extraction after training
324    extract_patterns: bool,
325}
326
327impl TrainingPipeline {
328    /// Create a new training pipeline
329    pub fn new(name: impl Into<String>, config: SonaConfig) -> Self {
330        let name = name.into();
331        Self {
332            name: name.clone(),
333            engine: SonaEngine::with_config(config),
334            batch_config: BatchConfig::default(),
335            training_method: TrainingMethod::default(),
336            stage: PipelineStage::Idle,
337            examples: Vec::new(),
338            validation_examples: Vec::new(),
339            metrics: TrainingMetrics::new(&name),
340            callback: Box::new(NoOpCallback),
341            extract_patterns: true,
342        }
343    }
344
345    /// Create from template
346    pub fn from_template(template: TrainingTemplate) -> Self {
347        let batch_config = BatchConfig::for_data_size(&template.expected_data_size);
348        let mut pipeline = Self::new(&template.name, template.sona_config);
349        pipeline.batch_config = batch_config;
350        pipeline.training_method = template.training_method;
351        pipeline
352    }
353
354    /// Set batch configuration
355    pub fn with_batch_config(mut self, config: BatchConfig) -> Self {
356        self.batch_config = config;
357        self
358    }
359
360    /// Set training method
361    pub fn with_training_method(mut self, method: TrainingMethod) -> Self {
362        self.training_method = method;
363        self
364    }
365
366    /// Set callback
367    pub fn with_callback<C: TrainingCallback + 'static>(mut self, callback: C) -> Self {
368        self.callback = Box::new(callback);
369        self
370    }
371
372    /// Enable/disable pattern extraction
373    pub fn with_pattern_extraction(mut self, enabled: bool) -> Self {
374        self.extract_patterns = enabled;
375        self
376    }
377
378    /// Add a training example
379    pub fn add_example(&mut self, example: TrainingExample) {
380        self.examples.push(example);
381    }
382
383    /// Add multiple training examples
384    pub fn add_examples(&mut self, examples: impl IntoIterator<Item = TrainingExample>) {
385        self.examples.extend(examples);
386    }
387
388    /// Add validation example
389    pub fn add_validation_example(&mut self, example: TrainingExample) {
390        self.validation_examples.push(example);
391    }
392
393    /// Get current stage
394    pub fn stage(&self) -> &PipelineStage {
395        &self.stage
396    }
397
398    /// Get number of examples
399    pub fn example_count(&self) -> usize {
400        self.examples.len()
401    }
402
403    /// Get metrics
404    pub fn metrics(&self) -> &TrainingMetrics {
405        &self.metrics
406    }
407
408    /// Get engine reference
409    pub fn engine(&self) -> &SonaEngine {
410        &self.engine
411    }
412
413    /// Get mutable engine reference
414    pub fn engine_mut(&mut self) -> &mut SonaEngine {
415        &mut self.engine
416    }
417
418    /// Run the training pipeline
419    pub fn train(&mut self) -> Result<TrainingResult, String> {
420        let start = Instant::now();
421
422        // Preprocessing
423        self.set_stage(PipelineStage::Preprocessing);
424        self.preprocess()?;
425
426        // Training
427        self.set_stage(PipelineStage::Training);
428        let epoch_stats = self.run_training()?;
429
430        // Validation (if examples provided)
431        if !self.validation_examples.is_empty() {
432            self.set_stage(PipelineStage::Validation);
433            self.run_validation()?;
434        }
435
436        // Pattern extraction
437        if self.extract_patterns {
438            self.set_stage(PipelineStage::PatternExtraction);
439            self.engine.force_learn();
440        }
441
442        self.set_stage(PipelineStage::Completed);
443
444        let result = TrainingResult {
445            pipeline_name: self.name.clone(),
446            epochs_completed: epoch_stats.len(),
447            total_examples: self.metrics.total_examples,
448            patterns_learned: self.metrics.patterns_learned,
449            final_avg_quality: self.metrics.avg_quality(),
450            total_duration_secs: start.elapsed().as_secs_f64(),
451            epoch_stats,
452            validation_quality: self.metrics.validation_quality,
453        };
454
455        self.callback.on_training_complete(&result);
456        Ok(result)
457    }
458
459    /// Set stage and notify callback
460    fn set_stage(&mut self, stage: PipelineStage) {
461        self.stage = stage.clone();
462        self.callback.on_stage_change(&stage);
463    }
464
465    /// Preprocess examples
466    fn preprocess(&mut self) -> Result<(), String> {
467        if self.examples.is_empty() {
468            return Err("No training examples provided".into());
469        }
470
471        // Shuffle if configured
472        if self.batch_config.shuffle {
473            use rand::seq::SliceRandom;
474            let mut rng = rand::thread_rng();
475            self.examples.shuffle(&mut rng);
476        }
477
478        Ok(())
479    }
480
481    /// Run training epochs
482    fn run_training(&mut self) -> Result<Vec<EpochStats>, String> {
483        let mut all_epoch_stats = Vec::new();
484        let mut best_quality = 0.0f32;
485        let mut patience_counter = 0usize;
486
487        for epoch in 0..self.batch_config.epochs {
488            let epoch_start = Instant::now();
489            let mut epoch_quality_sum = 0.0f32;
490            let mut epoch_examples = 0usize;
491
492            // Create batch indices (to avoid borrow checker issues)
493            let batch_size = self.batch_config.batch_size;
494            let total_examples = self.examples.len();
495            let mut batch_indices: Vec<(usize, usize)> = Vec::new();
496            let mut start = 0;
497            while start < total_examples {
498                let end = (start + batch_size).min(total_examples);
499                if end > start && (!self.batch_config.drop_last || end - start == batch_size) {
500                    batch_indices.push((start, end));
501                }
502                start = end;
503            }
504            let total_batches = batch_indices.len();
505
506            for (batch_idx, (start, end)) in batch_indices.into_iter().enumerate() {
507                let batch_quality = self.train_batch_range(start, end)?;
508                let batch_len = end - start;
509                epoch_quality_sum += batch_quality * batch_len as f32;
510                epoch_examples += batch_len;
511
512                self.callback.on_batch_complete(
513                    batch_idx,
514                    total_batches,
515                    epoch_quality_sum / epoch_examples as f32,
516                );
517            }
518
519            let epoch_avg_quality = if epoch_examples > 0 {
520                epoch_quality_sum / epoch_examples as f32
521            } else {
522                0.0
523            };
524
525            let epoch_stats = EpochStats {
526                epoch,
527                examples_processed: epoch_examples,
528                avg_quality: epoch_avg_quality,
529                duration_secs: epoch_start.elapsed().as_secs_f64(),
530            };
531
532            self.callback.on_epoch_complete(epoch, &epoch_stats);
533            all_epoch_stats.push(epoch_stats);
534
535            // Early stopping check
536            if let Some(patience) = self.batch_config.early_stopping_patience {
537                let improvement = epoch_avg_quality - best_quality;
538                if improvement > self.batch_config.min_quality_improvement {
539                    best_quality = epoch_avg_quality;
540                    patience_counter = 0;
541                } else {
542                    patience_counter += 1;
543                    if patience_counter >= patience {
544                        break; // Early stop
545                    }
546                }
547            }
548
549            // Reshuffle for next epoch
550            if self.batch_config.shuffle && epoch + 1 < self.batch_config.epochs {
551                use rand::seq::SliceRandom;
552                let mut rng = rand::thread_rng();
553                self.examples.shuffle(&mut rng);
554            }
555        }
556
557        Ok(all_epoch_stats)
558    }
559
560    /// Train on examples in a range
561    fn train_batch_range(&mut self, start: usize, end: usize) -> Result<f32, String> {
562        let mut quality_sum = 0.0f32;
563        let batch_len = end - start;
564
565        for idx in start..end {
566            let example = &self.examples[idx];
567
568            // Begin trajectory using builder API
569            let mut builder = self.engine.begin_trajectory(example.embedding.clone());
570
571            // Set route
572            if let Some(ref route) = example.route {
573                builder.set_model_route(route);
574            }
575
576            // Add context
577            for ctx in &example.context {
578                builder.add_context(ctx);
579            }
580
581            // Add step
582            builder.add_step(
583                example.get_activations(),
584                example.get_attention(),
585                example.get_reward() * example.weight,
586            );
587
588            // End trajectory
589            self.engine.end_trajectory(builder, example.quality);
590
591            quality_sum += example.quality;
592            self.metrics.total_examples += 1;
593            self.metrics.add_quality_sample(example.quality);
594        }
595
596        // Run tick to process accumulated trajectories
597        self.engine.tick();
598
599        Ok(quality_sum / batch_len as f32)
600    }
601
602    /// Run validation
603    fn run_validation(&mut self) -> Result<(), String> {
604        let mut quality_sum = 0.0f32;
605
606        for example in &self.validation_examples {
607            // Apply learned transformations
608            let mut output = vec![0.0f32; example.embedding.len()];
609            self.engine
610                .apply_micro_lora(&example.embedding, &mut output);
611
612            // In a real scenario, you'd evaluate the model output
613            // For now, we track the expected quality
614            quality_sum += example.quality;
615        }
616
617        self.metrics.validation_quality = Some(quality_sum / self.validation_examples.len() as f32);
618
619        Ok(())
620    }
621
622    /// Clear examples (keep engine state)
623    pub fn clear_examples(&mut self) {
624        self.examples.clear();
625        self.validation_examples.clear();
626    }
627
628    /// Reset pipeline (clear examples and metrics)
629    pub fn reset(&mut self) {
630        self.clear_examples();
631        self.metrics = TrainingMetrics::new(&self.name);
632        self.stage = PipelineStage::Idle;
633    }
634}
635
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods store the values they are given.
    #[test]
    fn test_training_example() {
        let embedding = vec![0.1; 256];
        let example = TrainingExample::new(embedding, 0.8)
            .with_route("test")
            .with_context("ctx1")
            .with_weight(1.5)
            .with_tag("test");

        assert_eq!(example.quality, 0.8);
        assert_eq!(example.route.as_deref(), Some("test"));
        assert_eq!(example.weight, 1.5);
    }

    /// The `Small` size hint selects batches of 16 over 5 epochs.
    #[test]
    fn test_batch_config() {
        let BatchConfig {
            batch_size, epochs, ..
        } = BatchConfig::for_data_size(&DataSizeHint::Small);

        assert_eq!(batch_size, 16);
        assert_eq!(epochs, 5);
    }

    /// A fresh pipeline is idle and holds no examples.
    #[test]
    fn test_pipeline_creation() {
        let pipeline = TrainingPipeline::new("test", SonaConfig::default());

        assert_eq!(pipeline.stage(), &PipelineStage::Idle);
        assert_eq!(pipeline.example_count(), 0);
    }

    /// A pipeline built from a template takes the template's name.
    #[test]
    fn test_pipeline_from_template() {
        let template = TrainingTemplate::code_agent().with_hidden_dim(256);
        let pipeline = TrainingPipeline::from_template(template);

        assert_eq!(pipeline.name, "code-agent");
    }

    /// Two epochs over five examples complete and count the examples.
    #[test]
    fn test_pipeline_training() {
        let batch_config = BatchConfig {
            batch_size: 2,
            epochs: 2,
            ..Default::default()
        };
        let mut pipeline =
            TrainingPipeline::new("test", SonaConfig::default()).with_batch_config(batch_config);

        // Add examples
        pipeline.add_examples(
            (0..5).map(|i| TrainingExample::new(vec![i as f32 * 0.1; 256], 0.7 + i as f32 * 0.05)),
        );

        let result = pipeline.train().unwrap();
        assert_eq!(result.epochs_completed, 2);
        assert!(result.total_examples > 0);
    }

    /// Supplying a validation example produces a validation quality.
    #[test]
    fn test_pipeline_with_validation() {
        let mut pipeline = TrainingPipeline::new("test", SonaConfig::default())
            .with_batch_config(BatchConfig::single_pass());

        let training = TrainingExample::new(vec![0.1; 256], 0.8);
        let validation = TrainingExample::new(vec![0.2; 256], 0.9);
        pipeline.add_example(training);
        pipeline.add_validation_example(validation);

        let result = pipeline.train().unwrap();
        assert!(result.validation_quality.is_some());
    }
}