scirs2_core/array_protocol/training.rs

// Copyright (c) 2025, `SciRS2` Team
//
// Licensed under either of
//
// * Apache License, Version 2.0
//   (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
// * MIT license
//   (LICENSE-MIT or http://opensource.org/licenses/MIT)
//
// at your option.
//

//! Training utilities for neural networks using the array protocol.
//!
//! This module provides utilities for training neural networks using the
//! array protocol, including datasets, data loaders, loss functions, and
//! training loops.
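//!
//! # Example
//!
//! A minimal end-to-end sketch, assuming a `model` built with the `neural`
//! module and an `optimizer` from the `grad` module (both are placeholders
//! here):
//!
//! ```ignore
//! use ndarray::Array2;
//!
//! // Wrap raw arrays into a dataset and a shuffling loader.
//! let dataset = Box::new(InMemoryDataset::from_arrays(
//!     Array2::<f64>::ones((100, 5)),
//!     Array2::<f64>::zeros((100, 2)),
//! ));
//! let mut loader = DataLoader::new(dataset, 16, true, Some(42));
//!
//! // Assemble a trainer and run the training loop.
//! let mut trainer = Trainer::new(model, optimizer, Box::new(MSELoss::new(None)));
//! trainer.add_callback(Box::new(ProgressCallback::new(true)));
//! trainer.train(&mut loader, 10, None)?;
//! ```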

use std::fmt;
use std::time::Instant;

use ndarray::{Array, Array0, Dimension};
use rand::seq::SliceRandom;
use rand::Rng;
use rand::SeedableRng;

use crate::array_protocol::grad::{GradientDict, Optimizer};
use crate::array_protocol::ml_ops::ActivationFunc;
use crate::array_protocol::neural::Sequential;
use crate::array_protocol::operations::{multiply, subtract};
use crate::array_protocol::{activation, ArrayProtocol, NdarrayWrapper};
use crate::error::{CoreError, CoreResult, ErrorContext};

/// Type alias for a batch of data: `(inputs, targets)`.
pub type BatchData = (Vec<Box<dyn ArrayProtocol>>, Vec<Box<dyn ArrayProtocol>>);

/// Dataset trait for providing data samples.
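///
/// # Example
///
/// A minimal sketch of a custom dataset that serves constant samples (the
/// struct name, shapes, and values are illustrative):
///
/// ```ignore
/// struct ConstantDataset;
///
/// impl Dataset for ConstantDataset {
///     fn len(&self) -> usize {
///         4
///     }
///
///     fn get(&self, index: usize) -> Option<(Box<dyn ArrayProtocol>, Box<dyn ArrayProtocol>)> {
///         if index >= self.len() {
///             return None;
///         }
///         let input = ndarray::Array1::<f64>::ones(3).into_dyn();
///         let target = ndarray::Array1::<f64>::zeros(2).into_dyn();
///         Some((
///             Box::new(NdarrayWrapper::new(input)),
///             Box::new(NdarrayWrapper::new(target)),
///         ))
///     }
///
///     fn inputshape(&self) -> Vec<usize> {
///         vec![3]
///     }
///
///     fn outputshape(&self) -> Vec<usize> {
///         vec![2]
///     }
/// }
/// ```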
pub trait Dataset {
    /// Get the number of samples in the dataset.
    fn len(&self) -> usize;

    /// Check if the dataset is empty.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get a sample from the dataset by index.
    fn get(&self, index: usize) -> Option<(Box<dyn ArrayProtocol>, Box<dyn ArrayProtocol>)>;

    /// Get the shape of a single input sample.
    fn inputshape(&self) -> Vec<usize>;

    /// Get the shape of a single target sample.
    fn outputshape(&self) -> Vec<usize>;
}

/// In-memory dataset backed by vectors of arrays.
pub struct InMemoryDataset {
    /// Input data samples.
    inputs: Vec<Box<dyn ArrayProtocol>>,

    /// Target output samples.
    targets: Vec<Box<dyn ArrayProtocol>>,

    /// Per-sample input shape.
    inputshape: Vec<usize>,

    /// Per-sample output shape.
    outputshape: Vec<usize>,
}

impl InMemoryDataset {
    /// Create a new in-memory dataset.
    pub fn new(
        inputs: Vec<Box<dyn ArrayProtocol>>,
        targets: Vec<Box<dyn ArrayProtocol>>,
        inputshape: Vec<usize>,
        outputshape: Vec<usize>,
    ) -> Self {
        assert_eq!(
            inputs.len(),
            targets.len(),
            "Inputs and targets must have the same length"
        );

        Self {
            inputs,
            targets,
            inputshape,
            outputshape,
        }
    }

    /// Create an in-memory dataset from arrays.
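    ///
    /// The first axis is treated as the sample axis: an input array of shape
    /// `[N, ...]` paired with a target array of shape `[N, ...]` is split
    /// into `N` per-sample pairs.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use ndarray::Array2;
    ///
    /// // 10 samples with 5 input features and 2 target values each.
    /// let dataset = InMemoryDataset::from_arrays(
    ///     Array2::<f64>::ones((10, 5)),
    ///     Array2::<f64>::zeros((10, 2)),
    /// );
    /// assert_eq!(dataset.len(), 10);
    /// assert_eq!(dataset.inputshape(), vec![5]);
    /// ```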
    pub fn from_arrays<T, D1, D2>(inputs: Array<T, D1>, targets: Array<T, D2>) -> Self
    where
        T: Clone + Send + Sync + 'static,
        D1: Dimension + Send + Sync,
        D2: Dimension + Send + Sync,
    {
        let inputshape = inputs.shape().to_vec();
        let outputshape = targets.shape().to_vec();

        // The first axis is the batch axis.
        let num_samples = inputshape[0];
        assert_eq!(
            num_samples, outputshape[0],
            "Inputs and targets must have the same number of samples"
        );

        let mut input_samples = Vec::with_capacity(num_samples);
        let mut target_samples = Vec::with_capacity(num_samples);

        // Convert to dynamic-dimension arrays so samples of arbitrary rank
        // can be sliced out uniformly.
        let dyn_inputs = inputs.into_dyn();
        let dyn_targets = targets.into_dyn();

        for i in 0..num_samples {
            // index_axis works across dimensionalities, unlike a fixed slice.
            let input_view = dyn_inputs.index_axis(ndarray::Axis(0), i);
            let inputarray = input_view.to_owned();
            input_samples.push(Box::new(NdarrayWrapper::new(inputarray)) as Box<dyn ArrayProtocol>);

            let target_view = dyn_targets.index_axis(ndarray::Axis(0), i);
            let target_array = target_view.to_owned();
            target_samples
                .push(Box::new(NdarrayWrapper::new(target_array)) as Box<dyn ArrayProtocol>);
        }

        Self {
            inputs: input_samples,
            targets: target_samples,
            inputshape: inputshape[1..].to_vec(),
            outputshape: outputshape[1..].to_vec(),
        }
    }
}

impl Dataset for InMemoryDataset {
    fn len(&self) -> usize {
        self.inputs.len()
    }

    fn get(&self, index: usize) -> Option<(Box<dyn ArrayProtocol>, Box<dyn ArrayProtocol>)> {
        if index >= self.len() {
            return None;
        }

        Some((self.inputs[index].clone(), self.targets[index].clone()))
    }

    fn inputshape(&self) -> Vec<usize> {
        self.inputshape.clone()
    }

    fn outputshape(&self) -> Vec<usize> {
        self.outputshape.clone()
    }
}

/// Data loader for batching and shuffling datasets.
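///
/// # Example
///
/// A sketch of iterating a dataset in shuffled mini-batches (`dataset` is
/// assumed to be a `Box<dyn Dataset>`):
///
/// ```ignore
/// let mut loader = DataLoader::new(dataset, 32, true, Some(42));
/// while let Some((inputs, targets)) = loader.next_batch() {
///     // inputs.len() == targets.len(), at most 32 samples per batch
/// }
/// loader.reset(); // rewind (and reshuffle) before the next epoch
/// ```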
pub struct DataLoader {
    /// The dataset to load from.
    dataset: Box<dyn Dataset>,

    /// Batch size.
    batch_size: usize,

    /// Whether to shuffle the dataset.
    shuffle: bool,

    /// Random number generator seed.
    seed: Option<u64>,

    /// Indices of the dataset.
    indices: Vec<usize>,

    /// Current position in the dataset.
    position: usize,
}

impl DataLoader {
    /// Create a new data loader.
    pub fn new(
        dataset: Box<dyn Dataset>,
        batch_size: usize,
        shuffle: bool,
        seed: Option<u64>,
    ) -> Self {
        let indices = (0..dataset.len()).collect();

        let mut loader = Self {
            dataset,
            batch_size,
            shuffle,
            seed,
            indices,
            position: 0,
        };

        // Shuffle up front so the first epoch is shuffled even if the
        // caller never calls `reset`.
        loader.reset();
        loader
    }

    /// Reset the data loader, reshuffling the indices if shuffling is enabled.
    pub fn reset(&mut self) {
        self.position = 0;

        if self.shuffle {
            let mut rng = match self.seed {
                Some(s) => rand::rngs::StdRng::seed_from_u64(s),
                None => {
                    // Derive a StdRng from a random seed drawn from the
                    // thread-local generator
                    let mut rng = rand::rng();
                    let random_seed: u64 = rng.random();
                    rand::rngs::StdRng::seed_from_u64(random_seed)
                }
            };

            self.indices.shuffle(&mut rng);
        }
    }

    /// Get the next batch from the dataset.
    pub fn next_batch(&mut self) -> Option<BatchData> {
        if self.position >= self.dataset.len() {
            return None;
        }

        // Determine how many samples to take
        let remaining = self.dataset.len() - self.position;
        let batch_size = std::cmp::min(self.batch_size, remaining);

        // Get the batch
        let mut inputs = Vec::with_capacity(batch_size);
        let mut targets = Vec::with_capacity(batch_size);

        for i in 0..batch_size {
            let index = self.indices[self.position + i];
            if let Some((input, target)) = self.dataset.get(index) {
                inputs.push(input);
                targets.push(target);
            }
        }

        // Update position
        self.position += batch_size;

        Some((inputs, targets))
    }

    /// Get the number of batches in the dataset.
    pub fn numbatches(&self) -> usize {
        self.dataset.len().div_ceil(self.batch_size)
    }

    /// Get a reference to the dataset.
    pub fn dataset(&self) -> &dyn Dataset {
        self.dataset.as_ref()
    }
}

/// Iterator implementation for `DataLoader`.
impl Iterator for DataLoader {
    type Item = BatchData;

    fn next(&mut self) -> Option<Self::Item> {
        self.next_batch()
    }
}

/// Loss function trait.
///
/// Implementors provide a `forward` pass that produces a (possibly
/// reduced) loss value and a `backward` pass that produces the gradient
/// of the loss with respect to the predictions.
pub trait Loss {
    /// Compute the loss between predictions and targets.
    fn forward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>>;

    /// Compute the gradient of the loss with respect to predictions.
    fn backward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>>;

    /// Get the name of the loss function.
    fn name(&self) -> &str;
}

/// Mean squared error loss.
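///
/// For predictions `p` and targets `t`, the element-wise loss is
/// `(p - t)^2`; with `"mean"` reduction the result is averaged over all
/// elements, and the gradient is scaled to `2 * (p - t) / N`.
///
/// # Example
///
/// ```ignore
/// use ndarray::Array2;
///
/// let mse = MSELoss::new(Some("mean"));
/// let preds = NdarrayWrapper::new(Array2::<f64>::ones((2, 3)));
/// let targets = NdarrayWrapper::new(Array2::<f64>::zeros((2, 3)));
/// let loss = mse.forward(&preds, &targets)?; // mean((1 - 0)^2) = 1.0
/// ```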
pub struct MSELoss {
    /// Name of the loss function.
    name: String,

    /// Reduction mode: `"none"`, `"mean"`, or `"sum"`.
    reduction: String,
}

impl MSELoss {
    /// Create a new MSE loss.
    ///
    /// `reduction` may be `"none"`, `"mean"`, or `"sum"`, and defaults to
    /// `"mean"`.
    pub fn new(reduction: Option<&str>) -> Self {
        Self {
            name: "MSELoss".to_string(),
            reduction: reduction.unwrap_or("mean").to_string(),
        }
    }
}

impl Loss for MSELoss {
    fn forward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>> {
        // Compute squared difference
        let diff = subtract(predictions, targets)?;
        let squared = multiply(diff.as_ref(), diff.as_ref())?;

        // Apply reduction
        match self.reduction.as_str() {
            "none" => Ok(squared),
            "mean" => {
                // Compute mean of all elements
                if let Some(array) = squared
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, ndarray::IxDyn>>()
                {
                    let mean = array.as_array().mean().unwrap();
                    let result = Array0::<f64>::from_elem((), mean);
                    Ok(Box::new(NdarrayWrapper::new(result)))
                } else {
                    Err(CoreError::NotImplementedError(ErrorContext::new(
                        "Mean reduction not implemented for this array type".to_string(),
                    )))
                }
            }
            "sum" => {
                // Compute sum of all elements
                if let Some(array) = squared
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, ndarray::IxDyn>>()
                {
                    let sum = array.as_array().sum();
                    let result = Array0::<f64>::from_elem((), sum);
                    Ok(Box::new(NdarrayWrapper::new(result)))
                } else {
                    Err(CoreError::NotImplementedError(ErrorContext::new(
                        "Sum reduction not implemented for this array type".to_string(),
                    )))
                }
            }
            _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
                "Unknown reduction: {reduction}",
                reduction = self.reduction
            )))),
        }
    }

    fn backward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>> {
        // Gradient of MSE loss: 2 * (predictions - targets)
        let diff = subtract(predictions, targets)?;
        let factor = Box::new(NdarrayWrapper::new(ndarray::Array0::<f64>::from_elem(
            (),
            2.0,
        )));
        let grad = multiply(factor.as_ref(), diff.as_ref())?;

        // Apply reduction scaling if needed
        match self.reduction.as_str() {
            "none" => Ok(grad),
            "mean" => {
                // For mean reduction, scale by 1/N
                if let Some(array) = grad
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, ndarray::IxDyn>>()
                {
                    let n = array.as_array().len() as f64;
                    let scale_factor = Box::new(NdarrayWrapper::new(
                        ndarray::Array0::<f64>::from_elem((), 1.0 / n),
                    ));
                    Ok(multiply(scale_factor.as_ref(), grad.as_ref())?)
                } else {
                    Ok(grad)
                }
            }
            "sum" => Ok(grad),
            _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
                "Unknown reduction: {reduction}",
                reduction = self.reduction
            )))),
        }
    }

    fn name(&self) -> &str {
        &self.name
    }
}

/// Cross-entropy loss.
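///
/// Applies a softmax to the raw predictions and computes
/// `-sum(targets * ln(softmax(predictions)))` against one-hot targets.
/// The backward pass uses the softmax/cross-entropy shortcut: the gradient
/// of the loss with respect to the logits is `softmax(predictions) - targets`.
///
/// # Example
///
/// A sketch with hypothetical `logits` and one-hot `one_hot` values that
/// implement `ArrayProtocol` with matching shapes:
///
/// ```ignore
/// let ce = CrossEntropyLoss::new(Some("mean"));
/// let loss = ce.forward(&logits, &one_hot)?;
/// let grad = ce.backward(&logits, &one_hot)?;
/// ```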
pub struct CrossEntropyLoss {
    /// Name of the loss function.
    name: String,

    /// Reduction mode: `"none"`, `"mean"`, or `"sum"`.
    reduction: String,
}

impl CrossEntropyLoss {
    /// Create a new cross-entropy loss.
    ///
    /// `reduction` may be `"none"`, `"mean"`, or `"sum"`, and defaults to
    /// `"mean"`.
    pub fn new(reduction: Option<&str>) -> Self {
        Self {
            name: "CrossEntropyLoss".to_string(),
            reduction: reduction.unwrap_or("mean").to_string(),
        }
    }
}

impl Loss for CrossEntropyLoss {
    fn forward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>> {
        // Apply softmax to predictions
        let softmax_preds = activation(predictions, ActivationFunc::Softmax)?;

        // Compute cross-entropy
        if let (Some(preds_array), Some(targets_array)) = (
            softmax_preds
                .as_any()
                .downcast_ref::<NdarrayWrapper<f64, ndarray::IxDyn>>(),
            targets
                .as_any()
                .downcast_ref::<NdarrayWrapper<f64, ndarray::IxDyn>>(),
        ) {
            let preds = preds_array.as_array();
            let targets = targets_array.as_array();

            // Compute -targets * ln(preds), clamping preds away from zero
            // to avoid ln(0)
            let log_preds = preds.mapv(|x| x.max(1e-10).ln());

            // Compute element-wise multiplication and then negate
            let mut losses = targets.clone();
            losses.zip_mut_with(&log_preds, |t, l| *t = -(*t * *l));

            // Apply reduction
            match self.reduction.as_str() {
                "none" => Ok(Box::new(NdarrayWrapper::new(losses))),
                "mean" => {
                    let mean = losses.mean().unwrap();
                    let result = Array0::<f64>::from_elem((), mean);
                    Ok(Box::new(NdarrayWrapper::new(result)))
                }
                "sum" => {
                    let sum = losses.sum();
                    let result = Array0::<f64>::from_elem((), sum);
                    Ok(Box::new(NdarrayWrapper::new(result)))
                }
                _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
                    "Unknown reduction: {reduction}",
                    reduction = self.reduction
                )))),
            }
        } else {
            Err(CoreError::NotImplementedError(ErrorContext::new(
                "CrossEntropy not implemented for these array types".to_string(),
            )))
        }
    }

    fn backward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>> {
        // For cross-entropy with softmax, the gradient simplifies to
        // softmax(predictions) - targets
        let softmax_preds = activation(predictions, ActivationFunc::Softmax)?;
        let grad = subtract(softmax_preds.as_ref(), targets)?;

        // Apply reduction scaling if needed
        match self.reduction.as_str() {
            "none" => Ok(grad),
            "mean" => {
                // For mean reduction, scale by 1/N
                if let Some(array) = grad
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, ndarray::IxDyn>>()
                {
                    let n = array.as_array().len() as f64;
                    let scale_factor = Box::new(NdarrayWrapper::new(
                        ndarray::Array0::<f64>::from_elem((), 1.0 / n),
                    ));
                    Ok(multiply(scale_factor.as_ref(), grad.as_ref())?)
                } else {
                    Ok(grad)
                }
            }
            "sum" => Ok(grad),
            _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
                "Unknown reduction: {reduction}",
                reduction = self.reduction
            )))),
        }
    }

    fn name(&self) -> &str {
        &self.name
    }
}

/// Metrics for evaluating model performance.
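///
/// # Example
///
/// ```ignore
/// let mut metrics = Metrics::new("train");
/// metrics.add_loss(1.0);
/// metrics.add_loss(3.0);
/// metrics.add_accuracy(0.75);
/// assert_eq!(metrics.mean_loss(), Some(2.0));
/// println!("{metrics}"); // train: loss = 2.0000, accuracy = 0.7500
/// ```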
pub struct Metrics {
    /// Loss values.
    losses: Vec<f64>,

    /// Accuracy values (if applicable).
    accuracies: Option<Vec<f64>>,

    /// Name of the metrics object.
    name: String,
}

impl Metrics {
    /// Create a new metrics object.
    pub fn new(name: &str) -> Self {
        Self {
            losses: Vec::new(),
            accuracies: None,
            name: name.to_string(),
        }
    }

    /// Add a loss value.
    pub fn add_loss(&mut self, loss: f64) {
        self.losses.push(loss);
    }

    /// Add an accuracy value.
    pub fn add_accuracy(&mut self, accuracy: f64) {
        if self.accuracies.is_none() {
            self.accuracies = Some(Vec::new());
        }

        if let Some(accuracies) = &mut self.accuracies {
            accuracies.push(accuracy);
        }
    }

    /// Get the mean loss, or `None` if no losses have been recorded.
    pub fn mean_loss(&self) -> Option<f64> {
        if self.losses.is_empty() {
            return None;
        }

        let sum: f64 = self.losses.iter().sum();
        Some(sum / self.losses.len() as f64)
    }

    /// Get the mean accuracy, or `None` if no accuracies have been recorded.
    pub fn mean_accuracy(&self) -> Option<f64> {
        if let Some(accuracies) = &self.accuracies {
            if accuracies.is_empty() {
                return None;
            }

            let sum: f64 = accuracies.iter().sum();
            Some(sum / accuracies.len() as f64)
        } else {
            None
        }
    }

    /// Reset the metrics.
    pub fn reset(&mut self) {
        self.losses.clear();
        if let Some(accuracies) = &mut self.accuracies {
            accuracies.clear();
        }
    }

    /// Get the name of the metrics object.
    pub fn name(&self) -> &str {
        &self.name
    }
}

impl fmt::Display for Metrics {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{}: loss = {:.4}",
            self.name,
            self.mean_loss().unwrap_or(0.0)
        )?;

        if let Some(acc) = self.mean_accuracy() {
            write!(f, ", accuracy = {acc:.4}")?;
        }

        Ok(())
    }
}

/// Training progress callback trait.
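///
/// Implementors receive hooks from [`Trainer`] at the start and end of
/// training, each epoch, and each batch.
///
/// # Example
///
/// A minimal sketch of a callback that records each epoch's mean loss (the
/// struct and field names are illustrative):
///
/// ```ignore
/// struct LossHistory {
///     per_epoch: Vec<f64>,
/// }
///
/// impl TrainingCallback for LossHistory {
///     fn on_epoch_start(&mut self, _epoch: usize, _numepochs: usize) {}
///
///     fn on_epoch_end(&mut self, _epoch: usize, _numepochs: usize, metrics: &Metrics) {
///         self.per_epoch.push(metrics.mean_loss().unwrap_or(0.0));
///     }
///
///     fn on_batch_start(&mut self, _batch: usize, _numbatches: usize) {}
///     fn on_batch_end(&mut self, _batch: usize, _numbatches: usize, _loss: f64) {}
///     fn on_train_start(&mut self, _numepochs: usize) {}
///     fn on_train_end(&mut self, _metrics: &Metrics) {}
/// }
/// ```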
pub trait TrainingCallback {
    /// Called at the start of each epoch.
    fn on_epoch_start(&mut self, epoch: usize, numepochs: usize);

    /// Called at the end of each epoch.
    fn on_epoch_end(&mut self, epoch: usize, numepochs: usize, metrics: &Metrics);

    /// Called at the start of each batch.
    fn on_batch_start(&mut self, batch: usize, numbatches: usize);

    /// Called at the end of each batch.
    fn on_batch_end(&mut self, batch: usize, numbatches: usize, loss: f64);

    /// Called at the start of training.
    fn on_train_start(&mut self, numepochs: usize);

    /// Called at the end of training.
    fn on_train_end(&mut self, metrics: &Metrics);
}

/// Callback that prints training progress to stdout.
pub struct ProgressCallback {
    /// Whether to print progress output.
    verbose: bool,

    /// Start time of the current epoch.
    epoch_start: Option<Instant>,

    /// Start time of training.
    train_start: Option<Instant>,
}

impl ProgressCallback {
    /// Create a new progress callback.
    pub fn new(verbose: bool) -> Self {
        Self {
            verbose,
            epoch_start: None,
            train_start: None,
        }
    }
}

impl TrainingCallback for ProgressCallback {
    fn on_epoch_start(&mut self, epoch: usize, numepochs: usize) {
        if self.verbose {
            println!("Epoch {}/{}", epoch + 1, numepochs);
        }

        self.epoch_start = Some(Instant::now());
    }

    fn on_epoch_end(&mut self, _epoch: usize, _numepochs: usize, metrics: &Metrics) {
        if self.verbose {
            if let Some(start) = self.epoch_start {
                let duration = start.elapsed();
                println!("{} - {}ms", metrics, duration.as_millis());
            } else {
                println!("{metrics}");
            }
        }
    }

    fn on_batch_start(&mut self, _batch: usize, _numbatches: usize) {
        // No-op for this callback
    }

    fn on_batch_end(&mut self, batch: usize, numbatches: usize, loss: f64) {
        // Print roughly ten progress updates per epoch
        if self.verbose && (batch + 1).is_multiple_of((numbatches / 10).max(1)) {
            print!("\rBatch {}/{} - loss: {:.4}", batch + 1, numbatches, loss);
            if batch + 1 == numbatches {
                println!();
            }
        }
    }

    fn on_train_start(&mut self, numepochs: usize) {
        if self.verbose {
            println!("Starting training for {numepochs} epochs");
        }

        self.train_start = Some(Instant::now());
    }

    fn on_train_end(&mut self, metrics: &Metrics) {
        if self.verbose {
            if let Some(start) = self.train_start {
                let duration = start.elapsed();
                println!("Training completed in {}s", duration.as_secs());
            } else {
                println!("Training completed");
            }

            if let Some(acc) = metrics.mean_accuracy() {
                println!("Final accuracy: {acc:.4}");
            }
        }
    }
}

/// Model trainer for neural networks.
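///
/// Owns the model, optimizer, and loss function, and drives the
/// train/validate loop while reporting progress through callbacks.
///
/// # Example
///
/// A sketch of a full run (the `model` and `optimizer` values are assumed
/// to come from the `neural` and `grad` modules):
///
/// ```ignore
/// let mut trainer = Trainer::new(model, optimizer, Box::new(MSELoss::new(None)));
/// trainer.add_callback(Box::new(ProgressCallback::new(true)));
/// trainer.train(&mut train_loader, 5, Some(&mut val_loader))?;
/// println!("{}", trainer.train_metrics());
/// ```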
pub struct Trainer {
    /// The model to train.
    model: Sequential,

    /// The optimizer to use.
    optimizer: Box<dyn Optimizer>,

    /// The loss function to use.
    lossfn: Box<dyn Loss>,

    /// The callbacks to use during training.
    callbacks: Vec<Box<dyn TrainingCallback>>,

    /// Training metrics.
    train_metrics: Metrics,

    /// Validation metrics.
    val_metrics: Option<Metrics>,
}

impl Trainer {
    /// Create a new trainer.
    pub fn new(model: Sequential, optimizer: Box<dyn Optimizer>, lossfn: Box<dyn Loss>) -> Self {
        Self {
            model,
            optimizer,
            lossfn,
            callbacks: Vec::new(),
            train_metrics: Metrics::new("train"),
            val_metrics: None,
        }
    }

    /// Add a callback to the trainer.
    pub fn add_callback(&mut self, callback: Box<dyn TrainingCallback>) {
        self.callbacks.push(callback);
    }

    /// Train the model.
    pub fn train(
        &mut self,
        train_loader: &mut DataLoader,
        numepochs: usize,
        mut val_loader: Option<&mut DataLoader>,
    ) -> CoreResult<()> {
        // Notify callbacks that training is starting
        for callback in &mut self.callbacks {
            callback.on_train_start(numepochs);
        }

        // Initialize validation metrics if needed
        if val_loader.is_some() && self.val_metrics.is_none() {
            self.val_metrics = Some(Metrics::new("val"));
        }

        // Train for the specified number of epochs
        for epoch in 0..numepochs {
            // Reset metrics
            self.train_metrics.reset();
            if let Some(metrics) = &mut self.val_metrics {
                metrics.reset();
            }

            // Notify callbacks that the epoch is starting
            for callback in &mut self.callbacks {
                callback.on_epoch_start(epoch, numepochs);
            }

            // Train on the training set
            self.train_epoch(train_loader)?;

            // Validate on the validation set if provided
            if let Some(ref mut val_loader) = val_loader {
                self.validate(val_loader)?;
            }

            // Notify callbacks that the epoch is ending, reporting
            // validation metrics when available and training metrics
            // otherwise
            for callback in &mut self.callbacks {
                callback.on_epoch_end(
                    epoch,
                    numepochs,
                    if let Some(val_metrics) = &self.val_metrics {
                        val_metrics
                    } else {
                        &self.train_metrics
                    },
                );
            }
        }

        // Notify callbacks that training is ending
        for callback in &mut self.callbacks {
            callback.on_train_end(if let Some(val_metrics) = &self.val_metrics {
                val_metrics
            } else {
                &self.train_metrics
            });
        }

        Ok(())
    }

    /// Train for one epoch.
    fn train_epoch(&mut self, dataloader: &mut DataLoader) -> CoreResult<()> {
        // Set model to training mode
        self.model.train();

        // Reset data loader
        dataloader.reset();

        let numbatches = dataloader.numbatches();

        // Train on batches
        for batch_idx in 0..numbatches {
            let (inputs, targets) = dataloader.next_batch().unwrap();

            // Notify callbacks that the batch is starting
            for callback in &mut self.callbacks {
                callback.on_batch_start(batch_idx, numbatches);
            }

            // Forward and backward pass on the batch
            let batch_loss = self.train_batch(&inputs, &targets)?;

            // Update metrics
            self.train_metrics.add_loss(batch_loss);

            // Notify callbacks that the batch is ending
            for callback in &mut self.callbacks {
                callback.on_batch_end(batch_idx, numbatches, batch_loss);
            }
        }

        Ok(())
    }

    /// Train on a single batch.
    fn train_batch(
        &mut self,
        inputs: &[Box<dyn ArrayProtocol>],
        targets: &[Box<dyn ArrayProtocol>],
    ) -> CoreResult<f64> {
        // Zero gradients
        self.optimizer.zero_grad();

        let mut batch_loss = 0.0;

        for (input, target) in inputs.iter().zip(targets.iter()) {
            // Forward pass through the model
            let output = self.model.forward(input.as_ref())?;

            // Compute the loss
            let loss = self.lossfn.forward(output.as_ref(), target.as_ref())?;

            // Accumulate the scalar loss value
            if let Some(loss_array) = loss
                .as_any()
                .downcast_ref::<NdarrayWrapper<f64, ndarray::IxDyn>>()
            {
                batch_loss += loss_array.as_array().sum();
            }

            // Backward pass: compute gradients for this sample via
            // backpropagation through the loss and the model
            let gradients = self.compute_gradients(
                input.as_ref(),
                target.as_ref(),
                output.as_ref(),
                loss.as_ref(),
            )?;

            // Apply the gradients directly using a fixed default learning rate
            let learningrate = 0.001;
            self.apply_gradients(&gradients, learningrate)?;

            // Hand the gradients to the optimizer so that momentum-based
            // optimizers can update their internal state
            self.optimizer.accumulate_gradients(&gradients)?;
        }

        // Compute the average loss over the batch
        let batch_loss = batch_loss / inputs.len() as f64;

        // Update weights
        self.optimizer.step()?;

        Ok(batch_loss)
    }

    /// Compute gradients via backpropagation.
    fn compute_gradients(
        &self,
        input: &dyn ArrayProtocol,
        target: &dyn ArrayProtocol,
        output: &dyn ArrayProtocol,
        _loss: &dyn ArrayProtocol,
    ) -> CoreResult<GradientDict> {
        // Start backpropagation from the loss
        let mut gradients = GradientDict::new();

        // Compute the gradient of the loss with respect to the output
        let loss_grad = self.lossfn.backward(output, target)?;

        // Backpropagate through the model
        let model_gradients = self.model.backward(input, loss_grad.as_ref())?;

        // Merge gradients
        gradients.merge(model_gradients);

        Ok(gradients)
    }

    /// Apply computed gradients to model parameters.
    fn apply_gradients(&mut self, gradients: &GradientDict, learningrate: f64) -> CoreResult<()> {
        // Apply gradients to each parameter in the model
        for (param_name, gradient) in gradients.iter() {
            self.model
                .update_parameter(param_name, gradient.as_ref(), learningrate)?;
        }

        Ok(())
    }

    /// Validate the model.
    fn validate(&mut self, dataloader: &mut DataLoader) -> CoreResult<()> {
        // Set model to evaluation mode
        self.model.eval();

        // Reset validation metrics; if none exist, there is nothing to do
        if let Some(metrics) = &mut self.val_metrics {
            metrics.reset();
        } else {
            return Ok(());
        }

        // Reset data loader
        dataloader.reset();

        let numbatches = dataloader.numbatches();

        // Validate on batches
        for _ in 0..numbatches {
            let (inputs, targets) = dataloader.next_batch().unwrap();

            // Forward pass without gradient tracking
            let mut batch_loss = 0.0;
            let mut batch_correct = 0;
            let mut batch_total = 0;

            for (input, target) in inputs.iter().zip(targets.iter()) {
                // Forward pass through model
                let output = self.model.forward(input.as_ref())?;

                // Compute loss
                let loss = self.lossfn.forward(output.as_ref(), target.as_ref())?;

                // Get loss value
                if let Some(loss_array) = loss
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, ndarray::IxDyn>>()
                {
                    let loss_value = loss_array.as_array().sum();
                    batch_loss += loss_value;
                }

                // Compute accuracy for classification problems
                if let (Some(output_array), Some(target_array)) = (
                    output
                        .as_any()
                        .downcast_ref::<NdarrayWrapper<f64, ndarray::Ix2>>(),
                    target
                        .as_any()
                        .downcast_ref::<NdarrayWrapper<f64, ndarray::Ix2>>(),
                ) {
                    // Get predictions (argmax)
                    let output_vec = output_array.as_array();
                    let target_vec = target_array.as_array();

                    // For simplicity, assume 2D arrays [batch_size, num_classes]
                    if output_vec.ndim() == 2 && target_vec.ndim() == 2 {
                        for (out_row, target_row) in
                            output_vec.outer_iter().zip(target_vec.outer_iter())
                        {
                            // Find the index of the maximum value in the output row
                            let mut max_idx = 0;
                            let mut max_val = out_row[0];

                            for (i, &val) in out_row.iter().enumerate().skip(1) {
                                if val > max_val {
                                    max_idx = i;
                                    max_val = val;
                                }
                            }

                            // Find the index of 1 in the target row (one-hot encoding)
                            if let Some(target_idx) = target_row.iter().position(|&x| x == 1.0) {
                                if max_idx == target_idx {
                                    batch_correct += 1;
                                }
                            }

                            batch_total += 1;
                        }
                    }
                }
            }

            // Compute average loss and accuracy
            let batch_loss = batch_loss / inputs.len() as f64;
            let batch_accuracy = if batch_total > 0 {
                batch_correct as f64 / batch_total as f64
            } else {
                0.0
            };

            // Update validation metrics
            if let Some(metrics) = &mut self.val_metrics {
                metrics.add_loss(batch_loss);
                metrics.add_accuracy(batch_accuracy);
            }
        }

        Ok(())
    }

    /// Get training metrics.
    pub const fn train_metrics(&self) -> &Metrics {
        &self.train_metrics
    }

    /// Get validation metrics.
    pub fn val_metrics(&self) -> Option<&Metrics> {
        self.val_metrics.as_ref()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::array_protocol::{self, NdarrayWrapper};
    use ndarray::Array2;

    #[test]
    fn test_in_memory_dataset() {
        // Create input and target arrays
        let inputs = Array2::<f64>::ones((10, 5));
        let targets = Array2::<f64>::zeros((10, 2));

        // Create dataset
        let dataset = InMemoryDataset::from_arrays(inputs, targets);

        // Check properties
        assert_eq!(dataset.len(), 10);
        assert_eq!(dataset.inputshape(), vec![5]);
        assert_eq!(dataset.outputshape(), vec![2]);

        // Get a sample
        let (input, target) = dataset.get(0).unwrap();
        assert!(input
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, ndarray::IxDyn>>()
            .is_some());
        assert!(target
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, ndarray::IxDyn>>()
            .is_some());
    }

    #[test]
    fn test_dataloader() {
        // Create input and target arrays
        let inputs = Array2::<f64>::ones((10, 5));
        let targets = Array2::<f64>::zeros((10, 2));

        // Create dataset and data loader
        let dataset = Box::new(InMemoryDataset::from_arrays(inputs, targets));
        let mut loader = DataLoader::new(dataset, 4, true, Some(42));

        // Check properties
        assert_eq!(loader.numbatches(), 3);

        // Get batches
        let (batch1_inputs, batch1_targets) = loader.next_batch().unwrap();
        assert_eq!(batch1_inputs.len(), 4);
        assert_eq!(batch1_targets.len(), 4);

        let (batch2_inputs, batch2_targets) = loader.next_batch().unwrap();
        assert_eq!(batch2_inputs.len(), 4);
        assert_eq!(batch2_targets.len(), 4);

        let (batch3_inputs, batch3_targets) = loader.next_batch().unwrap();
        assert_eq!(batch3_inputs.len(), 2);
        assert_eq!(batch3_targets.len(), 2);

        // Reset and get another batch
        loader.reset();
        let (batch1_inputs, batch1_targets) = loader.next_batch().unwrap();
        assert_eq!(batch1_inputs.len(), 4);
        assert_eq!(batch1_targets.len(), 4);
    }

    #[test]
    fn test_mse_loss() {
        // Initialize the array protocol system
        array_protocol::init();

        // Create prediction and target arrays
        let predictions = Array2::<f64>::ones((2, 3));
        let targets = Array2::<f64>::zeros((2, 3));

        let predictions_wrapped = NdarrayWrapper::new(predictions);
        let targets_wrapped = NdarrayWrapper::new(targets);

        // Create loss function
        let mse = MSELoss::new(Some("mean"));

        // Compute loss with proper error handling
        match mse.forward(&predictions_wrapped, &targets_wrapped) {
            Ok(loss) => {
                if let Some(loss_array) = loss
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, ndarray::Ix0>>()
                {
                    // Expected: mean((1 - 0)^2) = 1.0
                    assert_eq!(loss_array.as_array()[()], 1.0);
                } else {
                    println!("Loss not of expected type NdarrayWrapper<f64, Ix0>");
                }
            }
            Err(e) => {
                println!("MSE Loss forward not fully implemented: {e}");
            }
        }
    }

    #[test]
    fn test_metrics() {
        // Create metrics
        let mut metrics = Metrics::new("test");

        // Add loss values
        metrics.add_loss(1.0);
        metrics.add_loss(2.0);
        metrics.add_loss(3.0);

        // Add accuracy values
        metrics.add_accuracy(0.5);
        metrics.add_accuracy(0.6);
        metrics.add_accuracy(0.7);

        // Check mean values
        assert_eq!(metrics.mean_loss().unwrap(), 2.0);
        assert_eq!(metrics.mean_accuracy().unwrap(), 0.6);

        // Reset metrics
        metrics.reset();
        assert!(metrics.mean_loss().is_none());
        assert!(metrics.mean_accuracy().is_none());
    }
}