//! `trustformers_training/simplified_trainer.rs`
//!
//! Simplified, callback-driven training loop (`SimpleTrainer`) with a
//! builder API (`SimpleTrainerBuilder`) and a set of built-in callbacks
//! for logging, progress display, early stopping, checkpointing, and
//! metric history tracking.
1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5use std::time::{Duration, Instant};
6
7use crate::losses::Loss;
8use crate::metrics::{Metric, MetricCollection};
9
/// Simplified trainer interface for easy model training.
///
/// Generic over the model `M`, dataset `D`, and loss function `L`.
/// Drives a (currently simulated) epoch/step loop and dispatches
/// lifecycle events to registered [`SimpleCallback`]s.
#[allow(dead_code)]
pub struct SimpleTrainer<M, D, L> {
    // Shared, lockable handle so callers can inspect the model via `get_model`.
    model: Arc<RwLock<M>>,
    #[allow(dead_code)]
    train_dataset: D,
    // Optional held-out dataset; evaluation is skipped when `None`.
    eval_dataset: Option<D>,
    loss_fn: L,
    config: SimpleTrainingConfig,
    // Invoked in registration order at each lifecycle event.
    callbacks: Vec<Box<dyn SimpleCallback>>,
    metrics: MetricCollection,
    // Mutable progress snapshot shared with callbacks by reference.
    state: TrainingState,
}
23
/// Hyperparameters and bookkeeping settings for a training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimpleTrainingConfig {
    pub learning_rate: f64,
    pub batch_size: usize,
    pub num_epochs: u32,
    /// Run evaluation every N global steps; `None` disables step-based eval.
    pub eval_steps: Option<u32>,
    /// Trigger `on_save` callbacks every N global steps; `None` disables.
    pub save_steps: Option<u32>,
    /// Emit `on_log` callbacks every N global steps.
    pub logging_steps: u32,
    pub warmup_steps: u32,
    pub max_grad_norm: Option<f64>,
    pub seed: Option<u64>,
    pub output_dir: String,
    /// Epochs without improvement before stopping; `None` disables.
    pub early_stopping_patience: Option<u32>,
    /// Minimum change counted as an improvement for early stopping.
    pub early_stopping_threshold: Option<f64>,
}
39
impl Default for SimpleTrainingConfig {
    /// Conventional fine-tuning defaults: 3e-4 LR, batch 32, 3 epochs,
    /// periodic eval/save/logging, gradient clipping at 1.0, fixed seed.
    fn default() -> Self {
        Self {
            learning_rate: 3e-4,
            batch_size: 32,
            num_epochs: 3,
            eval_steps: Some(500),
            save_steps: Some(1000),
            logging_steps: 100,
            warmup_steps: 500,
            max_grad_norm: Some(1.0),
            seed: Some(42),
            output_dir: "./output".to_string(),
            // Early stopping is opt-in.
            early_stopping_patience: None,
            early_stopping_threshold: None,
        }
    }
}
58
/// Mutable snapshot of training progress, passed by reference to callbacks.
#[derive(Debug, Clone)]
pub struct TrainingState {
    pub epoch: u32,
    /// Monotonically increasing step counter across all epochs.
    pub global_step: u32,
    pub train_loss: f64,
    pub eval_loss: Option<f64>,
    pub learning_rate: f64,
    pub is_training: bool,
    pub best_metric: Option<f64>,
    pub patience_counter: u32,
    /// Cooperative stop flag checked between epochs.
    pub should_stop: bool,
    /// Set when `train()` begins; used to report total wall-clock time.
    pub start_time: Option<Instant>,
    /// Latest named metric values (e.g. from evaluation).
    pub metrics: HashMap<String, f64>,
}
73
74impl Default for TrainingState {
75    fn default() -> Self {
76        Self {
77            epoch: 0,
78            global_step: 0,
79            train_loss: 0.0,
80            eval_loss: None,
81            learning_rate: 0.0,
82            is_training: false,
83            best_metric: None,
84            patience_counter: 0,
85            should_stop: false,
86            start_time: None,
87            metrics: HashMap::new(),
88        }
89    }
90}
91
/// Simplified callback interface.
///
/// Every hook has a no-op default, so implementors override only the
/// events they care about. Hooks return `Result` so a callback can
/// abort training by propagating an error.
pub trait SimpleCallback: Send + Sync {
    /// Called once before the first epoch.
    fn on_train_begin(
        &mut self,
        _state: &TrainingState,
        _config: &SimpleTrainingConfig,
    ) -> Result<()> {
        Ok(())
    }

