1use std::fmt;
14use std::time::Instant;
15
16use ::ndarray::{Array, Array0, Dimension};
17use rand::seq::SliceRandom;
18use rand::{Rng, RngExt, SeedableRng};
19
20use crate::array_protocol::grad::{GradientDict, Optimizer};
21use crate::array_protocol::ml_ops::ActivationFunc;
22use crate::array_protocol::neural::Sequential;
23use crate::array_protocol::operations::{multiply, subtract};
24use crate::array_protocol::{activation, ArrayProtocol, NdarrayWrapper};
25use crate::error::{CoreError, CoreResult, ErrorContext};
26
/// A mini-batch of training data: parallel vectors of input and target arrays.
pub type BatchData = (Vec<Box<dyn ArrayProtocol>>, Vec<Box<dyn ArrayProtocol>>);
29
/// A supervised dataset of (input, target) array pairs.
pub trait Dataset {
    /// Number of samples in the dataset.
    fn len(&self) -> usize;

    /// Whether the dataset contains no samples.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the sample at `index`, or `None` when `index` is out of range.
    fn get(&self, index: usize) -> Option<(Box<dyn ArrayProtocol>, Box<dyn ArrayProtocol>)>;

    /// Shape of a single input sample (without the batch axis).
    fn inputshape(&self) -> Vec<usize>;

    /// Shape of a single target sample (without the batch axis).
    fn outputshape(&self) -> Vec<usize>;
}
49
/// A dataset held entirely in memory as parallel vectors of boxed samples.
pub struct InMemoryDataset {
    // Per-sample input arrays (one boxed array per sample).
    inputs: Vec<Box<dyn ArrayProtocol>>,

    // Per-sample target arrays, parallel to `inputs`.
    targets: Vec<Box<dyn ArrayProtocol>>,

    // Shape of a single input sample (no batch axis).
    inputshape: Vec<usize>,

    // Shape of a single target sample (no batch axis).
    outputshape: Vec<usize>,
}
64
impl InMemoryDataset {
    /// Creates a dataset from pre-sliced per-sample arrays.
    ///
    /// `inputs` and `targets` are parallel vectors; `inputshape` and
    /// `outputshape` describe a single sample (no batch axis).
    ///
    /// # Panics
    /// Panics if `inputs` and `targets` differ in length.
    pub fn new(
        inputs: Vec<Box<dyn ArrayProtocol>>,
        targets: Vec<Box<dyn ArrayProtocol>>,
        inputshape: Vec<usize>,
        outputshape: Vec<usize>,
    ) -> Self {
        assert_eq!(
            inputs.len(),
            targets.len(),
            "Inputs and targets must have the same length"
        );

        Self {
            inputs,
            targets,
            inputshape,
            outputshape,
        }
    }

    /// Builds a dataset by slicing batched `ndarray` arrays along axis 0.
    ///
    /// Axis 0 of both arrays is treated as the sample axis; each sample is
    /// copied into its own boxed `NdarrayWrapper`. The stored per-sample
    /// shapes drop that leading axis.
    ///
    /// # Panics
    /// Panics if the two arrays disagree on the number of samples, or if
    /// either array is zero-dimensional (the `shape()[0]` access).
    pub fn from_arrays<T, D1, D2>(inputs: Array<T, D1>, targets: Array<T, D2>) -> Self
    where
        T: Clone + Send + Sync + 'static,
        D1: Dimension + Send + Sync,
        D2: Dimension + Send + Sync,
    {
        let inputshape = inputs.shape().to_vec();
        let outputshape = targets.shape().to_vec();

        // Axis 0 is the sample axis.
        let num_samples = inputshape[0];
        assert_eq!(
            num_samples, outputshape[0],
            "Inputs and targets must have the same number of samples"
        );

        let mut input_samples = Vec::with_capacity(num_samples);
        let mut target_samples = Vec::with_capacity(num_samples);

        // Erase the static dimension type so every sample is stored as IxDyn.
        let to_dyn_inputs = inputs.into_dyn();
        let to_dyn_targets = targets.into_dyn();

        for i in 0..num_samples {
            // Each sample is an owned copy of the i-th slice along axis 0.
            let input_view = to_dyn_inputs.index_axis(crate::ndarray::Axis(0), i);
            let inputarray = input_view.to_owned();
            input_samples.push(Box::new(NdarrayWrapper::new(inputarray)) as Box<dyn ArrayProtocol>);

            let target_view = to_dyn_targets.index_axis(crate::ndarray::Axis(0), i);
            let target_array = target_view.to_owned();
            target_samples
                .push(Box::new(NdarrayWrapper::new(target_array)) as Box<dyn ArrayProtocol>);
        }

        Self {
            inputs: input_samples,
            targets: target_samples,
            // Per-sample shapes: drop the leading batch axis.
            inputshape: inputshape[1..].to_vec(),
            outputshape: outputshape[1..].to_vec(),
        }
    }
}
131
132impl Dataset for InMemoryDataset {
133 fn len(&self) -> usize {
134 self.inputs.len()
135 }
136
137 fn get(&self, index: usize) -> Option<(Box<dyn ArrayProtocol>, Box<dyn ArrayProtocol>)> {
138 if index >= self.len() {
139 return None;
140 }
141
142 Some((self.inputs[index].clone(), self.targets[index].clone()))
143 }
144
145 fn inputshape(&self) -> Vec<usize> {
146 self.inputshape.clone()
147 }
148
149 fn outputshape(&self) -> Vec<usize> {
150 self.outputshape.clone()
151 }
152}
153
/// Iterates over a [`Dataset`] in (optionally shuffled) mini-batches.
pub struct DataLoader {
    // The dataset batches are drawn from.
    dataset: Box<dyn Dataset>,

