// ruvector_sona/training/pipeline.rs

1//! Training Pipeline for SONA
2//!
3//! Structured training workflows with batching and callbacks.
4
5use crate::engine::SonaEngine;
6use crate::types::SonaConfig;
7use super::templates::{TrainingTemplate, TrainingMethod, DataSizeHint};
8use super::metrics::{TrainingMetrics, TrainingResult, EpochStats};
9use serde::{Deserialize, Serialize};
10use std::time::Instant;
11
12/// Training example with all data needed for learning
13#[derive(Clone, Debug, Serialize, Deserialize)]
14pub struct TrainingExample {
15    /// Input embedding
16    pub embedding: Vec<f32>,
17    /// Hidden activations (optional, defaults to embedding)
18    pub activations: Option<Vec<f32>>,
19    /// Attention weights (optional)
20    pub attention: Option<Vec<f32>>,
21    /// Quality score [0.0, 1.0]
22    pub quality: f32,
23    /// Reward signal (optional, defaults to quality)
24    pub reward: Option<f32>,
25    /// Model route identifier
26    pub route: Option<String>,
27    /// Context identifiers
28    pub context: Vec<String>,
29    /// Example weight for importance sampling
30    pub weight: f32,
31    /// Tags for filtering
32    pub tags: Vec<String>,
33}
34
35impl TrainingExample {
36    /// Create a new training example
37    pub fn new(embedding: Vec<f32>, quality: f32) -> Self {
38        Self {
39            embedding,
40            activations: None,
41            attention: None,
42            quality,
43            reward: None,
44            route: None,
45            context: Vec::new(),
46            weight: 1.0,
47            tags: Vec::new(),
48        }
49    }
50
51    /// Set activations
52    pub fn with_activations(mut self, activations: Vec<f32>) -> Self {
53        self.activations = Some(activations);
54        self
55    }
56
57    /// Set attention
58    pub fn with_attention(mut self, attention: Vec<f32>) -> Self {
59        self.attention = Some(attention);
60        self
61    }
62
63    /// Set reward
64    pub fn with_reward(mut self, reward: f32) -> Self {
65        self.reward = Some(reward);
66        self
67    }
68
69    /// Set route
70    pub fn with_route(mut self, route: impl Into<String>) -> Self {
71        self.route = Some(route.into());
72        self
73    }
74
75    /// Add context
76    pub fn with_context(mut self, ctx: impl Into<String>) -> Self {
77        self.context.push(ctx.into());
78        self
79    }
80
81    /// Set weight
82    pub fn with_weight(mut self, weight: f32) -> Self {
83        self.weight = weight;
84        self
85    }
86
87    /// Add tag
88    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
89        self.tags.push(tag.into());
90        self
91    }
92
93    /// Get activations or default to embedding
94    pub fn get_activations(&self) -> Vec<f32> {
95        self.activations.clone().unwrap_or_else(|| self.embedding.clone())
96    }
97
98    /// Get attention or default
99    pub fn get_attention(&self) -> Vec<f32> {
100        self.attention.clone().unwrap_or_else(|| vec![1.0 / 64.0; 64])
101    }
102
103    /// Get reward or default to quality
104    pub fn get_reward(&self) -> f32 {
105        self.reward.unwrap_or(self.quality)
106    }
107}
108
109/// Batch configuration for training
110#[derive(Clone, Debug, Serialize, Deserialize)]
111pub struct BatchConfig {
112    /// Batch size
113    pub batch_size: usize,
114    /// Shuffle examples
115    pub shuffle: bool,
116    /// Drop incomplete last batch
117    pub drop_last: bool,
118    /// Number of epochs
119    pub epochs: usize,
120    /// Early stopping patience (None = disabled)
121    pub early_stopping_patience: Option<usize>,
122    /// Minimum quality improvement for early stopping
123    pub min_quality_improvement: f32,
124}
125
126impl Default for BatchConfig {
127    fn default() -> Self {
128        Self {
129            batch_size: 32,
130            shuffle: true,
131            drop_last: false,
132            epochs: 1,
133            early_stopping_patience: None,
134            min_quality_improvement: 0.001,
135        }
136    }
137}
138
139impl BatchConfig {
140    /// Create config for single pass (no batching)
141    pub fn single_pass() -> Self {
142        Self {
143            batch_size: usize::MAX,
144            shuffle: false,
145            drop_last: false,
146            epochs: 1,
147            early_stopping_patience: None,
148            min_quality_improvement: 0.0,
149        }
150    }
151
152    /// Create config optimized for size hint
153    pub fn for_data_size(hint: &DataSizeHint) -> Self {
154        match hint {
155            DataSizeHint::Tiny => Self {
156                batch_size: 8,
157                epochs: 10,
158                early_stopping_patience: Some(3),
159                ..Default::default()
160            },
161            DataSizeHint::Small => Self {
162                batch_size: 16,
163                epochs: 5,
164                early_stopping_patience: Some(2),
165                ..Default::default()
166            },
167            DataSizeHint::Medium => Self {
168                batch_size: 32,
169                epochs: 3,
170                early_stopping_patience: Some(2),
171                ..Default::default()
172            },
173            DataSizeHint::Large => Self {
174                batch_size: 64,
175                epochs: 2,
176                ..Default::default()
177            },
178            DataSizeHint::Massive => Self {
179                batch_size: 128,
180                epochs: 1,
181                ..Default::default()
182            },
183        }
184    }
185}
186
187/// Pipeline stage for tracking progress
188#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
189pub enum PipelineStage {
190    /// Not started
191    Idle,
192    /// Loading and preprocessing data
193    Preprocessing,
194    /// Training in progress
195    Training,
196    /// Running validation
197    Validation,
198    /// Extracting patterns
199    PatternExtraction,
200    /// Exporting results
201    Export,
202    /// Completed successfully
203    Completed,
204    /// Failed with error
205    Failed,
206}
207
208impl std::fmt::Display for PipelineStage {
209    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210        match self {
211            PipelineStage::Idle => write!(f, "idle"),
212            PipelineStage::Preprocessing => write!(f, "preprocessing"),
213            PipelineStage::Training => write!(f, "training"),
214            PipelineStage::Validation => write!(f, "validation"),
215            PipelineStage::PatternExtraction => write!(f, "pattern_extraction"),
216            PipelineStage::Export => write!(f, "export"),
217            PipelineStage::Completed => write!(f, "completed"),
218            PipelineStage::Failed => write!(f, "failed"),
219        }
220    }
221}
222
223/// Callback trait for training events
224pub trait TrainingCallback: Send + Sync {
225    /// Called when stage changes
226    fn on_stage_change(&self, _stage: &PipelineStage) {}
227
228    /// Called after each batch
229    fn on_batch_complete(&self, _batch_idx: usize, _total_batches: usize, _avg_quality: f32) {}
230
231    /// Called after each epoch
232    fn on_epoch_complete(&self, _epoch: usize, _stats: &EpochStats) {}
233
234    /// Called when training completes
235    fn on_training_complete(&self, _result: &TrainingResult) {}
236
237    /// Called on error
238    fn on_error(&self, _error: &str) {}
239}
240
241/// No-op callback implementation
242pub struct NoOpCallback;
243impl TrainingCallback for NoOpCallback {}
244
/// Callback that reports training events as prefixed text lines.
pub struct LoggingCallback {
    /// Label printed in square brackets at the start of every line.
    prefix: String,
}

