1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5use std::time::{Duration, Instant};
6
7use crate::losses::Loss;
8use crate::metrics::{Metric, MetricCollection};
9
/// Minimal synchronous trainer that drives the epoch/step loop and
/// dispatches lifecycle events to the registered [`SimpleCallback`]s.
///
/// Generic over the model (`M`), dataset (`D`), and loss function (`L`).
#[allow(dead_code)] // some fields are held for a future real training loop; see below
pub struct SimpleTrainer<M, D, L> {
    // Shared handle so callers can inspect the model during/after training.
    model: Arc<RwLock<M>>,
    #[allow(dead_code)] // not consumed yet: `train_epoch` simulates batches
    train_dataset: D,
    // Optional held-out dataset; enables `evaluate` and per-epoch eval loss.
    eval_dataset: Option<D>,
    // Stored but not read by the simulated train/eval steps.
    loss_fn: L,
    config: SimpleTrainingConfig,
    // Fired in registration order at each lifecycle event.
    callbacks: Vec<Box<dyn SimpleCallback>>,
    // Metric registry; populated via `add_metric`.
    metrics: MetricCollection,
    // Mutable bookkeeping updated as training progresses.
    state: TrainingState,
}
23
/// Hyper-parameters and bookkeeping cadence for [`SimpleTrainer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimpleTrainingConfig {
    /// Optimizer learning rate (mirrored into `TrainingState` at train start).
    pub learning_rate: f64,
    /// Mini-batch size (not consumed by the simulated loop).
    pub batch_size: usize,
    /// Number of epochs to run.
    pub num_epochs: u32,
    /// Evaluate every N global steps; `None` disables periodic evaluation.
    pub eval_steps: Option<u32>,
    /// Fire `on_save` every N global steps; `None` disables checkpointing.
    pub save_steps: Option<u32>,
    /// Emit `on_log` every N global steps.
    pub logging_steps: u32,
    /// LR warmup steps (not yet consumed by the simulated loop).
    pub warmup_steps: u32,
    /// Gradient-norm clipping threshold (not yet consumed).
    pub max_grad_norm: Option<f64>,
    /// RNG seed for reproducibility (not yet consumed).
    pub seed: Option<u64>,
    /// Directory for run outputs/checkpoints.
    pub output_dir: String,
    /// Non-improving epochs tolerated before stopping; `None` disables.
    pub early_stopping_patience: Option<u32>,
    /// Minimum improvement that resets patience; `None` disables.
    pub early_stopping_threshold: Option<f64>,
}
39
40impl Default for SimpleTrainingConfig {
41 fn default() -> Self {
42 Self {
43 learning_rate: 3e-4,
44 batch_size: 32,
45 num_epochs: 3,
46 eval_steps: Some(500),
47 save_steps: Some(1000),
48 logging_steps: 100,
49 warmup_steps: 500,
50 max_grad_norm: Some(1.0),
51 seed: Some(42),
52 output_dir: "./output".to_string(),
53 early_stopping_patience: None,
54 early_stopping_threshold: None,
55 }
56 }
57}
58
/// Mutable bookkeeping shared (read-only) with callbacks during training.
#[derive(Debug, Clone)]
pub struct TrainingState {
    /// Current 1-based epoch (0 before training starts).
    pub epoch: u32,
    /// Optimization steps taken across all epochs.
    pub global_step: u32,
    /// Average train loss of the most recent epoch.
    pub train_loss: f64,
    /// Most recent eval loss, if evaluation has run.
    pub eval_loss: Option<f64>,
    /// Learning rate currently in effect.
    pub learning_rate: f64,
    /// True while the training loop is running.
    pub is_training: bool,
    /// Best metric value recorded so far (early-stopping bookkeeping).
    pub best_metric: Option<f64>,
    /// Consecutive non-improving evaluations (early-stopping bookkeeping).
    pub patience_counter: u32,
    /// Cooperative stop flag checked after each epoch.
    pub should_stop: bool,
    /// Set when `train` begins; used to report total duration.
    pub start_time: Option<Instant>,
    /// Latest named metric values, merged in after each epoch.
    pub metrics: HashMap<String, f64>,
}
73
74impl Default for TrainingState {
75 fn default() -> Self {
76 Self {
77 epoch: 0,
78 global_step: 0,
79 train_loss: 0.0,
80 eval_loss: None,
81 learning_rate: 0.0,
82 is_training: false,
83 best_metric: None,
84 patience_counter: 0,
85 should_stop: false,
86 start_time: None,
87 metrics: HashMap::new(),
88 }
89 }
90}
91
/// Lifecycle hooks invoked by [`SimpleTrainer`] during training.
///
/// All hooks have no-op default implementations, so implementors override
/// only the events they care about. Hooks receive a read-only view of the
/// trainer's [`TrainingState`]; any state a callback needs to mutate must
/// live on the callback itself.
pub trait SimpleCallback: Send + Sync {
    /// Called once before the first epoch starts.
    fn on_train_begin(
        &mut self,
        _state: &TrainingState,
        _config: &SimpleTrainingConfig,
    ) -> Result<()> {
        Ok(())
    }