    // Maximum number of samples per batch.
    batch_size: usize,

    // Whether `reset` re-permutes the sample order.
    shuffle: bool,

    // Optional RNG seed for reproducible shuffles.
    seed: Option<u64>,

    // Current permutation of sample indices.
    indices: Vec<usize>,

    // Cursor into `indices`; samples before it have already been served.
    position: usize,
}
174
175impl DataLoader {
176 pub fn new(
178 dataset: Box<dyn Dataset>,
179 batch_size: usize,
180 shuffle: bool,
181 seed: Option<u64>,
182 ) -> Self {
183 let indices = (0..dataset.len()).collect();
184
185 Self {
186 dataset,
187 batch_size,
188 shuffle,
189 seed,
190 indices,
191 position: 0,
192 }
193 }
194
195 pub fn reset(&mut self) {
197 self.position = 0;
198
199 if self.shuffle {
200 let mut rng = match self.seed {
201 Some(s) => rand::rngs::StdRng::seed_from_u64(s),
202 None => {
203 let mut rng = rand::rng();
204 let random_seed: u64 = rng.random();
206 rand::rngs::StdRng::seed_from_u64(random_seed)
207 }
208 };
209
210 self.indices.shuffle(&mut rng);
211 }
212 }
213
214 pub fn next_batch(&mut self) -> Option<BatchData> {
216 if self.position >= self.dataset.len() {
217 return None;
218 }
219
220 let remaining = self.dataset.len() - self.position;
222 let batch_size = std::cmp::min(self.batch_size, remaining);
223
224 let mut inputs = Vec::with_capacity(batch_size);
226 let mut targets = Vec::with_capacity(batch_size);
227
228 for i in 0..batch_size {
229 let index = self.indices[self.position + i];
230 if let Some((input, target)) = self.dataset.get(index) {
231 inputs.push(input);
232 targets.push(target);
233 }
234 }
235
236 self.position += batch_size;
238
239 Some((inputs, targets))
240 }
241
242 pub fn numbatches(&self) -> usize {
244 self.dataset.len().div_ceil(self.batch_size)
245 }
246
247 pub fn dataset(&self) -> &dyn Dataset {
249 self.dataset.as_ref()
250 }
251}
252
253impl Iterator for DataLoader {
255 type Item = BatchData;
256
257 fn next(&mut self) -> Option<Self::Item> {
258 self.next_batch()
259 }
260}
261
/// A differentiable loss function over array-protocol values.
pub trait Loss {
    /// Computes the loss value for `predictions` against `targets`.
    fn forward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>>;

    /// Computes the gradient of the loss with respect to `predictions`.
    fn backward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>>;

    /// Human-readable name of this loss.
    fn name(&self) -> &str;
}
281
/// Mean-squared-error loss with a configurable reduction ("none", "mean", "sum").
pub struct MSELoss {
    // Display name of the loss.
    name: String,

