1use crate::{
7 Activation, Dataset, Error, Layer, Loss, Metric, Mlp, Optimizer, OptimizerState, Result,
8 Trainer, loss,
9};
10
11use rand::SeedableRng;
12use rand::rngs::StdRng;
13use rand::seq::SliceRandom;
14
/// Controls the order in which training samples are visited each epoch.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Shuffle {
    /// Visit samples in dataset order every epoch.
    #[default]
    None,
    /// Reshuffle indices every epoch using a deterministic RNG seeded once with this value.
    Seeded(u64),
}
24
/// Per-epoch learning-rate schedule applied on top of the base `lr`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum LrSchedule {
    /// Use the base learning rate for every epoch.
    #[default]
    Constant,
    /// Multiply the rate by `gamma` once every `step_size` epochs.
    Step { step_size: usize, gamma: f32 },
    /// Cosine decay from the base rate down to `min_lr` over the whole run.
    CosineAnnealing { min_lr: f32 },
}
36
37impl LrSchedule {
38 pub fn validate(self) -> Result<()> {
39 match self {
40 LrSchedule::Constant => Ok(()),
41 LrSchedule::Step { step_size, gamma } => {
42 if step_size == 0 {
43 return Err(Error::InvalidConfig(
44 "lr_schedule step_size must be > 0".to_owned(),
45 ));
46 }
47 if !(gamma.is_finite() && gamma > 0.0) {
48 return Err(Error::InvalidConfig(format!(
49 "lr_schedule gamma must be finite and > 0, got {gamma}"
50 )));
51 }
52 Ok(())
53 }
54 LrSchedule::CosineAnnealing { min_lr } => {
55 if !(min_lr.is_finite() && min_lr > 0.0) {
56 return Err(Error::InvalidConfig(format!(
57 "lr_schedule min_lr must be finite and > 0, got {min_lr}"
58 )));
59 }
60 Ok(())
61 }
62 }
63 }
64
65 fn lr_at_epoch(self, lr0: f32, epoch: usize, epochs: usize) -> f32 {
66 match self {
67 LrSchedule::Constant => lr0,
68 LrSchedule::Step { step_size, gamma } => {
69 let k = epoch / step_size;
70 lr0 * gamma.powi(k as i32)
71 }
72 LrSchedule::CosineAnnealing { min_lr } => {
73 if epochs <= 1 {
74 return lr0;
75 }
76
77 let t = epoch as f32;
78 let t_max = (epochs - 1) as f32;
79 let cos = (std::f32::consts::PI * (t / t_max)).cos();
80 min_lr + (lr0 - min_lr) * 0.5 * (1.0 + cos)
81 }
82 }
83 }
84}
85
/// Hyperparameters controlling a single call to `Mlp::fit`.
#[derive(Debug, Clone)]
pub struct FitConfig {
    /// Number of full passes over the training set (must be > 0).
    pub epochs: usize,
    /// Base learning rate (must be finite and > 0).
    pub lr: f32,
    /// Mini-batch size; 1 selects the purely per-sample update path.
    pub batch_size: usize,
    /// Sample-ordering policy applied each epoch.
    pub shuffle: Shuffle,
    /// Per-epoch adjustment of the base learning rate.
    pub lr_schedule: LrSchedule,
    /// Parameter-update rule used for each step.
    pub optimizer: Optimizer,
    /// Weight-decay coefficient (must be finite and >= 0; 0 disables it).
    pub weight_decay: f32,
    /// Optional global gradient-norm clip threshold (finite and > 0 when set).
    pub grad_clip_norm: Option<f32>,
    /// Loss minimized during training and reported per epoch.
    pub loss: Loss,
    /// Extra metrics accumulated per epoch alongside the loss.
    pub metrics: Vec<Metric>,
}
100
impl Default for FitConfig {
    /// Conservative baseline: 10 epochs of per-sample SGD at lr 1e-2,
    /// MSE loss, no shuffling, no regularization, and no extra metrics.
    fn default() -> Self {
        Self {
            epochs: 10,
            lr: 1e-2,
            batch_size: 1, // per-sample updates by default
            shuffle: Shuffle::None,
            lr_schedule: LrSchedule::Constant,
            optimizer: Optimizer::Sgd,
            weight_decay: 0.0,
            grad_clip_norm: None,
            loss: Loss::Mse,
            metrics: Vec::new(),
        }
    }
}
117
/// Result of `Mlp::fit`: one entry per completed epoch, in order.
#[derive(Debug, Clone)]
pub struct FitReport {
    /// Per-epoch training (and optional validation) summaries.
    pub epochs: Vec<EpochReport>,
}
123
/// Training/validation summary for a single epoch.
#[derive(Debug, Clone)]
pub struct EpochReport {
    /// Loss/metrics accumulated over this epoch's training pass.
    pub train: EvalReport,
    /// Evaluation on the validation set, when one was supplied to `fit`.
    pub val: Option<EvalReport>,
}
130
/// Mean loss plus mean metric values over one pass of a dataset.
#[derive(Debug, Clone)]
pub struct EvalReport {
    /// Mean loss per sample.
    pub loss: f32,
    /// Mean value per requested metric, in request order.
    pub metrics: Vec<(Metric, f32)>,
}
137
138impl EvalReport {
139 fn new(loss: f32, metrics: Vec<(Metric, f32)>) -> Self {
140 Self { loss, metrics }
141 }
142}
143
impl Mlp {
    /// Computes the mean loss and the mean of each requested metric over `data`.
    ///
    /// Forward-only — no parameters are modified. Shapes, the loss/activation
    /// compatibility, and every metric are validated before any sample is touched.
    ///
    /// # Errors
    /// Returns an error on dataset/model shape mismatch, an incompatible loss,
    /// an invalid metric, or a failed forward pass.
    ///
    /// NOTE(review): an empty `data` makes `inv_n` infinite and the reported
    /// loss NaN (0 * inf); callers such as `evaluate_mse` pre-check emptiness —
    /// confirm whether this method should reject empty datasets itself.
    pub fn evaluate(
        &self,
        data: &Dataset,
        loss_fn: Loss,
        metrics: &[Metric],
    ) -> Result<EvalReport> {
        validate_dataset_shapes(self, data)?;
        validate_loss_compat(self, loss_fn, data.target_dim())?;
        for &m in metrics {
            m.validate()?;
        }

        // Reusable forward scratch and output buffer: no per-sample allocation.
        let mut scratch = self.scratch();
        let mut out_buf = vec![0.0_f32; self.output_dim()];

        let mut total_loss = 0.0_f32;
        let mut metric_acc = MetricsAccum::new(self.output_dim(), metrics)?;

        for idx in 0..data.len() {
            let x = data.input(idx);
            let t = data.target(idx);

            self.predict_into(x, &mut scratch, &mut out_buf)?;
            total_loss += loss_fn.forward(&out_buf, t);
            metric_acc.update(&out_buf, t)?;
        }

        // Average accumulated sums over the number of samples.
        let inv_n = 1.0 / data.len() as f32;
        Ok(EvalReport::new(
            total_loss * inv_n,
            metric_acc.finish(data.len()),
        ))
    }

