yscv_model/trainer.rs

use std::collections::HashMap;

use yscv_autograd::Graph;
use yscv_optim::{Adam, AdamW, Sgd};
use yscv_tensor::Tensor;

use crate::{
    EpochTrainOptions, ModelError, SequentialModel, SupervisedDataset, SupervisedLoss,
    TrainingCallback, TrainingLog, train_epoch_adam_with_options_and_loss,
    train_epoch_adamw_with_options_and_loss, train_epoch_sgd_with_options_and_loss,
};

/// Compute a scalar loss value from prediction and target tensors (no autograd).
fn compute_raw_loss(predictions: &Tensor, targets: &Tensor, loss_kind: LossKind) -> f32 {
    match loss_kind {
        LossKind::Mse => {
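            // mean((pred - target)^2)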
            let diff = predictions
                .sub(targets)
                .expect("shape mismatch in val loss");
            let sq = diff.mul(&diff).expect("shape mismatch in val loss");
            sq.mean()
        }
        LossKind::CrossEntropy => {
            // -mean(target * ln(clamp(pred)))
            let clamped = predictions.clamp(1e-7, 1.0);
            let log_pred = clamped.ln();
            let product = targets.mul(&log_pred).expect("shape mismatch in val loss");
            -product.mean()
        }
        LossKind::Bce => {
            // -mean(target*ln(pred) + (1-target)*ln(1-pred))
            let eps = 1e-7_f32;
            let clamped = predictions.clamp(eps, 1.0 - eps);
            let log_p = clamped.ln();
            let one_minus_p = clamped.neg().add(&Tensor::scalar(1.0)).expect("bce add");
            let log_1mp = one_minus_p.clamp(eps, 1.0).ln();
            let one_minus_t = targets.neg().add(&Tensor::scalar(1.0)).expect("bce add");
            let term1 = targets.mul(&log_p).expect("bce mul");
            let term2 = one_minus_t.mul(&log_1mp).expect("bce mul");
            let sum = term1.add(&term2).expect("bce add");
            -sum.mean()
        }
    }
}

/// Which optimizer to use.
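///
/// A construction sketch (hyperparameter values here are illustrative, and
/// the `yscv_model::` path assumes this enum is re-exported at the crate
/// root, as the `crate::{..}` imports above suggest):
///
/// ```
/// use yscv_model::OptimizerKind;
///
/// // Plain SGD; a non-zero `momentum` enables momentum updates.
/// let sgd = OptimizerKind::Sgd { lr: 0.01, momentum: 0.9 };
/// // AdamW adds decoupled weight decay on top of Adam.
/// let adamw = OptimizerKind::AdamW { lr: 1e-3, weight_decay: 0.01 };
/// ```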
#[derive(Debug, Clone, PartialEq)]
pub enum OptimizerKind {
    Sgd { lr: f32, momentum: f32 },
    Adam { lr: f32 },
    AdamW { lr: f32, weight_decay: f32 },
}

/// Which loss function to use.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LossKind {
    Mse,
    CrossEntropy,
    Bce,
}

impl LossKind {
    fn to_supervised_loss(self) -> SupervisedLoss {
        match self {
            LossKind::Mse => SupervisedLoss::Mse,
            LossKind::CrossEntropy => SupervisedLoss::CrossEntropy,
            LossKind::Bce => SupervisedLoss::Bce,
        }
    }
}

/// High-level training configuration.
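///
/// Example of overriding the defaults (a sketch; values are illustrative
/// and the `yscv_model::` path is assumed as above):
///
/// ```
/// use yscv_model::{LossKind, OptimizerKind, TrainerConfig};
///
/// let config = TrainerConfig {
///     optimizer: OptimizerKind::Adam { lr: 1e-3 },
///     loss: LossKind::CrossEntropy,
///     epochs: 20,
///     batch_size: 64,
///     // Hold out the last 20% of samples for validation.
///     validation_split: Some(0.2),
/// };
/// ```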
#[derive(Debug, Clone, PartialEq)]
pub struct TrainerConfig {
    pub optimizer: OptimizerKind,
    pub loss: LossKind,
    pub epochs: usize,
    pub batch_size: usize,
    /// Optional fraction of data to hold out for validation (e.g. 0.2 = 20%).
    /// When `None`, no validation is performed.
    pub validation_split: Option<f32>,
}

impl Default for TrainerConfig {
    fn default() -> Self {
        Self {
            optimizer: OptimizerKind::Sgd {
                lr: 0.01,
                momentum: 0.0,
            },
            loss: LossKind::Mse,
            epochs: 10,
            batch_size: 32,
            validation_split: None,
        }
    }
}

/// Training result returned after fitting.
#[derive(Debug, Clone)]
pub struct TrainResult {
    pub epochs_trained: usize,
    pub final_loss: f32,
    pub history: Vec<HashMap<String, f32>>,
    /// Structured training log with CSV export and per-metric history queries.
    pub log: TrainingLog,
}

/// High-level trainer that wraps optimizer + loss + callbacks configuration.
pub struct Trainer {
    config: TrainerConfig,
    callbacks: Vec<Box<dyn TrainingCallback>>,
}

impl Trainer {
    /// Create a new trainer with the given configuration.
    pub fn new(config: TrainerConfig) -> Self {
        Self {
            config,
            callbacks: Vec::new(),
        }
    }

    /// Add a callback (EarlyStopping, BestModelCheckpoint, etc.).
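    ///
    /// Returns `&mut Self`, so calls can be chained. A sketch (the
    /// constructors shown here are assumed, not defined in this file):
    ///
    /// ```ignore
    /// trainer
    ///     .add_callback(Box::new(EarlyStopping::new(3)))
    ///     .add_callback(Box::new(BestModelCheckpoint::default()));
    /// ```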
    pub fn add_callback(&mut self, cb: Box<dyn TrainingCallback>) -> &mut Self {
        self.callbacks.push(cb);
        self
    }

    /// Train the model on the given data.
    ///
    /// `inputs` and `targets` are combined into a `SupervisedDataset`. For each
    /// epoch the method runs a full pass over mini-batches, computes the loss,
    /// back-propagates, and steps the optimizer. Callbacks are invoked after
    /// every epoch and may request early stopping.
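    ///
    /// A minimal end-to-end sketch (`SequentialModel::new`, `Graph::new`,
    /// and the `inputs`/`targets` tensors are assumed here, not shown):
    ///
    /// ```ignore
    /// let mut model = SequentialModel::new();
    /// // ... add layers ...
    /// let mut graph = Graph::new();
    /// let mut trainer = Trainer::new(TrainerConfig::default());
    /// let result = trainer.fit(&mut model, &mut graph, &inputs, &targets)?;
    /// println!("{} epochs, final loss {}", result.epochs_trained, result.final_loss);
    /// ```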
    pub fn fit(
        &mut self,
        model: &mut SequentialModel,
        graph: &mut Graph,
        inputs: &Tensor,
        targets: &Tensor,
    ) -> Result<TrainResult, ModelError> {
        // Auto-register CNN/attention/recurrent layer parameters as graph
        // variables so that layers created without `new_in_graph` work in
        // training mode.
        model.register_cnn_params(graph);

        if self.config.epochs == 0 {
            return Err(ModelError::InvalidEpochCount { epochs: 0 });
        }
        if self.config.batch_size == 0 {
            return Err(ModelError::InvalidBatchSize { batch_size: 0 });
        }

        // Split data into train / validation if requested.
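        // E.g. 100 samples with `validation_split = Some(0.2)` gives
        // val_count = round(100 * 0.2) = 20: rows 0..80 train, rows
        // 80..100 validate. The split is a plain tail slice, no shuffle.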
        let n_samples = inputs.shape()[0];
        let (train_inputs, train_targets, val_data) = match self.config.validation_split {
            Some(frac) if frac > 0.0 && frac < 1.0 => {
                let val_count = ((n_samples as f32) * frac).round() as usize;
                let val_count = val_count.max(1).min(n_samples - 1);
                let train_count = n_samples - val_count;
                let ti = inputs.narrow(0, 0, train_count)?;
                let tt = targets.narrow(0, 0, train_count)?;
                let vi = inputs.narrow(0, train_count, val_count)?;
                let vt = targets.narrow(0, train_count, val_count)?;
                (ti, tt, Some((vi, vt)))
            }
            _ => (inputs.clone(), targets.clone(), None),
        };

        let dataset = SupervisedDataset::new(train_inputs, train_targets)?;
        let supervised_loss = self.config.loss.to_supervised_loss();
        let loss_kind = self.config.loss;
        let epoch_options = EpochTrainOptions {
            batch_size: self.config.batch_size,
            ..EpochTrainOptions::default()
        };

        let mut history: Vec<HashMap<String, f32>> = Vec::with_capacity(self.config.epochs);
        let mut log = TrainingLog::new();
        let mut epochs_trained = 0usize;
        let mut final_loss = f32::NAN;

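        // Shared per-epoch bookkeeping for every optimizer branch below:
        // record the training loss, optionally compute validation loss,
        // notify callbacks, and evaluate to `true` if any callback asked to
        // stop. A macro rather than a closure, so the body can borrow
        // `graph`, `model`, and `self` freely and use `?` from `fit`.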
        macro_rules! epoch_body {
            ($epoch:expr, $metrics:expr) => {{
                final_loss = $metrics.mean_loss;
                epochs_trained = $epoch + 1;
                let mut epoch_metrics = HashMap::new();
                epoch_metrics.insert("loss".to_string(), $metrics.mean_loss);
                if let Some((ref val_inputs, ref val_targets)) = val_data {
                    // Use graph-based forward pass (works for all layer types).
                    let vi_node = graph.variable(val_inputs.clone());
                    let vo_node = model.forward(graph, vi_node)?;
                    let val_preds = graph.value(vo_node)?.clone();
                    let val_loss = compute_raw_loss(&val_preds, val_targets, loss_kind);
                    epoch_metrics.insert("val_loss".to_string(), val_loss);
                }
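                // Call `on_epoch_end` before `|| stop` so every callback
                // still runs even after one has already requested a stop.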
                let should_stop = self.callbacks.iter_mut().fold(false, |stop, cb| {
                    cb.on_epoch_end($epoch, &epoch_metrics) || stop
                });
                log.log_epoch(epoch_metrics.clone());
                history.push(epoch_metrics);
                should_stop
            }};
        }

        match &self.config.optimizer {
            OptimizerKind::Sgd { lr, momentum } => {
                let mut opt = Sgd::new(*lr)?;
                if *momentum != 0.0 {
                    opt = opt.with_momentum(*momentum)?;
                }
                for epoch in 0..self.config.epochs {
                    let metrics = train_epoch_sgd_with_options_and_loss(
                        graph,
                        model,
                        &mut opt,
                        &dataset,
                        epoch_options.clone(),
                        supervised_loss,
                    )?;
                    if epoch_body!(epoch, metrics) {
                        break;
                    }
                }
            }
            OptimizerKind::Adam { lr } => {
                let mut opt = Adam::new(*lr)?;
                for epoch in 0..self.config.epochs {
                    let metrics = train_epoch_adam_with_options_and_loss(
                        graph,
                        model,
                        &mut opt,
                        &dataset,
                        epoch_options.clone(),
                        supervised_loss,
                    )?;
                    if epoch_body!(epoch, metrics) {
                        break;
                    }
                }
            }
            OptimizerKind::AdamW { lr, weight_decay } => {
                let mut opt = AdamW::new(*lr)?.with_weight_decay(*weight_decay)?;
                for epoch in 0..self.config.epochs {
                    let metrics = train_epoch_adamw_with_options_and_loss(
                        graph,
                        model,
                        &mut opt,
                        &dataset,
                        epoch_options.clone(),
                        supervised_loss,
                    )?;
                    if epoch_body!(epoch, metrics) {
                        break;
                    }
                }
            }
        }

        Ok(TrainResult {
            epochs_trained,
            final_loss,
            history,
            log,
        })
    }
}