    /// Called once after the last epoch (or after early stop).
    fn on_train_end(&mut self, _state: &TrainingState) -> Result<()> {
        Ok(())
    }

    fn on_epoch_begin(&mut self, _epoch: u32, _state: &TrainingState) -> Result<()> {
        Ok(())
    }

    fn on_epoch_end(&mut self, _epoch: u32, _state: &TrainingState) -> Result<()> {
        Ok(())
    }

    /// `_step` is the step index within the current epoch, not the global step.
    fn on_step_begin(&mut self, _step: u32, _state: &TrainingState) -> Result<()> {
        Ok(())
    }

    fn on_step_end(&mut self, _step: u32, _state: &TrainingState) -> Result<()> {
        Ok(())
    }

    fn on_evaluate_begin(&mut self, _state: &TrainingState) -> Result<()> {
        Ok(())
    }

    fn on_evaluate_end(&mut self, _state: &TrainingState) -> Result<()> {
        Ok(())
    }

    /// Called when a checkpoint should be persisted (every `save_steps`).
    fn on_save(&mut self, _state: &TrainingState) -> Result<()> {
        Ok(())
    }

    /// Called every `logging_steps` with the freshly collected log values.
    fn on_log(&mut self, _logs: &HashMap<String, f64>, _state: &TrainingState) -> Result<()> {
        Ok(())
    }
}
138
/// Built-in logging callback that prints run progress to stdout.
pub struct LoggingCallback {
    // Currently only distinguishes Debug (per-step logs) from the rest.
    log_level: LogLevel,
}
143
/// Verbosity levels for [`LoggingCallback`].
#[derive(Debug, Clone)]
pub enum LogLevel {
    Debug,
    Info,
    Warning,
    Error,
}
151
impl LoggingCallback {
    /// Creates a logging callback at the given verbosity level.
    pub fn new(log_level: LogLevel) -> Self {
        Self { log_level }
    }
}
157
158impl SimpleCallback for LoggingCallback {
159    fn on_train_begin(
160        &mut self,
161        _state: &TrainingState,
162        config: &SimpleTrainingConfig,
163    ) -> Result<()> {
164        println!(
165            "🚀 Starting training with config: learning_rate={}, batch_size={}, epochs={}",
166            config.learning_rate, config.batch_size, config.num_epochs
167        );
168        Ok(())
169    }
170
171    fn on_epoch_begin(&mut self, epoch: u32, _state: &TrainingState) -> Result<()> {
172        println!("📚 Starting epoch {}", epoch);
173        Ok(())
174    }
175
176    fn on_epoch_end(&mut self, epoch: u32, state: &TrainingState) -> Result<()> {
177        let eval_info = if let Some(eval_loss) = state.eval_loss {
178            format!(", eval_loss: {:.4}", eval_loss)
179        } else {
180            String::new()
181        };
182
183        println!(
184            "✅ Epoch {} completed - train_loss: {:.4}{}",
185            epoch, state.train_loss, eval_info
186        );
187        Ok(())
188    }
189
190    fn on_log(&mut self, logs: &HashMap<String, f64>, state: &TrainingState) -> Result<()> {
191        if matches!(self.log_level, LogLevel::Debug) {
192            println!("📊 Step {} - {:?}", state.global_step, logs);
193        }
194        Ok(())
195    }
196
197    fn on_train_end(&mut self, state: &TrainingState) -> Result<()> {
198        if let Some(start_time) = state.start_time {
199            let duration = start_time.elapsed();
200            println!("🎉 Training completed in {:.2}s", duration.as_secs_f64());
201        }
202        Ok(())
203    }
204}
205
/// Progress bar callback rendering an in-place text bar on stdout.
pub struct ProgressCallback {
    /// Total steps expected for the whole run (100% mark).
    total_steps: u32,
    current_step: u32,
    /// Width of the bar in characters.
    bar_width: usize,
}
212
213impl ProgressCallback {
214    pub fn new(total_steps: u32) -> Self {
215        Self {
216            total_steps,
217            current_step: 0,
218            bar_width: 50,
219        }
220    }
221
222    fn update_progress(&mut self, step: u32) {
223        self.current_step = step;
224        let progress = (step as f64 / self.total_steps as f64).min(1.0);
225        let filled = (progress * self.bar_width as f64) as usize;
226        let empty = self.bar_width - filled;
227
228        let bar = format!("[{}{}]", "█".repeat(filled), "░".repeat(empty));
229
230        print!(
231            "\r{} {:.1}% ({}/{})",
232            bar,
233            progress * 100.0,
234            step,
235            self.total_steps
236        );
237        if step >= self.total_steps {
238            println!();
239        }
240    }
241}
242
impl SimpleCallback for ProgressCallback {
    /// Advances the bar after every training step.
    fn on_step_end(&mut self, step: u32, _state: &TrainingState) -> Result<()> {
        self.update_progress(step);
        Ok(())
    }
}
249
/// Early stopping callback: watches one metric and counts evaluations
/// without sufficient improvement.
pub struct EarlyStoppingCallback {
    /// Name of the metric to watch in `TrainingState::metrics`.
    monitor: String,
    /// Evaluations without improvement tolerated before triggering.
    patience: u32,
    /// Minimum delta that counts as an improvement.
    threshold: f64,
    /// Whether lower (`Min`) or higher (`Max`) values are better.
    mode: EarlyStoppingMode,
    best_value: Option<f64>,
    patience_counter: u32,
}
259
/// Improvement direction for monitored metrics (losses → `Min`,
/// accuracies → `Max`).
#[derive(Debug, Clone)]
pub enum EarlyStoppingMode {
    Min,
    Max,
}
265
impl EarlyStoppingCallback {
    /// Creates an early-stopping watcher for `monitor` with the given
    /// patience, improvement threshold, and direction.
    pub fn new(monitor: String, patience: u32, threshold: f64, mode: EarlyStoppingMode) -> Self {
        Self {
            monitor,
            patience,
            threshold,
            mode,
            best_value: None,
            patience_counter: 0,
        }
    }
}
278
279impl SimpleCallback for EarlyStoppingCallback {
280    fn on_evaluate_end(&mut self, state: &TrainingState) -> Result<()> {
281        if let Some(current_value) = state.metrics.get(&self.monitor) {
282            let improved = match self.best_value {
283                None => true,
284                Some(best) => match self.mode {
285                    EarlyStoppingMode::Min => *current_value < best - self.threshold,
286                    EarlyStoppingMode::Max => *current_value > best + self.threshold,
287                },
288            };
289
290            if improved {
291                self.best_value = Some(*current_value);
292                self.patience_counter = 0;
293                println!("đŸŽ¯ New best {}: {:.4}", self.monitor, current_value);
294            } else {
295                self.patience_counter += 1;
296                if self.patience_counter >= self.patience {
297                    println!(
298                        "âšī¸  Early stopping triggered. No improvement in {} for {} epochs",
299                        self.monitor, self.patience
300                    );
301                    // In a real implementation, we would set a flag to stop training
302                }
303            }
304        }
305        Ok(())
306    }
307}
308
/// Model checkpoint callback: decides when a checkpoint should be written.
pub struct CheckpointCallback {
    save_dir: String,
    /// When true, only save if the monitored metric improved.
    save_best_only: bool,
    /// Metric name to compare for "best"; `None` means save unconditionally.
    monitor: Option<String>,
    // Reuses the early-stopping direction enum; hard-coded to Min in `new`.
    mode: EarlyStoppingMode,
    best_value: Option<f64>,
}
317
impl CheckpointCallback {
    /// Creates a checkpoint callback.
    ///
    /// `mode` is fixed to `Min` (suited for loss-like monitors); there is
    /// currently no constructor parameter to select `Max`.
    pub fn new(save_dir: String, save_best_only: bool, monitor: Option<String>) -> Self {
        Self {
            save_dir,
            save_best_only,
            monitor,
            mode: EarlyStoppingMode::Min,
            best_value: None,
        }
    }
}
329
330impl SimpleCallback for CheckpointCallback {
331    fn on_save(&mut self, state: &TrainingState) -> Result<()> {
332        let should_save = if self.save_best_only {
333            if let (Some(_monitor), Some(current_value)) = (
334                &self.monitor,
335                self.monitor.as_ref().and_then(|m| state.metrics.get(m.as_str())),
336            ) {
337                let is_best = match self.best_value {
338                    None => true,
339                    Some(best) => match self.mode {
340                        EarlyStoppingMode::Min => *current_value < best,
341                        EarlyStoppingMode::Max => *current_value > best,
342                    },
343                };
344
345                if is_best {
346                    self.best_value = Some(*current_value);
347                }
348                is_best
349            } else {
350                true // Save if no monitor specified
351            }
352        } else {
353            true // Always save if not save_best_only
354        };
355
356        if should_save {
357            let checkpoint_path = format!("{}/checkpoint-{}", self.save_dir, state.global_step);
358            println!("💾 Saving checkpoint to {}", checkpoint_path);
359            // In a real implementation, would save model state here
360        }
361
362        Ok(())
363    }
364}
365
/// Metrics tracking callback: accumulates a per-metric history of every
/// logged value for the metrics it was asked to track.
pub struct MetricsCallback {
    /// Names of the metrics to record; all other log keys are ignored.
    tracked_metrics: Vec<String>,
    /// Metric name → chronological list of logged values.
    history: HashMap<String, Vec<f64>>,
}
371
impl MetricsCallback {
    /// Creates a tracker for the given metric names with empty history.
    pub fn new(tracked_metrics: Vec<String>) -> Self {
        Self {
            tracked_metrics,
            history: HashMap::new(),
        }
    }