    // Reduction mode: "none", "mean", or "sum".
    reduction: String,
}
290
291impl MSELoss {
292 pub fn new(reduction: Option<&str>) -> Self {
294 Self {
295 name: "MSELoss".to_string(),
296 reduction: reduction.unwrap_or("mean").to_string(),
297 }
298 }
299}
300
301impl Loss for MSELoss {
302 fn forward(
303 &self,
304 predictions: &dyn ArrayProtocol,
305 targets: &dyn ArrayProtocol,
306 ) -> CoreResult<Box<dyn ArrayProtocol>> {
307 let diff = subtract(predictions, targets)?;
309 let squared = multiply(diff.as_ref(), diff.as_ref())?;
310
311 match self.reduction.as_str() {
313 "none" => Ok(squared),
314 "mean" => {
315 if let Some(array) = squared
317 .as_any()
318 .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
319 {
320 let mean = array.as_array().mean().expect("Operation failed");
321 let result = Array0::<f64>::from_elem((), mean);
322 Ok(Box::new(NdarrayWrapper::new(result)))
323 } else {
324 Err(CoreError::NotImplementedError(ErrorContext::new(
325 "Mean reduction not implemented for this array type".to_string(),
326 )))
327 }
328 }
329 "sum" => {
330 if let Some(array) = squared
332 .as_any()
333 .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
334 {
335 let sum = array.as_array().sum();
336 let result = Array0::<f64>::from_elem((), sum);
337 Ok(Box::new(NdarrayWrapper::new(result)))
338 } else {
339 Err(CoreError::NotImplementedError(ErrorContext::new(
340 "Sum reduction not implemented for this array type".to_string(),
341 )))
342 }
343 }
344 _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
345 "Unknown reduction: {reduction}",
346 reduction = self.reduction
347 )))),
348 }
349 }
350
351 fn backward(
352 &self,
353 predictions: &dyn ArrayProtocol,
354 targets: &dyn ArrayProtocol,
355 ) -> CoreResult<Box<dyn ArrayProtocol>> {
356 let diff = subtract(predictions, targets)?;
358 let factor = Box::new(NdarrayWrapper::new(
359 crate::ndarray::Array0::<f64>::from_elem((), 2.0),
360 ));
361 let grad = multiply(factor.as_ref(), diff.as_ref())?;
362
363 match self.reduction.as_str() {
365 "none" => Ok(grad),
366 "mean" => {
367 if let Some(array) = grad
369 .as_any()
370 .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
371 {
372 let n = array.as_array().len() as f64;
373 let scale_factor = Box::new(NdarrayWrapper::new(
374 crate::ndarray::Array0::<f64>::from_elem((), 1.0 / n),
375 ));
376 Ok(multiply(scale_factor.as_ref(), grad.as_ref())?)
377 } else {
378 Ok(grad)
379 }
380 }
381 "sum" => Ok(grad),
382 _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
383 "Unknown reduction: {reduction}",
384 reduction = self.reduction
385 )))),
386 }
387 }
388
389 fn name(&self) -> &str {
390 &self.name
391 }
392}
393
/// Softmax cross-entropy loss with a configurable reduction ("none", "mean", "sum").
pub struct CrossEntropyLoss {
    // Display name of the loss.
    name: String,