    /// Called once after the training loop finishes (normally or early).
    fn on_train_end(&mut self, _state: &TrainingState) -> Result<()> {
        Ok(())
    }

    /// Called at the start of each epoch (`epoch` is 1-based).
    fn on_epoch_begin(&mut self, _epoch: u32, _state: &TrainingState) -> Result<()> {
        Ok(())
    }

    /// Called after each epoch, once its losses/metrics are in `state`.
    fn on_epoch_end(&mut self, _epoch: u32, _state: &TrainingState) -> Result<()> {
        Ok(())
    }

    /// Called before each step (`step` is 1-based within the epoch).
    fn on_step_begin(&mut self, _step: u32, _state: &TrainingState) -> Result<()> {
        Ok(())
    }

    /// Called after each step.
    fn on_step_end(&mut self, _step: u32, _state: &TrainingState) -> Result<()> {
        Ok(())
    }

    /// Called before an evaluation pass begins.
    fn on_evaluate_begin(&mut self, _state: &TrainingState) -> Result<()> {
        Ok(())
    }

    /// Called after an evaluation pass; `state.eval_loss` has been updated.
    fn on_evaluate_end(&mut self, _state: &TrainingState) -> Result<()> {
        Ok(())
    }

    /// Called whenever a checkpoint should be written.
    fn on_save(&mut self, _state: &TrainingState) -> Result<()> {
        Ok(())
    }

    /// Called at each logging interval with that step's logged values.
    fn on_log(&mut self, _logs: &HashMap<String, f64>, _state: &TrainingState) -> Result<()> {
        Ok(())
    }
}
138
/// Callback that prints human-readable progress messages to stdout.
pub struct LoggingCallback {
    // Verbosity; only `Debug` currently unlocks extra per-step output.
    log_level: LogLevel,
}
143
/// Verbosity level for [`LoggingCallback`].
///
/// NOTE(review): only `Debug` is distinguished by the current
/// implementation (`on_log` gating); the other levels behave identically.
#[derive(Debug, Clone)]
pub enum LogLevel {
    Debug,
    Info,
    Warning,
    Error,
}
151
152impl LoggingCallback {
153 pub fn new(log_level: LogLevel) -> Self {
154 Self { log_level }
155 }
156}
157
impl SimpleCallback for LoggingCallback {
    /// Print the key hyper-parameters when training starts.
    fn on_train_begin(
        &mut self,
        _state: &TrainingState,
        config: &SimpleTrainingConfig,
    ) -> Result<()> {
        println!(
            "đ Starting training with config: learning_rate={}, batch_size={}, epochs={}",
            config.learning_rate, config.batch_size, config.num_epochs
        );
        Ok(())
    }

    /// Announce the start of each epoch.
    fn on_epoch_begin(&mut self, epoch: u32, _state: &TrainingState) -> Result<()> {
        println!("đ Starting epoch {}", epoch);
        Ok(())
    }

