Skip to main content

scirs2_core/array_protocol/
training.rs

1// Copyright (c) 2025, `SciRS2` Team
2//
3// Licensed under the Apache License, Version 2.0
4// (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
5//
6
7//! Training utilities for neural networks using the array protocol.
8//!
9//! This module provides utilities for training neural networks using the
10//! array protocol, including datasets, dataloaders, loss functions, and
11//! training loops.
12
13use std::fmt;
14use std::time::Instant;
15
16use ::ndarray::{Array, Array0, Dimension};
17use rand::seq::SliceRandom;
18use rand::{Rng, RngExt, SeedableRng};
19
20use crate::array_protocol::grad::{GradientDict, Optimizer};
21use crate::array_protocol::ml_ops::ActivationFunc;
22use crate::array_protocol::neural::Sequential;
23use crate::array_protocol::operations::{multiply, subtract};
24use crate::array_protocol::{activation, ArrayProtocol, NdarrayWrapper};
25use crate::error::{CoreError, CoreResult, ErrorContext};
26
/// Type alias for batch data: parallel vectors of boxed input and target
/// arrays, one element per sample in the batch.
pub type BatchData = (Vec<Box<dyn ArrayProtocol>>, Vec<Box<dyn ArrayProtocol>>);
29
/// Dataset trait for providing supervised data samples: `(input, target)`
/// pairs addressable by index.
pub trait Dataset {
    /// Get the number of samples in the dataset.
    fn len(&self) -> usize;

    /// Check if the dataset is empty.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get a sample from the dataset by index, or `None` when `index` is
    /// out of range.
    fn get(&self, index: usize) -> Option<(Box<dyn ArrayProtocol>, Box<dyn ArrayProtocol>)>;

    /// Get the shape of a single input sample.
    fn inputshape(&self) -> Vec<usize>;

    /// Get the shape of a single target sample.
    fn outputshape(&self) -> Vec<usize>;
}
49
/// In-memory dataset holding pre-materialized samples as boxed arrays.
pub struct InMemoryDataset {
    /// Input data samples, one boxed array per sample.
    inputs: Vec<Box<dyn ArrayProtocol>>,

    /// Target output samples, parallel to `inputs`.
    targets: Vec<Box<dyn ArrayProtocol>>,

    /// Shape of a single input sample.
    inputshape: Vec<usize>,

    /// Shape of a single target sample.
    outputshape: Vec<usize>,
}
64
65impl InMemoryDataset {
66    /// Create a new in-memory dataset.
67    pub fn new(
68        inputs: Vec<Box<dyn ArrayProtocol>>,
69        targets: Vec<Box<dyn ArrayProtocol>>,
70        inputshape: Vec<usize>,
71        outputshape: Vec<usize>,
72    ) -> Self {
73        assert_eq!(
74            inputs.len(),
75            targets.len(),
76            "Inputs and targets must have the same length"
77        );
78
79        Self {
80            inputs,
81            targets,
82            inputshape,
83            outputshape,
84        }
85    }
86
87    /// Create an in-memory dataset from arrays.
88    pub fn from_arrays<T, D1, D2>(inputs: Array<T, D1>, targets: Array<T, D2>) -> Self
89    where
90        T: Clone + Send + Sync + 'static,
91        D1: Dimension + Send + Sync,
92        D2: Dimension + Send + Sync,
93    {
94        let inputshape = inputs.shape().to_vec();
95        let outputshape = targets.shape().to_vec();
96
97        // Handle batched _inputs
98        let num_samples = inputshape[0];
99        assert_eq!(
100            num_samples, outputshape[0],
101            "Inputs and targets must have the same number of samples"
102        );
103
104        let mut input_samples = Vec::with_capacity(num_samples);
105        let mut target_samples = Vec::with_capacity(num_samples);
106
107        // Create dynamic arrays with the appropriate shape to handle arbitrary dimensions
108        let to_dyn_inputs = inputs.into_dyn();
109        let to_dyn_targets = targets.into_dyn();
110
111        for i in 0..num_samples {
112            // Use index_axis instead of slice for better compatibility with different dimensions
113            let input_view = to_dyn_inputs.index_axis(crate::ndarray::Axis(0), i);
114            let inputarray = input_view.to_owned();
115            input_samples.push(Box::new(NdarrayWrapper::new(inputarray)) as Box<dyn ArrayProtocol>);
116
117            let target_view = to_dyn_targets.index_axis(crate::ndarray::Axis(0), i);
118            let target_array = target_view.to_owned();
119            target_samples
120                .push(Box::new(NdarrayWrapper::new(target_array)) as Box<dyn ArrayProtocol>);
121        }
122
123        Self {
124            inputs: input_samples,
125            targets: target_samples,
126            inputshape: inputshape[1..].to_vec(),
127            outputshape: outputshape[1..].to_vec(),
128        }
129    }
130}
131
132impl Dataset for InMemoryDataset {
133    fn len(&self) -> usize {
134        self.inputs.len()
135    }
136
137    fn get(&self, index: usize) -> Option<(Box<dyn ArrayProtocol>, Box<dyn ArrayProtocol>)> {
138        if index >= self.len() {
139            return None;
140        }
141
142        Some((self.inputs[index].clone(), self.targets[index].clone()))
143    }
144
145    fn inputshape(&self) -> Vec<usize> {
146        self.inputshape.clone()
147    }
148
149    fn outputshape(&self) -> Vec<usize> {
150        self.outputshape.clone()
151    }
152}
153
/// Data loader that serves batches (optionally shuffled) from a dataset.
pub struct DataLoader {
    /// The dataset to load from.
    dataset: Box<dyn Dataset>,