    // Reduction mode: "none", "mean", or "sum".
    reduction: String,
}
402
403impl CrossEntropyLoss {
404 pub fn new(reduction: Option<&str>) -> Self {
406 Self {
407 name: "CrossEntropyLoss".to_string(),
408 reduction: reduction.unwrap_or("mean").to_string(),
409 }
410 }
411}
412
413impl Loss for CrossEntropyLoss {
414 fn forward(
415 &self,
416 predictions: &dyn ArrayProtocol,
417 targets: &dyn ArrayProtocol,
418 ) -> CoreResult<Box<dyn ArrayProtocol>> {
419 let softmax_preds = activation(predictions, ActivationFunc::Softmax)?;
421
422 if let (Some(preds_array), Some(targets_array)) = (
424 softmax_preds
425 .as_any()
426 .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>(),
427 targets
428 .as_any()
429 .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>(),
430 ) {
431 let preds = preds_array.as_array();
432 let targets = targets_array.as_array();
433
434 let log_preds = preds.mapv(|x| x.max(1e-10).ln());
436
437 let mut losses = targets.clone();
439 losses.zip_mut_with(&log_preds, |t, l| *t = -(*t * *l));
440
441 match self.reduction.as_str() {
443 "none" => Ok(Box::new(NdarrayWrapper::new(losses))),
444 "mean" => {
445 let mean = losses.mean().expect("Operation failed");
446 let result = Array0::<f64>::from_elem((), mean);
447 Ok(Box::new(NdarrayWrapper::new(result)))
448 }
449 "sum" => {
450 let sum = losses.sum();
451 let result = Array0::<f64>::from_elem((), sum);
452 Ok(Box::new(NdarrayWrapper::new(result)))
453 }
454 _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
455 "Unknown reduction: {reduction}",
456 reduction = self.reduction
457 )))),
458 }
459 } else {
460 Err(CoreError::NotImplementedError(ErrorContext::new(
461 "CrossEntropy not implemented for these array types".to_string(),
462 )))
463 }
464 }
465
466 fn backward(
467 &self,
468 predictions: &dyn ArrayProtocol,
469 targets: &dyn ArrayProtocol,
470 ) -> CoreResult<Box<dyn ArrayProtocol>> {
471 let softmax_preds = activation(predictions, ActivationFunc::Softmax)?;
473 let grad = subtract(softmax_preds.as_ref(), targets)?;
474
475 match self.reduction.as_str() {
477 "none" => Ok(grad),
478 "mean" => {
479 if let Some(array) = grad
481 .as_any()
482 .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
483 {
484 let n = array.as_array().len() as f64;
485 let scale_factor = Box::new(NdarrayWrapper::new(
486 crate::ndarray::Array0::<f64>::from_elem((), 1.0 / n),
487 ));
488 Ok(multiply(scale_factor.as_ref(), grad.as_ref())?)
489 } else {
490 Ok(grad)
491 }
492 }
493 "sum" => Ok(grad),
494 _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
495 "Unknown reduction: {reduction}",
496 reduction = self.reduction
497 )))),
498 }
499 }
500
501 fn name(&self) -> &str {
502 &self.name
503 }
504}
505
/// Accumulates per-batch losses (and optionally accuracies) for reporting.
pub struct Metrics {
    // Loss recorded for each batch.
    losses: Vec<f64>,

    // Accuracy per batch; `None` until the first accuracy is recorded.
    accuracies: Option<Vec<f64>>,

    // Label used when displaying the metrics (e.g. "train", "val").
    name: String,
}

impl Metrics {
    /// Creates an empty metrics accumulator with the given display name.
    pub fn new(name: &str) -> Self {
        Self {
            losses: Vec::new(),
            accuracies: None,
            name: name.to_string(),
        }
    }

    /// Records one batch loss.
    pub fn add_loss(&mut self, loss: f64) {
        self.losses.push(loss);
    }

    /// Records one batch accuracy, lazily allocating the accuracy vector.
    pub fn add_accuracy(&mut self, accuracy: f64) {
        // Idiomatic replacement for the is_none()/insert/if-let sequence.
        self.accuracies.get_or_insert_with(Vec::new).push(accuracy);
    }

    /// Mean of the recorded losses, or `None` if none were recorded.
    pub fn mean_loss(&self) -> Option<f64> {
        if self.losses.is_empty() {
            None
        } else {
            Some(self.losses.iter().sum::<f64>() / self.losses.len() as f64)
        }
    }

    /// Mean of the recorded accuracies, or `None` if none were recorded.
    pub fn mean_accuracy(&self) -> Option<f64> {
        self.accuracies
            .as_ref()
            .filter(|acc| !acc.is_empty())
            .map(|acc| acc.iter().sum::<f64>() / acc.len() as f64)
    }

    /// Clears all recorded values (the accuracy vector stays allocated).
    pub fn reset(&mut self) {
        self.losses.clear();
        if let Some(accuracies) = &mut self.accuracies {
            accuracies.clear();
        }
    }

    /// Display name of this metrics set.
    pub fn name(&self) -> &str {
        &self.name
    }
}
581
582impl fmt::Display for Metrics {
583 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
584 write!(
585 f,
586 "{}: loss = {:.4}",
587 self.name,
588 self.mean_loss().unwrap_or(0.0)
589 )?;
590
591 if let Some(acc) = self.mean_accuracy() {
592 write!(f, ", accuracy = {acc:.4}")?;
593 }
594
595 Ok(())
596 }
597}
598
/// Hooks invoked by the [`Trainer`] around training, epochs, and batches.
pub trait TrainingCallback {
    /// Called before each epoch begins.
    fn on_epoch_start(&mut self, epoch: usize, numepochs: usize);

    /// Called after each epoch with the metrics gathered during it.
    fn on_epoch_end(&mut self, epoch: usize, numepochs: usize, metrics: &Metrics);