    /// Trains the model on `train` for `cfg.epochs` epochs, optionally
    /// evaluating on `val` after each epoch, and returns per-epoch reports.
    ///
    /// One optimizer step is taken per sample when `cfg.batch_size == 1` and
    /// per mini-batch otherwise. Before every optimizer step, gradients are
    /// optionally clipped by global norm and weight decay is applied to the
    /// parameters directly (decoupled from the gradient).
    ///
    /// # Errors
    /// Returns an error when a dataset is empty or shape-incompatible, or when
    /// any part of `cfg` (epochs, lr, batch size, schedule, optimizer, weight
    /// decay, clip norm, loss, metrics) fails validation.
    pub fn fit(
        &mut self,
        train: &Dataset,
        val: Option<&Dataset>,
        cfg: FitConfig,
    ) -> Result<FitReport> {
        // --- Validate data and configuration before allocating anything. ---
        if train.is_empty() {
            return Err(Error::InvalidData(
                "train dataset must not be empty".to_owned(),
            ));
        }
        validate_dataset_shapes(self, train)?;
        validate_loss_compat(self, cfg.loss, train.target_dim())?;
        for &m in &cfg.metrics {
            m.validate()?;
        }

        if let Some(val) = val {
            if val.is_empty() {
                return Err(Error::InvalidData(
                    "val dataset must not be empty".to_owned(),
                ));
            }
            validate_dataset_shapes(self, val)?;
            validate_loss_compat(self, cfg.loss, val.target_dim())?;
        }

        if cfg.epochs == 0 {
            return Err(Error::InvalidConfig("epochs must be > 0".to_owned()));
        }
        if !(cfg.lr.is_finite() && cfg.lr > 0.0) {
            return Err(Error::InvalidConfig("lr must be finite and > 0".to_owned()));
        }
        if cfg.batch_size == 0 {
            return Err(Error::InvalidConfig("batch_size must be > 0".to_owned()));
        }

        cfg.lr_schedule.validate()?;

        cfg.optimizer.validate()?;
        if !(cfg.weight_decay.is_finite() && cfg.weight_decay >= 0.0) {
            return Err(Error::InvalidConfig(
                "weight_decay must be finite and >= 0".to_owned(),
            ));
        }
        if let Some(v) = cfg.grad_clip_norm
            && !(v.is_finite() && v > 0.0)
        {
            return Err(Error::InvalidConfig(
                "grad_clip_norm must be finite and > 0".to_owned(),
            ));
        }

        // --- Allocate optimizer state and reusable scratch buffers up front. ---
        let mut opt_state: OptimizerState = cfg.optimizer.state(self)?;
        let mut trainer = Trainer::new(self);
        // Batched buffers are only needed when mini-batching.
        let mut batch_scratch = if cfg.batch_size > 1 {
            Some(self.scratch_batch(cfg.batch_size))
        } else {
            None
        };
        let mut batch_backprop = if cfg.batch_size > 1 {
            Some(self.backprop_scratch_batch(cfg.batch_size))
        } else {
            None
        };
        // Per-batch dL/d_output buffer, one output-sized slot per sample.
        let mut d_outputs_batch = if cfg.batch_size > 1 {
            Some(vec![0.0_f32; cfg.batch_size * self.output_dim()])
        } else {
            None
        };
        // Shuffled mini-batches are non-contiguous in the dataset, so their
        // inputs must be gathered into one flat buffer for the batched pass.
        let mut gather_inputs = if cfg.batch_size > 1 {
            match cfg.shuffle {
                Shuffle::None => None,
                Shuffle::Seeded(_) => Some(vec![0.0_f32; cfg.batch_size * self.input_dim()]),
            }
        } else {
            None
        };
        let mut reports = Vec::with_capacity(cfg.epochs);

        // Index permutation and RNG exist only when shuffling is requested;
        // the RNG is seeded once so the whole run is deterministic.
        let mut indices: Vec<usize> = match cfg.shuffle {
            Shuffle::None => Vec::new(),
            Shuffle::Seeded(_) => (0..train.len()).collect(),
        };

        let mut rng = match cfg.shuffle {
            Shuffle::None => None,
            Shuffle::Seeded(seed) => Some(StdRng::seed_from_u64(seed)),
        };

        for epoch in 0..cfg.epochs {
            let epoch_lr = cfg.lr_schedule.lr_at_epoch(cfg.lr, epoch, cfg.epochs);
            debug_assert!(epoch_lr.is_finite() && epoch_lr > 0.0);

            let mut epoch_loss = 0.0_f32;
            let mut metric_acc = MetricsAccum::new(self.output_dim(), &cfg.metrics)?;

            match cfg.shuffle {
                Shuffle::None => {
                    if cfg.batch_size == 1 {
                        // Per-sample SGD in dataset order: forward, loss grad,
                        // backward, clip/decay, optimizer step — per sample.
                        for idx in 0..train.len() {
                            let input = train.input(idx);
                            let target = train.target(idx);

                            self.forward(input, &mut trainer.scratch);
                            let pred = trainer.scratch.output();

                            let loss_val =
                                cfg.loss
                                    .backward(pred, target, trainer.grads.d_output_mut());
                            epoch_loss += loss_val;
                            metric_acc.update(pred, target)?;

                            self.backward(input, &trainer.scratch, &mut trainer.grads);

                            if let Some(max_norm) = cfg.grad_clip_norm {
                                trainer.grads.clip_global_norm_params(max_norm);
                            }
                            self.apply_weight_decay(epoch_lr, cfg.weight_decay);
                            opt_state.step(self, &mut trainer.grads, epoch_lr);
                        }
                    } else {
                        for batch in train.batches(cfg.batch_size) {
                            if batch.len() == cfg.batch_size {
                                // Full batch: fused batched forward/backward over
                                // the dataset's already-contiguous inputs.
                                // NOTE(review): grads are not pre-zeroed here —
                                // presumably `backward_batch` overwrites them;
                                // confirm against its implementation.
                                let bs = batch_scratch.as_mut().expect("batch_scratch must exist");
                                let bb =
                                    batch_backprop.as_mut().expect("batch_backprop must exist");
                                let d_out = d_outputs_batch
                                    .as_mut()
                                    .expect("d_outputs_batch must exist");

                                self.forward_batch(batch.inputs_flat(), bs);
                                let preds = bs.output();

                                // Per-sample loss gradients written into the
                                // sample's slot of the batched d_output buffer.
                                for b in 0..batch.len() {
                                    let pred =
                                        &preds[b * self.output_dim()..(b + 1) * self.output_dim()];
                                    let target = batch.target(b);
                                    let d_slice = &mut d_out
                                        [b * self.output_dim()..(b + 1) * self.output_dim()];
                                    let loss_val = cfg.loss.backward(pred, target, d_slice);
                                    epoch_loss += loss_val;
                                    metric_acc.update(pred, target)?;
                                }

                                self.backward_batch(
                                    batch.inputs_flat(),
                                    bs,
                                    d_out,
                                    &mut trainer.grads,
                                    bb,
                                );
                            } else {
                                // Tail batch shorter than batch_size: accumulate
                                // per-sample grads, then average them explicitly.
                                trainer.grads.zero_params();
                                for b in 0..batch.len() {
                                    let input = batch.input(b);
                                    let target = batch.target(b);

                                    self.forward(input, &mut trainer.scratch);
                                    let pred = trainer.scratch.output();

                                    let loss_val = cfg.loss.backward(
                                        pred,
                                        target,
                                        trainer.grads.d_output_mut(),
                                    );
                                    epoch_loss += loss_val;
                                    metric_acc.update(pred, target)?;

                                    self.backward_accumulate(
                                        input,
                                        &trainer.scratch,
                                        &mut trainer.grads,
                                    );
                                }
                                trainer.grads.scale_params(1.0 / batch.len() as f32);
                            }

                            // One optimizer step per mini-batch.
                            if let Some(max_norm) = cfg.grad_clip_norm {
                                trainer.grads.clip_global_norm_params(max_norm);
                            }
                            self.apply_weight_decay(epoch_lr, cfg.weight_decay);
                            opt_state.step(self, &mut trainer.grads, epoch_lr);
                        }
                    }
                }
                Shuffle::Seeded(_) => {
                    // Fresh permutation of sample indices every epoch.
                    let rng = rng.as_mut().expect("rng must be initialized for shuffling");
                    indices.shuffle(rng);

                    if cfg.batch_size == 1 {
                        // Per-sample SGD in shuffled order.
                        for &idx in &indices {
                            let input = train.input(idx);
                            let target = train.target(idx);

                            self.forward(input, &mut trainer.scratch);
                            let pred = trainer.scratch.output();

                            let loss_val =
                                cfg.loss
                                    .backward(pred, target, trainer.grads.d_output_mut());
                            epoch_loss += loss_val;
                            metric_acc.update(pred, target)?;

                            self.backward(input, &trainer.scratch, &mut trainer.grads);

                            if let Some(max_norm) = cfg.grad_clip_norm {
                                trainer.grads.clip_global_norm_params(max_norm);
                            }
                            self.apply_weight_decay(epoch_lr, cfg.weight_decay);
                            opt_state.step(self, &mut trainer.grads, epoch_lr);
                        }
                    } else {
                        for batch in indices.chunks(cfg.batch_size) {
                            if batch.len() == cfg.batch_size {
                                // Full batch: gather the shuffled samples'
                                // inputs into a contiguous buffer first.
                                let bs = batch_scratch.as_mut().expect("batch_scratch must exist");
                                let bb =
                                    batch_backprop.as_mut().expect("batch_backprop must exist");
                                let d_out = d_outputs_batch
                                    .as_mut()
                                    .expect("d_outputs_batch must exist");
                                let x_gather =
                                    gather_inputs.as_mut().expect("gather_inputs must exist");

                                let in_dim = self.input_dim();
                                let out_dim = self.output_dim();
                                debug_assert_eq!(x_gather.len(), cfg.batch_size * in_dim);
                                debug_assert_eq!(d_out.len(), cfg.batch_size * out_dim);

                                for (b, &idx) in batch.iter().enumerate() {
                                    let x = train.input(idx);
                                    let x0 = b * in_dim;
                                    x_gather[x0..x0 + in_dim].copy_from_slice(x);
                                }

                                self.forward_batch(x_gather, bs);
                                let preds = bs.output();

                                // Loss gradients per slot, targets looked up via
                                // the shuffled index.
                                for (b, &idx) in batch.iter().enumerate() {
                                    let pred = &preds[b * out_dim..(b + 1) * out_dim];
                                    let target = train.target(idx);
                                    let d_slice = &mut d_out[b * out_dim..(b + 1) * out_dim];

                                    let loss_val = cfg.loss.backward(pred, target, d_slice);
                                    epoch_loss += loss_val;
                                    metric_acc.update(pred, target)?;
                                }

                                self.backward_batch(x_gather, bs, d_out, &mut trainer.grads, bb);
                            } else {
                                // Tail batch: accumulate and average as above.
                                trainer.grads.zero_params();

                                for &idx in batch {
                                    let input = train.input(idx);
                                    let target = train.target(idx);

                                    self.forward(input, &mut trainer.scratch);
                                    let pred = trainer.scratch.output();

                                    let loss_val = cfg.loss.backward(
                                        pred,
                                        target,
                                        trainer.grads.d_output_mut(),
                                    );
                                    epoch_loss += loss_val;
                                    metric_acc.update(pred, target)?;

                                    self.backward_accumulate(
                                        input,
                                        &trainer.scratch,
                                        &mut trainer.grads,
                                    );
                                }

                                trainer.grads.scale_params(1.0 / batch.len() as f32);
                            }

                            // One optimizer step per mini-batch.
                            if let Some(max_norm) = cfg.grad_clip_norm {
                                trainer.grads.clip_global_norm_params(max_norm);
                            }
                            self.apply_weight_decay(epoch_lr, cfg.weight_decay);
                            opt_state.step(self, &mut trainer.grads, epoch_lr);
                        }
                    }
                }
            }

            // Mean training loss over all samples; per-epoch validation pass.
            let inv_n = 1.0 / train.len() as f32;
            let train_report = EvalReport::new(epoch_loss * inv_n, metric_acc.finish(train.len()));
            let val_report = match val {
                Some(v) => Some(self.evaluate(v, cfg.loss, &cfg.metrics)?),
                None => None,
            };

            reports.push(EpochReport {
                train: train_report,
                val: val_report,
            });
        }

        Ok(FitReport { epochs: reports })
    }

