use std::collections::HashMap;

use yscv_autograd::Graph;
use yscv_optim::{Adam, AdamW, Sgd};
use yscv_tensor::Tensor;

use crate::{
    EpochTrainOptions, ModelError, SequentialModel, SupervisedDataset, SupervisedLoss,
    TrainingCallback, TrainingLog, train_epoch_adam_with_options_and_loss,
    train_epoch_adamw_with_options_and_loss, train_epoch_sgd_with_options_and_loss,
};
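/// Evaluates `loss_kind` directly on prediction/target tensors, outside the
/// autograd graph. Used below to report validation loss, where no gradients
/// are needed.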
fn compute_raw_loss(predictions: &Tensor, targets: &Tensor, loss_kind: LossKind) -> f32 {
    match loss_kind {
        // Mean squared error: mean((p - t)^2).
        LossKind::Mse => {
            let diff = predictions
                .sub(targets)
                .expect("shape mismatch in val loss");
            let sq = diff.mul(&diff).expect("shape mismatch in val loss");
            sq.mean()
        }
        // Cross-entropy: -mean(t * ln(p)), with p clamped away from zero so
        // the logarithm stays finite.
        LossKind::CrossEntropy => {
            let clamped = predictions.clamp(1e-7, 1.0);
            let log_pred = clamped.ln();
            let product = targets.mul(&log_pred).expect("shape mismatch in val loss");
            -product.mean()
        }
        // Binary cross-entropy: -mean(t * ln(p) + (1 - t) * ln(1 - p)), with
        // p clamped to [eps, 1 - eps] so both logarithms stay finite.
        LossKind::Bce => {
            let eps = 1e-7_f32;
            let clamped = predictions.clamp(eps, 1.0 - eps);
            let log_p = clamped.ln();
            let one_minus_p = clamped.neg().add(&Tensor::scalar(1.0)).expect("bce add");
            let log_1mp = one_minus_p.clamp(eps, 1.0).ln();
            let one_minus_t = targets.neg().add(&Tensor::scalar(1.0)).expect("bce add");
            let term1 = targets.mul(&log_p).expect("bce mul");
            let term2 = one_minus_t.mul(&log_1mp).expect("bce mul");
            let sum = term1.add(&term2).expect("bce add");
            -sum.mean()
        }
    }
}
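/// Which optimizer [`Trainer`] should construct, together with its
/// hyperparameters.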
#[derive(Debug, Clone, PartialEq)]
pub enum OptimizerKind {
    /// Stochastic gradient descent; `momentum = 0.0` disables momentum.
    Sgd { lr: f32, momentum: f32 },
    /// Adam with the given learning rate.
    Adam { lr: f32 },
    /// Adam with decoupled weight decay.
    AdamW { lr: f32, weight_decay: f32 },
}
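/// Which loss function to optimize during training and to report for
/// validation.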
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LossKind {
    Mse,
    CrossEntropy,
    Bce,
}

impl LossKind {
    /// Maps this public loss selector to the internal `SupervisedLoss`
    /// expected by the per-epoch training helpers.
    fn to_supervised_loss(self) -> SupervisedLoss {
        match self {
            LossKind::Mse => SupervisedLoss::Mse,
            LossKind::CrossEntropy => SupervisedLoss::CrossEntropy,
            LossKind::Bce => SupervisedLoss::Bce,
        }
    }
}
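/// Configuration for a [`Trainer`]. Construct directly or start from
/// `TrainerConfig::default()` (SGD with lr 0.01 and no momentum, MSE loss,
/// 10 epochs, batch size 32, no validation split).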
#[derive(Debug, Clone, PartialEq)]
pub struct TrainerConfig {
    pub optimizer: OptimizerKind,
    pub loss: LossKind,
    pub epochs: usize,
    pub batch_size: usize,
    /// If `Some(frac)` with `0.0 < frac < 1.0`, the trailing `frac` of the
    /// samples is held out for validation instead of being trained on.
    pub validation_split: Option<f32>,
}

impl Default for TrainerConfig {
    fn default() -> Self {
        Self {
            optimizer: OptimizerKind::Sgd {
                lr: 0.01,
                momentum: 0.0,
            },
            loss: LossKind::Mse,
            epochs: 10,
            batch_size: 32,
            validation_split: None,
        }
    }
}
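/// Summary returned by [`Trainer::fit`].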
#[derive(Debug, Clone)]
pub struct TrainResult {
    /// Number of epochs actually run; may be fewer than configured if a
    /// callback requested early stopping.
    pub epochs_trained: usize,
    /// Mean training loss of the last completed epoch.
    pub final_loss: f32,
    /// Per-epoch metric maps, keyed `"loss"` and (when a validation split is
    /// active) `"val_loss"`.
    pub history: Vec<HashMap<String, f32>>,
    pub log: TrainingLog,
}
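/// Drives training of a [`SequentialModel`]: validates the configuration,
/// optionally splits off validation data, builds the configured optimizer,
/// and invokes callbacks at each epoch boundary.
///
/// Illustrative usage; `model`, `graph`, `inputs`, and `targets` are assumed
/// to be built elsewhere, which is why the sketch is marked `ignore` rather
/// than compiled as a doctest:
///
/// ```ignore
/// let mut trainer = Trainer::new(TrainerConfig {
///     optimizer: OptimizerKind::Adam { lr: 1e-3 },
///     loss: LossKind::Mse,
///     epochs: 20,
///     batch_size: 64,
///     validation_split: Some(0.2),
/// });
/// let result = trainer.fit(&mut model, &mut graph, &inputs, &targets)?;
/// println!("{} epochs, final loss {}", result.epochs_trained, result.final_loss);
/// ```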
pub struct Trainer {
    config: TrainerConfig,
    /// Epoch-boundary hooks, invoked in registration order.
    callbacks: Vec<Box<dyn TrainingCallback>>,
}

impl Trainer {
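    /// Creates a trainer with the given configuration and no callbacks.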
    pub fn new(config: TrainerConfig) -> Self {
        Self {
            config,
            callbacks: Vec::new(),
        }
    }
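    /// Registers an epoch-boundary callback; returns `&mut Self` so calls
    /// can be chained.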
    pub fn add_callback(&mut self, cb: Box<dyn TrainingCallback>) -> &mut Self {
        self.callbacks.push(cb);
        self
    }
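    /// Trains `model` on `inputs`/`targets` according to the configuration.
    ///
    /// If `validation_split` is set, the trailing fraction of the samples is
    /// held out and reported as `"val_loss"` after each epoch. Registered
    /// callbacks run at every epoch boundary and may request early stopping.
    /// Returns an error if `epochs` or `batch_size` is zero, or if any tensor
    /// or training step fails.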
    pub fn fit(
        &mut self,
        model: &mut SequentialModel,
        graph: &mut Graph,
        inputs: &Tensor,
        targets: &Tensor,
    ) -> Result<TrainResult, ModelError> {
        // Register the model's trainable parameters with the graph before
        // any training steps run.
        model.register_cnn_params(graph);

        if self.config.epochs == 0 {
            return Err(ModelError::InvalidEpochCount { epochs: 0 });
        }
        if self.config.batch_size == 0 {
            return Err(ModelError::InvalidBatchSize { batch_size: 0 });
        }
        // Optionally hold out the trailing fraction of the samples for
        // validation; splitting requires at least two samples, otherwise
        // `n_samples - 1` below would underflow.
        let n_samples = inputs.shape()[0];
        let (train_inputs, train_targets, val_data) = match self.config.validation_split {
            Some(frac) if frac > 0.0 && frac < 1.0 && n_samples > 1 => {
                let val_count = ((n_samples as f32) * frac).round() as usize;
                // Keep at least one sample on each side of the split.
                let val_count = val_count.clamp(1, n_samples - 1);
                let train_count = n_samples - val_count;
                let ti = inputs.narrow(0, 0, train_count)?;
                let tt = targets.narrow(0, 0, train_count)?;
                let vi = inputs.narrow(0, train_count, val_count)?;
                let vt = targets.narrow(0, train_count, val_count)?;
                (ti, tt, Some((vi, vt)))
            }
            _ => (inputs.clone(), targets.clone(), None),
        };
        let dataset = SupervisedDataset::new(train_inputs, train_targets)?;
        let supervised_loss = self.config.loss.to_supervised_loss();
        let loss_kind = self.config.loss;
        let epoch_options = EpochTrainOptions {
            batch_size: self.config.batch_size,
            ..EpochTrainOptions::default()
        };

        let mut history: Vec<HashMap<String, f32>> = Vec::with_capacity(self.config.epochs);
        let mut log = TrainingLog::new();
        let mut epochs_trained = 0usize;
        // NaN until at least one epoch completes.
        let mut final_loss = f32::NAN;
        // Shared per-epoch bookkeeping: records metrics, computes validation
        // loss, and runs callbacks. A macro rather than a helper function so
        // it can use `?` and mutate the locals above from each optimizer arm
        // below. Expands to `true` when a callback requests early stopping.
        macro_rules! epoch_body {
            ($epoch:expr, $metrics:expr) => {{
                final_loss = $metrics.mean_loss;
                epochs_trained = $epoch + 1;
                let mut epoch_metrics = HashMap::new();
                epoch_metrics.insert("loss".to_string(), $metrics.mean_loss);
                if let Some((ref val_inputs, ref val_targets)) = val_data {
                    // Forward pass over the held-out data; the loss itself is
                    // computed on raw tensors, outside the graph.
                    let vi_node = graph.variable(val_inputs.clone());
                    let vo_node = model.forward(graph, vi_node)?;
                    let val_preds = graph.value(vo_node)?.clone();
                    let val_loss = compute_raw_loss(&val_preds, val_targets, loss_kind);
                    epoch_metrics.insert("val_loss".to_string(), val_loss);
                }
                // The callback call sits on the left of `||` so that every
                // callback still runs once a stop has been requested.
                let should_stop = self.callbacks.iter_mut().fold(false, |stop, cb| {
                    cb.on_epoch_end($epoch, &epoch_metrics) || stop
                });
                log.log_epoch(epoch_metrics.clone());
                history.push(epoch_metrics);
                should_stop
            }};
        }
        // The optimizers are distinct concrete types, so each arm drives its
        // own epoch loop through the matching training helper.
        match &self.config.optimizer {
            OptimizerKind::Sgd { lr, momentum } => {
                let mut opt = Sgd::new(*lr)?;
                if *momentum != 0.0 {
                    opt = opt.with_momentum(*momentum)?;
                }
                for epoch in 0..self.config.epochs {
                    let metrics = train_epoch_sgd_with_options_and_loss(
                        graph,
                        model,
                        &mut opt,
                        &dataset,
                        epoch_options.clone(),
                        supervised_loss,
                    )?;
                    if epoch_body!(epoch, metrics) {
                        break;
                    }
                }
            }
            OptimizerKind::Adam { lr } => {
                let mut opt = Adam::new(*lr)?;
                for epoch in 0..self.config.epochs {
                    let metrics = train_epoch_adam_with_options_and_loss(
                        graph,
                        model,
                        &mut opt,
                        &dataset,
                        epoch_options.clone(),
                        supervised_loss,
                    )?;
                    if epoch_body!(epoch, metrics) {
                        break;
                    }
                }
            }
            OptimizerKind::AdamW { lr, weight_decay } => {
                let mut opt = AdamW::new(*lr)?.with_weight_decay(*weight_decay)?;
                for epoch in 0..self.config.epochs {
                    let metrics = train_epoch_adamw_with_options_and_loss(
                        graph,
                        model,
                        &mut opt,
                        &dataset,
                        epoch_options.clone(),
                        supervised_loss,
                    )?;
                    if epoch_body!(epoch, metrics) {
                        break;
                    }
                }
            }
        }
        Ok(TrainResult {
            epochs_trained,
            final_loss,
            history,
            log,
        })
    }
}