    /// Called before each batch is processed.
    fn on_batch_start(&mut self, batch: usize, numbatches: usize);

    /// Called after each batch with that batch's average loss.
    fn on_batch_end(&mut self, batch: usize, numbatches: usize, loss: f64);

    /// Called once before training begins.
    fn on_train_start(&mut self, numepochs: usize);

    /// Called once after training finishes with the final metrics.
    fn on_train_end(&mut self, metrics: &Metrics);
}
619
/// Console progress reporter for training; prints only when `verbose` is set.
pub struct ProgressCallback {
    // When false, all output is suppressed.
    verbose: bool,

    // Start time of the current epoch, set by `on_epoch_start`.
    epoch_start: Option<Instant>,

    // Start time of the whole training run, set by `on_train_start`.
    train_start: Option<Instant>,
}

impl ProgressCallback {
    /// Creates a callback; timers stay unset until training begins.
    pub fn new(verbose: bool) -> Self {
        ProgressCallback {
            verbose,
            epoch_start: None,
            train_start: None,
        }
    }
}
642
impl TrainingCallback for ProgressCallback {
    /// Prints the epoch header and starts the per-epoch timer.
    fn on_epoch_start(&mut self, epoch: usize, numepochs: usize) {
        if self.verbose {
            println!("Epoch {}/{}", epoch + 1, numepochs);
        }

        self.epoch_start = Some(Instant::now());
    }

    /// Prints the epoch metrics, with elapsed time when the timer was started.
    fn on_epoch_end(&mut self, _epoch: usize, numepochs: usize, metrics: &Metrics) {
        if self.verbose {
            if let Some(start) = self.epoch_start {
                let duration = start.elapsed();
                println!("{} - {}ms", metrics, duration.as_millis());
            } else {
                println!("{metrics}");
            }
        }
    }

    /// No-op: batch start is not reported.
    fn on_batch_start(&mut self, _batch: usize, _numbatches: usize) {
    }

    /// Reports progress roughly every 10% of batches (at least once per batch
    /// when there are fewer than 10), overwriting the line with `\r`.
    fn on_batch_end(&mut self, batch: usize, numbatches: usize, loss: f64) {
        if self.verbose && (batch + 1) % (numbatches / 10).max(1) == 0 {
            print!("\rBatch {}/{} - loss: {:.4}", batch + 1, numbatches, loss);
            if batch + 1 == numbatches {
                println!();
            }
        }
    }

    /// Announces training and starts the overall timer.
    fn on_train_start(&mut self, numepochs: usize) {
        if self.verbose {
            println!("Starting training for {numepochs} epochs");
        }

        self.train_start = Some(Instant::now());
    }

    /// Prints total training time and, when available, the final accuracy.
    fn on_train_end(&mut self, metrics: &Metrics) {
        if self.verbose {
            if let Some(start) = self.train_start {
                let duration = start.elapsed();
                println!("Training completed in {}s", duration.as_secs());
            } else {
                println!("Training completed");
            }

            if let Some(acc) = metrics.mean_accuracy() {
                println!("Final accuracy: {acc:.4}");
            }
        }
    }
}
699
/// Drives the training loop: model, optimizer, loss, callbacks, and metrics.
pub struct Trainer {
    // The model being trained.
    model: Sequential,

    // Optimizer that accumulates gradients and steps once per batch.
    optimizer: Box<dyn Optimizer>,

    // Loss used for both the forward loss value and backward gradients.
    lossfn: Box<dyn Loss>,

    // Callbacks invoked around training, epochs, and batches.
    callbacks: Vec<Box<dyn TrainingCallback>>,

    // Metrics accumulated over the training set.
    train_metrics: Metrics,

    // Metrics accumulated over the validation set, when a val loader is used.
    val_metrics: Option<Metrics>,
}
720
impl Trainer {
    /// Creates a trainer; validation metrics are allocated lazily on the first
    /// `train` call that supplies a validation loader.
    pub fn new(model: Sequential, optimizer: Box<dyn Optimizer>, lossfn: Box<dyn Loss>) -> Self {
        Self {
            model,
            optimizer,
            lossfn,
            callbacks: Vec::new(),
            train_metrics: Metrics::new("train"),
            val_metrics: None,
        }
    }

    /// Registers a callback; callbacks are invoked in registration order.
    pub fn add_callback(&mut self, callback: Box<dyn TrainingCallback>) {
        self.callbacks.push(callback);
    }

