// scirs2_core/array_protocol/training.rs

// Copyright (c) 2025, `SciRS2` Team
//
// Licensed under either of
//
// * Apache License, Version 2.0
//   (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
// * MIT license
//   (LICENSE-MIT or http://opensource.org/licenses/MIT)
//
// at your option.
//

//! Training utilities for neural networks using the array protocol.
//!
//! This module provides utilities for training neural networks using the
//! array protocol, including datasets, data loaders, loss functions, and
//! training loops.

use std::fmt;
use std::time::Instant;

use ::ndarray::{Array, Array0, Dimension};
use rand::seq::SliceRandom;
use rand::Rng;
use rand::SeedableRng;

use crate::array_protocol::grad::{GradientDict, Optimizer};
use crate::array_protocol::ml_ops::ActivationFunc;
use crate::array_protocol::neural::Sequential;
use crate::array_protocol::operations::{multiply, subtract};
use crate::array_protocol::{activation, ArrayProtocol, NdarrayWrapper};
use crate::error::{CoreError, CoreResult, ErrorContext};

/// Type alias for a batch of (inputs, targets) produced by a [`DataLoader`].
pub type BatchData = (Vec<Box<dyn ArrayProtocol>>, Vec<Box<dyn ArrayProtocol>>);

/// Dataset trait for providing data samples.
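///
/// # Example
///
/// A minimal sketch of a custom dataset. The `ConstantDataset` type below is
/// illustrative only, so the example is marked `ignore` and not compiled as a
/// doctest:
///
/// ```ignore
/// struct ConstantDataset {
///     n: usize,
/// }
///
/// impl Dataset for ConstantDataset {
///     fn len(&self) -> usize {
///         self.n
///     }
///
///     fn get(&self, index: usize) -> Option<(Box<dyn ArrayProtocol>, Box<dyn ArrayProtocol>)> {
///         if index >= self.n {
///             return None;
///         }
///         // Every sample is the same constant pair: input of shape [3], target of shape [1].
///         let x = ndarray::Array1::<f64>::ones(3).into_dyn();
///         let y = ndarray::Array1::<f64>::zeros(1).into_dyn();
///         Some((Box::new(NdarrayWrapper::new(x)), Box::new(NdarrayWrapper::new(y))))
///     }
///
///     fn inputshape(&self) -> Vec<usize> {
///         vec![3]
///     }
///
///     fn outputshape(&self) -> Vec<usize> {
///         vec![1]
///     }
/// }
/// ```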
pub trait Dataset {
    /// Get the number of samples in the dataset.
    fn len(&self) -> usize;

    /// Check if the dataset is empty.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get a sample from the dataset by index.
    fn get(&self, index: usize) -> Option<(Box<dyn ArrayProtocol>, Box<dyn ArrayProtocol>)>;

    /// Get the input shape of the dataset.
    fn inputshape(&self) -> Vec<usize>;

    /// Get the output shape of the dataset.
    fn outputshape(&self) -> Vec<usize>;
}

/// In-memory dataset with arrays.
pub struct InMemoryDataset {
    /// Input data samples.
    inputs: Vec<Box<dyn ArrayProtocol>>,

    /// Target output samples.
    targets: Vec<Box<dyn ArrayProtocol>>,

    /// Input shape.
    inputshape: Vec<usize>,

    /// Output shape.
    outputshape: Vec<usize>,
}

impl InMemoryDataset {
    /// Create a new in-memory dataset.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` and `targets` have different lengths.
    pub fn new(
        inputs: Vec<Box<dyn ArrayProtocol>>,
        targets: Vec<Box<dyn ArrayProtocol>>,
        inputshape: Vec<usize>,
        outputshape: Vec<usize>,
    ) -> Self {
        assert_eq!(
            inputs.len(),
            targets.len(),
            "Inputs and targets must have the same length"
        );

        Self {
            inputs,
            targets,
            inputshape,
            outputshape,
        }
    }

    /// Create an in-memory dataset from arrays.
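    ///
    /// The first axis is treated as the sample axis: an input of shape
    /// `[N, ...]` and a target of shape `[N, ...]` are split into `N`
    /// individual samples.
    ///
    /// # Example
    ///
    /// This mirrors the unit test at the bottom of this file (marked `ignore`
    /// so it is not compiled as a doctest):
    ///
    /// ```ignore
    /// let inputs = ndarray::Array2::<f64>::ones((10, 5));
    /// let targets = ndarray::Array2::<f64>::zeros((10, 2));
    /// let dataset = InMemoryDataset::from_arrays(inputs, targets);
    ///
    /// assert_eq!(dataset.len(), 10);
    /// assert_eq!(dataset.inputshape(), vec![5]);
    /// assert_eq!(dataset.outputshape(), vec![2]);
    /// ```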
    pub fn from_arrays<T, D1, D2>(inputs: Array<T, D1>, targets: Array<T, D2>) -> Self
    where
        T: Clone + Send + Sync + 'static,
        D1: Dimension + Send + Sync,
        D2: Dimension + Send + Sync,
    {
        let inputshape = inputs.shape().to_vec();
        let outputshape = targets.shape().to_vec();

        // Handle batched inputs: the first axis is the sample axis
        let num_samples = inputshape[0];
        assert_eq!(
            num_samples, outputshape[0],
            "Inputs and targets must have the same number of samples"
        );

        let mut input_samples = Vec::with_capacity(num_samples);
        let mut target_samples = Vec::with_capacity(num_samples);

        // Convert to dynamic-dimension arrays so arbitrary dimensionalities can be handled
        let to_dyn_inputs = inputs.into_dyn();
        let to_dyn_targets = targets.into_dyn();

        for i in 0..num_samples {
            // Use index_axis instead of slice for better compatibility with different dimensions
            let input_view = to_dyn_inputs.index_axis(crate::ndarray::Axis(0), i);
            let inputarray = input_view.to_owned();
            input_samples.push(Box::new(NdarrayWrapper::new(inputarray)) as Box<dyn ArrayProtocol>);

            let target_view = to_dyn_targets.index_axis(crate::ndarray::Axis(0), i);
            let target_array = target_view.to_owned();
            target_samples
                .push(Box::new(NdarrayWrapper::new(target_array)) as Box<dyn ArrayProtocol>);
        }

        Self {
            inputs: input_samples,
            targets: target_samples,
            inputshape: inputshape[1..].to_vec(),
            outputshape: outputshape[1..].to_vec(),
        }
    }
}

impl Dataset for InMemoryDataset {
    fn len(&self) -> usize {
        self.inputs.len()
    }

    fn get(&self, index: usize) -> Option<(Box<dyn ArrayProtocol>, Box<dyn ArrayProtocol>)> {
        if index >= self.len() {
            return None;
        }

        Some((self.inputs[index].clone(), self.targets[index].clone()))
    }

    fn inputshape(&self) -> Vec<usize> {
        self.inputshape.clone()
    }

    fn outputshape(&self) -> Vec<usize> {
        self.outputshape.clone()
    }
}

/// Data loader for batching and shuffling datasets.
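///
/// `DataLoader` implements [`Iterator`], so batches can be consumed with a
/// `for` loop. Note that shuffling only happens when [`DataLoader::reset`] is
/// called, so call it at the start of every epoch. A usage sketch (marked
/// `ignore` so it is not compiled as a doctest):
///
/// ```ignore
/// let dataset = Box::new(InMemoryDataset::from_arrays(inputs, targets));
/// let mut loader = DataLoader::new(dataset, 4, true, Some(42));
///
/// loader.reset(); // shuffles the indices because `shuffle` is true
/// for (batch_inputs, batch_targets) in &mut loader {
///     // Each batch holds at most 4 samples; the final batch may be smaller.
/// }
/// ```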
pub struct DataLoader {
    /// The dataset to load from.
    dataset: Box<dyn Dataset>,

    /// Batch size.
    batch_size: usize,

    /// Whether to shuffle the dataset on `reset`.
    shuffle: bool,

    /// Random number generator seed.
    seed: Option<u64>,

    /// Indices of the dataset.
    indices: Vec<usize>,

    /// Current position in the dataset.
    position: usize,
}

impl DataLoader {
    /// Create a new data loader.
    ///
    /// The dataset is traversed in order until [`DataLoader::reset`] is
    /// called, which rewinds the loader and shuffles if `shuffle` is true.
    pub fn new(
        dataset: Box<dyn Dataset>,
        batch_size: usize,
        shuffle: bool,
        seed: Option<u64>,
    ) -> Self {
        let indices = (0..dataset.len()).collect();

        Self {
            dataset,
            batch_size,
            shuffle,
            seed,
            indices,
            position: 0,
        }
    }

    /// Reset the data loader to the beginning, reshuffling if enabled.
    pub fn reset(&mut self) {
        self.position = 0;

        if self.shuffle {
            let mut rng = match self.seed {
                Some(s) => rand::rngs::StdRng::seed_from_u64(s),
                None => {
                    // Derive a fresh seed from the thread-local RNG and create a new StdRng
                    let mut rng = rand::rng();
                    let random_seed: u64 = rng.random();
                    rand::rngs::StdRng::seed_from_u64(random_seed)
                }
            };

            self.indices.shuffle(&mut rng);
        }
    }

    /// Get the next batch from the dataset.
    pub fn next_batch(&mut self) -> Option<BatchData> {
        if self.position >= self.dataset.len() {
            return None;
        }

        // Determine how many samples to take
        let remaining = self.dataset.len() - self.position;
        let batch_size = std::cmp::min(self.batch_size, remaining);

        // Get the batch
        let mut inputs = Vec::with_capacity(batch_size);
        let mut targets = Vec::with_capacity(batch_size);

        for i in 0..batch_size {
            let index = self.indices[self.position + i];
            if let Some((input, target)) = self.dataset.get(index) {
                inputs.push(input);
                targets.push(target);
            }
        }

        // Update position
        self.position += batch_size;

        Some((inputs, targets))
    }

    /// Get the number of batches in the dataset (the last batch may be partial).
    pub fn numbatches(&self) -> usize {
        self.dataset.len().div_ceil(self.batch_size)
    }

    /// Get a reference to the dataset.
    pub fn dataset(&self) -> &dyn Dataset {
        self.dataset.as_ref()
    }
}

/// Iterator implementation for `DataLoader`.
impl Iterator for DataLoader {
    type Item = BatchData;

    fn next(&mut self) -> Option<Self::Item> {
        self.next_batch()
    }
}

/// Loss function trait.
pub trait Loss {
    /// Compute the loss between predictions and targets.
    fn forward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>>;

    /// Compute the gradient of the loss with respect to predictions.
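    ///
    /// For `MSELoss` with `"mean"` reduction, for example, this returns
    /// `2 * (predictions - targets) / N`, where `N` is the number of elements.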
    fn backward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>>;

    /// Get the name of the loss function.
    fn name(&self) -> &str;
}

/// Mean squared error loss.
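///
/// With the default `"mean"` reduction, `forward` returns a scalar equal to
/// `mean((predictions - targets)^2)`. A usage sketch mirroring the unit test
/// below (marked `ignore` so it is not compiled as a doctest):
///
/// ```ignore
/// let predictions = NdarrayWrapper::new(ndarray::Array2::<f64>::ones((2, 3)));
/// let targets = NdarrayWrapper::new(ndarray::Array2::<f64>::zeros((2, 3)));
///
/// let mse = MSELoss::new(Some("mean"));
/// let loss = mse.forward(&predictions, &targets)?; // scalar: mean((1 - 0)^2) = 1.0
/// ```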
pub struct MSELoss {
    /// Name of the loss function.
    name: String,

    /// Reduction mode: `"none"`, `"mean"`, or `"sum"`.
    reduction: String,
}

impl MSELoss {
    /// Create a new MSE loss. `reduction` defaults to `"mean"`.
    pub fn new(reduction: Option<&str>) -> Self {
        Self {
            name: "MSELoss".to_string(),
            reduction: reduction.unwrap_or("mean").to_string(),
        }
    }
}

impl Loss for MSELoss {
    fn forward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>> {
        // Compute squared difference
        let diff = subtract(predictions, targets)?;
        let squared = multiply(diff.as_ref(), diff.as_ref())?;

        // Apply reduction
        match self.reduction.as_str() {
            "none" => Ok(squared),
            "mean" => {
                // Compute mean of all elements
                if let Some(array) = squared
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
                {
                    let mean = array.as_array().mean().unwrap();
                    let result = Array0::<f64>::from_elem((), mean);
                    Ok(Box::new(NdarrayWrapper::new(result)))
                } else {
                    Err(CoreError::NotImplementedError(ErrorContext::new(
                        "Mean reduction not implemented for this array type".to_string(),
                    )))
                }
            }
            "sum" => {
                // Compute sum of all elements
                if let Some(array) = squared
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
                {
                    let sum = array.as_array().sum();
                    let result = Array0::<f64>::from_elem((), sum);
                    Ok(Box::new(NdarrayWrapper::new(result)))
                } else {
                    Err(CoreError::NotImplementedError(ErrorContext::new(
                        "Sum reduction not implemented for this array type".to_string(),
                    )))
                }
            }
            _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
                "Unknown reduction: {reduction}",
                reduction = self.reduction
            )))),
        }
    }

    fn backward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>> {
        // Gradient of MSE loss: 2 * (predictions - targets)
        let diff = subtract(predictions, targets)?;
        let factor = Box::new(NdarrayWrapper::new(
            crate::ndarray::Array0::<f64>::from_elem((), 2.0),
        ));
        let grad = multiply(factor.as_ref(), diff.as_ref())?;

        // Apply reduction scaling if needed
        match self.reduction.as_str() {
            "none" => Ok(grad),
            "mean" => {
                // For mean reduction, scale by 1/N
                if let Some(array) = grad
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
                {
                    let n = array.as_array().len() as f64;
                    let scale_factor = Box::new(NdarrayWrapper::new(
                        crate::ndarray::Array0::<f64>::from_elem((), 1.0 / n),
                    ));
                    Ok(multiply(scale_factor.as_ref(), grad.as_ref())?)
                } else {
                    Ok(grad)
                }
            }
            "sum" => Ok(grad),
            _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
                "Unknown reduction: {reduction}",
                reduction = self.reduction
            )))),
        }
    }

    fn name(&self) -> &str {
        &self.name
    }
}

/// Cross-entropy loss.
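///
/// The forward pass applies a softmax to `predictions` and then computes the
/// element-wise loss `-targets * ln(max(p, 1e-10))` before reduction, so
/// callers should pass raw logits together with one-hot (or soft) targets.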
pub struct CrossEntropyLoss {
    /// Name of the loss function.
    name: String,

    /// Reduction mode: `"none"`, `"mean"`, or `"sum"`.
    reduction: String,
}

impl CrossEntropyLoss {
    /// Create a new cross-entropy loss. `reduction` defaults to `"mean"`.
    pub fn new(reduction: Option<&str>) -> Self {
        Self {
            name: "CrossEntropyLoss".to_string(),
            reduction: reduction.unwrap_or("mean").to_string(),
        }
    }
}

impl Loss for CrossEntropyLoss {
    fn forward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>> {
        // Apply softmax to predictions
        let softmax_preds = activation(predictions, ActivationFunc::Softmax)?;

        // Compute cross-entropy
        if let (Some(preds_array), Some(targets_array)) = (
            softmax_preds
                .as_any()
                .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>(),
            targets
                .as_any()
                .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>(),
        ) {
            let preds = preds_array.as_array();
            let targets = targets_array.as_array();

            // Compute -targets * log(preds), clamping to avoid log(0)
            let log_preds = preds.mapv(|x| x.max(1e-10).ln());

            // Compute element-wise multiplication and then negate
            let mut losses = targets.clone();
            losses.zip_mut_with(&log_preds, |t, l| *t = -(*t * *l));

            // Apply reduction
            match self.reduction.as_str() {
                "none" => Ok(Box::new(NdarrayWrapper::new(losses))),
                "mean" => {
                    let mean = losses.mean().unwrap();
                    let result = Array0::<f64>::from_elem((), mean);
                    Ok(Box::new(NdarrayWrapper::new(result)))
                }
                "sum" => {
                    let sum = losses.sum();
                    let result = Array0::<f64>::from_elem((), sum);
                    Ok(Box::new(NdarrayWrapper::new(result)))
                }
                _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
                    "Unknown reduction: {reduction}",
                    reduction = self.reduction
                )))),
            }
        } else {
            Err(CoreError::NotImplementedError(ErrorContext::new(
                "CrossEntropy not implemented for these array types".to_string(),
            )))
        }
    }

    fn backward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>> {
        // For cross-entropy with softmax: gradient is softmax(predictions) - targets
        let softmax_preds = activation(predictions, ActivationFunc::Softmax)?;
        let grad = subtract(softmax_preds.as_ref(), targets)?;

        // Apply reduction scaling if needed
        match self.reduction.as_str() {
            "none" => Ok(grad),
            "mean" => {
                // For mean reduction, scale by 1/N
                if let Some(array) = grad
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
                {
                    let n = array.as_array().len() as f64;
                    let scale_factor = Box::new(NdarrayWrapper::new(
                        crate::ndarray::Array0::<f64>::from_elem((), 1.0 / n),
                    ));
                    Ok(multiply(scale_factor.as_ref(), grad.as_ref())?)
                } else {
                    Ok(grad)
                }
            }
            "sum" => Ok(grad),
            _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
                "Unknown reduction: {reduction}",
                reduction = self.reduction
            )))),
        }
    }

    fn name(&self) -> &str {
        &self.name
    }
}

/// Metrics for evaluating model performance.
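///
/// # Example
///
/// A usage sketch (marked `ignore` so it is not compiled as a doctest):
///
/// ```ignore
/// let mut metrics = Metrics::new("train");
/// metrics.add_loss(1.0);
/// metrics.add_loss(3.0);
/// metrics.add_accuracy(0.5);
///
/// assert_eq!(metrics.mean_loss(), Some(2.0));
/// assert_eq!(metrics.mean_accuracy(), Some(0.5));
/// println!("{metrics}"); // "train: loss = 2.0000, accuracy = 0.5000"
/// ```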
pub struct Metrics {
    /// Loss values.
    losses: Vec<f64>,

    /// Accuracy values (if applicable).
    accuracies: Option<Vec<f64>>,

    /// Name of the metrics object.
    name: String,
}

impl Metrics {
    /// Create a new metrics object.
    pub fn new(name: &str) -> Self {
        Self {
            losses: Vec::new(),
            accuracies: None,
            name: name.to_string(),
        }
    }

    /// Add a loss value.
    pub fn add_loss(&mut self, loss: f64) {
        self.losses.push(loss);
    }

    /// Add an accuracy value.
    pub fn add_accuracy(&mut self, accuracy: f64) {
        if self.accuracies.is_none() {
            self.accuracies = Some(Vec::new());
        }

        if let Some(accuracies) = &mut self.accuracies {
            accuracies.push(accuracy);
        }
    }

    /// Get the mean loss, or `None` if no losses have been recorded.
    pub fn mean_loss(&self) -> Option<f64> {
        if self.losses.is_empty() {
            return None;
        }

        let sum: f64 = self.losses.iter().sum();
        Some(sum / self.losses.len() as f64)
    }

    /// Get the mean accuracy, or `None` if no accuracies have been recorded.
    pub fn mean_accuracy(&self) -> Option<f64> {
        if let Some(accuracies) = &self.accuracies {
            if accuracies.is_empty() {
                return None;
            }

            let sum: f64 = accuracies.iter().sum();
            Some(sum / accuracies.len() as f64)
        } else {
            None
        }
    }

    /// Reset the metrics.
    pub fn reset(&mut self) {
        self.losses.clear();
        if let Some(accuracies) = &mut self.accuracies {
            accuracies.clear();
        }
    }

    /// Get the name of the metrics object.
    pub fn name(&self) -> &str {
        &self.name
    }
}

impl fmt::Display for Metrics {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{}: loss = {:.4}",
            self.name,
            self.mean_loss().unwrap_or(0.0)
        )?;

        if let Some(acc) = self.mean_accuracy() {
            write!(f, ", accuracy = {acc:.4}")?;
        }

        Ok(())
    }
}

/// Training progress callback trait.
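///
/// Implementors can leave hooks they do not need as no-ops. A minimal custom
/// callback sketch (the `LossLogger` type is illustrative only, so the example
/// is marked `ignore`):
///
/// ```ignore
/// struct LossLogger;
///
/// impl TrainingCallback for LossLogger {
///     fn on_epoch_start(&mut self, _epoch: usize, _numepochs: usize) {}
///     fn on_epoch_end(&mut self, epoch: usize, _numepochs: usize, metrics: &Metrics) {
///         println!("epoch {}: {}", epoch + 1, metrics);
///     }
///     fn on_batch_start(&mut self, _batch: usize, _numbatches: usize) {}
///     fn on_batch_end(&mut self, _batch: usize, _numbatches: usize, _loss: f64) {}
///     fn on_train_start(&mut self, _numepochs: usize) {}
///     fn on_train_end(&mut self, _metrics: &Metrics) {}
/// }
/// ```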
pub trait TrainingCallback {
    /// Called at the start of each epoch.
    fn on_epoch_start(&mut self, epoch: usize, numepochs: usize);

    /// Called at the end of each epoch.
    fn on_epoch_end(&mut self, epoch: usize, numepochs: usize, metrics: &Metrics);

    /// Called at the start of each batch.
    fn on_batch_start(&mut self, batch: usize, numbatches: usize);

    /// Called at the end of each batch.
    fn on_batch_end(&mut self, batch: usize, numbatches: usize, loss: f64);

    /// Called at the start of training.
    fn on_train_start(&mut self, numepochs: usize);

    /// Called at the end of training.
    fn on_train_end(&mut self, metrics: &Metrics);
}

/// Callback that prints training progress to stdout.
pub struct ProgressCallback {
    /// Whether to print progress output.
    verbose: bool,

    /// Start time of the current epoch.
    epoch_start: Option<Instant>,

    /// Start time of training.
    train_start: Option<Instant>,
}

impl ProgressCallback {
    /// Create a new progress callback.
    pub fn new(verbose: bool) -> Self {
        Self {
            verbose,
            epoch_start: None,
            train_start: None,
        }
    }
}

impl TrainingCallback for ProgressCallback {
    fn on_epoch_start(&mut self, epoch: usize, numepochs: usize) {
        if self.verbose {
            println!("Epoch {}/{}", epoch + 1, numepochs);
        }

        self.epoch_start = Some(Instant::now());
    }

    fn on_epoch_end(&mut self, _epoch: usize, _numepochs: usize, metrics: &Metrics) {
        if self.verbose {
            if let Some(start) = self.epoch_start {
                let duration = start.elapsed();
                println!("{} - {}ms", metrics, duration.as_millis());
            } else {
                println!("{metrics}");
            }
        }
    }

    fn on_batch_start(&mut self, _batch: usize, _numbatches: usize) {
        // No-op for this callback
    }

    fn on_batch_end(&mut self, batch: usize, numbatches: usize, loss: f64) {
        // Print roughly ten progress updates per epoch
        if self.verbose && (batch + 1) % (numbatches / 10).max(1) == 0 {
            print!("\rBatch {}/{} - loss: {:.4}", batch + 1, numbatches, loss);
            if batch + 1 == numbatches {
                println!();
            }
        }
    }

    fn on_train_start(&mut self, numepochs: usize) {
        if self.verbose {
            println!("Starting training for {numepochs} epochs");
        }

        self.train_start = Some(Instant::now());
    }

    fn on_train_end(&mut self, metrics: &Metrics) {
        if self.verbose {
            if let Some(start) = self.train_start {
                let duration = start.elapsed();
                println!("Training completed in {}s", duration.as_secs());
            } else {
                println!("Training completed");
            }

            if let Some(acc) = metrics.mean_accuracy() {
                println!("Final accuracy: {acc:.4}");
            }
        }
    }
}

/// Model trainer for neural networks.
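///
/// # Example
///
/// A sketch of a complete training run. Construction of the `Sequential`
/// model and the optimizer is elided because those types live in the `neural`
/// and `grad` modules; the example is marked `ignore` for that reason:
///
/// ```ignore
/// let mut trainer = Trainer::new(model, optimizer, Box::new(MSELoss::new(None)));
/// trainer.add_callback(Box::new(ProgressCallback::new(true)));
///
/// // Train for 10 epochs with optional validation after each epoch.
/// trainer.train(&mut train_loader, 10, Some(&mut val_loader))?;
/// println!("{}", trainer.train_metrics());
/// ```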
pub struct Trainer {
    /// The model to train.
    model: Sequential,

    /// The optimizer to use.
    optimizer: Box<dyn Optimizer>,

    /// The loss function to use.
    lossfn: Box<dyn Loss>,

    /// The callbacks to use during training.
    callbacks: Vec<Box<dyn TrainingCallback>>,

    /// Training metrics.
    train_metrics: Metrics,

    /// Validation metrics.
    val_metrics: Option<Metrics>,
}

impl Trainer {
    /// Create a new trainer.
    pub fn new(model: Sequential, optimizer: Box<dyn Optimizer>, lossfn: Box<dyn Loss>) -> Self {
        Self {
            model,
            optimizer,
            lossfn,
            callbacks: Vec::new(),
            train_metrics: Metrics::new("train"),
            val_metrics: None,
        }
    }

    /// Add a callback to the trainer.
    pub fn add_callback(&mut self, callback: Box<dyn TrainingCallback>) {
        self.callbacks.push(callback);
    }

    /// Train the model for `numepochs` epochs, optionally validating after each epoch.
    pub fn train(
        &mut self,
        train_loader: &mut DataLoader,
        numepochs: usize,
        mut val_loader: Option<&mut DataLoader>,
    ) -> CoreResult<()> {
        // Notify callbacks that training is starting
        for callback in &mut self.callbacks {
            callback.on_train_start(numepochs);
        }

        // Initialize validation metrics if needed
        if val_loader.is_some() && self.val_metrics.is_none() {
            self.val_metrics = Some(Metrics::new("val"));
        }

        // Train for the specified number of epochs
        for epoch in 0..numepochs {
            // Reset metrics
            self.train_metrics.reset();
            if let Some(metrics) = &mut self.val_metrics {
                metrics.reset();
            }

            // Notify callbacks that the epoch is starting
            for callback in &mut self.callbacks {
                callback.on_epoch_start(epoch, numepochs);
            }

            // Train on the training set
            self.train_epoch(train_loader)?;

            // Validate on the validation set if provided
            if let Some(ref mut val_loader) = val_loader {
                self.validate(val_loader)?;
            }

            // Notify callbacks that the epoch is ending, preferring
            // validation metrics when they are available
            for callback in &mut self.callbacks {
                callback.on_epoch_end(
                    epoch,
                    numepochs,
                    if let Some(val_metrics) = &self.val_metrics {
                        val_metrics
                    } else {
                        &self.train_metrics
                    },
                );
            }
        }

        // Notify callbacks that training is ending
        for callback in &mut self.callbacks {
            callback.on_train_end(if let Some(val_metrics) = &self.val_metrics {
                val_metrics
            } else {
                &self.train_metrics
            });
        }

        Ok(())
    }

    /// Train for one epoch.
    fn train_epoch(&mut self, dataloader: &mut DataLoader) -> CoreResult<()> {
        // Set model to training mode
        self.model.train();

        // Reset data loader (reshuffles if shuffling is enabled)
        dataloader.reset();

        let numbatches = dataloader.numbatches();

        // Train on batches
        for batch_idx in 0..numbatches {
            let (inputs, targets) = dataloader.next_batch().unwrap();

            // Notify callbacks that the batch is starting
            for callback in &mut self.callbacks {
                callback.on_batch_start(batch_idx, numbatches);
            }

            // Forward and backward pass on the batch
            let batch_loss = self.train_batch(&inputs, &targets)?;

            // Update metrics
            self.train_metrics.add_loss(batch_loss);

            // Notify callbacks that the batch is ending
            for callback in &mut self.callbacks {
                callback.on_batch_end(batch_idx, numbatches, batch_loss);
            }
        }

        Ok(())
    }

    /// Train on a single batch, returning the mean per-sample loss.
    fn train_batch(
        &mut self,
        inputs: &[Box<dyn ArrayProtocol>],
        targets: &[Box<dyn ArrayProtocol>],
    ) -> CoreResult<f64> {
        // Zero gradients
        self.optimizer.zero_grad();

        // Forward pass
        let mut batch_loss = 0.0;

        for (input, target) in inputs.iter().zip(targets.iter()) {
            // Forward pass through model
            let output = self.model.forward(input.as_ref())?;

            // Compute loss
            let loss = self.lossfn.forward(output.as_ref(), target.as_ref())?;

            // Get loss value
            if let Some(loss_array) = loss
                .as_any()
                .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
            {
                let loss_value = loss_array.as_array().sum();
                batch_loss += loss_value;
            }

            // Backward pass. Full automatic differentiation is not wired up
            // yet; instead, `compute_gradients` backpropagates the loss
            // gradient through the model explicitly.

            // Default learning rate for the direct parameter update below.
            // TODO: source this from the optimizer configuration.
            let learningrate = 0.001;

            // Compute gradients via backpropagation, reusing the output and
            // loss from the forward pass above.
            let gradients = self.compute_gradients(
                input.as_ref(),
                target.as_ref(),
                output.as_ref(),
                loss.as_ref(),
            )?;

            // Apply gradients to model parameters
            self.apply_gradients(&gradients, learningrate)?;

            // Store gradients in optimizer for momentum-based optimizers
            self.optimizer.accumulate_gradients(&gradients)?;
        }

        // Compute average loss over the batch
        let batch_loss = batch_loss / inputs.len() as f64;

        // Update weights
        self.optimizer.step()?;

        Ok(batch_loss)
    }

    /// Compute gradients via backpropagation.
    fn compute_gradients(
        &self,
        input: &dyn ArrayProtocol,
        target: &dyn ArrayProtocol,
        output: &dyn ArrayProtocol,
        _loss: &dyn ArrayProtocol,
    ) -> CoreResult<GradientDict> {
        // Start backpropagation from the loss
        let mut gradients = GradientDict::new();

        // Compute gradient of the loss with respect to the output
        let loss_grad = self.lossfn.backward(output, target)?;

        // Backpropagate through the model
        let model_gradients = self.model.backward(input, loss_grad.as_ref())?;

        // Merge gradients
        gradients.merge(model_gradients);

        Ok(gradients)
    }

    /// Apply computed gradients to model parameters.
    fn apply_gradients(&mut self, gradients: &GradientDict, learningrate: f64) -> CoreResult<()> {
        // Apply gradients to each parameter in the model
        for (param_name, gradient) in gradients.iter() {
            self.model
                .update_parameter(param_name, gradient.as_ref(), learningrate)?;
        }

        Ok(())
    }

    /// Validate the model on a validation set.
    fn validate(&mut self, dataloader: &mut DataLoader) -> CoreResult<()> {
        // Set model to evaluation mode
        self.model.eval();

        // Reset validation metrics; if none are configured, skip validation
        if let Some(metrics) = &mut self.val_metrics {
            metrics.reset();
        } else {
            return Ok(());
        }

        // Reset data loader
        dataloader.reset();

        let numbatches = dataloader.numbatches();

        // Validate on batches
        for _ in 0..numbatches {
            let (inputs, targets) = dataloader.next_batch().unwrap();

            // Forward pass without gradient tracking
            let mut batch_loss = 0.0;
            let mut batch_correct = 0;
            let mut batch_total = 0;

            for (input, target) in inputs.iter().zip(targets.iter()) {
                // Forward pass through model
                let output = self.model.forward(input.as_ref())?;

                // Compute loss
                let loss = self.lossfn.forward(output.as_ref(), target.as_ref())?;

                // Get loss value
                if let Some(loss_array) = loss
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
                {
                    let loss_value = loss_array.as_array().sum();
                    batch_loss += loss_value;
                }

                // Compute accuracy for classification problems
                if let (Some(output_array), Some(target_array)) = (
                    output
                        .as_any()
                        .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::Ix2>>(),
                    target
                        .as_any()
                        .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::Ix2>>(),
                ) {
                    // Get predictions (argmax over each row)
                    let output_vec = output_array.as_array();
                    let target_vec = target_array.as_array();

                    // Assumes 2D arrays of shape [batch_size, num_classes]
                    if output_vec.ndim() == 2 && target_vec.ndim() == 2 {
                        for (out_row, target_row) in
                            output_vec.outer_iter().zip(target_vec.outer_iter())
                        {
                            // Find the index of the maximum value in the output row
                            let mut max_idx = 0;
                            let mut max_val = out_row[0];

                            for (i, &val) in out_row.iter().enumerate().skip(1) {
                                if val > max_val {
                                    max_idx = i;
                                    max_val = val;
                                }
                            }

                            // Find the index of 1 in the target row (one-hot encoding)
                            if let Some(target_idx) = target_row.iter().position(|&x| x == 1.0) {
                                if max_idx == target_idx {
                                    batch_correct += 1;
                                }
                            }

                            batch_total += 1;
                        }
                    }
                }
            }

            // Compute average loss and accuracy
            let batch_loss = batch_loss / inputs.len() as f64;
            let batch_accuracy = if batch_total > 0 {
                batch_correct as f64 / batch_total as f64
            } else {
                0.0
            };

            // Update validation metrics
            if let Some(metrics) = &mut self.val_metrics {
                metrics.add_loss(batch_loss);
                metrics.add_accuracy(batch_accuracy);
            }
        }

        Ok(())
    }

    /// Get training metrics.
    pub const fn train_metrics(&self) -> &Metrics {
        &self.train_metrics
    }

    /// Get validation metrics.
    pub fn val_metrics(&self) -> Option<&Metrics> {
        self.val_metrics.as_ref()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::array_protocol::{self, NdarrayWrapper};
    use ::ndarray::Array2;

    #[test]
    fn test_in_memory_dataset() {
        // Create input and target arrays
        let inputs = Array2::<f64>::ones((10, 5));
        let targets = Array2::<f64>::zeros((10, 2));

        // Create dataset
        let dataset = InMemoryDataset::from_arrays(inputs, targets);

        // Check properties
        assert_eq!(dataset.len(), 10);
        assert_eq!(dataset.inputshape(), vec![5]);
        assert_eq!(dataset.outputshape(), vec![2]);

        // Get a sample
        let (input, target) = dataset.get(0).unwrap();
        assert!(input
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
            .is_some());
        assert!(target
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
            .is_some());
    }

    #[test]
    fn test_dataloader() {
        // Create input and target arrays
        let inputs = Array2::<f64>::ones((10, 5));
        let targets = Array2::<f64>::zeros((10, 2));

        // Create dataset and data loader
        let dataset = Box::new(InMemoryDataset::from_arrays(inputs, targets));
        let mut loader = DataLoader::new(dataset, 4, true, Some(42));

        // Check properties
        assert_eq!(loader.numbatches(), 3);

        // Get batches
        let (batch1_inputs, batch1_targets) = loader.next_batch().unwrap();
        assert_eq!(batch1_inputs.len(), 4);
        assert_eq!(batch1_targets.len(), 4);

        let (batch2_inputs, batch2_targets) = loader.next_batch().unwrap();
        assert_eq!(batch2_inputs.len(), 4);
        assert_eq!(batch2_targets.len(), 4);

        let (batch3_inputs, batch3_targets) = loader.next_batch().unwrap();
        assert_eq!(batch3_inputs.len(), 2);
        assert_eq!(batch3_targets.len(), 2);

        // Reset and get another batch
        loader.reset();
        let (batch1_inputs, batch1_targets) = loader.next_batch().unwrap();
        assert_eq!(batch1_inputs.len(), 4);
        assert_eq!(batch1_targets.len(), 4);
    }

    #[test]
    fn test_mse_loss() {
        // Initialize the array protocol system
        array_protocol::init();

        // Create prediction and target arrays
        let predictions = Array2::<f64>::ones((2, 3));
        let targets = Array2::<f64>::zeros((2, 3));

        let predictions_wrapped = NdarrayWrapper::new(predictions);
        let targets_wrapped = NdarrayWrapper::new(targets);

        // Create loss function
        let mse = MSELoss::new(Some("mean"));

        // Compute loss with proper error handling
        match mse.forward(&predictions_wrapped, &targets_wrapped) {
            Ok(loss) => {
                if let Some(loss_array) = loss
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::Ix0>>()
                {
                    // Expected: mean((1 - 0)^2) = 1.0
                    assert_eq!(loss_array.as_array()[()], 1.0);
                } else {
                    println!("Loss not of expected type NdarrayWrapper<f64, Ix0>");
                }
            }
            Err(e) => {
                println!("MSE Loss forward not fully implemented: {e}");
            }
        }
    }

    #[test]
    fn test_metrics() {
        // Create metrics
        let mut metrics = Metrics::new("test");

        // Add loss values
        metrics.add_loss(1.0);
        metrics.add_loss(2.0);
        metrics.add_loss(3.0);

        // Add accuracy values
        metrics.add_accuracy(0.5);
        metrics.add_accuracy(0.6);
        metrics.add_accuracy(0.7);

        // Check mean values
        assert_eq!(metrics.mean_loss().unwrap(), 2.0);
        assert_eq!(metrics.mean_accuracy().unwrap(), 0.6);

        // Reset metrics
        metrics.reset();
        assert!(metrics.mean_loss().is_none());
        assert!(metrics.mean_accuracy().is_none());
    }
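
    #[test]
    fn test_metrics_display() {
        // Sketch of the Display contract implemented above: the mean loss is
        // printed with four decimal places, and accuracy is appended only
        // once at least one accuracy value has been recorded.
        let mut metrics = Metrics::new("test");
        metrics.add_loss(2.0);
        assert_eq!(format!("{metrics}"), "test: loss = 2.0000");

        metrics.add_accuracy(0.5);
        assert_eq!(format!("{metrics}"), "test: loss = 2.0000, accuracy = 0.5000");
    }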
}