// rust_mlp/train.rs
1//! High-level training and evaluation APIs.
2//!
3//! These methods validate dataset shapes and return `Result`, while internally using
4//! allocation-free per-sample forward/backward passes.
5
6use crate::{
7    Activation, Dataset, Error, Layer, Loss, Metric, Mlp, Optimizer, OptimizerState, Result,
8    Trainer, loss,
9};
10
11use rand::SeedableRng;
12use rand::rngs::StdRng;
13use rand::seq::SliceRandom;
14
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
/// Training data shuffling strategy.
///
/// Used by `FitConfig` to control the order in which training samples are
/// visited each epoch.
pub enum Shuffle {
    /// Keep samples in dataset order.
    #[default]
    None,
    /// Shuffle each epoch using a deterministic RNG seed.
    ///
    /// Two `fit` runs with the same seed (and otherwise identical inputs)
    /// visit samples in the same order.
    Seeded(u64),
}
24
#[derive(Debug, Clone, Copy, PartialEq, Default)]
/// Learning rate schedule (applied per epoch).
pub enum LrSchedule {
    /// Keep `lr` constant.
    #[default]
    Constant,
    /// Step decay: `lr = lr0 * gamma^(epoch / step_size)`.
    ///
    /// `step_size` must be > 0; `gamma` must be finite and > 0
    /// (values above 1 grow the rate instead of decaying it).
    Step { step_size: usize, gamma: f32 },
    /// Cosine annealing from `lr0` down to `min_lr`.
    ///
    /// `min_lr` must be finite and > 0 — the training loop asserts a strictly
    /// positive learning rate every epoch.
    CosineAnnealing { min_lr: f32 },
}
36
37impl LrSchedule {
38    pub fn validate(self) -> Result<()> {
39        match self {
40            LrSchedule::Constant => Ok(()),
41            LrSchedule::Step { step_size, gamma } => {
42                if step_size == 0 {
43                    return Err(Error::InvalidConfig(
44                        "lr_schedule step_size must be > 0".to_owned(),
45                    ));
46                }
47                if !(gamma.is_finite() && gamma > 0.0) {
48                    return Err(Error::InvalidConfig(format!(
49                        "lr_schedule gamma must be finite and > 0, got {gamma}"
50                    )));
51                }
52                Ok(())
53            }
54            LrSchedule::CosineAnnealing { min_lr } => {
55                if !(min_lr.is_finite() && min_lr > 0.0) {
56                    return Err(Error::InvalidConfig(format!(
57                        "lr_schedule min_lr must be finite and > 0, got {min_lr}"
58                    )));
59                }
60                Ok(())
61            }
62        }
63    }
64
65    fn lr_at_epoch(self, lr0: f32, epoch: usize, epochs: usize) -> f32 {
66        match self {
67            LrSchedule::Constant => lr0,
68            LrSchedule::Step { step_size, gamma } => {
69                let k = epoch / step_size;
70                lr0 * gamma.powi(k as i32)
71            }
72            LrSchedule::CosineAnnealing { min_lr } => {
73                if epochs <= 1 {
74                    return lr0;
75                }
76
77                let t = epoch as f32;
78                let t_max = (epochs - 1) as f32;
79                let cos = (std::f32::consts::PI * (t / t_max)).cos();
80                min_lr + (lr0 - min_lr) * 0.5 * (1.0 + cos)
81            }
82        }
83    }
84}
85
#[derive(Debug, Clone)]
/// Configuration for `Mlp::fit`.
pub struct FitConfig {
    /// Number of full passes over the training set. Must be > 0.
    pub epochs: usize,
    /// Base learning rate. Must be finite and > 0.
    pub lr: f32,
    /// Mini-batch size. Must be > 0; 1 selects the per-sample (pure SGD) path.
    pub batch_size: usize,
    /// Per-epoch sample ordering strategy.
    pub shuffle: Shuffle,
    /// Per-epoch learning rate schedule applied on top of `lr`.
    pub lr_schedule: LrSchedule,
    /// Parameter update rule.
    pub optimizer: Optimizer,
    /// Weight decay coefficient. Must be finite and >= 0 (0 disables it).
    pub weight_decay: f32,
    /// Optional global gradient-norm clipping threshold; when set, must be
    /// finite and > 0.
    pub grad_clip_norm: Option<f32>,
    /// Training loss (also used for validation reporting).
    pub loss: Loss,
    /// Extra metrics reported per epoch, in addition to the loss.
    pub metrics: Vec<Metric>,
}
100
101impl Default for FitConfig {
102    fn default() -> Self {
103        Self {
104            epochs: 10,
105            lr: 1e-2,
106            batch_size: 1,
107            shuffle: Shuffle::None,
108            lr_schedule: LrSchedule::Constant,
109            optimizer: Optimizer::Sgd,
110            weight_decay: 0.0,
111            grad_clip_norm: None,
112            loss: Loss::Mse,
113            metrics: Vec::new(),
114        }
115    }
116}
117
#[derive(Debug, Clone)]
/// Output of a training run.
pub struct FitReport {
    /// One entry per completed epoch, in order.
    pub epochs: Vec<EpochReport>,
}
123
#[derive(Debug, Clone)]
/// Report for a single epoch.
pub struct EpochReport {
    /// Loss/metrics accumulated over the training pass of this epoch.
    pub train: EvalReport,
    /// Validation results; present only when a validation set was supplied.
    pub val: Option<EvalReport>,
}
130
#[derive(Debug, Clone)]
/// Output of `Mlp::evaluate`.
pub struct EvalReport {
    /// Mean loss over all samples.
    pub loss: f32,
    /// Mean value of each requested metric, paired with the metric itself.
    pub metrics: Vec<(Metric, f32)>,
}
137
impl EvalReport {
    /// Internal constructor; `loss` and `metrics` are already averaged.
    fn new(loss: f32, metrics: Vec<(Metric, f32)>) -> Self {
        Self { loss, metrics }
    }
}
143
144impl Mlp {
145    /// Evaluate a dataset with a loss and optional metrics.
146    pub fn evaluate(
147        &self,
148        data: &Dataset,
149        loss_fn: Loss,
150        metrics: &[Metric],
151    ) -> Result<EvalReport> {
152        validate_dataset_shapes(self, data)?;
153        validate_loss_compat(self, loss_fn, data.target_dim())?;
154        for &m in metrics {
155            m.validate()?;
156        }
157
158        let mut scratch = self.scratch();
159        let mut out_buf = vec![0.0_f32; self.output_dim()];
160
161        let mut total_loss = 0.0_f32;
162        let mut metric_acc = MetricsAccum::new(self.output_dim(), metrics)?;
163
164        for idx in 0..data.len() {
165            let x = data.input(idx);
166            let t = data.target(idx);
167
168            self.predict_into(x, &mut scratch, &mut out_buf)?;
169            total_loss += loss_fn.forward(&out_buf, t);
170            metric_acc.update(&out_buf, t)?;
171        }
172
173        let inv_n = 1.0 / data.len() as f32;
174        Ok(EvalReport::new(
175            total_loss * inv_n,
176            metric_acc.finish(data.len()),
177        ))
178    }
179
    /// Train the model on a dataset.
    ///
    /// This is a "batteries included" API intended to be easy to use.
    /// Internally it still uses allocation-free forward/backward via scratch buffers.
    ///
    /// When `val` is provided, the model is evaluated on it after every epoch
    /// and the result is attached to that epoch's report.
    ///
    /// # Errors
    ///
    /// Returns an error if either dataset is empty or shape-incompatible with
    /// the model, or if any `cfg` field fails validation (epochs/lr/batch_size,
    /// schedule, optimizer, weight_decay, grad_clip_norm, loss, metrics).
    pub fn fit(
        &mut self,
        train: &Dataset,
        val: Option<&Dataset>,
        cfg: FitConfig,
    ) -> Result<FitReport> {
        // --- Validate datasets before touching any configuration. ---
        if train.is_empty() {
            return Err(Error::InvalidData(
                "train dataset must not be empty".to_owned(),
            ));
        }
        validate_dataset_shapes(self, train)?;
        validate_loss_compat(self, cfg.loss, train.target_dim())?;
        for &m in &cfg.metrics {
            m.validate()?;
        }

        if let Some(val) = val {
            if val.is_empty() {
                return Err(Error::InvalidData(
                    "val dataset must not be empty".to_owned(),
                ));
            }
            validate_dataset_shapes(self, val)?;
            validate_loss_compat(self, cfg.loss, val.target_dim())?;
        }

        // --- Validate scalar hyperparameters. ---
        if cfg.epochs == 0 {
            return Err(Error::InvalidConfig("epochs must be > 0".to_owned()));
        }
        if !(cfg.lr.is_finite() && cfg.lr > 0.0) {
            return Err(Error::InvalidConfig("lr must be finite and > 0".to_owned()));
        }
        if cfg.batch_size == 0 {
            return Err(Error::InvalidConfig("batch_size must be > 0".to_owned()));
        }

        cfg.lr_schedule.validate()?;

        cfg.optimizer.validate()?;
        if !(cfg.weight_decay.is_finite() && cfg.weight_decay >= 0.0) {
            return Err(Error::InvalidConfig(
                "weight_decay must be finite and >= 0".to_owned(),
            ));
        }
        if let Some(v) = cfg.grad_clip_norm
            && !(v.is_finite() && v > 0.0)
        {
            return Err(Error::InvalidConfig(
                "grad_clip_norm must be finite and > 0".to_owned(),
            ));
        }

        // --- Allocate training state once, up front. Batched buffers are only
        // needed when batch_size > 1; `gather_inputs` only when shuffling also
        // forces non-contiguous batches. ---
        let mut opt_state: OptimizerState = cfg.optimizer.state(self)?;
        let mut trainer = Trainer::new(self);
        let mut batch_scratch = if cfg.batch_size > 1 {
            Some(self.scratch_batch(cfg.batch_size))
        } else {
            None
        };
        let mut batch_backprop = if cfg.batch_size > 1 {
            Some(self.backprop_scratch_batch(cfg.batch_size))
        } else {
            None
        };
        let mut d_outputs_batch = if cfg.batch_size > 1 {
            Some(vec![0.0_f32; cfg.batch_size * self.output_dim()])
        } else {
            None
        };
        let mut gather_inputs = if cfg.batch_size > 1 {
            match cfg.shuffle {
                Shuffle::None => None,
                Shuffle::Seeded(_) => Some(vec![0.0_f32; cfg.batch_size * self.input_dim()]),
            }
        } else {
            None
        };
        let mut reports = Vec::with_capacity(cfg.epochs);

        // Only allocate an indices buffer if we need shuffling.
        let mut indices: Vec<usize> = match cfg.shuffle {
            Shuffle::None => Vec::new(),
            Shuffle::Seeded(_) => (0..train.len()).collect(),
        };

        let mut rng = match cfg.shuffle {
            Shuffle::None => None,
            Shuffle::Seeded(seed) => Some(StdRng::seed_from_u64(seed)),
        };

        for epoch in 0..cfg.epochs {
            // Schedule is applied per epoch; `validate` above guarantees a
            // finite, positive rate here.
            let epoch_lr = cfg.lr_schedule.lr_at_epoch(cfg.lr, epoch, cfg.epochs);
            debug_assert!(epoch_lr.is_finite() && epoch_lr > 0.0);

            let mut epoch_loss = 0.0_f32;
            let mut metric_acc = MetricsAccum::new(self.output_dim(), &cfg.metrics)?;

            // Four code paths: {no shuffle, shuffled} x {per-sample, batched}.
            // In every path the update order is: (optional) clip -> weight
            // decay -> optimizer step.
            match cfg.shuffle {
                Shuffle::None => {
                    if cfg.batch_size == 1 {
                        // Per-sample SGD in dataset order.
                        for idx in 0..train.len() {
                            let input = train.input(idx);
                            let target = train.target(idx);

                            self.forward(input, &mut trainer.scratch);
                            let pred = trainer.scratch.output();

                            let loss_val =
                                cfg.loss
                                    .backward(pred, target, trainer.grads.d_output_mut());
                            epoch_loss += loss_val;
                            metric_acc.update(pred, target)?;

                            self.backward(input, &trainer.scratch, &mut trainer.grads);

                            if let Some(max_norm) = cfg.grad_clip_norm {
                                trainer.grads.clip_global_norm_params(max_norm);
                            }
                            self.apply_weight_decay(epoch_lr, cfg.weight_decay);
                            opt_state.step(self, &mut trainer.grads, epoch_lr);
                        }
                    } else {
                        // Mini-batches straight from the dataset's contiguous
                        // storage — no gather needed.
                        for batch in train.batches(cfg.batch_size) {
                            // Batched fast path for full-size batches.
                            if batch.len() == cfg.batch_size {
                                let bs = batch_scratch.as_mut().expect("batch_scratch must exist");
                                let bb =
                                    batch_backprop.as_mut().expect("batch_backprop must exist");
                                let d_out = d_outputs_batch
                                    .as_mut()
                                    .expect("d_outputs_batch must exist");

                                self.forward_batch(batch.inputs_flat(), bs);
                                let preds = bs.output();

                                for b in 0..batch.len() {
                                    let pred =
                                        &preds[b * self.output_dim()..(b + 1) * self.output_dim()];
                                    let target = batch.target(b);
                                    let d_slice = &mut d_out
                                        [b * self.output_dim()..(b + 1) * self.output_dim()];
                                    let loss_val = cfg.loss.backward(pred, target, d_slice);
                                    epoch_loss += loss_val;
                                    metric_acc.update(pred, target)?;
                                }

                                self.backward_batch(
                                    batch.inputs_flat(),
                                    bs,
                                    d_out,
                                    &mut trainer.grads,
                                    bb,
                                );
                            } else {
                                // Remainder batch: fall back to per-sample accumulation.
                                trainer.grads.zero_params();
                                for b in 0..batch.len() {
                                    let input = batch.input(b);
                                    let target = batch.target(b);

                                    self.forward(input, &mut trainer.scratch);
                                    let pred = trainer.scratch.output();

                                    let loss_val = cfg.loss.backward(
                                        pred,
                                        target,
                                        trainer.grads.d_output_mut(),
                                    );
                                    epoch_loss += loss_val;
                                    metric_acc.update(pred, target)?;

                                    self.backward_accumulate(
                                        input,
                                        &trainer.scratch,
                                        &mut trainer.grads,
                                    );
                                }
                                // Average accumulated gradients over the batch.
                                trainer.grads.scale_params(1.0 / batch.len() as f32);
                            }

                            if let Some(max_norm) = cfg.grad_clip_norm {
                                trainer.grads.clip_global_norm_params(max_norm);
                            }
                            self.apply_weight_decay(epoch_lr, cfg.weight_decay);
                            opt_state.step(self, &mut trainer.grads, epoch_lr);
                        }
                    }
                }
                Shuffle::Seeded(_) => {
                    // Reshuffle the index permutation at the start of every
                    // epoch, using the persistent seeded RNG for determinism.
                    let rng = rng.as_mut().expect("rng must be initialized for shuffling");
                    indices.shuffle(rng);

                    if cfg.batch_size == 1 {
                        // Per-sample SGD in shuffled order.
                        for &idx in &indices {
                            let input = train.input(idx);
                            let target = train.target(idx);

                            self.forward(input, &mut trainer.scratch);
                            let pred = trainer.scratch.output();

                            let loss_val =
                                cfg.loss
                                    .backward(pred, target, trainer.grads.d_output_mut());
                            epoch_loss += loss_val;
                            metric_acc.update(pred, target)?;

                            self.backward(input, &trainer.scratch, &mut trainer.grads);

                            if let Some(max_norm) = cfg.grad_clip_norm {
                                trainer.grads.clip_global_norm_params(max_norm);
                            }
                            self.apply_weight_decay(epoch_lr, cfg.weight_decay);
                            opt_state.step(self, &mut trainer.grads, epoch_lr);
                        }
                    } else {
                        for batch in indices.chunks(cfg.batch_size) {
                            // Batched fast path for full-size batches: gather inputs into a
                            // contiguous buffer, then run GEMM-based forward/backward.
                            if batch.len() == cfg.batch_size {
                                let bs = batch_scratch.as_mut().expect("batch_scratch must exist");
                                let bb =
                                    batch_backprop.as_mut().expect("batch_backprop must exist");
                                let d_out = d_outputs_batch
                                    .as_mut()
                                    .expect("d_outputs_batch must exist");
                                let x_gather =
                                    gather_inputs.as_mut().expect("gather_inputs must exist");

                                let in_dim = self.input_dim();
                                let out_dim = self.output_dim();
                                debug_assert_eq!(x_gather.len(), cfg.batch_size * in_dim);
                                debug_assert_eq!(d_out.len(), cfg.batch_size * out_dim);

                                // Copy the shuffled samples into one contiguous
                                // input block.
                                for (b, &idx) in batch.iter().enumerate() {
                                    let x = train.input(idx);
                                    let x0 = b * in_dim;
                                    x_gather[x0..x0 + in_dim].copy_from_slice(x);
                                }

                                self.forward_batch(x_gather, bs);
                                let preds = bs.output();

                                for (b, &idx) in batch.iter().enumerate() {
                                    let pred = &preds[b * out_dim..(b + 1) * out_dim];
                                    let target = train.target(idx);
                                    let d_slice = &mut d_out[b * out_dim..(b + 1) * out_dim];

                                    let loss_val = cfg.loss.backward(pred, target, d_slice);
                                    epoch_loss += loss_val;
                                    metric_acc.update(pred, target)?;
                                }

                                self.backward_batch(x_gather, bs, d_out, &mut trainer.grads, bb);
                            } else {
                                // Remainder batch: fall back to per-sample accumulation.
                                trainer.grads.zero_params();

                                for &idx in batch {
                                    let input = train.input(idx);
                                    let target = train.target(idx);

                                    self.forward(input, &mut trainer.scratch);
                                    let pred = trainer.scratch.output();

                                    let loss_val = cfg.loss.backward(
                                        pred,
                                        target,
                                        trainer.grads.d_output_mut(),
                                    );
                                    epoch_loss += loss_val;
                                    metric_acc.update(pred, target)?;

                                    self.backward_accumulate(
                                        input,
                                        &trainer.scratch,
                                        &mut trainer.grads,
                                    );
                                }

                                // Average accumulated gradients over the batch.
                                trainer.grads.scale_params(1.0 / batch.len() as f32);
                            }

                            if let Some(max_norm) = cfg.grad_clip_norm {
                                trainer.grads.clip_global_norm_params(max_norm);
                            }
                            self.apply_weight_decay(epoch_lr, cfg.weight_decay);
                            opt_state.step(self, &mut trainer.grads, epoch_lr);
                        }
                    }
                }
            }

            // Per-epoch reporting: average the training loss/metrics, then
            // (optionally) run a full validation pass.
            let inv_n = 1.0 / train.len() as f32;
            let train_report = EvalReport::new(epoch_loss * inv_n, metric_acc.finish(train.len()));
            let val_report = match val {
                Some(v) => Some(self.evaluate(v, cfg.loss, &cfg.metrics)?),
                None => None,
            };

            reports.push(EpochReport {
                train: train_report,
                val: val_report,
            });
        }

        Ok(FitReport { epochs: reports })
    }
492
493    /// Predict outputs for all inputs in `data`.
494    ///
495    /// Returns a flat buffer with shape `(len, output_dim)`.
496    pub fn predict(&self, data: &Dataset) -> Result<Vec<f32>> {
497        if data.is_empty() {
498            return Err(Error::InvalidData("dataset must not be empty".to_owned()));
499        }
500        if data.input_dim() != self.input_dim() {
501            return Err(Error::InvalidData(format!(
502                "dataset input_dim {} does not match model input_dim {}",
503                data.input_dim(),
504                self.input_dim()
505            )));
506        }
507
508        let mut scratch = self.scratch();
509        let out_dim = self.output_dim();
510        let mut preds = vec![0.0_f32; data.len() * out_dim];
511
512        for idx in 0..data.len() {
513            let input = data.input(idx);
514            let y = self.forward(input, &mut scratch);
515            let start = idx * out_dim;
516            preds[start..start + out_dim].copy_from_slice(y);
517        }
518
519        Ok(preds)
520    }
521
522    /// Predict outputs for inputs (X).
523    ///
524    /// Returns a flat buffer with shape `(len, output_dim)`.
525    pub fn predict_inputs(&self, inputs: &crate::Inputs) -> Result<Vec<f32>> {
526        if inputs.is_empty() {
527            return Err(Error::InvalidData("inputs must not be empty".to_owned()));
528        }
529        if inputs.input_dim() != self.input_dim() {
530            return Err(Error::InvalidData(format!(
531                "inputs input_dim {} does not match model input_dim {}",
532                inputs.input_dim(),
533                self.input_dim()
534            )));
535        }
536
537        let mut scratch = self.scratch();
538        let out_dim = self.output_dim();
539        let mut preds = vec![0.0_f32; inputs.len() * out_dim];
540
541        for idx in 0..inputs.len() {
542            let x = inputs.input(idx);
543            let y = self.forward(x, &mut scratch);
544            let start = idx * out_dim;
545            preds[start..start + out_dim].copy_from_slice(y);
546        }
547
548        Ok(preds)
549    }
550
551    /// Evaluate mean MSE over a dataset.
552    ///
553    /// This is a convenience wrapper around `evaluate`.
554    pub fn evaluate_mse(&self, data: &Dataset) -> Result<f32> {
555        if data.is_empty() {
556            return Err(Error::InvalidData("dataset must not be empty".to_owned()));
557        }
558        Ok(self.evaluate(data, Loss::Mse, &[])?.loss)
559    }
560}
561
562fn validate_dataset_shapes(model: &Mlp, data: &Dataset) -> Result<()> {
563    if data.input_dim() != model.input_dim() {
564        return Err(Error::InvalidData(format!(
565            "dataset input_dim {} does not match model input_dim {}",
566            data.input_dim(),
567            model.input_dim()
568        )));
569    }
570    if data.target_dim() != model.output_dim() {
571        return Err(Error::InvalidData(format!(
572            "dataset target_dim {} does not match model output_dim {}",
573            data.target_dim(),
574            model.output_dim()
575        )));
576    }
577    Ok(())
578}
579
580fn validate_loss_compat(model: &Mlp, loss_fn: Loss, target_dim: usize) -> Result<()> {
581    loss_fn.validate()?;
582
583    match loss_fn {
584        Loss::Mse | Loss::Mae => Ok(()),
585        Loss::BinaryCrossEntropyWithLogits => {
586            if target_dim != 1 {
587                return Err(Error::InvalidConfig(format!(
588                    "BinaryCrossEntropyWithLogits requires output_dim == 1, got {target_dim}"
589                )));
590            }
591            let last = last_layer_activation(model);
592            if last != Activation::Identity {
593                return Err(Error::InvalidConfig(
594                    "BinaryCrossEntropyWithLogits expects raw logits; set the output layer activation to Identity"
595                        .to_owned(),
596                ));
597            }
598            Ok(())
599        }
600        Loss::SoftmaxCrossEntropy => {
601            if target_dim < 2 {
602                return Err(Error::InvalidConfig(format!(
603                    "SoftmaxCrossEntropy requires output_dim >= 2, got {target_dim}"
604                )));
605            }
606            let last = last_layer_activation(model);
607            if last != Activation::Identity {
608                return Err(Error::InvalidConfig(
609                    "SoftmaxCrossEntropy expects raw logits; set the output layer activation to Identity".to_owned(),
610                ));
611            }
612            Ok(())
613        }
614    }
615}
616
617fn last_layer_activation(model: &Mlp) -> Activation {
618    // `Mlp` is guaranteed to have at least one layer when constructed via `MlpBuilder`.
619    last_layer(model)
620        .expect("mlp must have at least one layer")
621        .activation()
622}
623
624fn last_layer(model: &Mlp) -> Option<&Layer> {
625    // We intentionally keep `Mlp`'s internal layout private. This helper uses a
626    // public accessor to inspect the last layer when validating logits-based losses.
627    model.layer(model.num_layers().checked_sub(1)?)
628}
629
/// Running per-metric sums over a pass; averaged by `finish`.
struct MetricsAccum {
    // Expected length of every `pred`/`target` slice fed to `update`.
    output_dim: usize,
    // Metrics being tracked (validated on construction).
    metrics: Vec<Metric>,
    // Per-metric running sums, parallel to `metrics`.
    sums: Vec<f32>,
}
635
636impl MetricsAccum {
637    fn new(output_dim: usize, metrics: &[Metric]) -> Result<Self> {
638        let mut ms = Vec::with_capacity(metrics.len());
639        for &m in metrics {
640            m.validate()?;
641            ms.push(m);
642        }
643        Ok(Self {
644            output_dim,
645            metrics: ms,
646            sums: vec![0.0; metrics.len()],
647        })
648    }
649
650    fn update(&mut self, pred: &[f32], target: &[f32]) -> Result<()> {
651        if self.metrics.is_empty() {
652            return Ok(());
653        }
654        if pred.len() != target.len() {
655            return Err(Error::InvalidData(format!(
656                "pred/target length mismatch: {} vs {}",
657                pred.len(),
658                target.len()
659            )));
660        }
661        if pred.len() != self.output_dim {
662            return Err(Error::InvalidData(format!(
663                "pred len {} does not match expected output_dim {}",
664                pred.len(),
665                self.output_dim
666            )));
667        }
668
669        for (idx, &m) in self.metrics.iter().enumerate() {
670            self.sums[idx] += metric_value(m, pred, target)?;
671        }
672        Ok(())
673    }
674
675    fn finish(self, n: usize) -> Vec<(Metric, f32)> {
676        if self.metrics.is_empty() {
677            return Vec::new();
678        }
679
680        let inv_n = 1.0 / n as f32;
681        self.metrics
682            .into_iter()
683            .zip(self.sums)
684            .map(|(m, s)| (m, s * inv_n))
685            .collect()
686    }
687}
688
689fn metric_value(metric: Metric, pred: &[f32], target: &[f32]) -> Result<f32> {
690    match metric {
691        Metric::Mse => Ok(loss::mse(pred, target)),
692        Metric::Mae => Ok(loss::mae(pred, target)),
693        Metric::Accuracy => Ok(accuracy(pred, target)?),
694        Metric::TopKAccuracy { k } => Ok(topk_accuracy(pred, target, k)?),
695    }
696}
697
698fn accuracy(pred: &[f32], target: &[f32]) -> Result<f32> {
699    if pred.len() != target.len() {
700        return Err(Error::InvalidData(format!(
701            "pred/target length mismatch: {} vs {}",
702            pred.len(),
703            target.len()
704        )));
705    }
706    if pred.is_empty() {
707        return Ok(0.0);
708    }
709
710    if pred.len() == 1 {
711        // Binary accuracy.
712        let y = pred[0];
713        let t = target[0];
714        let pred_label = if y >= 0.5 { 1 } else { 0 };
715        let true_label = if t >= 0.5 { 1 } else { 0 };
716        Ok(if pred_label == true_label { 1.0 } else { 0.0 })
717    } else {
718        // Multiclass (argmax).
719        let pred_idx = argmax(pred);
720        let true_idx = argmax(target);
721        Ok(if pred_idx == true_idx { 1.0 } else { 0.0 })
722    }
723}
724
725fn topk_accuracy(pred: &[f32], target: &[f32], k: usize) -> Result<f32> {
726    if pred.len() != target.len() {
727        return Err(Error::InvalidData(format!(
728            "pred/target length mismatch: {} vs {}",
729            pred.len(),
730            target.len()
731        )));
732    }
733    if pred.len() <= 1 {
734        return Err(Error::InvalidConfig(
735            "TopKAccuracy requires output_dim > 1".to_owned(),
736        ));
737    }
738    if k == 0 || k > pred.len() {
739        return Err(Error::InvalidConfig(format!(
740            "TopKAccuracy requires 1 <= k <= output_dim, got k={k} output_dim={}",
741            pred.len()
742        )));
743    }
744
745    let true_idx = argmax(target);
746
747    // Find if true_idx is in top-k of pred without allocating:
748    // Count how many logits are strictly greater than pred[true_idx].
749    let true_score = pred[true_idx];
750    let mut num_greater = 0_usize;
751    for (i, &v) in pred.iter().enumerate() {
752        if i != true_idx && v > true_score {
753            num_greater += 1;
754        }
755    }
756    Ok(if num_greater < k { 1.0 } else { 0.0 })
757}
758
/// Index of the largest element; ties resolve to the earliest index.
///
/// NaN entries never win a `>` comparison, so they are effectively skipped
/// (an all-NaN slice yields index 0).
fn argmax(xs: &[f32]) -> usize {
    debug_assert!(!xs.is_empty());
    let mut best = 0;
    for i in 1..xs.len() {
        if xs[i] > xs[best] {
            best = i;
        }
    }
    best
}
771
#[cfg(test)]
mod tests {
    use crate::{Activation, Dataset, Loss, Metric, MlpBuilder};