impl LoggingCallback {
    /// Create a logging callback whose output lines start with `[prefix]`.
    pub fn new(prefix: impl Into<String>) -> Self {
        let prefix: String = prefix.into();
        LoggingCallback { prefix }
    }
}
256
257impl TrainingCallback for LoggingCallback {
258    fn on_stage_change(&self, stage: &PipelineStage) {
259        println!("[{}] Stage: {}", self.prefix, stage);
260    }
261
262    fn on_batch_complete(&self, batch_idx: usize, total_batches: usize, avg_quality: f32) {
263        if batch_idx % 10 == 0 || batch_idx == total_batches - 1 {
264            println!(
265                "[{}] Batch {}/{}: avg_quality={:.4}",
266                self.prefix, batch_idx + 1, total_batches, avg_quality
267            );
268        }
269    }
270
271    fn on_epoch_complete(&self, epoch: usize, stats: &EpochStats) {
272        println!(
273            "[{}] Epoch {}: examples={}, avg_quality={:.4}, duration={:.2}s",
274            self.prefix, epoch + 1, stats.examples_processed, stats.avg_quality, stats.duration_secs
275        );
276    }
277
278    fn on_training_complete(&self, result: &TrainingResult) {
279        println!(
280            "[{}] Training complete: epochs={}, patterns={}, final_quality={:.4}",
281            self.prefix, result.epochs_completed, result.patterns_learned, result.final_avg_quality
282        );
283    }
284
285    fn on_error(&self, error: &str) {
286        eprintln!("[{}] ERROR: {}", self.prefix, error);
287    }
288}
289
290/// Training pipeline for structured training workflows
291pub struct TrainingPipeline {
292    /// Pipeline name
293    name: String,
294    /// SONA engine
295    engine: SonaEngine,
296    /// Batch configuration
297    batch_config: BatchConfig,
298    /// Training method
299    training_method: TrainingMethod,
300    /// Current stage
301    stage: PipelineStage,
302    /// Training examples buffer
303    examples: Vec<TrainingExample>,
304    /// Validation examples
305    validation_examples: Vec<TrainingExample>,
306    /// Training metrics
307    metrics: TrainingMetrics,
308    /// Callback
309    callback: Box<dyn TrainingCallback>,
310    /// Enable pattern extraction after training
311    extract_patterns: bool,
312}
313
314impl TrainingPipeline {
315    /// Create a new training pipeline
316    pub fn new(name: impl Into<String>, config: SonaConfig) -> Self {
317        let name = name.into();
318        Self {
319            name: name.clone(),
320            engine: SonaEngine::with_config(config),
321            batch_config: BatchConfig::default(),
322            training_method: TrainingMethod::default(),
323            stage: PipelineStage::Idle,
324            examples: Vec::new(),
325            validation_examples: Vec::new(),
326            metrics: TrainingMetrics::new(&name),
327            callback: Box::new(NoOpCallback),
328            extract_patterns: true,
329        }
330    }
331
332    /// Create from template
333    pub fn from_template(template: TrainingTemplate) -> Self {
334        let batch_config = BatchConfig::for_data_size(&template.expected_data_size);
335        let mut pipeline = Self::new(&template.name, template.sona_config);
336        pipeline.batch_config = batch_config;
337        pipeline.training_method = template.training_method;
338        pipeline
339    }
340
341    /// Set batch configuration
342    pub fn with_batch_config(mut self, config: BatchConfig) -> Self {
343        self.batch_config = config;
344        self
345    }
346
347    /// Set training method
348    pub fn with_training_method(mut self, method: TrainingMethod) -> Self {
349        self.training_method = method;
350        self
351    }
352
353    /// Set callback
354    pub fn with_callback<C: TrainingCallback + 'static>(mut self, callback: C) -> Self {
355        self.callback = Box::new(callback);
356        self
357    }
358
359    /// Enable/disable pattern extraction
360    pub fn with_pattern_extraction(mut self, enabled: bool) -> Self {
361        self.extract_patterns = enabled;
362        self
363    }
364
365    /// Add a training example
366    pub fn add_example(&mut self, example: TrainingExample) {
367        self.examples.push(example);
368    }
369
370    /// Add multiple training examples
371    pub fn add_examples(&mut self, examples: impl IntoIterator<Item = TrainingExample>) {
372        self.examples.extend(examples);
373    }
374
375    /// Add validation example
376    pub fn add_validation_example(&mut self, example: TrainingExample) {
377        self.validation_examples.push(example);
378    }
379
380    /// Get current stage
381    pub fn stage(&self) -> &PipelineStage {
382        &self.stage
383    }
384
385    /// Get number of examples
386    pub fn example_count(&self) -> usize {
387        self.examples.len()
388    }
389
390    /// Get metrics
391    pub fn metrics(&self) -> &TrainingMetrics {
392        &self.metrics
393    }
394
395    /// Get engine reference
396    pub fn engine(&self) -> &SonaEngine {
397        &self.engine
398    }
399
400    /// Get mutable engine reference
401    pub fn engine_mut(&mut self) -> &mut SonaEngine {
402        &mut self.engine
403    }
404
405    /// Run the training pipeline
406    pub fn train(&mut self) -> Result<TrainingResult, String> {
407        let start = Instant::now();
408
409        // Preprocessing
410        self.set_stage(PipelineStage::Preprocessing);
411        self.preprocess()?;
412
413        // Training
414        self.set_stage(PipelineStage::Training);
415        let epoch_stats = self.run_training()?;
416
417        // Validation (if examples provided)
418        if !self.validation_examples.is_empty() {
419            self.set_stage(PipelineStage::Validation);
420            self.run_validation()?;
421        }
422
423        // Pattern extraction
424        if self.extract_patterns {
425            self.set_stage(PipelineStage::PatternExtraction);
426            self.engine.force_learn();
427        }
428
429        self.set_stage(PipelineStage::Completed);
430
431        let result = TrainingResult {
432            pipeline_name: self.name.clone(),
433            epochs_completed: epoch_stats.len(),
434            total_examples: self.metrics.total_examples,
435            patterns_learned: self.metrics.patterns_learned,
436            final_avg_quality: self.metrics.avg_quality(),
437            total_duration_secs: start.elapsed().as_secs_f64(),
438            epoch_stats,
439            validation_quality: self.metrics.validation_quality,
440        };
441
442        self.callback.on_training_complete(&result);
443        Ok(result)
444    }
445
446    /// Set stage and notify callback
447    fn set_stage(&mut self, stage: PipelineStage) {
448        self.stage = stage.clone();
449        self.callback.on_stage_change(&stage);
450    }
451
452    /// Preprocess examples
453    fn preprocess(&mut self) -> Result<(), String> {
454        if self.examples.is_empty() {
455            return Err("No training examples provided".into());
456        }
457
458        // Shuffle if configured
459        if self.batch_config.shuffle {
460            use rand::seq::SliceRandom;
461            let mut rng = rand::thread_rng();
462            self.examples.shuffle(&mut rng);
463        }
464
465        Ok(())
466    }
467
468    /// Run training epochs
469    fn run_training(&mut self) -> Result<Vec<EpochStats>, String> {
470        let mut all_epoch_stats = Vec::new();
471        let mut best_quality = 0.0f32;
472        let mut patience_counter = 0usize;
473
474        for epoch in 0..self.batch_config.epochs {
475            let epoch_start = Instant::now();
476            let mut epoch_quality_sum = 0.0f32;
477            let mut epoch_examples = 0usize;
478
479            // Create batch indices (to avoid borrow checker issues)
480            let batch_size = self.batch_config.batch_size;
481            let total_examples = self.examples.len();
482            let mut batch_indices: Vec<(usize, usize)> = Vec::new();
483            let mut start = 0;
484            while start < total_examples {
485                let end = (start + batch_size).min(total_examples);
486                if end > start && (!self.batch_config.drop_last || end - start == batch_size) {
487                    batch_indices.push((start, end));
488                }
489                start = end;
490            }
491            let total_batches = batch_indices.len();
492
493            for (batch_idx, (start, end)) in batch_indices.into_iter().enumerate() {
494                let batch_quality = self.train_batch_range(start, end)?;
495                let batch_len = end - start;
496                epoch_quality_sum += batch_quality * batch_len as f32;
497                epoch_examples += batch_len;
498
499                self.callback.on_batch_complete(
500                    batch_idx,
501                    total_batches,
502                    epoch_quality_sum / epoch_examples as f32,
503                );
504            }
505
506            let epoch_avg_quality = if epoch_examples > 0 {
507                epoch_quality_sum / epoch_examples as f32
508            } else {
509                0.0
510            };
511
512            let epoch_stats = EpochStats {
513                epoch,
514                examples_processed: epoch_examples,
515                avg_quality: epoch_avg_quality,
516                duration_secs: epoch_start.elapsed().as_secs_f64(),
517            };
518
519            self.callback.on_epoch_complete(epoch, &epoch_stats);
520            all_epoch_stats.push(epoch_stats);
521
522            // Early stopping check
523            if let Some(patience) = self.batch_config.early_stopping_patience {
524                let improvement = epoch_avg_quality - best_quality;
525                if improvement > self.batch_config.min_quality_improvement {
526                    best_quality = epoch_avg_quality;
527                    patience_counter = 0;
528                } else {
529                    patience_counter += 1;
530                    if patience_counter >= patience {
531                        break; // Early stop
532                    }
533                }
534            }
535
536            // Reshuffle for next epoch
537            if self.batch_config.shuffle && epoch + 1 < self.batch_config.epochs {
538                use rand::seq::SliceRandom;
539                let mut rng = rand::thread_rng();
540                self.examples.shuffle(&mut rng);
541            }
542        }
543
544        Ok(all_epoch_stats)
545    }
546
547    /// Train on examples in a range
548    fn train_batch_range(&mut self, start: usize, end: usize) -> Result<f32, String> {
549        let mut quality_sum = 0.0f32;
550        let batch_len = end - start;
551
552        for idx in start..end {
553            let example = &self.examples[idx];
554
555            // Begin trajectory using builder API
556            let mut builder = self.engine.begin_trajectory(example.embedding.clone());
557
558            // Set route
559            if let Some(ref route) = example.route {
560                builder.set_model_route(route);
561            }
562
563            // Add context
564            for ctx in &example.context {
565                builder.add_context(ctx);
566            }
567
568            // Add step
569            builder.add_step(
570                example.get_activations(),
571                example.get_attention(),
572                example.get_reward() * example.weight,
573            );
574
575            // End trajectory
576            self.engine.end_trajectory(builder, example.quality);
577
578            quality_sum += example.quality;
579            self.metrics.total_examples += 1;
580            self.metrics.add_quality_sample(example.quality);
581        }
582
583        // Run tick to process accumulated trajectories
584        self.engine.tick();
585
586        Ok(quality_sum / batch_len as f32)
587    }
588
589    /// Run validation
590    fn run_validation(&mut self) -> Result<(), String> {
591        let mut quality_sum = 0.0f32;
592
593        for example in &self.validation_examples {
594            // Apply learned transformations
595            let mut output = vec![0.0f32; example.embedding.len()];
596            self.engine.apply_micro_lora(&example.embedding, &mut output);
597
598            // In a real scenario, you'd evaluate the model output
599            // For now, we track the expected quality
600            quality_sum += example.quality;
601        }
602
603        self.metrics.validation_quality = Some(
604            quality_sum / self.validation_examples.len() as f32
605        );
606
607        Ok(())
608    }
609
610    /// Clear examples (keep engine state)
611    pub fn clear_examples(&mut self) {
612        self.examples.clear();
613        self.validation_examples.clear();
614    }
615
616    /// Reset pipeline (clear examples and metrics)
617    pub fn reset(&mut self) {
618        self.clear_examples();
619        self.metrics = TrainingMetrics::new(&self.name);
620        self.stage = PipelineStage::Idle;
621    }
622}
623
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods populate the fields they are named after.
    #[test]
    fn test_training_example() {
        let sample = TrainingExample::new(vec![0.1; 256], 0.8)
            .with_route("test")
            .with_context("ctx1")
            .with_weight(1.5)
            .with_tag("test");

        assert_eq!(sample.quality, 0.8);
        assert_eq!(sample.route.as_deref(), Some("test"));
        assert_eq!(sample.weight, 1.5);
    }

    /// The `Small` size hint maps to batches of 16 over 5 epochs.
    #[test]
    fn test_batch_config() {
        let small = BatchConfig::for_data_size(&DataSizeHint::Small);
        assert_eq!(small.batch_size, 16);
        assert_eq!(small.epochs, 5);
    }

    /// A fresh pipeline is idle and holds no examples.
    #[test]
    fn test_pipeline_creation() {
        let fresh = TrainingPipeline::new("test", SonaConfig::default());
        assert_eq!(fresh.stage(), &PipelineStage::Idle);
        assert_eq!(fresh.example_count(), 0);
    }

    /// A pipeline built from a template takes over the template's name.
    #[test]
    fn test_pipeline_from_template() {
        let template = TrainingTemplate::code_agent().with_hidden_dim(256);
        let built = TrainingPipeline::from_template(template);
        assert_eq!(built.name, "code-agent");
    }

    /// Two epochs over five examples complete and count processed examples.
    #[test]
    fn test_pipeline_training() {
        let two_epochs = BatchConfig {
            batch_size: 2,
            epochs: 2,
            ..Default::default()
        };
        let mut pipeline =
            TrainingPipeline::new("test", SonaConfig::default()).with_batch_config(two_epochs);

        // Add examples
        pipeline.add_examples(
            (0..5).map(|i| TrainingExample::new(vec![i as f32 * 0.1; 256], 0.7 + i as f32 * 0.05)),
        );

        let outcome = pipeline.train().unwrap();
        assert_eq!(outcome.epochs_completed, 2);
        assert!(outcome.total_examples > 0);
    }

    /// Supplying a validation example yields a validation quality in the result.
    #[test]
    fn test_pipeline_with_validation() {
        let mut pipeline = TrainingPipeline::new("test", SonaConfig::default())
            .with_batch_config(BatchConfig::single_pass());

        pipeline.add_example(TrainingExample::new(vec![0.1; 256], 0.8));
        pipeline.add_validation_example(TrainingExample::new(vec![0.2; 256], 0.9));

        let outcome = pipeline.train().unwrap();
        assert!(outcome.validation_quality.is_some());
    }
}
696}