    /// Runs the full training loop for `numepochs` epochs, optionally
    /// validating after each epoch.
    ///
    /// Epoch-end and train-end callbacks receive the validation metrics when a
    /// validation loader was supplied, otherwise the training metrics.
    pub fn train(
        &mut self,
        train_loader: &mut DataLoader,
        numepochs: usize,
        mut val_loader: Option<&mut DataLoader>,
    ) -> CoreResult<()> {
        for callback in &mut self.callbacks {
            callback.on_train_start(numepochs);
        }

        // Lazily create validation metrics the first time validation is used.
        if val_loader.is_some() && self.val_metrics.is_none() {
            self.val_metrics = Some(Metrics::new("val"));
        }

        for epoch in 0..numepochs {
            // Metrics are per-epoch; clear both sets at each epoch start.
            self.train_metrics.reset();
            if let Some(metrics) = &mut self.val_metrics {
                metrics.reset();
            }

            for callback in &mut self.callbacks {
                callback.on_epoch_start(epoch, numepochs);
            }

            self.train_epoch(train_loader)?;

            if let Some(ref mut val_loader) = val_loader {
                self.validate(val_loader)?;
            }

            for callback in &mut self.callbacks {
                callback.on_epoch_end(
                    epoch,
                    numepochs,
                    // Prefer validation metrics when validation ran this epoch.
                    if let Some(val_metrics) = &self.val_metrics {
                        val_metrics
                    } else {
                        &self.train_metrics
                    },
                );
            }
        }

        for callback in &mut self.callbacks {
            callback.on_train_end(if let Some(val_metrics) = &self.val_metrics {
                val_metrics
            } else {
                &self.train_metrics
            });
        }

        Ok(())
    }

    /// Runs one training epoch: resets the loader (re-shuffling if enabled),
    /// then trains on every batch, recording each batch's mean loss.
    fn train_epoch(&mut self, dataloader: &mut DataLoader) -> CoreResult<()> {
        // Put the model into training mode (e.g. enabling dropout-like layers).
        self.model.train();

        dataloader.reset();

        let numbatches = dataloader.numbatches();

        for batch_idx in 0..numbatches {
            // `numbatches` bounds the loop, so `next_batch` cannot be `None`
            // here; the expect guards that invariant.
            let (inputs, targets) = dataloader.next_batch().expect("Operation failed");
            for callback in &mut self.callbacks {
                callback.on_batch_start(batch_idx, numbatches);
            }

            let batch_loss = self.train_batch(&inputs, &targets)?;

            self.train_metrics.add_loss(batch_loss);

            for callback in &mut self.callbacks {
                callback.on_batch_end(batch_idx, numbatches, batch_loss);
            }
        }

        Ok(())
    }

    /// Trains on a single batch, returning the mean per-sample loss.
    ///
    /// NOTE(review): each sample runs TWO forward passes — `output` (used only
    /// to accumulate the reported loss) and `current_output` (used for the
    /// gradients). That doubles the forward cost per sample; confirm whether
    /// the duplicate pass is intentional.
    ///
    /// NOTE(review): gradients are applied directly with a hard-coded
    /// `learningrate` of 0.001 AND accumulated into the optimizer, which then
    /// steps once at batch end — parameters appear to be updated twice per
    /// batch. Confirm this double update is intended.
    fn train_batch(
        &mut self,
        inputs: &[Box<dyn ArrayProtocol>],
        targets: &[Box<dyn ArrayProtocol>],
    ) -> CoreResult<f64> {
        self.optimizer.zero_grad();

        let mut batch_loss = 0.0;

        for (input, target) in inputs.iter().zip(targets.iter()) {
            let output = self.model.forward(input.as_ref())?;

            let loss = self.lossfn.forward(output.as_ref(), target.as_ref())?;

            // Accumulate the scalar loss; non-matching array types are
            // silently skipped (contribute 0 to the batch loss).
            if let Some(loss_array) = loss
                .as_any()
                .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
            {
                let loss_value = loss_array.as_array().sum();
                batch_loss += loss_value;
            }

            // Second forward pass; its output feeds the gradient computation.
            let learningrate = 0.001; let current_output = self.model.forward(input.as_ref())?;
            let current_loss = self
                .lossfn
                .forward(current_output.as_ref(), target.as_ref())?;
            // Computed but unused; kept for parity with the loss above.
            let _current_loss_value = if let Some(loss_array) = current_loss
                .as_any()
                .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
            {
                loss_array.as_array().sum()
            } else {
                0.0
            };

            let gradients = self.compute_gradients(
                input.as_ref(),
                target.as_ref(),
                current_output.as_ref(),
                current_loss.as_ref(),
            )?;

            self.apply_gradients(&gradients, learningrate)?;

            self.optimizer.accumulate_gradients(&gradients)?;
        }

        // Mean loss over the batch. NOTE(review): divides by zero (NaN) for an
        // empty batch; `DataLoader` never yields one, but confirm for other callers.
        let batch_loss = batch_loss / inputs.len() as f64;

        self.optimizer.step()?;

        Ok(batch_loss)
    }