    use super::Shuffle;

    #[test]
    fn evaluate_computes_accuracy_for_multiclass_one_hot() {
        // Make a tiny dataset where the model is forced to output logits we can control.
        // We'll build a 2 -> 3 identity-ish model.
        let mlp = MlpBuilder::new(2)
            .unwrap()
            .add_layer(3, Activation::Identity)
            .unwrap()
            .build_with_seed(0)
            .unwrap();

        // Create data: two samples, one-hot targets.
        let xs = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
        let ys = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
        let data = Dataset::from_rows(&xs, &ys).unwrap();

        // We cannot currently mutate parameters through the public API.
        // This test focuses on metric shape handling; we still should be able to run evaluate.
        let report = mlp
            .evaluate(&data, Loss::SoftmaxCrossEntropy, &[Metric::Accuracy])
            .unwrap();
        assert_eq!(report.metrics.len(), 1);
    }

    #[test]
    fn shuffle_seeded_is_deterministic() {
        // Two identical models trained with the same seed must produce
        // bit-identical losses.
        let mut a = MlpBuilder::new(2)
            .unwrap()
            .add_layer(4, Activation::Tanh)
            .unwrap()
            .add_layer(1, Activation::Identity)
            .unwrap()
            .build_with_seed(0)
            .unwrap();
        let mut b = a.clone();

        // Tiny regression dataset.
        let xs = vec![
            vec![0.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 0.0],
            vec![1.0, 1.0],
            vec![2.0, 0.0],
        ];
        let ys = vec![vec![0.0], vec![1.0], vec![1.0], vec![2.0], vec![2.0]];
        let train = Dataset::from_rows(&xs, &ys).unwrap();

        // batch_size 2 over 5 samples also exercises the remainder-batch path.
        let cfg = super::FitConfig {
            epochs: 10,
            lr: 0.05,
            batch_size: 2,
            shuffle: Shuffle::Seeded(123),
            lr_schedule: super::LrSchedule::Constant,
            optimizer: crate::Optimizer::Sgd,
            weight_decay: 0.0,
            grad_clip_norm: None,
            loss: Loss::Mse,
            metrics: vec![],
        };

        let rep_a = a.fit(&train, None, cfg.clone()).unwrap();
        let rep_b = b.fit(&train, None, cfg).unwrap();

        // Compare bit patterns, not approximate values: determinism is the claim.
        let last_a = rep_a.epochs.last().unwrap().train.loss;
        let last_b = rep_b.epochs.last().unwrap().train.loss;
        assert_eq!(last_a.to_bits(), last_b.to_bits());
    }
}
845}