// ruvector_sona/training/pipeline.rs
1//! Training Pipeline for SONA
2//!
3//! Structured training workflows with batching and callbacks.
4
5use super::metrics::{EpochStats, TrainingMetrics, TrainingResult};
6use super::templates::{DataSizeHint, TrainingMethod, TrainingTemplate};
7use crate::engine::SonaEngine;
8use crate::time_compat::Instant;
9use crate::types::SonaConfig;
10use serde::{Deserialize, Serialize};
11
/// Training example with all data needed for learning
///
/// Only `embedding` and `quality` are required; the remaining fields are
/// optional refinements with sensible defaults (see the `get_*` accessors
/// and the `with_*` builder methods on the `impl`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrainingExample {
    /// Input embedding
    pub embedding: Vec<f32>,
    /// Hidden activations (optional, defaults to embedding)
    pub activations: Option<Vec<f32>>,
    /// Attention weights (optional)
    pub attention: Option<Vec<f32>>,
    /// Quality score [0.0, 1.0]
    pub quality: f32,
    /// Reward signal (optional, defaults to quality)
    pub reward: Option<f32>,
    /// Model route identifier
    pub route: Option<String>,
    /// Context identifiers
    pub context: Vec<String>,
    /// Example weight for importance sampling (1.0 = neutral)
    pub weight: f32,
    /// Tags for filtering
    pub tags: Vec<String>,
}
34
35impl TrainingExample {
36    /// Create a new training example
37    pub fn new(embedding: Vec<f32>, quality: f32) -> Self {
38        Self {
39            embedding,
40            activations: None,
41            attention: None,
42            quality,
43            reward: None,
44            route: None,
45            context: Vec::new(),
46            weight: 1.0,
47            tags: Vec::new(),
48        }
49    }
50
51    /// Set activations
52    pub fn with_activations(mut self, activations: Vec<f32>) -> Self {
53        self.activations = Some(activations);
54        self
55    }
56
57    /// Set attention
58    pub fn with_attention(mut self, attention: Vec<f32>) -> Self {
59        self.attention = Some(attention);
60        self
61    }
62
63    /// Set reward
64    pub fn with_reward(mut self, reward: f32) -> Self {
65        self.reward = Some(reward);
66        self
67    }
68
69    /// Set route
70    pub fn with_route(mut self, route: impl Into<String>) -> Self {
71        self.route = Some(route.into());
72        self
73    }
74
75    /// Add context
76    pub fn with_context(mut self, ctx: impl Into<String>) -> Self {
77        self.context.push(ctx.into());
78        self
79    }
80
81    /// Set weight
82    pub fn with_weight(mut self, weight: f32) -> Self {
83        self.weight = weight;
84        self
85    }
86
87    /// Add tag
88    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
89        self.tags.push(tag.into());
90        self
91    }
92
93    /// Get activations or default to embedding
94    pub fn get_activations(&self) -> Vec<f32> {
95        self.activations
96            .clone()
97            .unwrap_or_else(|| self.embedding.clone())
98    }
99
100    /// Get attention or default
101    pub fn get_attention(&self) -> Vec<f32> {
102        self.attention
103            .clone()
104            .unwrap_or_else(|| vec![1.0 / 64.0; 64])
105    }
106
107    /// Get reward or default to quality
108    pub fn get_reward(&self) -> f32 {
109        self.reward.unwrap_or(self.quality)
110    }
111}
112
/// Batch configuration for training
///
/// Controls how the example buffer is split into batches, how many passes
/// are made over it, and when training stops early.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BatchConfig {
    /// Batch size
    pub batch_size: usize,
    /// Shuffle examples (before training and between epochs)
    pub shuffle: bool,
    /// Drop incomplete last batch
    pub drop_last: bool,
    /// Number of epochs
    pub epochs: usize,
    /// Early stopping patience in epochs (None = disabled)
    pub early_stopping_patience: Option<usize>,
    /// Minimum quality improvement for early stopping
    pub min_quality_improvement: f32,
}
129
130impl Default for BatchConfig {
131    fn default() -> Self {
132        Self {
133            batch_size: 32,
134            shuffle: true,
135            drop_last: false,
136            epochs: 1,
137            early_stopping_patience: None,
138            min_quality_improvement: 0.001,
139        }
140    }
141}
142
143impl BatchConfig {
144    /// Create config for single pass (no batching)
145    pub fn single_pass() -> Self {
146        Self {
147            batch_size: usize::MAX,
148            shuffle: false,
149            drop_last: false,
150            epochs: 1,
151            early_stopping_patience: None,
152            min_quality_improvement: 0.0,
153        }
154    }
155
156    /// Create config optimized for size hint
157    pub fn for_data_size(hint: &DataSizeHint) -> Self {
158        match hint {
159            DataSizeHint::Tiny => Self {
160                batch_size: 8,
161                epochs: 10,
162                early_stopping_patience: Some(3),
163                ..Default::default()
164            },
165            DataSizeHint::Small => Self {
166                batch_size: 16,
167                epochs: 5,
168                early_stopping_patience: Some(2),
169                ..Default::default()
170            },
171            DataSizeHint::Medium => Self {
172                batch_size: 32,
173                epochs: 3,
174                early_stopping_patience: Some(2),
175                ..Default::default()
176            },
177            DataSizeHint::Large => Self {
178                batch_size: 64,
179                epochs: 2,
180                ..Default::default()
181            },
182            DataSizeHint::Massive => Self {
183                batch_size: 128,
184                epochs: 1,
185                ..Default::default()
186            },
187        }
188    }
189}
190
/// Pipeline stage for tracking progress
///
/// Stages are entered in declaration order during a successful run;
/// `Validation` and `PatternExtraction` may be skipped depending on
/// pipeline configuration.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum PipelineStage {
    /// Not started
    Idle,
    /// Loading and preprocessing data
    Preprocessing,
    /// Training in progress
    Training,
    /// Running validation
    Validation,
    /// Extracting patterns
    PatternExtraction,
    /// Exporting results
    Export,
    /// Completed successfully
    Completed,
    /// Failed with error
    Failed,
}
211
212impl std::fmt::Display for PipelineStage {
213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        match self {
215            PipelineStage::Idle => write!(f, "idle"),
216            PipelineStage::Preprocessing => write!(f, "preprocessing"),
217            PipelineStage::Training => write!(f, "training"),
218            PipelineStage::Validation => write!(f, "validation"),
219            PipelineStage::PatternExtraction => write!(f, "pattern_extraction"),
220            PipelineStage::Export => write!(f, "export"),
221            PipelineStage::Completed => write!(f, "completed"),
222            PipelineStage::Failed => write!(f, "failed"),
223        }
224    }
225}
226
/// Callback trait for training events
///
/// Every method has an empty default implementation, so implementors only
/// override the events they care about. `Send + Sync` lets callbacks be
/// shared across threads.
pub trait TrainingCallback: Send + Sync {
    /// Called when stage changes
    fn on_stage_change(&self, _stage: &PipelineStage) {}