    /// Maximum number of samples per batch (the final batch may be smaller).
    batch_size: usize,

    /// Whether `reset` reshuffles the sample order.
    shuffle: bool,

    /// Optional RNG seed for reproducible shuffling.
    seed: Option<u64>,

    /// Current (possibly shuffled) ordering of sample indices.
    indices: Vec<usize>,

    /// Cursor into `indices` marking the next sample to serve.
    position: usize,
}
174
175impl DataLoader {
176    /// Create a new data loader.
177    pub fn new(
178        dataset: Box<dyn Dataset>,
179        batch_size: usize,
180        shuffle: bool,
181        seed: Option<u64>,
182    ) -> Self {
183        let indices = (0..dataset.len()).collect();
184
185        Self {
186            dataset,
187            batch_size,
188            shuffle,
189            seed,
190            indices,
191            position: 0,
192        }
193    }
194
195    /// Reset the data loader.
196    pub fn reset(&mut self) {
197        self.position = 0;
198
199        if self.shuffle {
200            let mut rng = match self.seed {
201                Some(s) => rand::rngs::StdRng::seed_from_u64(s),
202                None => {
203                    let mut rng = rand::rng();
204                    // Get a random seed from rng and create a new StdRng
205                    let random_seed: u64 = rng.random();
206                    rand::rngs::StdRng::seed_from_u64(random_seed)
207                }
208            };
209
210            self.indices.shuffle(&mut rng);
211        }
212    }
213
214    /// Get the next batch from the dataset.
215    pub fn next_batch(&mut self) -> Option<BatchData> {
216        if self.position >= self.dataset.len() {
217            return None;
218        }
219
220        // Determine how many samples to take
221        let remaining = self.dataset.len() - self.position;
222        let batch_size = std::cmp::min(self.batch_size, remaining);
223
224        // Get the batch
225        let mut inputs = Vec::with_capacity(batch_size);
226        let mut targets = Vec::with_capacity(batch_size);
227
228        for i in 0..batch_size {
229            let index = self.indices[self.position + i];
230            if let Some((input, target)) = self.dataset.get(index) {
231                inputs.push(input);
232                targets.push(target);
233            }
234        }
235
236        // Update position
237        self.position += batch_size;
238
239        Some((inputs, targets))
240    }
241
242    /// Get the number of batches in the dataset.
243    pub fn numbatches(&self) -> usize {
244        self.dataset.len().div_ceil(self.batch_size)
245    }
246
247    /// Get a reference to the dataset.
248    pub fn dataset(&self) -> &dyn Dataset {
249        self.dataset.as_ref()
250    }
251}
252
/// Iterator implementation for DataLoader.
///
/// Yields batches via `next_batch`. Note that `reset` must be called
/// manually to begin a new pass (and to apply shuffling).
impl Iterator for DataLoader {
    type Item = BatchData;

    fn next(&mut self) -> Option<Self::Item> {
        self.next_batch()
    }
}
261
/// Loss function trait: a differentiable objective with a forward value
/// and a gradient with respect to the predictions.
pub trait Loss {
    /// Compute the loss between predictions and targets.
    fn forward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>>;

    /// Compute the gradient of the loss with respect to predictions.
    fn backward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>>;

    /// Get the human-readable name of the loss function.
    fn name(&self) -> &str;
}
281
/// Mean squared error loss.
pub struct MSELoss {
    /// Name of the loss function.
    name: String,

    /// Reduction mode: "none", "mean", or "sum".
    reduction: String,
}