    /// History for one metric, or `None` if it was never logged.
    pub fn get_history(&self, metric: &str) -> Option<&Vec<f64>> {
        self.history.get(metric)
    }

    /// Full recorded history for all tracked metrics.
    pub fn get_all_history(&self) -> &HashMap<String, Vec<f64>> {
        &self.history
    }
}
388
389impl SimpleCallback for MetricsCallback {
390    fn on_log(&mut self, logs: &HashMap<String, f64>, _state: &TrainingState) -> Result<()> {
391        for metric in &self.tracked_metrics {
392            if let Some(value) = logs.get(metric) {
393                self.history.entry(metric.clone()).or_default().push(*value);
394            }
395        }
396        Ok(())
397    }
398}
399
impl<M, D, L> SimpleTrainer<M, D, L>
where
    M: Send + Sync,
    D: Clone,
    L: Loss + Send + Sync,
{
    /// Creates a trainer with no eval dataset, no callbacks, and a
    /// default (zeroed) training state.
    pub fn new(model: M, train_dataset: D, loss_fn: L, config: SimpleTrainingConfig) -> Self {
        Self {
            model: Arc::new(RwLock::new(model)),
            train_dataset,
            eval_dataset: None,
            loss_fn,
            config,
            callbacks: Vec::new(),
            metrics: MetricCollection::new(),
            state: TrainingState::default(),
        }
    }

    /// Attaches an evaluation dataset (builder-style, consumes `self`).
    pub fn with_eval_dataset(mut self, eval_dataset: D) -> Self {
        self.eval_dataset = Some(eval_dataset);
        self
    }

    /// Registers a callback; callbacks fire in registration order.
    pub fn add_callback(mut self, callback: Box<dyn SimpleCallback>) -> Self {
        self.callbacks.push(callback);
        self
    }

    /// Registers a metric with the internal metric collection.
    pub fn add_metric(&mut self, metric: Box<dyn Metric>) -> &mut Self {
        self.metrics.add_metric_mut(metric);
        self
    }

    /// Start training with the configured parameters.
    ///
    /// Runs `num_epochs` epochs (1-based), firing callbacks at each
    /// lifecycle event, and returns aggregate results. Errors from any
    /// callback or training step abort the run.
    pub fn train(&mut self) -> Result<TrainingResults> {
        self.state.start_time = Some(Instant::now());
        self.state.learning_rate = self.config.learning_rate;
        self.state.is_training = true;

        // Call train begin callbacks
        for callback in &mut self.callbacks {
            callback.on_train_begin(&self.state, &self.config)?;
        }

        let mut training_history = Vec::new();

        for epoch in 1..=self.config.num_epochs {
            self.state.epoch = epoch;

            // Call epoch begin callbacks
            for callback in &mut self.callbacks {
                callback.on_epoch_begin(epoch, &self.state)?;
            }

            // Train epoch
            let epoch_result = self.train_epoch()?;
            training_history.push(epoch_result.clone());

            // Update state so epoch-end callbacks see this epoch's losses.
            self.state.train_loss = epoch_result.train_loss;
            self.state.eval_loss = epoch_result.eval_loss;

            // Update metrics in state
            for (key, value) in &epoch_result.metrics {
                self.state.metrics.insert(key.clone(), *value);
            }

            // Call epoch end callbacks
            for callback in &mut self.callbacks {
                callback.on_epoch_end(epoch, &self.state)?;
            }

            // Check for early stopping
            if self.should_stop_early()? {
                println!("Training stopped early at epoch {}", epoch);
                break;
            }
        }

        self.state.is_training = false;

        // Call train end callbacks
        for callback in &mut self.callbacks {
            callback.on_train_end(&self.state)?;
        }

        Ok(TrainingResults {
            final_train_loss: self.state.train_loss,
            final_eval_loss: self.state.eval_loss,
            best_metric: self.state.best_metric,
            total_epochs: self.state.epoch,
            total_steps: self.state.global_step,
            training_time: self
                .state
                .start_time
                .expect("start_time is set at beginning of train method")
                .elapsed(),
            history: training_history,
        })
    }

    /// Runs one epoch of (simulated) training steps and returns its summary.
    ///
    /// Fires step begin/end, periodic log, eval, and save callbacks, then
    /// runs an end-of-epoch evaluation when an eval dataset is configured.
    fn train_epoch(&mut self) -> Result<EpochResult> {
        let mut total_loss = 0.0;
        let mut step_count = 0;

        // Simplified training loop (in practice would iterate over actual batches)
        let steps_per_epoch = 100; // Placeholder

        for step in 1..=steps_per_epoch {
            self.state.global_step += 1;

            // Call step begin callbacks
            for callback in &mut self.callbacks {
                callback.on_step_begin(step, &self.state)?;
            }

            // Simulate training step
            let step_loss = self.train_step()?;
            total_loss += step_loss;
            step_count += 1;

            // Logging: emit current step loss and LR every `logging_steps`.
            if self.state.global_step % self.config.logging_steps == 0 {
                let logs = {
                    let mut logs = HashMap::new();
                    logs.insert("train_loss".to_string(), step_loss);
                    logs.insert("learning_rate".to_string(), self.state.learning_rate);
                    logs
                };

                for callback in &mut self.callbacks {
                    callback.on_log(&logs, &self.state)?;
                }
            }

            // Evaluation: step-based eval when configured.
            if let Some(eval_steps) = self.config.eval_steps {
                if self.state.global_step % eval_steps == 0 {
                    self.evaluate()?;
                }
            }

            // Saving: fire on_save callbacks every `save_steps`.
            if let Some(save_steps) = self.config.save_steps {
                if self.state.global_step % save_steps == 0 {
                    for callback in &mut self.callbacks {
                        callback.on_save(&self.state)?;
                    }
                }
            }

            // Call step end callbacks
            for callback in &mut self.callbacks {
                callback.on_step_end(step, &self.state)?;
            }
        }

        // step_count is always steps_per_epoch (100) here, so never zero.
        let avg_train_loss = total_loss / step_count as f64;

        // Run evaluation at end of epoch if we have eval dataset
        let eval_loss = if self.eval_dataset.is_some() { Some(self.evaluate()?) } else { None };

        Ok(EpochResult {
            epoch: self.state.epoch,
            train_loss: avg_train_loss,
            eval_loss,
            metrics: self.state.metrics.clone(),
        })
    }

    /// Single training step (simulated): returns a synthetic loss that
    /// decays with the global step count.
    fn train_step(&mut self) -> Result<f64> {
        // Simplified training step - in practice would:
        // 1. Get batch from dataset
        // 2. Forward pass
        // 3. Compute loss
        // 4. Backward pass
        // 5. Update weights

        // Simulate decreasing loss
        let loss = 1.0 / (1.0 + self.state.global_step as f64 * 0.001);
        Ok(loss)
    }

    /// Runs a (simulated) evaluation pass and returns the eval loss.
    ///
    /// Returns `Ok(0.0)` immediately if no eval dataset is configured.
    /// Fires evaluate begin/end callbacks and records the loss in state.
    fn evaluate(&mut self) -> Result<f64> {
        if self.eval_dataset.is_none() {
            return Ok(0.0);
        }

        // Call evaluate begin callbacks
        for callback in &mut self.callbacks {
            callback.on_evaluate_begin(&self.state)?;
        }

        // Simplified evaluation - in practice would:
        // 1. Set model to eval mode
        // 2. Iterate over eval dataset
        // 3. Compute metrics
        // 4. Set model back to train mode

        // Synthetic eval loss decaying with epoch number.
        let eval_loss = 0.5 / (1.0 + self.state.epoch as f64 * 0.1);

        // Update state
        self.state.eval_loss = Some(eval_loss);

        // Call evaluate end callbacks
        for callback in &mut self.callbacks {
            callback.on_evaluate_end(&self.state)?;
        }

        Ok(eval_loss)
    }

    /// Returns whether training should stop after the current epoch.
    ///
    /// NOTE(review): `state.best_metric` and `state.patience_counter` are
    /// never written anywhere in this file, so the config-driven branch
    /// below can currently never return `true` — presumably a stub to be
    /// completed alongside the early-stopping callback; verify intent.
    fn should_stop_early(&self) -> Result<bool> {
        // Check if any callback has requested early stopping
        if let (Some(patience), Some(threshold)) = (
            self.config.early_stopping_patience,
            self.config.early_stopping_threshold,
        ) {
            if let Some(current_loss) = self.state.eval_loss {
                if let Some(best_metric) = self.state.best_metric {
                    if current_loss > best_metric + threshold {
                        return Ok(self.state.patience_counter >= patience);
                    }
                }
            }
        }

        Ok(self.state.should_stop)
    }

    /// Get current training state
    pub fn get_state(&self) -> &TrainingState {
        &self.state
    }

    /// Get model reference
    pub fn get_model(&self) -> Arc<RwLock<M>> {
        Arc::clone(&self.model)
    }
}
641
/// Aggregate results returned by [`SimpleTrainer::train`].
#[derive(Debug, Clone)]
pub struct TrainingResults {
    pub final_train_loss: f64,
    pub final_eval_loss: Option<f64>,
    pub best_metric: Option<f64>,
    /// Last epoch reached (equals `num_epochs` unless stopped early).
    pub total_epochs: u32,
    pub total_steps: u32,
    /// Wall-clock duration of the whole run.
    pub training_time: Duration,
    /// One entry per completed epoch, in order.
    pub history: Vec<EpochResult>,
}
652
/// Per-epoch summary: average train loss, optional eval loss, and a
/// snapshot of the state's metrics at epoch end.
#[derive(Debug, Clone)]
pub struct EpochResult {
    pub epoch: u32,
    pub train_loss: f64,
    pub eval_loss: Option<f64>,
    pub metrics: HashMap<String, f64>,
}
660
/// Builder pattern for easier trainer configuration.
///
/// `model`, `train_dataset`, and `loss_fn` are required; everything else
/// is optional and falls back to [`SimpleTrainingConfig::default`].
pub struct SimpleTrainerBuilder<M, D, L> {
    model: Option<M>,
    train_dataset: Option<D>,
    eval_dataset: Option<D>,
    loss_fn: Option<L>,
    config: SimpleTrainingConfig,
    callbacks: Vec<Box<dyn SimpleCallback>>,
    /// Metrics forwarded to the trainer's collection at `build()`.
    metrics: Vec<Box<dyn Metric>>,
}
671
impl<M, D, L> Default for SimpleTrainerBuilder<M, D, L>
where
    M: Send + Sync,
    D: Clone,
    L: Loss + Send + Sync,
{
    /// Equivalent to [`SimpleTrainerBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
682
683impl<M, D, L> SimpleTrainerBuilder<M, D, L>
684where
685    M: Send + Sync,
686    D: Clone,
687    L: Loss + Send + Sync,
688{
689    pub fn new() -> Self {
690        Self {
691            model: None,
692            train_dataset: None,
693            eval_dataset: None,
694            loss_fn: None,
695            config: SimpleTrainingConfig::default(),
696            callbacks: Vec::new(),
697            metrics: Vec::new(),
698        }
699    }
700
701    pub fn model(mut self, model: M) -> Self {
702        self.model = Some(model);
703        self
704    }
705
706    pub fn train_dataset(mut self, dataset: D) -> Self {
707        self.train_dataset = Some(dataset);
708        self
709    }
710
711    pub fn eval_dataset(mut self, dataset: D) -> Self {
712        self.eval_dataset = Some(dataset);
713        self
714    }
715
716    pub fn loss_function(mut self, loss_fn: L) -> Self {
717        self.loss_fn = Some(loss_fn);
718        self
719    }
720
721    pub fn learning_rate(mut self, lr: f64) -> Self {
722        self.config.learning_rate = lr;
723        self
724    }
725
726    pub fn batch_size(mut self, batch_size: usize) -> Self {
727        self.config.batch_size = batch_size;
728        self
729    }
730
731    pub fn num_epochs(mut self, epochs: u32) -> Self {
732        self.config.num_epochs = epochs;
733        self
734    }
735
736    pub fn output_dir(mut self, dir: String) -> Self {
737        self.config.output_dir = dir;
738        self
739    }
740
741    pub fn with_logging(mut self) -> Self {
742        self.callbacks.push(Box::new(LoggingCallback::new(LogLevel::Info)));
743        self
744    }
745
746    pub fn with_progress_bar(self) -> Self {
747        // Would need total steps calculation here
748        self
749    }
750
751    pub fn with_early_stopping(mut self, monitor: String, patience: u32, threshold: f64) -> Self {
752        self.callbacks.push(Box::new(EarlyStoppingCallback::new(
753            monitor,
754            patience,
755            threshold,
756            EarlyStoppingMode::Min,
757        )));
758        self
759    }
760
761    pub fn with_checkpoints(mut self, save_dir: String, save_best_only: bool) -> Self {
762        self.callbacks.push(Box::new(CheckpointCallback::new(
763            save_dir,
764            save_best_only,
765            Some("eval_loss".to_string()),
766        )));
767        self
768    }
769
770    pub fn build(self) -> Result<SimpleTrainer<M, D, L>> {
771        let model = self.model.context("Model is required")?;
772        let train_dataset = self.train_dataset.context("Training dataset is required")?;
773        let loss_fn = self.loss_fn.context("Loss function is required")?;
774
775        let mut trainer = SimpleTrainer::new(model, train_dataset, loss_fn, self.config);
776
777        if let Some(eval_dataset) = self.eval_dataset {
778            trainer = trainer.with_eval_dataset(eval_dataset);
779        }
780
781        for callback in self.callbacks {
782            trainer = trainer.add_callback(callback);
783        }
784
785        for metric in self.metrics {
786            trainer.add_metric(metric);
787        }
788
789        Ok(trainer)
790    }
791}
792
#[cfg(test)]
mod tests {
    use super::*;
    use crate::losses::MSELoss;

    // Minimal stand-ins satisfying the trainer's generic bounds.
    #[derive(Clone)]
    struct DummyDataset;

    struct DummyModel;

    // Construction leaves the trainer in an idle, zeroed state.
    #[test]
    fn test_simple_trainer_creation() {
        let model = DummyModel;
        let dataset = DummyDataset;
        let loss_fn = MSELoss::new();
        let config = SimpleTrainingConfig::default();

        let trainer = SimpleTrainer::new(model, dataset, loss_fn, config);
        assert_eq!(trainer.state.epoch, 0);
        assert!(!trainer.state.is_training);
    }

    // Builder propagates config overrides into the built trainer.
    #[test]
    fn test_simple_trainer_builder() {
        let result = SimpleTrainerBuilder::new()
            .model(DummyModel)
            .train_dataset(DummyDataset)
            .loss_function(MSELoss::new())
            .learning_rate(0.001)
            .batch_size(16)
            .num_epochs(5)
            .with_logging()
            .build();

        assert!(result.is_ok());
        let trainer = result.expect("operation failed in test");
        assert_eq!(trainer.config.learning_rate, 0.001);
        assert_eq!(trainer.config.batch_size, 16);
        assert_eq!(trainer.config.num_epochs, 5);
    }

    // Smoke test: logging hooks succeed on default state/config.
    #[test]
    fn test_logging_callback() {
        let mut callback = LoggingCallback::new(LogLevel::Info);
        let state = TrainingState::default();
        let config = SimpleTrainingConfig::default();

        // Test that callbacks don't panic
        assert!(callback.on_train_begin(&state, &config).is_ok());
        assert!(callback.on_epoch_begin(1, &state).is_ok());
        assert!(callback.on_epoch_end(1, &state).is_ok());
        assert!(callback.on_train_end(&state).is_ok());
    }

    // First eval sets the best value; a worse value bumps patience.
    #[test]
    fn test_early_stopping_callback() {
        let mut callback =
            EarlyStoppingCallback::new("eval_loss".to_string(), 3, 0.01, EarlyStoppingMode::Min);

        let mut state = TrainingState::default();
        state.metrics.insert("eval_loss".to_string(), 0.5);

        // First evaluation - should set best value
        assert!(callback.on_evaluate_end(&state).is_ok());
        assert_eq!(callback.best_value, Some(0.5));
        assert_eq!(callback.patience_counter, 0);

        // No improvement
        state.metrics.insert("eval_loss".to_string(), 0.6);
        assert!(callback.on_evaluate_end(&state).is_ok());
        assert_eq!(callback.patience_counter, 1);
    }

    // Only tracked metric names are recorded into history.
    #[test]
    fn test_metrics_callback() {
        let mut callback = MetricsCallback::new(vec!["loss".to_string(), "accuracy".to_string()]);

        let mut logs = HashMap::new();
        logs.insert("loss".to_string(), 0.5);
        logs.insert("accuracy".to_string(), 0.9);
        logs.insert("other_metric".to_string(), 0.1); // Should be ignored

        let state = TrainingState::default();
        assert!(callback.on_log(&logs, &state).is_ok());

        assert_eq!(callback.get_history("loss"), Some(&vec![0.5]));
        assert_eq!(callback.get_history("accuracy"), Some(&vec![0.9]));
        assert_eq!(callback.get_history("other_metric"), None);
    }

    // Pin the documented config defaults.
    #[test]
    fn test_config_defaults() {
        let config = SimpleTrainingConfig::default();
        assert_eq!(config.learning_rate, 3e-4);
        assert_eq!(config.batch_size, 32);
        assert_eq!(config.num_epochs, 3);
        assert_eq!(config.logging_steps, 100);
        assert_eq!(config.warmup_steps, 500);
        assert_eq!(config.seed, Some(42));
    }
}
893}