    /// Summarize the epoch: train loss, plus eval loss when one was computed.
    fn on_epoch_end(&mut self, epoch: u32, state: &TrainingState) -> Result<()> {
        let eval_info = if let Some(eval_loss) = state.eval_loss {
            format!(", eval_loss: {:.4}", eval_loss)
        } else {
            String::new()
        };

        println!(
            "â
 Epoch {} completed - train_loss: {:.4}{}",
            epoch, state.train_loss, eval_info
        );
        Ok(())
    }

    /// Dump raw step logs, but only at `Debug` verbosity.
    fn on_log(&mut self, logs: &HashMap<String, f64>, state: &TrainingState) -> Result<()> {
        if matches!(self.log_level, LogLevel::Debug) {
            println!("đ Step {} - {:?}", state.global_step, logs);
        }
        Ok(())
    }

    /// Report total wall-clock time once training ends.
    fn on_train_end(&mut self, state: &TrainingState) -> Result<()> {
        if let Some(start_time) = state.start_time {
            let duration = start_time.elapsed();
            println!("đ Training completed in {:.2}s", duration.as_secs_f64());
        }
        Ok(())
    }
}
205
/// Callback that renders an in-place terminal progress bar per step.
pub struct ProgressCallback {
    // Expected total number of steps for the whole run.
    total_steps: u32,
    // Last step passed to `update_progress`.
    current_step: u32,
    // Bar width in characters (fixed at construction).
    bar_width: usize,
}
212
impl ProgressCallback {
    /// Create a progress bar spanning `total_steps` steps (50 chars wide).
    pub fn new(total_steps: u32) -> Self {
        Self {
            total_steps,
            current_step: 0,
            bar_width: 50,
        }
    }