impl MSELoss {
    /// Create a new MSE loss. When `reduction` is `None`, "mean" is used.
    pub fn new(reduction: Option<&str>) -> Self {
        let reduction = reduction.unwrap_or("mean").to_owned();
        Self {
            name: String::from("MSELoss"),
            reduction,
        }
    }
}
300
301impl Loss for MSELoss {
302    fn forward(
303        &self,
304        predictions: &dyn ArrayProtocol,
305        targets: &dyn ArrayProtocol,
306    ) -> CoreResult<Box<dyn ArrayProtocol>> {
307        // Compute squared difference
308        let diff = subtract(predictions, targets)?;
309        let squared = multiply(diff.as_ref(), diff.as_ref())?;
310
311        // Apply reduction
312        match self.reduction.as_str() {
313            "none" => Ok(squared),
314            "mean" => {
315                // Compute mean of all elements
316                if let Some(array) = squared
317                    .as_any()
318                    .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
319                {
320                    let mean = array.as_array().mean().expect("Operation failed");
321                    let result = Array0::<f64>::from_elem((), mean);
322                    Ok(Box::new(NdarrayWrapper::new(result)))
323                } else {
324                    Err(CoreError::NotImplementedError(ErrorContext::new(
325                        "Mean reduction not implemented for this array type".to_string(),
326                    )))
327                }
328            }
329            "sum" => {
330                // Compute sum of all elements
331                if let Some(array) = squared
332                    .as_any()
333                    .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
334                {
335                    let sum = array.as_array().sum();
336                    let result = Array0::<f64>::from_elem((), sum);
337                    Ok(Box::new(NdarrayWrapper::new(result)))
338                } else {
339                    Err(CoreError::NotImplementedError(ErrorContext::new(
340                        "Sum reduction not implemented for this array type".to_string(),
341                    )))
342                }
343            }
344            _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
345                "Unknown reduction: {reduction}",
346                reduction = self.reduction
347            )))),
348        }
349    }
350
351    fn backward(
352        &self,
353        predictions: &dyn ArrayProtocol,
354        targets: &dyn ArrayProtocol,
355    ) -> CoreResult<Box<dyn ArrayProtocol>> {
356        // Gradient of MSE loss: 2 * (predictions - targets)
357        let diff = subtract(predictions, targets)?;
358        let factor = Box::new(NdarrayWrapper::new(
359            crate::ndarray::Array0::<f64>::from_elem((), 2.0),
360        ));
361        let grad = multiply(factor.as_ref(), diff.as_ref())?;
362
363        // Apply reduction scaling if needed
364        match self.reduction.as_str() {
365            "none" => Ok(grad),
366            "mean" => {
367                // For mean reduction, scale by 1/N
368                if let Some(array) = grad
369                    .as_any()
370                    .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
371                {
372                    let n = array.as_array().len() as f64;
373                    let scale_factor = Box::new(NdarrayWrapper::new(
374                        crate::ndarray::Array0::<f64>::from_elem((), 1.0 / n),
375                    ));
376                    Ok(multiply(scale_factor.as_ref(), grad.as_ref())?)
377                } else {
378                    Ok(grad)
379                }
380            }
381            "sum" => Ok(grad),
382            _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
383                "Unknown reduction: {reduction}",
384                reduction = self.reduction
385            )))),
386        }
387    }
388
389    fn name(&self) -> &str {
390        &self.name
391    }
392}
393
/// Cross-entropy loss (softmax is applied to the predictions internally).
pub struct CrossEntropyLoss {
    /// Name of the loss function.
    name: String,

    /// Reduction mode: "none", "mean", or "sum".
    reduction: String,
}

