use crate::{
    extract_batch, BatchConfig, BatchIterator, CallbackList, Loss, LrScheduler, MetricTracker,
    Optimizer, TrainResult,
};
use scirs2_core::ndarray::{Array, ArrayView, Ix2};
use std::collections::HashMap;
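
/// Mutable snapshot of a training run: epoch and batch counters, the most
/// recent losses, the optimizer's current learning rate, and the latest
/// validation metrics.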
#[derive(Debug, Clone)]
pub struct TrainingState {
    /// Index of the current epoch (0-based).
    pub epoch: usize,
    /// Index of the current batch within the epoch (0-based).
    pub batch: usize,
    /// Mean training loss over the most recent epoch.
    pub train_loss: f64,
    /// Most recent validation loss, if validation has run.
    pub val_loss: Option<f64>,
    /// Loss of the most recent batch.
    pub batch_loss: f64,
    /// Current learning rate reported by the optimizer.
    pub learning_rate: f64,
    /// Latest validation metrics, keyed by metric name.
    pub metrics: HashMap<String, f64>,
}

impl Default for TrainingState {
    fn default() -> Self {
        Self {
            epoch: 0,
            batch: 0,
            train_loss: 0.0,
            val_loss: None,
            batch_loss: 0.0,
            learning_rate: 0.001,
            metrics: HashMap::new(),
        }
    }
}
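
/// Configuration options for [`Trainer`].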
#[derive(Debug, Clone)]
pub struct TrainerConfig {
    /// Total number of epochs to train for.
    pub num_epochs: usize,
    /// Configuration forwarded to the batch iterator.
    pub batch_config: BatchConfig,
    /// Whether to run validation after every epoch.
    pub validate_every_epoch: bool,
    /// Log batch loss every `log_frequency` batches (0 disables logging).
    pub log_frequency: usize,
    /// Whether to step the learning-rate scheduler after each epoch.
    pub use_scheduler: bool,
}

impl Default for TrainerConfig {
    fn default() -> Self {
        Self {
            num_epochs: 10,
            batch_config: BatchConfig::default(),
            validate_every_epoch: true,
            log_frequency: 100,
            use_scheduler: false,
        }
    }
}
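
/// Drives the training loop: batching, forward/backward passes, optimizer
/// steps, optional learning-rate scheduling, callbacks, and metric tracking.
///
/// Illustrative construction (using the same types as the tests below):
///
/// ```ignore
/// let trainer = Trainer::new(
///     TrainerConfig::default(),
///     Box::new(MseLoss),
///     Box::new(SgdOptimizer::new(OptimizerConfig::default())),
/// );
/// ```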
pub struct Trainer {
    /// Training-loop configuration.
    config: TrainerConfig,
    /// Loss function used for both training and validation.
    loss_fn: Box<dyn Loss>,
    /// Optimizer that updates the model parameters.
    optimizer: Box<dyn Optimizer>,
    /// Optional learning-rate scheduler, stepped once per epoch.
    scheduler: Option<Box<dyn LrScheduler>>,
    /// Callbacks invoked at the train, epoch, batch, and validation hooks.
    callbacks: CallbackList,
    /// Tracker for validation metrics.
    metrics: MetricTracker,
    /// Current training state, shared with callbacks.
    state: TrainingState,
}

impl Trainer {
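    /// Creates a trainer with no scheduler, an empty callback list, and a
    /// fresh metric tracker; use the `with_*` builders to customize these.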
    pub fn new(
        config: TrainerConfig,
        loss_fn: Box<dyn Loss>,
        optimizer: Box<dyn Optimizer>,
    ) -> Self {
        Self {
            config,
            loss_fn,
            optimizer,
            scheduler: None,
            callbacks: CallbackList::new(),
            metrics: MetricTracker::new(),
            state: TrainingState::default(),
        }
    }
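
    /// Attaches a learning-rate scheduler (builder style). The scheduler is
    /// only stepped when `TrainerConfig::use_scheduler` is true.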
    pub fn with_scheduler(mut self, scheduler: Box<dyn LrScheduler>) -> Self {
        self.scheduler = Some(scheduler);
        self
    }
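
    /// Replaces the callback list (builder style).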
    pub fn with_callbacks(mut self, callbacks: CallbackList) -> Self {
        self.callbacks = callbacks;
        self
    }
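
    /// Replaces the metric tracker (builder style).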
    pub fn with_metrics(mut self, metrics: MetricTracker) -> Self {
        self.metrics = metrics;
        self
    }
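
    /// Runs the full training loop for `TrainerConfig::num_epochs` epochs and
    /// returns the per-epoch history. Validation runs only when both
    /// `val_data` and `val_targets` are provided and `validate_every_epoch`
    /// is set.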
    pub fn train(
        &mut self,
        train_data: &ArrayView<f64, Ix2>,
        train_targets: &ArrayView<f64, Ix2>,
        val_data: Option<&ArrayView<f64, Ix2>>,
        val_targets: Option<&ArrayView<f64, Ix2>>,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<TrainingHistory> {
        let mut history = TrainingHistory::new();

        // Record the optimizer's initial learning rate.
        self.state.learning_rate = self.optimizer.get_lr();

        self.callbacks.on_train_begin(&self.state)?;

        for epoch in 0..self.config.num_epochs {
            self.state.epoch = epoch;

            self.callbacks.on_epoch_begin(epoch, &self.state)?;

            // Run one full pass over the training data.
            let epoch_loss = self.train_epoch(train_data, train_targets, parameters)?;

            self.state.train_loss = epoch_loss;
            history.train_loss.push(epoch_loss);

            // Optionally evaluate on the validation set.
            if self.config.validate_every_epoch {
                if let (Some(val_data), Some(val_targets)) = (val_data, val_targets) {
                    let val_loss = self.validate(val_data, val_targets, parameters)?;
                    self.state.val_loss = Some(val_loss);
                    history.val_loss.push(val_loss);

                    // Compute validation metrics and record them in the history.
                    let predictions = self.forward(val_data, parameters)?;
                    let metrics = self.metrics.compute_all(&predictions.view(), val_targets)?;
                    self.state.metrics = metrics.clone();

                    for (name, value) in metrics {
                        history.metrics.entry(name).or_default().push(value);
                    }

                    self.callbacks.on_validation_end(&self.state)?;
                }
            }

            // Step the learning-rate scheduler, if one is configured.
            if self.config.use_scheduler {
                if let Some(scheduler) = &mut self.scheduler {
                    scheduler.step(&mut *self.optimizer);
                    self.state.learning_rate = self.optimizer.get_lr();
                }
            }

            self.callbacks.on_epoch_end(epoch, &self.state)?;

            // Stop early if a callback (e.g. early stopping) requested it.
            if self.callbacks.should_stop() {
                println!("Early stopping triggered at epoch {}", epoch);
                break;
            }
        }

        self.callbacks.on_train_end(&self.state)?;

        Ok(history)
    }
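
    /// Trains for a single epoch and returns the mean batch loss.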
    fn train_epoch(
        &mut self,
        train_data: &ArrayView<f64, Ix2>,
        train_targets: &ArrayView<f64, Ix2>,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<f64> {
        let mut total_loss = 0.0;
        let mut num_batches = 0;

        let mut batch_iter =
            BatchIterator::new(train_data.nrows(), self.config.batch_config.clone());

        while let Some(batch_indices) = batch_iter.next_batch() {
            self.state.batch = num_batches;

            self.callbacks.on_batch_begin(num_batches, &self.state)?;

            // Gather the rows for this batch.
            let batch_data = extract_batch(train_data, &batch_indices)?;
            let batch_targets = extract_batch(train_targets, &batch_indices)?;

            // Forward pass and loss computation.
            let predictions = self.forward(&batch_data.view(), parameters)?;

            let loss = self
                .loss_fn
                .compute(&predictions.view(), &batch_targets.view())?;
            self.state.batch_loss = loss;
            total_loss += loss;

            // Backward pass: gradient of the loss, then parameter gradients.
            let loss_grad = self
                .loss_fn
                .gradient(&predictions.view(), &batch_targets.view())?;

            let gradients = self.backward(&batch_data.view(), &loss_grad.view(), parameters)?;

            self.optimizer.step(parameters, &gradients)?;

            self.callbacks.on_batch_end(num_batches, &self.state)?;

            num_batches += 1;

            // A log_frequency of 0 disables periodic logging (and avoids a
            // modulo-by-zero panic).
            if self.config.log_frequency > 0 && num_batches % self.config.log_frequency == 0 {
                log::debug!("Batch {}: loss={:.6}", num_batches, loss);
            }
        }

        // Guard against dividing by zero when the data yields no batches.
        if num_batches == 0 {
            return Ok(0.0);
        }

        Ok(total_loss / num_batches as f64)
    }
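
    /// Computes the mean loss over the validation set without updating
    /// parameters.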
    fn validate(
        &mut self,
        val_data: &ArrayView<f64, Ix2>,
        val_targets: &ArrayView<f64, Ix2>,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<f64> {
        let mut total_loss = 0.0;
        let mut num_batches = 0;

        let mut batch_iter = BatchIterator::new(val_data.nrows(), self.config.batch_config.clone());

        while let Some(batch_indices) = batch_iter.next_batch() {
            let batch_data = extract_batch(val_data, &batch_indices)?;
            let batch_targets = extract_batch(val_targets, &batch_indices)?;

            let predictions = self.forward(&batch_data.view(), parameters)?;
            let loss = self
                .loss_fn
                .compute(&predictions.view(), &batch_targets.view())?;

            total_loss += loss;
            num_batches += 1;
        }

        // Guard against dividing by zero when the data yields no batches.
        if num_batches == 0 {
            return Ok(0.0);
        }

        Ok(total_loss / num_batches as f64)
    }
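
    /// Forward pass. Currently a placeholder that ignores `parameters` and
    /// returns the input unchanged.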
    fn forward(
        &self,
        data: &ArrayView<f64, Ix2>,
        _parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<Array<f64, Ix2>> {
        // Placeholder: the identity model. A real model would apply its
        // parameters here; this keeps the trainer testable on its own.
        Ok(data.to_owned())
    }
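
    /// Backward pass. Currently a placeholder that returns zero-filled
    /// gradients matching each parameter's shape.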
    fn backward(
        &self,
        _data: &ArrayView<f64, Ix2>,
        _loss_grad: &ArrayView<f64, Ix2>,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        // Placeholder: zero gradients of the right shape for every parameter,
        // so the optimizer step is a no-op.
        let mut gradients = HashMap::new();

        for (name, param) in parameters {
            gradients.insert(name.clone(), Array::zeros(param.raw_dim()));
        }

        Ok(gradients)
    }
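
    /// Returns the current training state.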
    pub fn get_state(&self) -> &TrainingState {
        &self.state
    }
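
    /// Saves a checkpoint (parameters, optimizer/scheduler state, training
    /// state, and history) to `path`.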
    pub fn save_checkpoint(
        &self,
        path: &std::path::PathBuf,
        parameters: &HashMap<String, Array<f64, Ix2>>,
        history: &TrainingHistory,
        best_val_loss: Option<f64>,
    ) -> TrainResult<()> {
        use crate::TrainingCheckpoint;

        // Capture optimizer and (optional) scheduler state so training can
        // resume exactly where it left off.
        let optimizer_state = self.optimizer.state_dict();
        let scheduler_state = self.scheduler.as_ref().map(|s| s.state_dict());

        let checkpoint = TrainingCheckpoint::new(
            self.state.epoch,
            parameters,
            &optimizer_state,
            scheduler_state,
            &self.state,
            &history.train_loss,
            &history.val_loss,
            &history.metrics,
            best_val_loss,
        );

        checkpoint.save(path)?;

        println!("Training checkpoint saved to {:?}", path);
        Ok(())
    }
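
    /// Loads a checkpoint from `path`, restoring optimizer, scheduler, and
    /// training state. Returns the parameters (restored as 1 x len row
    /// matrices), the saved history, and the last completed epoch.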
    #[allow(clippy::type_complexity)]
    pub fn load_checkpoint(
        &mut self,
        path: &std::path::PathBuf,
    ) -> TrainResult<(HashMap<String, Array<f64, Ix2>>, TrainingHistory, usize)> {
        use crate::TrainingCheckpoint;

        let checkpoint = TrainingCheckpoint::load(path)?;

        println!(
            "Loading checkpoint from epoch {} (val_loss: {:?})",
            checkpoint.epoch, checkpoint.val_loss
        );

        // The checkpoint stores each parameter as a flat Vec<f64>; restore it
        // as a 1 x len row matrix, since the original 2-D shape is not saved.
        let mut parameters = HashMap::new();
        for (name, values) in checkpoint.parameters {
            let len = values.len();
            let array = Array::from_vec(values);
            parameters.insert(
                name,
                array.into_shape_with_order((1, len)).map_err(|e| {
                    crate::TrainError::CheckpointError(format!(
                        "Failed to reshape parameter: {}",
                        e
                    ))
                })?,
            );
        }

        self.optimizer.load_state_dict(checkpoint.optimizer_state);

        if let (Some(scheduler), Some(scheduler_state)) =
            (self.scheduler.as_mut(), checkpoint.scheduler_state.as_ref())
        {
            scheduler.load_state_dict(scheduler_state)?;
        }

        let history = TrainingHistory {
            train_loss: checkpoint.train_loss_history,
            val_loss: checkpoint.val_loss_history,
            metrics: checkpoint.metrics_history,
        };

        // Restore the in-memory training state from the checkpoint.
        self.state.epoch = checkpoint.epoch;
        self.state.train_loss = checkpoint.train_loss;
        self.state.val_loss = checkpoint.val_loss;
        self.state.learning_rate = checkpoint.learning_rate;

        println!(
            "Checkpoint loaded successfully. Resuming from epoch {}",
            checkpoint.epoch + 1
        );

        Ok((parameters, history, checkpoint.epoch))
    }
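
    /// Resumes training from a checkpoint, running only the remaining epochs
    /// and appending the new history to the checkpointed one.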
    #[allow(clippy::type_complexity)]
    pub fn train_from_checkpoint(
        &mut self,
        checkpoint_path: &std::path::PathBuf,
        train_data: &ArrayView<f64, Ix2>,
        train_targets: &ArrayView<f64, Ix2>,
        val_data: Option<&ArrayView<f64, Ix2>>,
        val_targets: Option<&ArrayView<f64, Ix2>>,
    ) -> TrainResult<(HashMap<String, Array<f64, Ix2>>, TrainingHistory)> {
        let (mut parameters, mut history, start_epoch) = self.load_checkpoint(checkpoint_path)?;

        // `start_epoch` is the last completed epoch, so start_epoch + 1 epochs
        // are already done. Temporarily shrink num_epochs to the remainder.
        let remaining_epochs = self.config.num_epochs.saturating_sub(start_epoch + 1);
        let original_num_epochs = self.config.num_epochs;
        self.config.num_epochs = remaining_epochs;

        println!(
            "Resuming training: {} epochs completed, {} epochs remaining",
            start_epoch + 1,
            remaining_epochs
        );

        let continued_history = self.train(
            train_data,
            train_targets,
            val_data,
            val_targets,
            &mut parameters,
        )?;

        // Restore the configured epoch count.
        self.config.num_epochs = original_num_epochs;

        // Merge the new history into the history loaded from the checkpoint.
        history.train_loss.extend(continued_history.train_loss);
        history.val_loss.extend(continued_history.val_loss);
        for (metric_name, values) in continued_history.metrics {
            history
                .metrics
                .entry(metric_name)
                .or_default()
                .extend(values);
        }

        Ok((parameters, history))
    }
}
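
/// Per-epoch record of a training run: losses and named validation metrics.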
#[derive(Debug, Clone)]
pub struct TrainingHistory {
    /// Training loss for each completed epoch.
    pub train_loss: Vec<f64>,
    /// Validation loss for each epoch where validation ran.
    pub val_loss: Vec<f64>,
    /// Per-epoch values for each tracked metric, keyed by metric name.
    pub metrics: HashMap<String, Vec<f64>>,
}

impl TrainingHistory {
    pub fn new() -> Self {
        Self {
            train_loss: Vec::new(),
            val_loss: Vec::new(),
            metrics: HashMap::new(),
        }
    }
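
    /// Returns the (epoch index, loss) pair with the lowest validation loss,
    /// or `None` if no validation losses were recorded.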
    pub fn best_val_loss(&self) -> Option<(usize, f64)> {
        self.val_loss
            .iter()
            .enumerate()
            // total_cmp imposes a total order on f64, so a NaN entry cannot
            // panic the comparison (unlike partial_cmp + unwrap).
            .min_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(idx, &loss)| (idx, loss))
    }
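
    /// Returns the recorded history for `metric_name`, if any.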
    pub fn get_metric_history(&self, metric_name: &str) -> Option<&Vec<f64>> {
        self.metrics.get(metric_name)
    }
}

impl Default for TrainingHistory {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MseLoss, OptimizerConfig, SgdOptimizer};

    #[test]
    fn test_trainer_creation() {
        let config = TrainerConfig {
            num_epochs: 5,
            ..Default::default()
        };

        let loss = Box::new(MseLoss);
        let optimizer = Box::new(SgdOptimizer::new(OptimizerConfig::default()));

        let trainer = Trainer::new(config, loss, optimizer);
        assert_eq!(trainer.config.num_epochs, 5);
    }

    #[test]
    fn test_training_history() {
        let mut history = TrainingHistory::new();
        history.train_loss.push(1.0);
        history.train_loss.push(0.8);
        history.train_loss.push(0.6);

        history.val_loss.push(1.2);
        history.val_loss.push(0.9);
        history.val_loss.push(0.7);

        let (best_epoch, best_loss) = history.best_val_loss().unwrap();
        assert_eq!(best_epoch, 2);
        assert_eq!(best_loss, 0.7);
    }

    #[test]
    fn test_training_state() {
        let state = TrainingState {
            epoch: 5,
            batch: 100,
            train_loss: 0.5,
            val_loss: Some(0.6),
            batch_loss: 0.4,
            learning_rate: 0.001,
            metrics: HashMap::new(),
        };

        assert_eq!(state.epoch, 5);
        assert_eq!(state.batch, 100);
        assert!((state.train_loss - 0.5).abs() < 1e-6);
    }
}