    /// Redraw the single-line progress bar for `step`.
    ///
    /// NOTE(review): with `total_steps == 0` the ratio is 0/0 (NaN) and
    /// `min(1.0)` then yields 1.0, so a full bar is drawn — confirm this
    /// degenerate configuration is acceptable.
    fn update_progress(&mut self, step: u32) {
        self.current_step = step;
        // Clamp so overrunning `total_steps` never overflows the bar.
        let progress = (step as f64 / self.total_steps as f64).min(1.0);
        let filled = (progress * self.bar_width as f64) as usize;
        let empty = self.bar_width - filled;

        let bar = format!("[{}{}]", "â".repeat(filled), "â".repeat(empty));

        // Carriage return redraws the same terminal line each call.
        print!(
            "\r{} {:.1}% ({}/{})",
            bar,
            progress * 100.0,
            step,
            self.total_steps
        );
        // Finish the line once the bar is complete.
        if step >= self.total_steps {
            println!();
        }
    }
}
242
impl SimpleCallback for ProgressCallback {
    /// Advance the progress bar after every completed step.
    fn on_step_end(&mut self, step: u32, _state: &TrainingState) -> Result<()> {
        self.update_progress(step);
        Ok(())
    }
}
249
/// Callback that tracks a monitored metric and announces when patience
/// runs out (see the NOTE on its `on_evaluate_end` implementation).
pub struct EarlyStoppingCallback {
    // Name of the metric to look up in `TrainingState::metrics`.
    monitor: String,
    // Consecutive non-improving evaluations tolerated.
    patience: u32,
    // Minimum change that counts as an improvement.
    threshold: f64,
    // Whether lower (`Min`) or higher (`Max`) values are better.
    mode: EarlyStoppingMode,
    // Best metric value observed so far.
    best_value: Option<f64>,
    // Evaluations since the last improvement.
    patience_counter: u32,
}
259
/// Direction of improvement for a monitored metric.
#[derive(Debug, Clone)]
pub enum EarlyStoppingMode {
    /// Lower is better (e.g. losses).
    Min,
    /// Higher is better (e.g. accuracy).
    Max,
}
265
266impl EarlyStoppingCallback {
267 pub fn new(monitor: String, patience: u32, threshold: f64, mode: EarlyStoppingMode) -> Self {
268 Self {
269 monitor,
270 patience,
271 threshold,
272 mode,
273 best_value: None,
274 patience_counter: 0,
275 }
276 }
277}
278
impl SimpleCallback for EarlyStoppingCallback {
    /// After each evaluation, compare the monitored metric with the best
    /// value seen so far and update the internal patience counter.
    ///
    /// NOTE(review): when patience runs out this only *prints* the
    /// early-stopping message — `state` is borrowed immutably, so
    /// `TrainingState::should_stop` is never set and the trainer will not
    /// actually halt because of this callback. Confirm whether that is
    /// intended or whether stopping should be wired through the trainer.
    fn on_evaluate_end(&mut self, state: &TrainingState) -> Result<()> {
        if let Some(current_value) = state.metrics.get(&self.monitor) {
            // An improvement must clear `threshold` relative to the best.
            let improved = match self.best_value {
                None => true,
                Some(best) => match self.mode {
                    EarlyStoppingMode::Min => *current_value < best - self.threshold,
                    EarlyStoppingMode::Max => *current_value > best + self.threshold,
                },
            };

            if improved {
                self.best_value = Some(*current_value);
                self.patience_counter = 0;
                println!("đ¯ New best {}: {:.4}", self.monitor, current_value);
            } else {
                self.patience_counter += 1;
                if self.patience_counter >= self.patience {
                    println!(
                        "âšī¸ Early stopping triggered. No improvement in {} for {} epochs",
                        self.monitor, self.patience
                    );
                }
            }
        }
        Ok(())
    }
}
308
/// Callback that (currently just announces) checkpoint writes, optionally
/// only when a monitored metric improves.
pub struct CheckpointCallback {
    // Directory checkpoints are written under.
    save_dir: String,
    // When true, only save on improvement of the monitored metric.
    save_best_only: bool,
    // Metric name to monitor; `None` means no metric gating.
    monitor: Option<String>,
    // Improvement direction; fixed to `Min` by `new`.
    mode: EarlyStoppingMode,
    // Best monitored value observed so far.
    best_value: Option<f64>,
}
317
318impl CheckpointCallback {
319 pub fn new(save_dir: String, save_best_only: bool, monitor: Option<String>) -> Self {
320 Self {
321 save_dir,
322 save_best_only,
323 monitor,
324 mode: EarlyStoppingMode::Min,
325 best_value: None,
326 }
327 }
328}
329
330impl SimpleCallback for CheckpointCallback {
331 fn on_save(&mut self, state: &TrainingState) -> Result<()> {
332 let should_save = if self.save_best_only {
333 if let (Some(_monitor), Some(current_value)) = (
334 &self.monitor,
335 self.monitor.as_ref().and_then(|m| state.metrics.get(m.as_str())),
336 ) {
337 let is_best = match self.best_value {
338 None => true,
339 Some(best) => match self.mode {
340 EarlyStoppingMode::Min => *current_value < best,
341 EarlyStoppingMode::Max => *current_value > best,
342 },
343 };
344
345 if is_best {
346 self.best_value = Some(*current_value);
347 }
348 is_best
349 } else {
350 true }
352 } else {
353 true };
355
356 if should_save {
357 let checkpoint_path = format!("{}/checkpoint-{}", self.save_dir, state.global_step);
358 println!("đž Saving checkpoint to {}", checkpoint_path);
359 }
361
362 Ok(())
363 }
364}
365
/// Callback that accumulates a per-metric history from `on_log` events.
pub struct MetricsCallback {
    // Only these metric names are recorded; others are ignored.
    tracked_metrics: Vec<String>,
    // Logged values per metric, in arrival order.
    history: HashMap<String, Vec<f64>>,
}
371
372impl MetricsCallback {
373 pub fn new(tracked_metrics: Vec<String>) -> Self {
374 Self {
375 tracked_metrics,
376 history: HashMap::new(),
377 }
378 }
379
380 pub fn get_history(&self, metric: &str) -> Option<&Vec<f64>> {
381 self.history.get(metric)
382 }
383
384 pub fn get_all_history(&self) -> &HashMap<String, Vec<f64>> {
385 &self.history
386 }
387}
388
389impl SimpleCallback for MetricsCallback {
390 fn on_log(&mut self, logs: &HashMap<String, f64>, _state: &TrainingState) -> Result<()> {
391 for metric in &self.tracked_metrics {
392 if let Some(value) = logs.get(metric) {
393 self.history.entry(metric.clone()).or_default().push(*value);
394 }
395 }
396 Ok(())
397 }
398}
399
400impl<M, D, L> SimpleTrainer<M, D, L>
401where
402 M: Send + Sync,
403 D: Clone,
404 L: Loss + Send + Sync,
405{
406 pub fn new(model: M, train_dataset: D, loss_fn: L, config: SimpleTrainingConfig) -> Self {
407 Self {
408 model: Arc::new(RwLock::new(model)),
409 train_dataset,
410 eval_dataset: None,
411 loss_fn,
412 config,
413 callbacks: Vec::new(),
414 metrics: MetricCollection::new(),
415 state: TrainingState::default(),
416 }
417 }
418
419 pub fn with_eval_dataset(mut self, eval_dataset: D) -> Self {
420 self.eval_dataset = Some(eval_dataset);
421 self
422 }
423
424 pub fn add_callback(mut self, callback: Box<dyn SimpleCallback>) -> Self {
425 self.callbacks.push(callback);
426 self
427 }
428
429 pub fn add_metric(&mut self, metric: Box<dyn Metric>) -> &mut Self {
430 self.metrics.add_metric_mut(metric);
431 self
432 }
433
434 pub fn train(&mut self) -> Result<TrainingResults> {
436 self.state.start_time = Some(Instant::now());
437 self.state.learning_rate = self.config.learning_rate;
438 self.state.is_training = true;
439
440 for callback in &mut self.callbacks {
442 callback.on_train_begin(&self.state, &self.config)?;
443 }
444
445 let mut training_history = Vec::new();
446
447 for epoch in 1..=self.config.num_epochs {
448 self.state.epoch = epoch;
449
450 for callback in &mut self.callbacks {
452 callback.on_epoch_begin(epoch, &self.state)?;
453 }
454
455 let epoch_result = self.train_epoch()?;
457 training_history.push(epoch_result.clone());
458
459 self.state.train_loss = epoch_result.train_loss;
461 self.state.eval_loss = epoch_result.eval_loss;
462
463 for (key, value) in &epoch_result.metrics {
465 self.state.metrics.insert(key.clone(), *value);
466 }
467
468 for callback in &mut self.callbacks {
470 callback.on_epoch_end(epoch, &self.state)?;
471 }
472
473 if self.should_stop_early()? {
475 println!("Training stopped early at epoch {}", epoch);
476 break;
477 }
478 }
479
480 self.state.is_training = false;
481
482 for callback in &mut self.callbacks {
484 callback.on_train_end(&self.state)?;
485 }
486
487 Ok(TrainingResults {
488 final_train_loss: self.state.train_loss,
489 final_eval_loss: self.state.eval_loss,
490 best_metric: self.state.best_metric,
491 total_epochs: self.state.epoch,
492 total_steps: self.state.global_step,
493 training_time: self
494 .state
495 .start_time
496 .expect("start_time is set at beginning of train method")
497 .elapsed(),
498 history: training_history,
499 })
500 }
501
502 fn train_epoch(&mut self) -> Result<EpochResult> {
503 let mut total_loss = 0.0;
504 let mut step_count = 0;
505
506 let steps_per_epoch = 100; for step in 1..=steps_per_epoch {
510 self.state.global_step += 1;
511
512 for callback in &mut self.callbacks {
514 callback.on_step_begin(step, &self.state)?;
515 }
516
517 let step_loss = self.train_step()?;
519 total_loss += step_loss;
520 step_count += 1;
521
522 if self.state.global_step % self.config.logging_steps == 0 {
524 let logs = {
525 let mut logs = HashMap::new();
526 logs.insert("train_loss".to_string(), step_loss);
527 logs.insert("learning_rate".to_string(), self.state.learning_rate);
528 logs
529 };
530
531 for callback in &mut self.callbacks {
532 callback.on_log(&logs, &self.state)?;
533 }
534 }
535
536 if let Some(eval_steps) = self.config.eval_steps {
538 if self.state.global_step % eval_steps == 0 {
539 self.evaluate()?;
540 }
541 }
542
543 if let Some(save_steps) = self.config.save_steps {
545 if self.state.global_step % save_steps == 0 {
546 for callback in &mut self.callbacks {
547 callback.on_save(&self.state)?;
548 }
549 }
550 }
551
552 for callback in &mut self.callbacks {
554 callback.on_step_end(step, &self.state)?;
555 }
556 }
557
558 let avg_train_loss = total_loss / step_count as f64;
559
560 let eval_loss = if self.eval_dataset.is_some() { Some(self.evaluate()?) } else { None };
562
563 Ok(EpochResult {
564 epoch: self.state.epoch,
565 train_loss: avg_train_loss,
566 eval_loss,
567 metrics: self.state.metrics.clone(),
568 })
569 }
570
571 fn train_step(&mut self) -> Result<f64> {
572 let loss = 1.0 / (1.0 + self.state.global_step as f64 * 0.001);
581 Ok(loss)
582 }
583
584 fn evaluate(&mut self) -> Result<f64> {
585 if self.eval_dataset.is_none() {
586 return Ok(0.0);
587 }
588
589 for callback in &mut self.callbacks {
591 callback.on_evaluate_begin(&self.state)?;
592 }
593
594 let eval_loss = 0.5 / (1.0 + self.state.epoch as f64 * 0.1);
601
602 self.state.eval_loss = Some(eval_loss);
604
605 for callback in &mut self.callbacks {
607 callback.on_evaluate_end(&self.state)?;
608 }
609
610 Ok(eval_loss)
611 }
612
613 fn should_stop_early(&self) -> Result<bool> {
614 if let (Some(patience), Some(threshold)) = (
616 self.config.early_stopping_patience,
617 self.config.early_stopping_threshold,
618 ) {
619 if let Some(current_loss) = self.state.eval_loss {
620 if let Some(best_metric) = self.state.best_metric {
621 if current_loss > best_metric + threshold {
622 return Ok(self.state.patience_counter >= patience);
623 }
624 }
625 }
626 }
627
628 Ok(self.state.should_stop)
629 }
630
631 pub fn get_state(&self) -> &TrainingState {
633 &self.state
634 }
635
636 pub fn get_model(&self) -> Arc<RwLock<M>> {
638 Arc::clone(&self.model)
639 }
640}
641
/// Aggregate outcome of a completed (or early-stopped) training run.
#[derive(Debug, Clone)]
pub struct TrainingResults {
    /// Average train loss of the last completed epoch.
    pub final_train_loss: f64,
    /// Eval loss of the last epoch, if an eval dataset was attached.
    pub final_eval_loss: Option<f64>,
    /// Best metric value recorded in the trainer state, if any.
    pub best_metric: Option<f64>,
    /// Number of epochs actually run (may be fewer than configured).
    pub total_epochs: u32,
    /// Total optimization steps taken across all epochs.
    pub total_steps: u32,
    /// Wall-clock duration of the run.
    pub training_time: Duration,
    /// Per-epoch results in chronological order.
    pub history: Vec<EpochResult>,
}
652
/// Losses and metrics produced by a single epoch.
#[derive(Debug, Clone)]
pub struct EpochResult {
    /// 1-based epoch index.
    pub epoch: u32,
    /// Average train loss across the epoch's steps.
    pub train_loss: f64,
    /// End-of-epoch eval loss, when an eval dataset is attached.
    pub eval_loss: Option<f64>,
    /// Snapshot of the trainer's metric values at epoch end.
    pub metrics: HashMap<String, f64>,
}
660
/// Fluent builder for [`SimpleTrainer`]; model, training dataset, and loss
/// function are required, everything else is optional.
pub struct SimpleTrainerBuilder<M, D, L> {
    // Required before `build` succeeds.
    model: Option<M>,
    train_dataset: Option<D>,
    // Optional held-out dataset.
    eval_dataset: Option<D>,
    // Required before `build` succeeds.
    loss_fn: Option<L>,
    // Starts from `SimpleTrainingConfig::default`, mutated by setters.
    config: SimpleTrainingConfig,
    // Callbacks forwarded to the trainer at build time.
    callbacks: Vec<Box<dyn SimpleCallback>>,
    // Metrics forwarded to the trainer at build time.
    metrics: Vec<Box<dyn Metric>>,
}
671
impl<M, D, L> Default for SimpleTrainerBuilder<M, D, L>
where
    M: Send + Sync,
    D: Clone,
    L: Loss + Send + Sync,
{
    /// Equivalent to [`SimpleTrainerBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
682
683impl<M, D, L> SimpleTrainerBuilder<M, D, L>
684where
685 M: Send + Sync,
686 D: Clone,
687 L: Loss + Send + Sync,
688{
689 pub fn new() -> Self {
690 Self {
691 model: None,
692 train_dataset: None,
693 eval_dataset: None,
694 loss_fn: None,
695 config: SimpleTrainingConfig::default(),
696 callbacks: Vec::new(),
697 metrics: Vec::new(),
698 }
699 }
700
701 pub fn model(mut self, model: M) -> Self {
702 self.model = Some(model);
703 self
704 }
705
706 pub fn train_dataset(mut self, dataset: D) -> Self {
707 self.train_dataset = Some(dataset);
708 self
709 }
710
711 pub fn eval_dataset(mut self, dataset: D) -> Self {
712 self.eval_dataset = Some(dataset);
713 self
714 }
715
716 pub fn loss_function(mut self, loss_fn: L) -> Self {
717 self.loss_fn = Some(loss_fn);
718 self
719 }
720
721 pub fn learning_rate(mut self, lr: f64) -> Self {
722 self.config.learning_rate = lr;
723 self
724 }
725
726 pub fn batch_size(mut self, batch_size: usize) -> Self {
727 self.config.batch_size = batch_size;
728 self
729 }
730
731 pub fn num_epochs(mut self, epochs: u32) -> Self {
732 self.config.num_epochs = epochs;
733 self
734 }
735
736 pub fn output_dir(mut self, dir: String) -> Self {
737 self.config.output_dir = dir;
738 self
739 }
740
741 pub fn with_logging(mut self) -> Self {
742 self.callbacks.push(Box::new(LoggingCallback::new(LogLevel::Info)));
743 self
744 }
745
746 pub fn with_progress_bar(self) -> Self {
747 self
749 }
750
751 pub fn with_early_stopping(mut self, monitor: String, patience: u32, threshold: f64) -> Self {
752 self.callbacks.push(Box::new(EarlyStoppingCallback::new(
753 monitor,
754 patience,
755 threshold,
756 EarlyStoppingMode::Min,
757 )));
758 self
759 }
760
761 pub fn with_checkpoints(mut self, save_dir: String, save_best_only: bool) -> Self {
762 self.callbacks.push(Box::new(CheckpointCallback::new(
763 save_dir,
764 save_best_only,
765 Some("eval_loss".to_string()),
766 )));
767 self
768 }
769
770 pub fn build(self) -> Result<SimpleTrainer<M, D, L>> {
771 let model = self.model.context("Model is required")?;
772 let train_dataset = self.train_dataset.context("Training dataset is required")?;
773 let loss_fn = self.loss_fn.context("Loss function is required")?;
774
775 let mut trainer = SimpleTrainer::new(model, train_dataset, loss_fn, self.config);
776
777 if let Some(eval_dataset) = self.eval_dataset {
778 trainer = trainer.with_eval_dataset(eval_dataset);
779 }
780
781 for callback in self.callbacks {
782 trainer = trainer.add_callback(callback);
783 }
784
785 for metric in self.metrics {
786 trainer.add_metric(metric);
787 }
788
789 Ok(trainer)
790 }
791}
792
#[cfg(test)]
mod tests {
    use super::*;
    use crate::losses::MSELoss;

    // Zero-sized stand-ins satisfying the trainer's generic bounds.
    #[derive(Clone)]
    struct DummyDataset;

    struct DummyModel;

    // A freshly built trainer starts at epoch 0 and is not training.
    #[test]
    fn test_simple_trainer_creation() {
        let model = DummyModel;
        let dataset = DummyDataset;
        let loss_fn = MSELoss::new();
        let config = SimpleTrainingConfig::default();

        let trainer = SimpleTrainer::new(model, dataset, loss_fn, config);
        assert_eq!(trainer.state.epoch, 0);
        assert!(!trainer.state.is_training);
    }

    // Builder threads config overrides through to the constructed trainer.
    #[test]
    fn test_simple_trainer_builder() {
        let result = SimpleTrainerBuilder::new()
            .model(DummyModel)
            .train_dataset(DummyDataset)
            .loss_function(MSELoss::new())
            .learning_rate(0.001)
            .batch_size(16)
            .num_epochs(5)
            .with_logging()
            .build();

        assert!(result.is_ok());
        let trainer = result.expect("operation failed in test");
        assert_eq!(trainer.config.learning_rate, 0.001);
        assert_eq!(trainer.config.batch_size, 16);
        assert_eq!(trainer.config.num_epochs, 5);
    }

    // Logging hooks are side-effect-only and must never error.
    #[test]
    fn test_logging_callback() {
        let mut callback = LoggingCallback::new(LogLevel::Info);
        let state = TrainingState::default();
        let config = SimpleTrainingConfig::default();

        assert!(callback.on_train_begin(&state, &config).is_ok());
        assert!(callback.on_epoch_begin(1, &state).is_ok());
        assert!(callback.on_epoch_end(1, &state).is_ok());
        assert!(callback.on_train_end(&state).is_ok());
    }

    // First observation becomes the best value; a worse one bumps patience.
    #[test]
    fn test_early_stopping_callback() {
        let mut callback =
            EarlyStoppingCallback::new("eval_loss".to_string(), 3, 0.01, EarlyStoppingMode::Min);

        let mut state = TrainingState::default();
        state.metrics.insert("eval_loss".to_string(), 0.5);

        assert!(callback.on_evaluate_end(&state).is_ok());
        assert_eq!(callback.best_value, Some(0.5));
        assert_eq!(callback.patience_counter, 0);

        // 0.6 is worse than 0.5 in Min mode, so patience increments.
        state.metrics.insert("eval_loss".to_string(), 0.6);
        assert!(callback.on_evaluate_end(&state).is_ok());
        assert_eq!(callback.patience_counter, 1);
    }

    // Only metrics named at construction are recorded; others are ignored.
    #[test]
    fn test_metrics_callback() {
        let mut callback = MetricsCallback::new(vec!["loss".to_string(), "accuracy".to_string()]);

        let mut logs = HashMap::new();
        logs.insert("loss".to_string(), 0.5);
        logs.insert("accuracy".to_string(), 0.9);
        logs.insert("other_metric".to_string(), 0.1);

        let state = TrainingState::default();
        assert!(callback.on_log(&logs, &state).is_ok());

        assert_eq!(callback.get_history("loss"), Some(&vec![0.5]));
        assert_eq!(callback.get_history("accuracy"), Some(&vec![0.9]));
        assert_eq!(callback.get_history("other_metric"), None);
    }

    // Pins the documented defaults of `SimpleTrainingConfig::default`.
    #[test]
    fn test_config_defaults() {
        let config = SimpleTrainingConfig::default();
        assert_eq!(config.learning_rate, 3e-4);
        assert_eq!(config.batch_size, 32);
        assert_eq!(config.num_epochs, 3);
        assert_eq!(config.logging_steps, 100);
        assert_eq!(config.warmup_steps, 500);
        assert_eq!(config.seed, Some(42));
    }
}