impl CrossEntropyLoss {
    /// Create a new cross-entropy loss. When `reduction` is `None`,
    /// "mean" is used.
    pub fn new(reduction: Option<&str>) -> Self {
        let reduction = reduction.unwrap_or("mean").to_owned();
        Self {
            name: String::from("CrossEntropyLoss"),
            reduction,
        }
    }
}
412
413impl Loss for CrossEntropyLoss {
414    fn forward(
415        &self,
416        predictions: &dyn ArrayProtocol,
417        targets: &dyn ArrayProtocol,
418    ) -> CoreResult<Box<dyn ArrayProtocol>> {
419        // Apply softmax to predictions
420        let softmax_preds = activation(predictions, ActivationFunc::Softmax)?;
421
422        // Compute cross-entropy
423        if let (Some(preds_array), Some(targets_array)) = (
424            softmax_preds
425                .as_any()
426                .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>(),
427            targets
428                .as_any()
429                .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>(),
430        ) {
431            let preds = preds_array.as_array();
432            let targets = targets_array.as_array();
433
434            // Compute -targets * log(preds)
435            let log_preds = preds.mapv(|x| x.max(1e-10).ln());
436
437            // Compute element-wise multiplication and then negate
438            let mut losses = targets.clone();
439            losses.zip_mut_with(&log_preds, |t, l| *t = -(*t * *l));
440
441            // Apply reduction
442            match self.reduction.as_str() {
443                "none" => Ok(Box::new(NdarrayWrapper::new(losses))),
444                "mean" => {
445                    let mean = losses.mean().expect("Operation failed");
446                    let result = Array0::<f64>::from_elem((), mean);
447                    Ok(Box::new(NdarrayWrapper::new(result)))
448                }
449                "sum" => {
450                    let sum = losses.sum();
451                    let result = Array0::<f64>::from_elem((), sum);
452                    Ok(Box::new(NdarrayWrapper::new(result)))
453                }
454                _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
455                    "Unknown reduction: {reduction}",
456                    reduction = self.reduction
457                )))),
458            }
459        } else {
460            Err(CoreError::NotImplementedError(ErrorContext::new(
461                "CrossEntropy not implemented for these array types".to_string(),
462            )))
463        }
464    }
465
466    fn backward(
467        &self,
468        predictions: &dyn ArrayProtocol,
469        targets: &dyn ArrayProtocol,
470    ) -> CoreResult<Box<dyn ArrayProtocol>> {
471        // For cross-entropy with softmax: gradient is softmax(predictions) - targets
472        let softmax_preds = activation(predictions, ActivationFunc::Softmax)?;
473        let grad = subtract(softmax_preds.as_ref(), targets)?;
474
475        // Apply reduction scaling if needed
476        match self.reduction.as_str() {
477            "none" => Ok(grad),
478            "mean" => {
479                // For mean reduction, scale by 1/N
480                if let Some(array) = grad
481                    .as_any()
482                    .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
483                {
484                    let n = array.as_array().len() as f64;
485                    let scale_factor = Box::new(NdarrayWrapper::new(
486                        crate::ndarray::Array0::<f64>::from_elem((), 1.0 / n),
487                    ));
488                    Ok(multiply(scale_factor.as_ref(), grad.as_ref())?)
489                } else {
490                    Ok(grad)
491                }
492            }
493            "sum" => Ok(grad),
494            _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
495                "Unknown reduction: {reduction}",
496                reduction = self.reduction
497            )))),
498        }
499    }
500
501    fn name(&self) -> &str {
502        &self.name
503    }
504}
505
/// Accumulates per-batch metrics (loss and optional accuracy) for one phase
/// of training, e.g. "train" or "val".
pub struct Metrics {
    /// Recorded loss values.
    losses: Vec<f64>,

    /// Recorded accuracy values; `None` until the first accuracy is added.
    accuracies: Option<Vec<f64>>,

    /// Label used when displaying the metrics (e.g. "train").
    name: String,
}

impl Metrics {
    /// Create a new, empty metrics object with the given display name.
    pub fn new(name: &str) -> Self {
        Self {
            losses: Vec::new(),
            accuracies: None,
            name: name.to_string(),
        }
    }

    /// Record a loss value.
    pub fn add_loss(&mut self, loss: f64) {
        self.losses.push(loss);
    }

    /// Record an accuracy value, lazily creating the accuracy buffer on
    /// first use.
    pub fn add_accuracy(&mut self, accuracy: f64) {
        self.accuracies.get_or_insert_with(Vec::new).push(accuracy);
    }

    /// Mean of the recorded losses, or `None` if no loss has been recorded.
    pub fn mean_loss(&self) -> Option<f64> {
        if self.losses.is_empty() {
            None
        } else {
            Some(self.losses.iter().sum::<f64>() / self.losses.len() as f64)
        }
    }

    /// Mean of the recorded accuracies, or `None` if accuracy tracking was
    /// never started or nothing has been recorded since the last reset.
    pub fn mean_accuracy(&self) -> Option<f64> {
        self.accuracies
            .as_ref()
            .filter(|accs| !accs.is_empty())
            .map(|accs| accs.iter().sum::<f64>() / accs.len() as f64)
    }

    /// Clear all recorded values (the accuracy buffer stays allocated).
    pub fn reset(&mut self) {
        self.losses.clear();
        if let Some(accuracies) = &mut self.accuracies {
            accuracies.clear();
        }
    }

    /// Display name of this metrics object.
    pub fn name(&self) -> &str {
        &self.name
    }
}