    /// Runs a forward pass per sample in `data` and returns all predictions
    /// flattened row-major (`data.len() * output_dim` values).
    ///
    /// Targets in `data` are ignored and their width is not validated.
    ///
    /// # Errors
    /// Returns an error when `data` is empty or its input width does not match
    /// the model.
    pub fn predict(&self, data: &Dataset) -> Result<Vec<f32>> {
        if data.is_empty() {
            return Err(Error::InvalidData("dataset must not be empty".to_owned()));
        }
        if data.input_dim() != self.input_dim() {
            return Err(Error::InvalidData(format!(
                "dataset input_dim {} does not match model input_dim {}",
                data.input_dim(),
                self.input_dim()
            )));
        }

        let mut scratch = self.scratch();
        let out_dim = self.output_dim();
        let mut preds = vec![0.0_f32; data.len() * out_dim];

        for idx in 0..data.len() {
            let input = data.input(idx);
            let y = self.forward(input, &mut scratch);
            let start = idx * out_dim;
            preds[start..start + out_dim].copy_from_slice(y);
        }

        Ok(preds)
    }

    /// Like [`Mlp::predict`], but for target-less input collections.
    ///
    /// # Errors
    /// Returns an error when `inputs` is empty or its width does not match the
    /// model's input dimension.
    pub fn predict_inputs(&self, inputs: &crate::Inputs) -> Result<Vec<f32>> {
        if inputs.is_empty() {
            return Err(Error::InvalidData("inputs must not be empty".to_owned()));
        }
        if inputs.input_dim() != self.input_dim() {
            return Err(Error::InvalidData(format!(
                "inputs input_dim {} does not match model input_dim {}",
                inputs.input_dim(),
                self.input_dim()
            )));
        }

        let mut scratch = self.scratch();
        let out_dim = self.output_dim();
        let mut preds = vec![0.0_f32; inputs.len() * out_dim];

        for idx in 0..inputs.len() {
            let x = inputs.input(idx);
            let y = self.forward(x, &mut scratch);
            let start = idx * out_dim;
            preds[start..start + out_dim].copy_from_slice(y);
        }

        Ok(preds)
    }