    /// Backpropagates the loss gradient through the model and collects the
    /// per-parameter gradients. `_loss` is accepted for interface symmetry but
    /// unused: the gradient is recomputed from `output` and `target`.
    fn compute_gradients(
        &self,
        input: &dyn ArrayProtocol,
        target: &dyn ArrayProtocol,
        output: &dyn ArrayProtocol,
        _loss: &dyn ArrayProtocol,
    ) -> CoreResult<GradientDict> {
        let mut gradients = GradientDict::new();

        // dL/d(output) from the loss, then chain through the model.
        let loss_grad = self.lossfn.backward(output, target)?;

        let model_gradients = self.model.backward(input, loss_grad.as_ref())?;

        gradients.merge(model_gradients);

        Ok(gradients)
    }

    /// Applies each gradient to its named parameter with the given learning rate.
    fn apply_gradients(&mut self, gradients: &GradientDict, learningrate: f64) -> CoreResult<()> {
        for (param_name, gradient) in gradients.iter() {
            self.model
                .update_parameter(param_name, gradient.as_ref(), learningrate)?;
        }

        Ok(())
    }

    /// Evaluates the model on the validation loader, recording mean loss and
    /// (for 2-D one-hot outputs) argmax accuracy per batch.
    fn validate(&mut self, dataloader: &mut DataLoader) -> CoreResult<()> {
        // Put the model into evaluation mode.
        self.model.eval();

        // No validation metrics allocated means validation was never requested.
        if let Some(metrics) = &mut self.val_metrics {
            metrics.reset();
        } else {
            return Ok(());
        }

        dataloader.reset();

        let numbatches = dataloader.numbatches();

        for _ in 0..numbatches {
            // Loop is bounded by `numbatches`, so `next_batch` cannot be `None`.
            let (inputs, targets) = dataloader.next_batch().expect("Operation failed");
            let mut batch_loss = 0.0;
            let mut batch_correct = 0;
            let mut batch_total = 0;

            for (input, target) in inputs.iter().zip(targets.iter()) {
                let output = self.model.forward(input.as_ref())?;

                let loss = self.lossfn.forward(output.as_ref(), target.as_ref())?;

                if let Some(loss_array) = loss
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
                {
                    let loss_value = loss_array.as_array().sum();
                    batch_loss += loss_value;
                }

                // Accuracy is only computed for Ix2 (row-per-sample) outputs
                // with one-hot targets; other shapes contribute no accuracy.
                if let (Some(output_array), Some(target_array)) = (
                    output
                        .as_any()
                        .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::Ix2>>(),
                    target
                        .as_any()
                        .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::Ix2>>(),
                ) {
                    let output_vec = output_array.as_array();
                    let target_vec = target_array.as_array();

                    if output_vec.ndim() == 2 && target_vec.ndim() == 2 {
                        for (out_row, target_row) in
                            output_vec.outer_iter().zip(target_vec.outer_iter())
                        {
                            // Argmax over the prediction row.
                            let mut max_idx = 0;
                            let mut max_val = out_row[0];

                            for (i, &val) in out_row.iter().enumerate().skip(1) {
                                if val > max_val {
                                    max_idx = i;
                                    max_val = val;
                                }
                            }

                            // A row counts as correct when the argmax matches
                            // the one-hot position (exact 1.0 comparison).
                            if let Some(target_idx) = target_row.iter().position(|&x| x == 1.0) {
                                if max_idx == target_idx {
                                    batch_correct += 1;
                                }
                            }

                            batch_total += 1;
                        }
                    }
                }
            }

            let batch_loss = batch_loss / inputs.len() as f64;
            let batch_accuracy = if batch_total > 0 {
                batch_correct as f64 / batch_total as f64
            } else {
                0.0
            };

            if let Some(metrics) = &mut self.val_metrics {
                metrics.add_loss(batch_loss);
                metrics.add_accuracy(batch_accuracy);
            }
        }

        Ok(())
    }