impl fmt::Display for Metrics {
    /// Formats as `"<name>: loss = <mean>"`, appending
    /// `", accuracy = <mean>"` when accuracies are present. A missing mean
    /// loss prints as 0.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{}: loss = {:.4}",
            self.name,
            self.mean_loss().unwrap_or(0.0)
        )?;

        if let Some(acc) = self.mean_accuracy() {
            write!(f, ", accuracy = {acc:.4}")?;
        }

        Ok(())
    }
}
598
/// Observer hooks invoked at the boundaries of the training loop
/// (training / epoch / batch start and end).
pub trait TrainingCallback {
    /// Called at the start of each epoch (`epoch` is zero-based).
    fn on_epoch_start(&mut self, epoch: usize, numepochs: usize);

    /// Called at the end of each epoch with the metrics gathered during it.
    fn on_epoch_end(&mut self, epoch: usize, numepochs: usize, metrics: &Metrics);

    /// Called at the start of each batch (`batch` is zero-based).
    fn on_batch_start(&mut self, batch: usize, numbatches: usize);

    /// Called at the end of each batch with that batch's loss.
    fn on_batch_end(&mut self, batch: usize, numbatches: usize, loss: f64);

    /// Called once at the start of training.
    fn on_train_start(&mut self, numepochs: usize);

    /// Called once at the end of training with the final metrics.
    fn on_train_end(&mut self, metrics: &Metrics);
}
619
/// Progress bar callback for displaying training progress on stdout.
pub struct ProgressCallback {
    /// Whether to display a progress bar.
    verbose: bool,

    /// Start time of the current epoch.
    epoch_start: Option<Instant>,

    /// Start time of training.
    train_start: Option<Instant>,
}

impl ProgressCallback {
    /// Create a new progress callback; when `verbose` is false all output
    /// is suppressed.
    pub fn new(verbose: bool) -> Self {
        ProgressCallback {
            verbose,
            epoch_start: None,
            train_start: None,
        }
    }
}
642
impl TrainingCallback for ProgressCallback {
    /// Print the "Epoch i/N" header and start the epoch timer.
    fn on_epoch_start(&mut self, epoch: usize, numepochs: usize) {
        if self.verbose {
            println!("Epoch {}/{}", epoch + 1, numepochs);
        }

        self.epoch_start = Some(Instant::now());
    }

    /// Print the epoch metrics, with elapsed milliseconds when the epoch
    /// timer was started.
    fn on_epoch_end(&mut self, _epoch: usize, numepochs: usize, metrics: &Metrics) {
        if self.verbose {
            if let Some(start) = self.epoch_start {
                let duration = start.elapsed();
                println!("{} - {}ms", metrics, duration.as_millis());
            } else {
                println!("{metrics}");
            }
        }
    }

    /// Batch starts are not reported by this callback.
    fn on_batch_start(&mut self, _batch: usize, _numbatches: usize) {
        // No-op for this callback
    }

    /// Print batch progress roughly ten times per epoch (every
    /// `numbatches / 10` batches), overwriting the same line via '\r'.
    fn on_batch_end(&mut self, batch: usize, numbatches: usize, loss: f64) {
        if self.verbose && (batch + 1) % (numbatches / 10).max(1) == 0 {
            print!("\rBatch {}/{} - loss: {:.4}", batch + 1, numbatches, loss);
            if batch + 1 == numbatches {
                println!();
            }
        }
    }

    /// Announce the training run and start the overall timer.
    fn on_train_start(&mut self, numepochs: usize) {
        if self.verbose {
            println!("Starting training for {numepochs} epochs");
        }

        self.train_start = Some(Instant::now());
    }

    /// Print total elapsed time and, when available, the final accuracy.
    fn on_train_end(&mut self, metrics: &Metrics) {
        if self.verbose {
            if let Some(start) = self.train_start {
                let duration = start.elapsed();
                println!("Training completed in {}s", duration.as_secs());
            } else {
                println!("Training completed");
            }

            if let Some(acc) = metrics.mean_accuracy() {
                println!("Final accuracy: {acc:.4}");
            }
        }
    }
}
699
/// Model trainer that wires together a model, an optimizer, a loss
/// function, and callbacks into a training loop.
pub struct Trainer {
    /// The model to train.
    model: Sequential,

    /// The optimizer used to update model parameters.
    optimizer: Box<dyn Optimizer>,

    /// The loss function minimized during training.
    lossfn: Box<dyn Loss>,

    /// Callbacks invoked at training/epoch/batch boundaries.
    callbacks: Vec<Box<dyn TrainingCallback>>,

    /// Metrics accumulated over the training set.
    train_metrics: Metrics,