    /// Convenience wrapper: mean squared error of the model over `data`.
    ///
    /// # Errors
    /// Returns an error when `data` is empty or shape-incompatible with the model.
    pub fn evaluate_mse(&self, data: &Dataset) -> Result<f32> {
        if data.is_empty() {
            return Err(Error::InvalidData("dataset must not be empty".to_owned()));
        }
        Ok(self.evaluate(data, Loss::Mse, &[])?.loss)
    }
}
561
562fn validate_dataset_shapes(model: &Mlp, data: &Dataset) -> Result<()> {
563 if data.input_dim() != model.input_dim() {
564 return Err(Error::InvalidData(format!(
565 "dataset input_dim {} does not match model input_dim {}",
566 data.input_dim(),
567 model.input_dim()
568 )));
569 }
570 if data.target_dim() != model.output_dim() {
571 return Err(Error::InvalidData(format!(
572 "dataset target_dim {} does not match model output_dim {}",
573 data.target_dim(),
574 model.output_dim()
575 )));
576 }
577 Ok(())
578}
579
580fn validate_loss_compat(model: &Mlp, loss_fn: Loss, target_dim: usize) -> Result<()> {
581 loss_fn.validate()?;
582
583 match loss_fn {
584 Loss::Mse | Loss::Mae => Ok(()),
585 Loss::BinaryCrossEntropyWithLogits => {
586 if target_dim != 1 {
587 return Err(Error::InvalidConfig(format!(
588 "BinaryCrossEntropyWithLogits requires output_dim == 1, got {target_dim}"
589 )));
590 }
591 let last = last_layer_activation(model);
592 if last != Activation::Identity {
593 return Err(Error::InvalidConfig(
594 "BinaryCrossEntropyWithLogits expects raw logits; set the output layer activation to Identity"
595 .to_owned(),
596 ));
597 }
598 Ok(())
599 }
600 Loss::SoftmaxCrossEntropy => {
601 if target_dim < 2 {
602 return Err(Error::InvalidConfig(format!(
603 "SoftmaxCrossEntropy requires output_dim >= 2, got {target_dim}"
604 )));
605 }
606 let last = last_layer_activation(model);
607 if last != Activation::Identity {
608 return Err(Error::InvalidConfig(
609 "SoftmaxCrossEntropy expects raw logits; set the output layer activation to Identity".to_owned(),
610 ));
611 }
612 Ok(())
613 }
614 }
615}
616
617fn last_layer_activation(model: &Mlp) -> Activation {
618 last_layer(model)
620 .expect("mlp must have at least one layer")
621 .activation()
622}
623
624fn last_layer(model: &Mlp) -> Option<&Layer> {
625 model.layer(model.num_layers().checked_sub(1)?)
628}
629
/// Running per-metric sums used to average metric values over a dataset pass.
struct MetricsAccum {
    /// Expected length of every pred/target pair fed to `update`.
    output_dim: usize,
    /// Validated metrics, in the order requested by the caller.
    metrics: Vec<Metric>,
    /// Running sum per metric; parallel to `metrics`.
    sums: Vec<f32>,
}
635
636impl MetricsAccum {
637 fn new(output_dim: usize, metrics: &[Metric]) -> Result<Self> {
638 let mut ms = Vec::with_capacity(metrics.len());
639 for &m in metrics {
640 m.validate()?;
641 ms.push(m);
642 }
643 Ok(Self {
644 output_dim,
645 metrics: ms,
646 sums: vec![0.0; metrics.len()],
647 })
648 }
649
650 fn update(&mut self, pred: &[f32], target: &[f32]) -> Result<()> {
651 if self.metrics.is_empty() {
652 return Ok(());
653 }
654 if pred.len() != target.len() {
655 return Err(Error::InvalidData(format!(
656 "pred/target length mismatch: {} vs {}",
657 pred.len(),
658 target.len()
659 )));
660 }
661 if pred.len() != self.output_dim {
662 return Err(Error::InvalidData(format!(
663 "pred len {} does not match expected output_dim {}",
664 pred.len(),
665 self.output_dim
666 )));
667 }
668
669 for (idx, &m) in self.metrics.iter().enumerate() {
670 self.sums[idx] += metric_value(m, pred, target)?;
671 }
672 Ok(())
673 }
674
675 fn finish(self, n: usize) -> Vec<(Metric, f32)> {
676 if self.metrics.is_empty() {
677 return Vec::new();
678 }
679
680 let inv_n = 1.0 / n as f32;
681 self.metrics
682 .into_iter()
683 .zip(self.sums)
684 .map(|(m, s)| (m, s * inv_n))
685 .collect()
686 }
687}
688
689fn metric_value(metric: Metric, pred: &[f32], target: &[f32]) -> Result<f32> {
690 match metric {
691 Metric::Mse => Ok(loss::mse(pred, target)),
692 Metric::Mae => Ok(loss::mae(pred, target)),
693 Metric::Accuracy => Ok(accuracy(pred, target)?),
694 Metric::TopKAccuracy { k } => Ok(topk_accuracy(pred, target, k)?),
695 }
696}
697
698fn accuracy(pred: &[f32], target: &[f32]) -> Result<f32> {
699 if pred.len() != target.len() {
700 return Err(Error::InvalidData(format!(
701 "pred/target length mismatch: {} vs {}",
702 pred.len(),
703 target.len()
704 )));
705 }
706 if pred.is_empty() {
707 return Ok(0.0);
708 }
709
710 if pred.len() == 1 {
711 let y = pred[0];
713 let t = target[0];
714 let pred_label = if y >= 0.5 { 1 } else { 0 };
715 let true_label = if t >= 0.5 { 1 } else { 0 };
716 Ok(if pred_label == true_label { 1.0 } else { 0.0 })
717 } else {
718 let pred_idx = argmax(pred);
720 let true_idx = argmax(target);
721 Ok(if pred_idx == true_idx { 1.0 } else { 0.0 })
722 }
723}
724
725fn topk_accuracy(pred: &[f32], target: &[f32], k: usize) -> Result<f32> {
726 if pred.len() != target.len() {
727 return Err(Error::InvalidData(format!(
728 "pred/target length mismatch: {} vs {}",
729 pred.len(),
730 target.len()
731 )));
732 }
733 if pred.len() <= 1 {
734 return Err(Error::InvalidConfig(
735 "TopKAccuracy requires output_dim > 1".to_owned(),
736 ));
737 }
738 if k == 0 || k > pred.len() {
739 return Err(Error::InvalidConfig(format!(
740 "TopKAccuracy requires 1 <= k <= output_dim, got k={k} output_dim={}",
741 pred.len()
742 )));
743 }
744
745 let true_idx = argmax(target);
746
747 let true_score = pred[true_idx];
750 let mut num_greater = 0_usize;
751 for (i, &v) in pred.iter().enumerate() {
752 if i != true_idx && v > true_score {
753 num_greater += 1;
754 }
755 }
756 Ok(if num_greater < k { 1.0 } else { 0.0 })
757}
758
/// Index of the largest value in `xs`; ties resolve to the earliest index, and
/// NaN entries never win a comparison. Panics on an empty slice.
fn argmax(xs: &[f32]) -> usize {
    debug_assert!(!xs.is_empty());
    xs.iter()
        .enumerate()
        .skip(1)
        .fold((0, xs[0]), |(best_idx, best_val), (i, &v)| {
            // Strict `>` keeps the earliest index on ties.
            if v > best_val { (i, v) } else { (best_idx, best_val) }
        })
        .0
}
771
#[cfg(test)]
mod tests {
    use crate::{Activation, Dataset, Loss, Metric, MlpBuilder};

