
tensorlogic_train/trainer.rs

//! Main training loop implementation.
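//!
//! [`Trainer`] drives the epoch/batch loop and delegates losses, optimization,
//! learning-rate scheduling, callbacks, and metric tracking to the traits and
//! helpers re-exported by this crate.
//!
//! A minimal usage sketch is shown below. It is marked `ignore` because it
//! assumes the crate is named `tensorlogic_train` (inferred from the directory)
//! and that `Trainer`, `TrainerConfig`, `MseLoss`, `SgdOptimizer`, and
//! `OptimizerConfig` are re-exported at the crate root, as the tests at the
//! bottom of this file suggest.
//!
//! ```ignore
//! use std::collections::HashMap;
//! use scirs2_core::ndarray::{Array, Ix2};
//! use tensorlogic_train::{MseLoss, OptimizerConfig, SgdOptimizer, Trainer, TrainerConfig};
//!
//! let config = TrainerConfig { num_epochs: 3, ..Default::default() };
//! let mut trainer = Trainer::new(
//!     config,
//!     Box::new(MseLoss),
//!     Box::new(SgdOptimizer::new(OptimizerConfig::default())),
//! );
//!
//! // Toy data: the placeholder forward pass returns its input, so the only
//! // requirement here is that data and targets have matching shapes.
//! let x: Array<f64, Ix2> = Array::zeros((32, 4));
//! let y: Array<f64, Ix2> = Array::zeros((32, 4));
//! let mut params: HashMap<String, Array<f64, Ix2>> = HashMap::new();
//!
//! let history = trainer
//!     .train(&x.view(), &y.view(), None, None, &mut params)
//!     .unwrap();
//! assert_eq!(history.train_loss.len(), 3);
//! ```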

use crate::{
    extract_batch, BatchConfig, BatchIterator, CallbackList, Loss, LrScheduler, MetricTracker,
    Optimizer, TrainResult,
};
use scirs2_core::ndarray::{Array, ArrayView, Ix2};
use std::collections::HashMap;

/// Training state passed to callbacks.
#[derive(Debug, Clone)]
pub struct TrainingState {
    /// Current epoch number.
    pub epoch: usize,
    /// Current batch number within epoch.
    pub batch: usize,
    /// Training loss for current epoch.
    pub train_loss: f64,
    /// Validation loss (if validation is performed).
    pub val_loss: Option<f64>,
    /// Loss for current batch.
    pub batch_loss: f64,
    /// Current learning rate.
    pub learning_rate: f64,
    /// Additional metrics.
    pub metrics: HashMap<String, f64>,
}

impl Default for TrainingState {
    fn default() -> Self {
        Self {
            epoch: 0,
            batch: 0,
            train_loss: 0.0,
            val_loss: None,
            batch_loss: 0.0,
            learning_rate: 0.001,
            metrics: HashMap::new(),
        }
    }
}

/// Configuration for training.
#[derive(Debug, Clone)]
pub struct TrainerConfig {
    /// Number of epochs to train.
    pub num_epochs: usize,
    /// Batch configuration.
    pub batch_config: BatchConfig,
    /// Whether to validate after each epoch.
    pub validate_every_epoch: bool,
    /// Frequency of logging (every N batches).
    pub log_frequency: usize,
    /// Whether to use learning rate scheduler.
    pub use_scheduler: bool,
}

impl Default for TrainerConfig {
    fn default() -> Self {
        Self {
            num_epochs: 10,
            batch_config: BatchConfig::default(),
            validate_every_epoch: true,
            log_frequency: 100,
            use_scheduler: false,
        }
    }
}

/// Main trainer for model training.
pub struct Trainer {
    /// Configuration.
    config: TrainerConfig,
    /// Loss function.
    loss_fn: Box<dyn Loss>,
    /// Optimizer.
    optimizer: Box<dyn Optimizer>,
    /// Optional learning rate scheduler.
    scheduler: Option<Box<dyn LrScheduler>>,
    /// Callbacks.
    callbacks: CallbackList,
    /// Metric tracker.
    metrics: MetricTracker,
    /// Training state.
    state: TrainingState,
}

impl Trainer {
    /// Create a new trainer.
    pub fn new(
        config: TrainerConfig,
        loss_fn: Box<dyn Loss>,
        optimizer: Box<dyn Optimizer>,
    ) -> Self {
        Self {
            config,
            loss_fn,
            optimizer,
            scheduler: None,
            callbacks: CallbackList::new(),
            metrics: MetricTracker::new(),
            state: TrainingState::default(),
        }
    }

    /// Set learning rate scheduler.
    pub fn with_scheduler(mut self, scheduler: Box<dyn LrScheduler>) -> Self {
        self.scheduler = Some(scheduler);
        self
    }

    /// Set callbacks.
    pub fn with_callbacks(mut self, callbacks: CallbackList) -> Self {
        self.callbacks = callbacks;
        self
    }

    /// Set metrics.
    pub fn with_metrics(mut self, metrics: MetricTracker) -> Self {
        self.metrics = metrics;
        self
    }

    /// Train the model.
    pub fn train(
        &mut self,
        train_data: &ArrayView<f64, Ix2>,
        train_targets: &ArrayView<f64, Ix2>,
        val_data: Option<&ArrayView<f64, Ix2>>,
        val_targets: Option<&ArrayView<f64, Ix2>>,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<TrainingHistory> {
        let mut history = TrainingHistory::new();

        // Initialize state
        self.state.learning_rate = self.optimizer.get_lr();

        // Call on_train_begin
        self.callbacks.on_train_begin(&self.state)?;

        // Training loop
        for epoch in 0..self.config.num_epochs {
            self.state.epoch = epoch;

            // Call on_epoch_begin
            self.callbacks.on_epoch_begin(epoch, &self.state)?;

            // Train one epoch
            let epoch_loss = self.train_epoch(train_data, train_targets, parameters)?;

            self.state.train_loss = epoch_loss;
            history.train_loss.push(epoch_loss);

            // Validation
            if self.config.validate_every_epoch {
                if let (Some(val_data), Some(val_targets)) = (val_data, val_targets) {
                    let val_loss = self.validate(val_data, val_targets, parameters)?;
                    self.state.val_loss = Some(val_loss);
                    history.val_loss.push(val_loss);

                    // Compute metrics
                    let predictions = self.forward(val_data, parameters)?;
                    let metrics = self.metrics.compute_all(&predictions.view(), val_targets)?;
                    self.state.metrics = metrics.clone();

                    for (name, value) in metrics {
                        history.metrics.entry(name).or_default().push(value);
                    }

                    // Call on_validation_end
                    self.callbacks.on_validation_end(&self.state)?;
                }
            }

            // Update learning rate
            if self.config.use_scheduler {
                if let Some(scheduler) = &mut self.scheduler {
                    scheduler.step(&mut *self.optimizer);
                    self.state.learning_rate = self.optimizer.get_lr();
                }
            }

            // Call on_epoch_end
            self.callbacks.on_epoch_end(epoch, &self.state)?;

            // Check for early stopping
            if self.callbacks.should_stop() {
                println!("Early stopping triggered at epoch {}", epoch);
                break;
            }
        }

        // Call on_train_end
        self.callbacks.on_train_end(&self.state)?;

        Ok(history)
    }

    /// Train for one epoch.
    fn train_epoch(
        &mut self,
        train_data: &ArrayView<f64, Ix2>,
        train_targets: &ArrayView<f64, Ix2>,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<f64> {
        let mut total_loss = 0.0;
        let mut num_batches = 0;

        let mut batch_iter =
            BatchIterator::new(train_data.nrows(), self.config.batch_config.clone());

        while let Some(batch_indices) = batch_iter.next_batch() {
            self.state.batch = num_batches;

            // Call on_batch_begin
            self.callbacks.on_batch_begin(num_batches, &self.state)?;

            // Extract batch
            let batch_data = extract_batch(train_data, &batch_indices)?;
            let batch_targets = extract_batch(train_targets, &batch_indices)?;

            // Forward pass
            let predictions = self.forward(&batch_data.view(), parameters)?;

            // Compute loss
            let loss = self
                .loss_fn
                .compute(&predictions.view(), &batch_targets.view())?;
            self.state.batch_loss = loss;
            total_loss += loss;

            // Compute gradients
            let loss_grad = self
                .loss_fn
                .gradient(&predictions.view(), &batch_targets.view())?;

            // Backward pass (simplified; a real implementation would use autodiff)
            let gradients = self.backward(&batch_data.view(), &loss_grad.view(), parameters)?;

            // Update parameters
            self.optimizer.step(parameters, &gradients)?;

            // Call on_batch_end
            self.callbacks.on_batch_end(num_batches, &self.state)?;

            num_batches += 1;

            // Logging
            if num_batches % self.config.log_frequency == 0 {
                log::debug!("Batch {}: loss={:.6}", num_batches, loss);
            }
        }

        Ok(total_loss / num_batches as f64)
    }

    /// Validate the model.
    fn validate(
        &mut self,
        val_data: &ArrayView<f64, Ix2>,
        val_targets: &ArrayView<f64, Ix2>,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<f64> {
        let mut total_loss = 0.0;
        let mut num_batches = 0;

        let mut batch_iter = BatchIterator::new(val_data.nrows(), self.config.batch_config.clone());

        while let Some(batch_indices) = batch_iter.next_batch() {
            let batch_data = extract_batch(val_data, &batch_indices)?;
            let batch_targets = extract_batch(val_targets, &batch_indices)?;

            let predictions = self.forward(&batch_data.view(), parameters)?;
            let loss = self
                .loss_fn
                .compute(&predictions.view(), &batch_targets.view())?;

            total_loss += loss;
            num_batches += 1;
        }

        Ok(total_loss / num_batches as f64)
    }

    /// Forward pass (placeholder - actual implementation depends on model).
    fn forward(
        &self,
        data: &ArrayView<f64, Ix2>,
        _parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<Array<f64, Ix2>> {
        // This is a placeholder implementation
        // In a real scenario, this would depend on the model architecture
        // For now, return input as output
        Ok(data.to_owned())
    }

    /// Backward pass (placeholder - actual implementation would use autodiff).
    fn backward(
        &self,
        _data: &ArrayView<f64, Ix2>,
        _loss_grad: &ArrayView<f64, Ix2>,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        // This is a placeholder implementation
        // In a real scenario, this would use automatic differentiation
        let mut gradients = HashMap::new();

        for (name, param) in parameters {
            // Placeholder: zero gradients (a real implementation would compute them)
            gradients.insert(name.clone(), Array::zeros(param.raw_dim()));
        }

        Ok(gradients)
    }

    /// Get current training state.
    pub fn get_state(&self) -> &TrainingState {
        &self.state
    }

    /// Save a complete training checkpoint.
    ///
    /// This saves all state needed to resume training, including:
    /// - Model parameters
    /// - Optimizer state
    /// - Scheduler state (if present)
    /// - Training history
    /// - Current epoch and losses
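    ///
    /// A hedged sketch of the save/resume round trip (marked `ignore`;
    /// `trainer`, `parameters`, and `history` are illustrative bindings and the
    /// file name is arbitrary):
    ///
    /// ```ignore
    /// let path = std::path::PathBuf::from("checkpoint.bin");
    /// let best = history.best_val_loss().map(|(_, loss)| loss);
    /// trainer.save_checkpoint(&path, &parameters, &history, best)?;
    /// // Later: restore the parameters, history, and epoch to resume from.
    /// let (parameters, history, start_epoch) = trainer.load_checkpoint(&path)?;
    /// ```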
    pub fn save_checkpoint(
        &self,
        path: &std::path::PathBuf,
        parameters: &HashMap<String, Array<f64, Ix2>>,
        history: &TrainingHistory,
        best_val_loss: Option<f64>,
    ) -> TrainResult<()> {
        use crate::TrainingCheckpoint;

        // Get optimizer state
        let optimizer_state = self.optimizer.state_dict();

        // Get scheduler state if present
        let scheduler_state = self.scheduler.as_ref().map(|s| s.state_dict());

        // Create checkpoint
        let checkpoint = TrainingCheckpoint::new(
            self.state.epoch,
            parameters,
            &optimizer_state,
            scheduler_state,
            &self.state,
            &history.train_loss,
            &history.val_loss,
            &history.metrics,
            best_val_loss,
        );

        // Save to file
        checkpoint.save(path)?;

        println!("Training checkpoint saved to {:?}", path);
        Ok(())
    }

    /// Resume training from a checkpoint.
    ///
    /// This restores all training state including parameters, optimizer state,
    /// and history. Training will resume from the saved epoch.
    ///
    /// Returns the restored parameters, history, and starting epoch.
    #[allow(clippy::type_complexity)]
    pub fn load_checkpoint(
        &mut self,
        path: &std::path::PathBuf,
    ) -> TrainResult<(HashMap<String, Array<f64, Ix2>>, TrainingHistory, usize)> {
        use crate::TrainingCheckpoint;
        use scirs2_core::ndarray::Array;

        // Load checkpoint
        let checkpoint = TrainingCheckpoint::load(path)?;

        println!(
            "Loading checkpoint from epoch {} (val_loss: {:?})",
            checkpoint.epoch, checkpoint.val_loss
        );

        // Restore parameters
        let mut parameters = HashMap::new();
        for (name, values) in checkpoint.parameters {
            // Note: the checkpoint stores flat parameter values, so the original
            // shape is not known here. As a stopgap the values are reshaped to
            // (1, len); in practice the model's load_state_dict should restore
            // the true shape.
            let len = values.len();
            let array = Array::from_vec(values);
            parameters.insert(
                name,
                array.into_shape_with_order((1, len)).map_err(|e| {
                    crate::TrainError::CheckpointError(format!(
                        "Failed to reshape parameter: {}",
                        e
                    ))
                })?,
            );
        }

        // Restore optimizer state
        self.optimizer.load_state_dict(checkpoint.optimizer_state);

        // Restore scheduler state
        if let (Some(scheduler), Some(scheduler_state)) =
            (self.scheduler.as_mut(), checkpoint.scheduler_state.as_ref())
        {
            scheduler.load_state_dict(scheduler_state)?;
        }

        // Restore training history
        let history = TrainingHistory {
            train_loss: checkpoint.train_loss_history,
            val_loss: checkpoint.val_loss_history,
            metrics: checkpoint.metrics_history,
        };

        // Restore training state
        self.state.epoch = checkpoint.epoch;
        self.state.train_loss = checkpoint.train_loss;
        self.state.val_loss = checkpoint.val_loss;
        self.state.learning_rate = checkpoint.learning_rate;

        println!(
            "Checkpoint loaded successfully. Resuming from epoch {}",
            checkpoint.epoch + 1
        );

        Ok((parameters, history, checkpoint.epoch))
    }

    /// Train the model starting from a checkpoint.
    ///
    /// This is a convenience method that loads a checkpoint and continues training.
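    ///
    /// A hedged sketch (marked `ignore`; `trainer`, `checkpoint_path`, `x`, and
    /// `y` are illustrative bindings):
    ///
    /// ```ignore
    /// let (parameters, history) =
    ///     trainer.train_from_checkpoint(&checkpoint_path, &x.view(), &y.view(), None, None)?;
    /// ```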
    #[allow(clippy::type_complexity)]
    pub fn train_from_checkpoint(
        &mut self,
        checkpoint_path: &std::path::PathBuf,
        train_data: &ArrayView<f64, Ix2>,
        train_targets: &ArrayView<f64, Ix2>,
        val_data: Option<&ArrayView<f64, Ix2>>,
        val_targets: Option<&ArrayView<f64, Ix2>>,
    ) -> TrainResult<(HashMap<String, Array<f64, Ix2>>, TrainingHistory)> {
        // Load checkpoint
        let (mut parameters, mut history, start_epoch) = self.load_checkpoint(checkpoint_path)?;

        // Adjust config to continue from checkpoint epoch
        let remaining_epochs = self.config.num_epochs.saturating_sub(start_epoch + 1);
        let original_num_epochs = self.config.num_epochs;
        self.config.num_epochs = remaining_epochs;

        println!(
            "Resuming training: {} epochs completed, {} epochs remaining",
            start_epoch + 1,
            remaining_epochs
        );

        // Continue training
        let continued_history = self.train(
            train_data,
            train_targets,
            val_data,
            val_targets,
            &mut parameters,
        )?;

        // Restore original config
        self.config.num_epochs = original_num_epochs;

        // Merge histories
        history.train_loss.extend(continued_history.train_loss);
        history.val_loss.extend(continued_history.val_loss);
        for (metric_name, values) in continued_history.metrics {
            history
                .metrics
                .entry(metric_name)
                .or_default()
                .extend(values);
        }

        Ok((parameters, history))
    }
}

/// Training history containing losses and metrics.
#[derive(Debug, Clone)]
pub struct TrainingHistory {
    /// Training loss per epoch.
    pub train_loss: Vec<f64>,
    /// Validation loss per epoch.
    pub val_loss: Vec<f64>,
    /// Metrics per epoch.
    pub metrics: HashMap<String, Vec<f64>>,
}

impl TrainingHistory {
    /// Create a new training history.
    pub fn new() -> Self {
        Self {
            train_loss: Vec::new(),
            val_loss: Vec::new(),
            metrics: HashMap::new(),
        }
    }

    /// Get the best (lowest) validation loss and its epoch, returned as `(epoch, loss)`.
    pub fn best_val_loss(&self) -> Option<(usize, f64)> {
        self.val_loss
            .iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(idx, &loss)| (idx, loss))
    }

    /// Get metric history.
    pub fn get_metric_history(&self, metric_name: &str) -> Option<&Vec<f64>> {
        self.metrics.get(metric_name)
    }
}

impl Default for TrainingHistory {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MseLoss, OptimizerConfig, SgdOptimizer};

    #[test]
    fn test_trainer_creation() {
        let config = TrainerConfig {
            num_epochs: 5,
            ..Default::default()
        };

        let loss = Box::new(MseLoss);
        let optimizer = Box::new(SgdOptimizer::new(OptimizerConfig::default()));

        let trainer = Trainer::new(config, loss, optimizer);
        assert_eq!(trainer.config.num_epochs, 5);
    }

    #[test]
    fn test_training_history() {
        let mut history = TrainingHistory::new();
        history.train_loss.push(1.0);
        history.train_loss.push(0.8);
        history.train_loss.push(0.6);

        history.val_loss.push(1.2);
        history.val_loss.push(0.9);
        history.val_loss.push(0.7);

        let (best_epoch, best_loss) = history.best_val_loss().unwrap();
        assert_eq!(best_epoch, 2);
        assert_eq!(best_loss, 0.7);
    }

    #[test]
    fn test_training_state() {
        let state = TrainingState {
            epoch: 5,
            batch: 100,
            train_loss: 0.5,
            val_loss: Some(0.6),
            batch_loss: 0.4,
            learning_rate: 0.001,
            metrics: HashMap::new(),
        };

        assert_eq!(state.epoch, 5);
        assert_eq!(state.batch, 100);
        assert!((state.train_loss - 0.5).abs() < 1e-6);
    }
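
    // The tests below are sketch-level checks that rely only on behaviour
    // visible in this file: the `Default` impls, the empty-history case, and
    // the placeholder forward/backward passes used by `train`.

    #[test]
    fn test_trainer_config_default() {
        let config = TrainerConfig::default();
        assert_eq!(config.num_epochs, 10);
        assert!(config.validate_every_epoch);
        assert_eq!(config.log_frequency, 100);
        assert!(!config.use_scheduler);
    }

    #[test]
    fn test_best_val_loss_empty_history() {
        // With no recorded validation losses there is no best epoch.
        assert!(TrainingHistory::new().best_val_loss().is_none());
    }

    #[test]
    fn test_train_smoke() {
        // Smoke test of the full loop using the placeholder forward/backward
        // passes and an empty parameter map. Assumes `BatchConfig::default()`
        // can iterate 64 rows without error; only the history length is
        // asserted, which holds regardless of the exact batching behaviour.
        let config = TrainerConfig {
            num_epochs: 2,
            ..Default::default()
        };
        let loss = Box::new(MseLoss);
        let optimizer = Box::new(SgdOptimizer::new(OptimizerConfig::default()));
        let mut trainer = Trainer::new(config, loss, optimizer);

        let x = Array::<f64, Ix2>::zeros((64, 4));
        let y = Array::<f64, Ix2>::zeros((64, 4));
        let mut params: HashMap<String, Array<f64, Ix2>> = HashMap::new();

        let history = trainer
            .train(&x.view(), &y.view(), None, None, &mut params)
            .expect("training with placeholder passes should not error");
        assert_eq!(history.train_loss.len(), 2);
    }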
}