    /// Metrics accumulated over the validation set (created on demand).
    val_metrics: Option<Metrics>,
}
720
721impl Trainer {
722    /// Create a new trainer.
723    pub fn new(model: Sequential, optimizer: Box<dyn Optimizer>, lossfn: Box<dyn Loss>) -> Self {
724        Self {
725            model,
726            optimizer,
727            lossfn,
728            callbacks: Vec::new(),
729            train_metrics: Metrics::new("train"),
730            val_metrics: None,
731        }
732    }
733
    /// Register a callback to receive training lifecycle events.
    pub fn add_callback(&mut self, callback: Box<dyn TrainingCallback>) {
        self.callbacks.push(callback);
    }
738
739    /// Train the model.
740    pub fn train(
741        &mut self,
742        train_loader: &mut DataLoader,
743        numepochs: usize,
744        mut val_loader: Option<&mut DataLoader>,
745    ) -> CoreResult<()> {
746        // Notify callbacks that training is starting
747        for callback in &mut self.callbacks {
748            callback.on_train_start(numepochs);
749        }
750
751        // Initialize validation metrics if needed
752        if val_loader.is_some() && self.val_metrics.is_none() {
753            self.val_metrics = Some(Metrics::new("val"));
754        }
755
756        // Train for the specified number of epochs
757        for epoch in 0..numepochs {
758            // Reset metrics
759            self.train_metrics.reset();
760            if let Some(metrics) = &mut self.val_metrics {
761                metrics.reset();
762            }
763
764            // Notify callbacks that epoch is starting
765            for callback in &mut self.callbacks {
766                callback.on_epoch_start(epoch, numepochs);
767            }
768
769            // Train on the training set
770            self.train_epoch(train_loader)?;
771
772            // Validate on the validation set if provided
773            if let Some(ref mut val_loader) = val_loader {
774                self.validate(val_loader)?;
775            }
776
777            // Notify callbacks that epoch is ending
778            for callback in &mut self.callbacks {
779                callback.on_epoch_end(
780                    epoch,
781                    numepochs,
782                    if let Some(val_metrics) = &self.val_metrics {
783                        val_metrics
784                    } else {
785                        &self.train_metrics
786                    },
787                );
788            }
789        }
790
791        // Notify callbacks that training is ending
792        for callback in &mut self.callbacks {
793            callback.on_train_end(if let Some(val_metrics) = &self.val_metrics {
794                val_metrics
795            } else {
796                &self.train_metrics
797            });
798        }
799
800        Ok(())
801    }
802
    /// Run one full training pass over `dataloader`, updating training
    /// metrics and firing batch callbacks.
    fn train_epoch(&mut self, dataloader: &mut DataLoader) -> CoreResult<()> {
        // Set model to training mode.
        self.model.train();

        // Reset data loader (rewinds the cursor and reshuffles if enabled).
        dataloader.reset();

        let numbatches = dataloader.numbatches();

        // Train on batches. After a reset, next_batch() yields exactly
        // `numbatches` batches, so the expect should not fire.
        for batch_idx in 0..numbatches {
            let (inputs, targets) = dataloader.next_batch().expect("Operation failed");
            // Notify callbacks that batch is starting
            for callback in &mut self.callbacks {
                callback.on_batch_start(batch_idx, numbatches);
            }

            // Forward/backward pass and parameter update for this batch.
            let batch_loss = self.train_batch(&inputs, &targets)?;

            // Update metrics
            self.train_metrics.add_loss(batch_loss);

            // Notify callbacks that batch is ending
            for callback in &mut self.callbacks {
                callback.on_batch_end(batch_idx, numbatches, batch_loss);
            }
        }

        Ok(())
    }
835
836    /// Train on a single batch.
837    fn train_batch(
838        &mut self,
839        inputs: &[Box<dyn ArrayProtocol>],
840        targets: &[Box<dyn ArrayProtocol>],
841    ) -> CoreResult<f64> {
842        // Zero gradients
843        self.optimizer.zero_grad();
844
845        // Forward pass
846        let mut batch_loss = 0.0;
847
848        for (input, target) in inputs.iter().zip(targets.iter()) {
849            // Forward pass through model
850            let output = self.model.forward(input.as_ref())?;
851
852            // Compute loss
853            let loss = self.lossfn.forward(output.as_ref(), target.as_ref())?;
854
855            // Get loss value
856            if let Some(loss_array) = loss
857                .as_any()
858                .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
859            {
860                let loss_value = loss_array.as_array().sum();
861                batch_loss += loss_value;
862            }
863
864            // Backward pass - compute gradients
865            // For now, implement a simple gradient approximation using finite differences
866            // In a full implementation, this would be automatic differentiation
867
868            let learningrate = 0.001; // Default learning rate
869
870            // Simple gradient estimation for demonstration
871            // This computes numerical gradients for the model parameters
872
873            // Get current output for gradient computation
874            let current_output = self.model.forward(input.as_ref())?;
875            let current_loss = self
876                .lossfn
877                .forward(current_output.as_ref(), target.as_ref())?;
878            let _current_loss_value = if let Some(loss_array) = current_loss
879                .as_any()
880                .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
881            {
882                loss_array.as_array().sum()
883            } else {
884                0.0
885            };
886
887            // Compute gradients via backpropagation
888            let gradients = self.compute_gradients(
889                input.as_ref(),
890                target.as_ref(),
891                current_output.as_ref(),
892                current_loss.as_ref(),
893            )?;
894
895            // Apply gradients to model parameters
896            self.apply_gradients(&gradients, learningrate)?;
897
898            // Store gradients in optimizer for momentum-based optimizers
899            self.optimizer.accumulate_gradients(&gradients)?;
900        }
901
902        // Compute average loss
903        let batch_loss = batch_loss / inputs.len() as f64;
904
905        // Update weights
906        self.optimizer.step()?;
907
908        Ok(batch_loss)
909    }
910
911    /// Compute gradients via backpropagation
912    fn compute_gradients(
913        &self,
914        input: &dyn ArrayProtocol,
915        target: &dyn ArrayProtocol,
916        output: &dyn ArrayProtocol,
917        _loss: &dyn ArrayProtocol,
918    ) -> CoreResult<GradientDict> {
919        // Start backpropagation from loss
920        let mut gradients = GradientDict::new();
921
922        // Compute gradient of loss with respect to output
923        let loss_grad = self.lossfn.backward(output, target)?;
924
925        // Backpropagate through the model
926        let model_gradients = self.model.backward(input, loss_grad.as_ref())?;
927
928        // Merge gradients
929        gradients.merge(model_gradients);
930
931        Ok(gradients)
932    }
933
934    /// Apply computed gradients to model parameters
935    fn apply_gradients(&mut self, gradients: &GradientDict, learningrate: f64) -> CoreResult<()> {
936        // Apply gradients to each parameter in the model
937        for (param_name, gradient) in gradients.iter() {
938            self.model
939                .update_parameter(param_name, gradient.as_ref(), learningrate)?;
940        }
941
942        Ok(())
943    }
944
945    /// Validate the model.
946    fn validate(&mut self, dataloader: &mut DataLoader) -> CoreResult<()> {
947        // Set model to evaluation mode
948        self.model.eval();
949
950        // Reset validation metrics
951        if let Some(metrics) = &mut self.val_metrics {
952            metrics.reset();
953        } else {
954            return Ok(());
955        }
956
957        // Reset data loader
958        dataloader.reset();
959
960        let numbatches = dataloader.numbatches();
961
962        // Validate on batches
963        for _ in 0..numbatches {
964            let (inputs, targets) = dataloader.next_batch().expect("Operation failed");
965            // Forward pass without gradient tracking
966            let mut batch_loss = 0.0;
967            let mut batch_correct = 0;
968            let mut batch_total = 0;
969
970            for (input, target) in inputs.iter().zip(targets.iter()) {
971                // Forward pass through model
972                let output = self.model.forward(input.as_ref())?;
973
974                // Compute loss
975                let loss = self.lossfn.forward(output.as_ref(), target.as_ref())?;
976
977                // Get loss value
978                if let Some(loss_array) = loss
979                    .as_any()
980                    .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
981                {
982                    let loss_value = loss_array.as_array().sum();
983                    batch_loss += loss_value;
984                }
985
986                // Compute accuracy for classification problems
987                if let (Some(output_array), Some(target_array)) = (
988                    output
989                        .as_any()
990                        .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::Ix2>>(),
991                    target
992                        .as_any()
993                        .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::Ix2>>(),
994                ) {
995                    // Get predictions (argmax)
996                    let output_vec = output_array.as_array();
997                    let target_vec = target_array.as_array();
998
999                    // For simplicity, assume 2D arrays [batch_size, num_classes]
1000                    if output_vec.ndim() == 2 && target_vec.ndim() == 2 {
1001                        for (out_row, target_row) in
1002                            output_vec.outer_iter().zip(target_vec.outer_iter())
1003                        {
1004                            // Find the index of the maximum value in the output row
1005                            let mut max_idx = 0;
1006                            let mut max_val = out_row[0];
1007
1008                            for (i, &val) in out_row.iter().enumerate().skip(1) {
1009                                if val > max_val {
1010                                    max_idx = i;
1011                                    max_val = val;
1012                                }
1013                            }
1014
1015                            // Find the index of 1 in the target row (one-hot encoding)
1016                            if let Some(target_idx) = target_row.iter().position(|&x| x == 1.0) {
1017                                if max_idx == target_idx {
1018                                    batch_correct += 1;
1019                                }
1020                            }
1021
1022                            batch_total += 1;
1023                        }
1024                    }
1025                }
1026            }
1027
1028            // Compute average loss and accuracy
1029            let batch_loss = batch_loss / inputs.len() as f64;
1030            let batch_accuracy = if batch_total > 0 {
1031                batch_correct as f64 / batch_total as f64
1032            } else {
1033                0.0
1034            };
1035
1036            // Update validation metrics
1037            if let Some(metrics) = &mut self.val_metrics {
1038                metrics.add_loss(batch_loss);
1039                metrics.add_accuracy(batch_accuracy);
1040            }
1041        }
1042
1043        Ok(())
1044    }
1045
    /// Borrow the metrics accumulated during training (loss/accuracy history).
    pub const fn train_metrics(&self) -> &Metrics {
        &self.train_metrics
    }
1050
    /// Borrow the validation metrics, if validation was configured for this trainer.
    pub fn val_metrics(&self) -> Option<&Metrics> {
        self.val_metrics.as_ref()
    }
1055}
1056
1057// Helper functions
1058
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array_protocol::{self, NdarrayWrapper};
    use ::ndarray::Array2;

    #[test]
    fn test_in_memory_dataset() {
        // Build a 10-sample dataset: 5 input features, 2 output targets.
        let dataset = InMemoryDataset::from_arrays(
            Array2::<f64>::ones((10, 5)),
            Array2::<f64>::zeros((10, 2)),
        );

        // Shape bookkeeping reflects the per-sample dimensions, not the batch.
        assert_eq!(dataset.len(), 10);
        assert_eq!(dataset.inputshape(), vec![5]);
        assert_eq!(dataset.outputshape(), vec![2]);

        // A fetched sample comes back wrapped as a dynamic-dimensional ndarray.
        let (sample_in, sample_out) = dataset.get(0).expect("Operation failed");
        assert!(sample_in
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
            .is_some());
        assert!(sample_out
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
            .is_some());
    }

    #[test]
    fn test_dataloader() {
        // Ten samples, batch size four, shuffled with a fixed seed.
        let data = Box::new(InMemoryDataset::from_arrays(
            Array2::<f64>::ones((10, 5)),
            Array2::<f64>::zeros((10, 2)),
        ));
        let mut loader = DataLoader::new(data, 4, true, Some(42));

        // ceil(10 / 4) == 3 batches per epoch.
        assert_eq!(loader.numbatches(), 3);

        // Successive batches: two full batches of 4, then a 2-sample remainder.
        for expected_len in [4usize, 4, 2] {
            let (batch_in, batch_out) = loader.next_batch().expect("Operation failed");
            assert_eq!(batch_in.len(), expected_len);
            assert_eq!(batch_out.len(), expected_len);
        }

        // After a reset the loader starts over with a full first batch.
        loader.reset();
        let (batch_in, batch_out) = loader.next_batch().expect("Operation failed");
        assert_eq!(batch_in.len(), 4);
        assert_eq!(batch_out.len(), 4);
    }

    #[test]
    fn test_mse_loss() {
        // The array protocol registry must be initialized before use.
        array_protocol::init();

        // All-ones predictions against all-zeros targets.
        let predictions_wrapped = NdarrayWrapper::new(Array2::<f64>::ones((2, 3)));
        let targets_wrapped = NdarrayWrapper::new(Array2::<f64>::zeros((2, 3)));

        // Mean-reduced MSE loss.
        let mse = MSELoss::new(Some("mean"));

        // Forward may not be fully implemented; only assert when it succeeds.
        match mse.forward(&predictions_wrapped, &targets_wrapped) {
            Ok(loss) => match loss
                .as_any()
                .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::Ix0>>()
            {
                // Expected: mean((1 - 0)^2) = 1.0
                Some(loss_array) => assert_eq!(loss_array.as_array()[()], 1.0),
                None => println!("Loss not of expected type NdarrayWrapper<f64, Ix0>"),
            },
            Err(e) => {
                println!("MSE Loss forward not fully implemented: {e}");
            }
        }
    }

    #[test]
    fn test_metrics() {
        let mut metrics = Metrics::new("test");

        // Record three loss and three accuracy observations.
        for loss in [1.0, 2.0, 3.0] {
            metrics.add_loss(loss);
        }
        for acc in [0.5, 0.6, 0.7] {
            metrics.add_accuracy(acc);
        }

        // Means over the recorded observations.
        assert_eq!(metrics.mean_loss().expect("Operation failed"), 2.0);
        assert_eq!(metrics.mean_accuracy().expect("Operation failed"), 0.6);

        // Resetting clears all recorded values.
        metrics.reset();
        assert!(metrics.mean_loss().is_none());
        assert!(metrics.mean_accuracy().is_none());
    }
}