    use super::Shuffle;

    /// A 3-way identity-output head with softmax cross-entropy must accept
    /// one-hot targets and report exactly the one requested metric.
    #[test]
    fn evaluate_computes_accuracy_for_multiclass_one_hot() {
        let mlp = MlpBuilder::new(2)
            .unwrap()
            .add_layer(3, Activation::Identity)
            .unwrap()
            .build_with_seed(0)
            .unwrap();

        let xs = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
        let ys = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
        let data = Dataset::from_rows(&xs, &ys).unwrap();

        let report = mlp
            .evaluate(&data, Loss::SoftmaxCrossEntropy, &[Metric::Accuracy])
            .unwrap();
        assert_eq!(report.metrics.len(), 1);
    }

    /// Two clones trained with the same shuffle seed must follow identical
    /// sample orders and therefore end with bit-identical losses.
    #[test]
    fn shuffle_seeded_is_deterministic() {
        let mut a = MlpBuilder::new(2)
            .unwrap()
            .add_layer(4, Activation::Tanh)
            .unwrap()
            .add_layer(1, Activation::Identity)
            .unwrap()
            .build_with_seed(0)
            .unwrap();
        let mut b = a.clone();

        // Five samples with batch_size 2 also exercises the short tail batch.
        let xs = vec![
            vec![0.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 0.0],
            vec![1.0, 1.0],
            vec![2.0, 0.0],
        ];
        let ys = vec![vec![0.0], vec![1.0], vec![1.0], vec![2.0], vec![2.0]];
        let train = Dataset::from_rows(&xs, &ys).unwrap();

        let cfg = super::FitConfig {
            epochs: 10,
            lr: 0.05,
            batch_size: 2,
            shuffle: Shuffle::Seeded(123),
            lr_schedule: super::LrSchedule::Constant,
            optimizer: crate::Optimizer::Sgd,
            weight_decay: 0.0,
            grad_clip_norm: None,
            loss: Loss::Mse,
            metrics: vec![],
        };

        let rep_a = a.fit(&train, None, cfg.clone()).unwrap();
        let rep_b = b.fit(&train, None, cfg).unwrap();

        // Compare raw bits: even a 1-ULP divergence should fail the test.
        let last_a = rep_a.epochs.last().unwrap().train.loss;
        let last_b = rep_b.epochs.last().unwrap().train.loss;
        assert_eq!(last_a.to_bits(), last_b.to_bits());
    }
}
845}