    /// Metrics accumulated over the training set.
    pub const fn train_metrics(&self) -> &Metrics {
        &self.train_metrics
    }

    /// Metrics accumulated over the validation set, if validation has run.
    pub fn val_metrics(&self) -> Option<&Metrics> {
        self.val_metrics.as_ref()
    }
}
1056
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array_protocol::{self, NdarrayWrapper};
    use ::ndarray::Array2;

    // Slicing a (10, 5) input / (10, 2) target pair along axis 0 must yield
    // 10 dyn-dimensional samples with the batch axis stripped from the shapes.
    #[test]
    fn test_in_memory_dataset() {
        let inputs = Array2::<f64>::ones((10, 5));
        let targets = Array2::<f64>::zeros((10, 2));

        let dataset = InMemoryDataset::from_arrays(inputs, targets);

        assert_eq!(dataset.len(), 10);
        assert_eq!(dataset.inputshape(), vec![5]);
        assert_eq!(dataset.outputshape(), vec![2]);

        // Samples are stored as NdarrayWrapper<f64, IxDyn> trait objects.
        let (input, target) = dataset.get(0).expect("Operation failed");
        assert!(input
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
            .is_some());
        assert!(target
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
            .is_some());
    }

    // 10 samples with batch_size 4 => three batches of 4, 4, and 2 samples;
    // reset() rewinds so the first full batch can be drawn again.
    #[test]
    fn test_dataloader() {
        let inputs = Array2::<f64>::ones((10, 5));
        let targets = Array2::<f64>::zeros((10, 2));

        let dataset = Box::new(InMemoryDataset::from_arrays(inputs, targets));
        let mut loader = DataLoader::new(dataset, 4, true, Some(42));

        assert_eq!(loader.numbatches(), 3);

        let (batch1_inputs, batch1_targets) = loader.next_batch().expect("Operation failed");
        assert_eq!(batch1_inputs.len(), 4);
        assert_eq!(batch1_targets.len(), 4);

        let (batch2_inputs, batch2_targets) = loader.next_batch().expect("Operation failed");
        assert_eq!(batch2_inputs.len(), 4);
        assert_eq!(batch2_targets.len(), 4);

        // Final partial batch holds the remaining 2 samples.
        let (batch3_inputs, batch3_targets) = loader.next_batch().expect("Operation failed");
        assert_eq!(batch3_inputs.len(), 2);
        assert_eq!(batch3_targets.len(), 2);

        loader.reset();
        let (batch1_inputs, batch1_targets) = loader.next_batch().expect("Operation failed");
        assert_eq!(batch1_inputs.len(), 4);
        assert_eq!(batch1_targets.len(), 4);
    }

    // Mean-reduced MSE of all-ones predictions vs all-zeros targets is exactly 1.0.
    #[test]
    fn test_mse_loss() {
        array_protocol::init();

        let predictions = Array2::<f64>::ones((2, 3));
        let targets = Array2::<f64>::zeros((2, 3));

        let predictions_wrapped = NdarrayWrapper::new(predictions);
        let targets_wrapped = NdarrayWrapper::new(targets);

        let mse = MSELoss::new(Some("mean"));

        // Tolerates a not-yet-implemented forward path: only asserts when the
        // result downcasts to the expected scalar wrapper.
        match mse.forward(&predictions_wrapped, &targets_wrapped) {
            Ok(loss) => {
                if let Some(loss_array) = loss
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::Ix0>>()
                {
                    assert_eq!(loss_array.as_array()[()], 1.0);
                } else {
                    println!("Loss not of expected type NdarrayWrapper<f64, Ix0>");
                }
            }
            Err(e) => {
                println!("MSE Loss forward not fully implemented: {e}");
            }
        }
    }

    // Means over recorded values, then reset() must clear everything.
    #[test]
    fn test_metrics() {
        let mut metrics = Metrics::new("test");

        metrics.add_loss(1.0);
        metrics.add_loss(2.0);
        metrics.add_loss(3.0);

        metrics.add_accuracy(0.5);
        metrics.add_accuracy(0.6);
        metrics.add_accuracy(0.7);

        assert_eq!(metrics.mean_loss().expect("Operation failed"), 2.0);
        assert_eq!(metrics.mean_accuracy().expect("Operation failed"), 0.6);

        metrics.reset();
        assert!(metrics.mean_loss().is_none());
        assert!(metrics.mean_accuracy().is_none());
    }
}