    /// Called after each batch with the running epoch-average quality
    fn on_batch_complete(&self, _batch_idx: usize, _total_batches: usize, _avg_quality: f32) {}

    /// Called after each epoch
    fn on_epoch_complete(&self, _epoch: usize, _stats: &EpochStats) {}

    /// Called when training completes
    fn on_training_complete(&self, _result: &TrainingResult) {}

    /// Called on error
    fn on_error(&self, _error: &str) {}
}
244
/// No-op callback implementation: relies entirely on the trait's empty
/// default methods, so every training event is silently ignored.
pub struct NoOpCallback;
impl TrainingCallback for NoOpCallback {}
248
/// Logging callback implementation
///
/// Prints progress to stdout (errors to stderr), prefixing every line
/// with a caller-supplied tag.
#[allow(dead_code)]
pub struct LoggingCallback {
    // Tag prepended to every log line, in square brackets.
    prefix: String,
}
254
255#[allow(dead_code)]
256impl LoggingCallback {
257    /// Create with prefix
258    pub fn new(prefix: impl Into<String>) -> Self {
259        Self {
260            prefix: prefix.into(),
261        }
262    }
263}
264
265impl TrainingCallback for LoggingCallback {
266    fn on_stage_change(&self, stage: &PipelineStage) {
267        println!("[{}] Stage: {}", self.prefix, stage);
268    }
269
270    fn on_batch_complete(&self, batch_idx: usize, total_batches: usize, avg_quality: f32) {
271        if batch_idx % 10 == 0 || batch_idx == total_batches - 1 {
272            println!(
273                "[{}] Batch {}/{}: avg_quality={:.4}",
274                self.prefix,
275                batch_idx + 1,
276                total_batches,
277                avg_quality
278            );
279        }
280    }
281
282    fn on_epoch_complete(&self, epoch: usize, stats: &EpochStats) {
283        println!(
284            "[{}] Epoch {}: examples={}, avg_quality={:.4}, duration={:.2}s",
285            self.prefix,
286            epoch + 1,
287            stats.examples_processed,
288            stats.avg_quality,
289            stats.duration_secs
290        );
291    }
292
293    fn on_training_complete(&self, result: &TrainingResult) {
294        println!(
295            "[{}] Training complete: epochs={}, patterns={}, final_quality={:.4}",
296            self.prefix, result.epochs_completed, result.patterns_learned, result.final_avg_quality
297        );
298    }
299
300    fn on_error(&self, error: &str) {
301        eprintln!("[{}] ERROR: {}", self.prefix, error);
302    }
303}
304
/// Training pipeline for structured training workflows
///
/// Owns a `SonaEngine` plus the example buffers, and drives staged
/// training (preprocess → train → validate → extract patterns) with
/// callback notifications along the way.
pub struct TrainingPipeline {
    /// Pipeline name (also used to label the metrics)
    name: String,
    /// SONA engine
    engine: SonaEngine,
    /// Batch configuration
    batch_config: BatchConfig,
    /// Training method
    training_method: TrainingMethod,
    /// Current stage
    stage: PipelineStage,
    /// Training examples buffer
    examples: Vec<TrainingExample>,
    /// Validation examples (validation runs only when non-empty)
    validation_examples: Vec<TrainingExample>,
    /// Training metrics
    metrics: TrainingMetrics,
    /// Callback (defaults to a no-op)
    callback: Box<dyn TrainingCallback>,
    /// Enable pattern extraction after training
    extract_patterns: bool,
}
328
329impl TrainingPipeline {
330    /// Create a new training pipeline
331    pub fn new(name: impl Into<String>, config: SonaConfig) -> Self {
332        let name = name.into();
333        Self {
334            name: name.clone(),
335            engine: SonaEngine::with_config(config),
336            batch_config: BatchConfig::default(),
337            training_method: TrainingMethod::default(),
338            stage: PipelineStage::Idle,
339            examples: Vec::new(),
340            validation_examples: Vec::new(),
341            metrics: TrainingMetrics::new(&name),
342            callback: Box::new(NoOpCallback),
343            extract_patterns: true,
344        }
345    }
346
347    /// Create from template
348    pub fn from_template(template: TrainingTemplate) -> Self {
349        let batch_config = BatchConfig::for_data_size(&template.expected_data_size);
350        let mut pipeline = Self::new(&template.name, template.sona_config);
351        pipeline.batch_config = batch_config;
352        pipeline.training_method = template.training_method;
353        pipeline
354    }
355
356    /// Set batch configuration
357    pub fn with_batch_config(mut self, config: BatchConfig) -> Self {
358        self.batch_config = config;
359        self
360    }
361
362    /// Set training method
363    pub fn with_training_method(mut self, method: TrainingMethod) -> Self {
364        self.training_method = method;
365        self
366    }
367
368    /// Set callback
369    pub fn with_callback<C: TrainingCallback + 'static>(mut self, callback: C) -> Self {
370        self.callback = Box::new(callback);
371        self
372    }
373
374    /// Enable/disable pattern extraction
375    pub fn with_pattern_extraction(mut self, enabled: bool) -> Self {
376        self.extract_patterns = enabled;
377        self
378    }
379
380    /// Add a training example
381    pub fn add_example(&mut self, example: TrainingExample) {
382        self.examples.push(example);
383    }
384
385    /// Add multiple training examples
386    pub fn add_examples(&mut self, examples: impl IntoIterator<Item = TrainingExample>) {
387        self.examples.extend(examples);
388    }
389
390    /// Add validation example
391    pub fn add_validation_example(&mut self, example: TrainingExample) {
392        self.validation_examples.push(example);
393    }
394
395    /// Get current stage
396    pub fn stage(&self) -> &PipelineStage {
397        &self.stage
398    }
399
400    /// Get number of examples
401    pub fn example_count(&self) -> usize {
402        self.examples.len()
403    }
404
405    /// Get metrics
406    pub fn metrics(&self) -> &TrainingMetrics {
407        &self.metrics
408    }
409
410    /// Get engine reference
411    pub fn engine(&self) -> &SonaEngine {
412        &self.engine
413    }
414
415    /// Get mutable engine reference
416    pub fn engine_mut(&mut self) -> &mut SonaEngine {
417        &mut self.engine
418    }
419
420    /// Run the training pipeline
421    pub fn train(&mut self) -> Result<TrainingResult, String> {
422        let start = Instant::now();
423
424        // Preprocessing
425        self.set_stage(PipelineStage::Preprocessing);
426        self.preprocess()?;
427
428        // Training
429        self.set_stage(PipelineStage::Training);
430        let epoch_stats = self.run_training()?;
431
432        // Validation (if examples provided)
433        if !self.validation_examples.is_empty() {
434            self.set_stage(PipelineStage::Validation);
435            self.run_validation()?;
436        }
437
438        // Pattern extraction
439        if self.extract_patterns {
440            self.set_stage(PipelineStage::PatternExtraction);
441            self.engine.force_learn();
442        }
443
444        self.set_stage(PipelineStage::Completed);
445
446        let result = TrainingResult {
447            pipeline_name: self.name.clone(),
448            epochs_completed: epoch_stats.len(),
449            total_examples: self.metrics.total_examples,
450            patterns_learned: self.metrics.patterns_learned,
451            final_avg_quality: self.metrics.avg_quality(),
452            total_duration_secs: start.elapsed().as_secs_f64(),
453            epoch_stats,
454            validation_quality: self.metrics.validation_quality,
455        };
456
457        self.callback.on_training_complete(&result);
458        Ok(result)
459    }
460
461    /// Set stage and notify callback
462    fn set_stage(&mut self, stage: PipelineStage) {
463        self.stage = stage.clone();
464        self.callback.on_stage_change(&stage);
465    }
466
467    /// Preprocess examples
468    fn preprocess(&mut self) -> Result<(), String> {
469        if self.examples.is_empty() {
470            return Err("No training examples provided".into());
471        }
472
473        // Shuffle if configured
474        if self.batch_config.shuffle {
475            use rand::seq::SliceRandom;
476            let mut rng = rand::thread_rng();
477            self.examples.shuffle(&mut rng);
478        }
479
480        Ok(())
481    }
482
483    /// Run training epochs
484    fn run_training(&mut self) -> Result<Vec<EpochStats>, String> {
485        let mut all_epoch_stats = Vec::new();
486        let mut best_quality = 0.0f32;
487        let mut patience_counter = 0usize;
488
489        for epoch in 0..self.batch_config.epochs {
490            let epoch_start = Instant::now();
491            let mut epoch_quality_sum = 0.0f32;
492            let mut epoch_examples = 0usize;
493
494            // Create batch indices (to avoid borrow checker issues)
495            let batch_size = self.batch_config.batch_size;
496            let total_examples = self.examples.len();
497            let mut batch_indices: Vec<(usize, usize)> = Vec::new();
498            let mut start = 0;
499            while start < total_examples {
500                let end = (start + batch_size).min(total_examples);
501                if end > start && (!self.batch_config.drop_last || end - start == batch_size) {
502                    batch_indices.push((start, end));
503                }
504                start = end;
505            }
506            let total_batches = batch_indices.len();
507
508            for (batch_idx, (start, end)) in batch_indices.into_iter().enumerate() {
509                let batch_quality = self.train_batch_range(start, end)?;
510                let batch_len = end - start;
511                epoch_quality_sum += batch_quality * batch_len as f32;
512                epoch_examples += batch_len;
513
514                self.callback.on_batch_complete(
515                    batch_idx,
516                    total_batches,
517                    epoch_quality_sum / epoch_examples as f32,
518                );
519            }
520
521            let epoch_avg_quality = if epoch_examples > 0 {
522                epoch_quality_sum / epoch_examples as f32
523            } else {
524                0.0
525            };
526
527            let epoch_stats = EpochStats {
528                epoch,
529                examples_processed: epoch_examples,
530                avg_quality: epoch_avg_quality,
531                duration_secs: epoch_start.elapsed().as_secs_f64(),
532            };
533
534            self.callback.on_epoch_complete(epoch, &epoch_stats);
535            all_epoch_stats.push(epoch_stats);
536
537            // Early stopping check
538            if let Some(patience) = self.batch_config.early_stopping_patience {
539                let improvement = epoch_avg_quality - best_quality;
540                if improvement > self.batch_config.min_quality_improvement {
541                    best_quality = epoch_avg_quality;
542                    patience_counter = 0;
543                } else {
544                    patience_counter += 1;
545                    if patience_counter >= patience {
546                        break; // Early stop
547                    }
548                }
549            }
550
551            // Reshuffle for next epoch
552            if self.batch_config.shuffle && epoch + 1 < self.batch_config.epochs {
553                use rand::seq::SliceRandom;
554                let mut rng = rand::thread_rng();
555                self.examples.shuffle(&mut rng);
556            }
557        }
558
559        Ok(all_epoch_stats)
560    }
561
562    /// Train on examples in a range
563    fn train_batch_range(&mut self, start: usize, end: usize) -> Result<f32, String> {
564        let mut quality_sum = 0.0f32;
565        let batch_len = end - start;
566
567        for idx in start..end {
568            let example = &self.examples[idx];
569
570            // Begin trajectory using builder API
571            let mut builder = self.engine.begin_trajectory(example.embedding.clone());
572
573            // Set route
574            if let Some(ref route) = example.route {
575                builder.set_model_route(route);
576            }
577
578            // Add context
579            for ctx in &example.context {
580                builder.add_context(ctx);
581            }
582
583            // Add step
584            builder.add_step(
585                example.get_activations(),
586                example.get_attention(),
587                example.get_reward() * example.weight,
588            );
589
590            // End trajectory
591            self.engine.end_trajectory(builder, example.quality);
592
593            quality_sum += example.quality;
594            self.metrics.total_examples += 1;
595            self.metrics.add_quality_sample(example.quality);
596        }
597
598        // Run tick to process accumulated trajectories
599        self.engine.tick();
600
601        Ok(quality_sum / batch_len as f32)
602    }
603
604    /// Run validation
605    fn run_validation(&mut self) -> Result<(), String> {
606        let mut quality_sum = 0.0f32;
607
608        for example in &self.validation_examples {
609            // Apply learned transformations
610            let mut output = vec![0.0f32; example.embedding.len()];
611            self.engine
612                .apply_micro_lora(&example.embedding, &mut output);
613
614            // In a real scenario, you'd evaluate the model output
615            // For now, we track the expected quality
616            quality_sum += example.quality;
617        }
618
619        self.metrics.validation_quality = Some(quality_sum / self.validation_examples.len() as f32);
620
621        Ok(())
622    }
623
624    /// Clear examples (keep engine state)
625    pub fn clear_examples(&mut self) {
626        self.examples.clear();
627        self.validation_examples.clear();
628    }
629
630    /// Reset pipeline (clear examples and metrics)
631    pub fn reset(&mut self) {
632        self.clear_examples();
633        self.metrics = TrainingMetrics::new(&self.name);
634        self.stage = PipelineStage::Idle;
635    }
636}
637
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods populate the example as expected.
    #[test]
    fn test_training_example() {
        let ex = TrainingExample::new(vec![0.1; 256], 0.8)
            .with_route("test")
            .with_context("ctx1")
            .with_weight(1.5)
            .with_tag("test");

        assert_eq!(ex.quality, 0.8);
        assert_eq!(ex.weight, 1.5);
        assert_eq!(ex.route, Some("test".into()));
    }

    /// Size hints map to the documented batch sizes and epoch counts.
    #[test]
    fn test_batch_config() {
        let cfg = BatchConfig::for_data_size(&DataSizeHint::Small);
        assert_eq!(cfg.epochs, 5);
        assert_eq!(cfg.batch_size, 16);
    }

    /// Fresh pipelines start idle and empty.
    #[test]
    fn test_pipeline_creation() {
        let pipeline = TrainingPipeline::new("test", SonaConfig::default());
        assert_eq!(pipeline.example_count(), 0);
        assert_eq!(pipeline.stage(), &PipelineStage::Idle);
    }

    /// Template name carries through to the pipeline.
    #[test]
    fn test_pipeline_from_template() {
        let template = TrainingTemplate::code_agent().with_hidden_dim(256);
        let pipeline = TrainingPipeline::from_template(template);
        assert_eq!(pipeline.name, "code-agent");
    }

    /// End-to-end training over two epochs of small batches.
    #[test]
    fn test_pipeline_training() {
        let cfg = BatchConfig {
            batch_size: 2,
            epochs: 2,
            ..Default::default()
        };
        let mut pipeline =
            TrainingPipeline::new("test", SonaConfig::default()).with_batch_config(cfg);

        // Five examples with increasing quality.
        pipeline.add_examples(
            (0..5).map(|i| TrainingExample::new(vec![i as f32 * 0.1; 256], 0.7 + i as f32 * 0.05)),
        );

        let result = pipeline.train().unwrap();
        assert_eq!(result.epochs_completed, 2);
        assert!(result.total_examples > 0);
    }

    /// Validation quality is recorded when validation examples exist.
    #[test]
    fn test_pipeline_with_validation() {
        let mut pipeline = TrainingPipeline::new("test", SonaConfig::default())
            .with_batch_config(BatchConfig::single_pass());

        pipeline.add_example(TrainingExample::new(vec![0.1; 256], 0.8));
        pipeline.add_validation_example(TrainingExample::new(vec![0.2; 256], 0.9));

        let result = pipeline.train().unwrap();
        assert!(result.validation_quality.is_some());
    }
}