//! Training utilities for array-protocol models: datasets, data loaders,
//! loss functions, metrics, callbacks, and a high-level `Trainer`.

use std::fmt;
use std::time::Instant;

use ::ndarray::{Array, Array0, Dimension};
use rand::seq::SliceRandom;
use rand::Rng;
use rand::SeedableRng;

use crate::array_protocol::grad::{GradientDict, Optimizer};
use crate::array_protocol::ml_ops::ActivationFunc;
use crate::array_protocol::neural::Sequential;
use crate::array_protocol::operations::{multiply, subtract};
use crate::array_protocol::{activation, ArrayProtocol, NdarrayWrapper};
use crate::error::{CoreError, CoreResult, ErrorContext};

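/// A batch of training data: parallel vectors of boxed input and target arrays.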
pub type BatchData = (Vec<Box<dyn ArrayProtocol>>, Vec<Box<dyn ArrayProtocol>>);

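/// A source of `(input, target)` training samples that can be indexed
/// individually, so a `DataLoader` can batch and shuffle them.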
pub trait Dataset {
    /// Returns the number of samples in the dataset.
    fn len(&self) -> usize;

    /// Returns `true` if the dataset contains no samples.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the `(input, target)` pair at `index`, or `None` if out of bounds.
    fn get(&self, index: usize) -> Option<(Box<dyn ArrayProtocol>, Box<dyn ArrayProtocol>)>;

    /// Shape of a single input sample (without the batch dimension).
    fn inputshape(&self) -> Vec<usize>;

    /// Shape of a single target sample (without the batch dimension).
    fn outputshape(&self) -> Vec<usize>;
}

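/// A dataset that holds all samples in memory as boxed `ArrayProtocol` arrays.
///
/// A minimal usage sketch (illustrative only; marked `ignore` because it is
/// not compiled as a doc-test and the import path is an assumption):
///
/// ```ignore
/// use ndarray::Array2;
///
/// // 10 samples with 5 features each, and 10 targets of width 2.
/// let inputs = Array2::<f64>::ones((10, 5));
/// let targets = Array2::<f64>::zeros((10, 2));
/// let dataset = InMemoryDataset::from_arrays(inputs, targets);
/// assert_eq!(dataset.len(), 10);
/// assert_eq!(dataset.inputshape(), vec![5]);
/// ```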
pub struct InMemoryDataset {
    /// Input samples, one boxed array per sample.
    inputs: Vec<Box<dyn ArrayProtocol>>,

    /// Target samples, parallel to `inputs`.
    targets: Vec<Box<dyn ArrayProtocol>>,

    /// Shape of a single input sample (without the batch dimension).
    inputshape: Vec<usize>,

    /// Shape of a single target sample (without the batch dimension).
    outputshape: Vec<usize>,
}

impl InMemoryDataset {
    /// Creates a dataset from pre-split samples.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` and `targets` have different lengths.
    pub fn new(
        inputs: Vec<Box<dyn ArrayProtocol>>,
        targets: Vec<Box<dyn ArrayProtocol>>,
        inputshape: Vec<usize>,
        outputshape: Vec<usize>,
    ) -> Self {
        assert_eq!(
            inputs.len(),
            targets.len(),
            "Inputs and targets must have the same length"
        );

        Self {
            inputs,
            targets,
            inputshape,
            outputshape,
        }
    }

    /// Creates a dataset from whole `ndarray` arrays whose first axis is the
    /// sample axis; each sample is split off into its own boxed array.
    pub fn from_arrays<T, D1, D2>(inputs: Array<T, D1>, targets: Array<T, D2>) -> Self
    where
        T: Clone + Send + Sync + 'static,
        D1: Dimension + Send + Sync,
        D2: Dimension + Send + Sync,
    {
        let inputshape = inputs.shape().to_vec();
        let outputshape = targets.shape().to_vec();

        let num_samples = inputshape[0];
        assert_eq!(
            num_samples, outputshape[0],
            "Inputs and targets must have the same number of samples"
        );

        let mut input_samples = Vec::with_capacity(num_samples);
        let mut target_samples = Vec::with_capacity(num_samples);

        let to_dyn_inputs = inputs.into_dyn();
        let to_dyn_targets = targets.into_dyn();

        for i in 0..num_samples {
            // Slice out sample `i` along the batch axis and box an owned copy.
            let input_view = to_dyn_inputs.index_axis(crate::ndarray::Axis(0), i);
            let inputarray = input_view.to_owned();
            input_samples.push(Box::new(NdarrayWrapper::new(inputarray)) as Box<dyn ArrayProtocol>);

            let target_view = to_dyn_targets.index_axis(crate::ndarray::Axis(0), i);
            let target_array = target_view.to_owned();
            target_samples
                .push(Box::new(NdarrayWrapper::new(target_array)) as Box<dyn ArrayProtocol>);
        }

        Self {
            inputs: input_samples,
            targets: target_samples,
            // Per-sample shapes exclude the leading batch dimension.
            inputshape: inputshape[1..].to_vec(),
            outputshape: outputshape[1..].to_vec(),
        }
    }
}

impl Dataset for InMemoryDataset {
    fn len(&self) -> usize {
        self.inputs.len()
    }

    fn get(&self, index: usize) -> Option<(Box<dyn ArrayProtocol>, Box<dyn ArrayProtocol>)> {
        if index >= self.len() {
            return None;
        }

        Some((self.inputs[index].clone(), self.targets[index].clone()))
    }

    fn inputshape(&self) -> Vec<usize> {
        self.inputshape.clone()
    }

    fn outputshape(&self) -> Vec<usize> {
        self.outputshape.clone()
    }
}

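/// Iterates over a `Dataset` in (optionally shuffled) mini-batches.
///
/// A minimal usage sketch (illustrative only, not compiled as a doc-test;
/// assumes `inputs` and `targets` arrays as in the `InMemoryDataset` example):
///
/// ```ignore
/// let dataset = Box::new(InMemoryDataset::from_arrays(inputs, targets));
/// // Batch size 4, shuffling enabled, fixed seed for reproducibility.
/// let mut loader = DataLoader::new(dataset, 4, true, Some(42));
/// loader.reset(); // reshuffles when `shuffle` is true
/// while let Some((batch_inputs, batch_targets)) = loader.next_batch() {
///     // train on the batch...
/// }
/// ```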
pub struct DataLoader {
    /// The underlying dataset.
    dataset: Box<dyn Dataset>,

    /// Maximum number of samples per batch.
    batch_size: usize,

    /// Whether to shuffle sample order on `reset`.
    shuffle: bool,

    /// Optional RNG seed for reproducible shuffling.
    seed: Option<u64>,

    /// Current (possibly shuffled) sample order.
    indices: Vec<usize>,

    /// Cursor into `indices` marking the next unread sample.
    position: usize,
}

impl DataLoader {
    /// Creates a new loader over `dataset`.
    pub fn new(
        dataset: Box<dyn Dataset>,
        batch_size: usize,
        shuffle: bool,
        seed: Option<u64>,
    ) -> Self {
        let indices = (0..dataset.len()).collect();

        Self {
            dataset,
            batch_size,
            shuffle,
            seed,
            indices,
            position: 0,
        }
    }

    /// Rewinds to the first batch and, if shuffling is enabled, reshuffles
    /// the sample order (seeded when a seed was provided).
    pub fn reset(&mut self) {
        self.position = 0;

        if self.shuffle {
            let mut rng = match self.seed {
                Some(s) => rand::rngs::StdRng::seed_from_u64(s),
                None => {
                    let mut rng = rand::rng();
                    let random_seed: u64 = rng.random();
                    rand::rngs::StdRng::seed_from_u64(random_seed)
                }
            };

            self.indices.shuffle(&mut rng);
        }
    }

    /// Returns the next batch, or `None` when the dataset is exhausted.
    /// The final batch may be smaller than `batch_size`.
    pub fn next_batch(&mut self) -> Option<BatchData> {
        if self.position >= self.dataset.len() {
            return None;
        }

        let remaining = self.dataset.len() - self.position;
        let batch_size = std::cmp::min(self.batch_size, remaining);

        let mut inputs = Vec::with_capacity(batch_size);
        let mut targets = Vec::with_capacity(batch_size);

        for i in 0..batch_size {
            let index = self.indices[self.position + i];
            if let Some((input, target)) = self.dataset.get(index) {
                inputs.push(input);
                targets.push(target);
            }
        }

        self.position += batch_size;

        Some((inputs, targets))
    }

    /// Number of batches per epoch, counting a trailing partial batch.
    pub fn numbatches(&self) -> usize {
        self.dataset.len().div_ceil(self.batch_size)
    }

    /// Returns a reference to the underlying dataset.
    pub fn dataset(&self) -> &dyn Dataset {
        self.dataset.as_ref()
    }
}

impl Iterator for DataLoader {
    type Item = BatchData;

    fn next(&mut self) -> Option<Self::Item> {
        self.next_batch()
    }
}

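/// A differentiable loss function: `forward` computes the loss value and
/// `backward` computes its gradient with respect to the predictions.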
pub trait Loss {
    /// Computes the loss for `predictions` against `targets`.
    fn forward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>>;

    /// Computes the gradient of the loss with respect to `predictions`.
    fn backward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>>;

    /// Returns the loss function's name.
    fn name(&self) -> &str;
}

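/// Mean squared error: `L = reduce((predictions - targets)^2)`, where the
/// reduction is `"mean"`, `"sum"`, or `"none"` (element-wise losses).
///
/// A minimal usage sketch (illustrative only, not compiled as a doc-test):
///
/// ```ignore
/// let mse = MSELoss::new(Some("mean"));
/// let loss = mse.forward(&predictions, &targets)?;  // 0-d array holding the mean
/// let grad = mse.backward(&predictions, &targets)?; // 2 * (pred - target) / n
/// ```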
pub struct MSELoss {
    /// The loss function's name.
    name: String,

    /// Reduction mode: `"mean"`, `"sum"`, or `"none"`.
    reduction: String,
}

impl MSELoss {
    /// Creates an MSE loss; the reduction defaults to `"mean"`.
    pub fn new(reduction: Option<&str>) -> Self {
        Self {
            name: "MSELoss".to_string(),
            reduction: reduction.unwrap_or("mean").to_string(),
        }
    }
}

impl Loss for MSELoss {
    fn forward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>> {
        let diff = subtract(predictions, targets)?;
        let squared = multiply(diff.as_ref(), diff.as_ref())?;

        match self.reduction.as_str() {
            "none" => Ok(squared),
            "mean" => {
                if let Some(array) = squared
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
                {
                    let mean = array.as_array().mean().unwrap();
                    let result = Array0::<f64>::from_elem((), mean);
                    Ok(Box::new(NdarrayWrapper::new(result)))
                } else {
                    Err(CoreError::NotImplementedError(ErrorContext::new(
                        "Mean reduction not implemented for this array type".to_string(),
                    )))
                }
            }
            "sum" => {
                if let Some(array) = squared
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
                {
                    let sum = array.as_array().sum();
                    let result = Array0::<f64>::from_elem((), sum);
                    Ok(Box::new(NdarrayWrapper::new(result)))
                } else {
                    Err(CoreError::NotImplementedError(ErrorContext::new(
                        "Sum reduction not implemented for this array type".to_string(),
                    )))
                }
            }
            _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
                "Unknown reduction: {reduction}",
                reduction = self.reduction
            )))),
        }
    }

    fn backward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>> {
        // d/d(pred) of (pred - target)^2 is 2 * (pred - target).
        let diff = subtract(predictions, targets)?;
        let factor = Box::new(NdarrayWrapper::new(
            crate::ndarray::Array0::<f64>::from_elem((), 2.0),
        ));
        let grad = multiply(factor.as_ref(), diff.as_ref())?;

        match self.reduction.as_str() {
            "none" => Ok(grad),
            "mean" => {
                // Mean reduction scales the gradient by 1/n.
                if let Some(array) = grad
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
                {
                    let n = array.as_array().len() as f64;
                    let scale_factor = Box::new(NdarrayWrapper::new(
                        crate::ndarray::Array0::<f64>::from_elem((), 1.0 / n),
                    ));
                    Ok(multiply(scale_factor.as_ref(), grad.as_ref())?)
                } else {
                    Ok(grad)
                }
            }
            "sum" => Ok(grad),
            _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
                "Unknown reduction: {reduction}",
                reduction = self.reduction
            )))),
        }
    }

    fn name(&self) -> &str {
        &self.name
    }
}

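/// Softmax cross-entropy: `L = reduce(-sum_i(t_i * ln(softmax(p)_i)))` over
/// one-hot targets `t` and raw predictions `p`. In `backward`, the combined
/// softmax + cross-entropy gradient simplifies to `softmax(p) - t`.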
pub struct CrossEntropyLoss {
    /// The loss function's name.
    name: String,

    /// Reduction mode: `"mean"`, `"sum"`, or `"none"`.
    reduction: String,
}

impl CrossEntropyLoss {
    /// Creates a cross-entropy loss; the reduction defaults to `"mean"`.
    pub fn new(reduction: Option<&str>) -> Self {
        Self {
            name: "CrossEntropyLoss".to_string(),
            reduction: reduction.unwrap_or("mean").to_string(),
        }
    }
}

impl Loss for CrossEntropyLoss {
    fn forward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>> {
        let softmax_preds = activation(predictions, ActivationFunc::Softmax)?;

        if let (Some(preds_array), Some(targets_array)) = (
            softmax_preds
                .as_any()
                .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>(),
            targets
                .as_any()
                .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>(),
        ) {
            let preds = preds_array.as_array();
            let targets = targets_array.as_array();

            // Clamp probabilities away from zero before taking the log.
            let log_preds = preds.mapv(|x| x.max(1e-10).ln());

            let mut losses = targets.clone();
            losses.zip_mut_with(&log_preds, |t, l| *t = -(*t * *l));

            match self.reduction.as_str() {
                "none" => Ok(Box::new(NdarrayWrapper::new(losses))),
                "mean" => {
                    let mean = losses.mean().unwrap();
                    let result = Array0::<f64>::from_elem((), mean);
                    Ok(Box::new(NdarrayWrapper::new(result)))
                }
                "sum" => {
                    let sum = losses.sum();
                    let result = Array0::<f64>::from_elem((), sum);
                    Ok(Box::new(NdarrayWrapper::new(result)))
                }
                _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
                    "Unknown reduction: {reduction}",
                    reduction = self.reduction
                )))),
            }
        } else {
            Err(CoreError::NotImplementedError(ErrorContext::new(
                "CrossEntropy not implemented for these array types".to_string(),
            )))
        }
    }

    fn backward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>> {
        // Combined softmax + cross-entropy gradient: softmax(p) - t.
        let softmax_preds = activation(predictions, ActivationFunc::Softmax)?;
        let grad = subtract(softmax_preds.as_ref(), targets)?;

        match self.reduction.as_str() {
            "none" => Ok(grad),
            "mean" => {
                if let Some(array) = grad
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
                {
                    let n = array.as_array().len() as f64;
                    let scale_factor = Box::new(NdarrayWrapper::new(
                        crate::ndarray::Array0::<f64>::from_elem((), 1.0 / n),
                    ));
                    Ok(multiply(scale_factor.as_ref(), grad.as_ref())?)
                } else {
                    Ok(grad)
                }
            }
            "sum" => Ok(grad),
            _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
                "Unknown reduction: {reduction}",
                reduction = self.reduction
            )))),
        }
    }

    fn name(&self) -> &str {
        &self.name
    }
}

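/// Accumulates per-batch losses (and optionally accuracies) and reports
/// their running means; accuracy tracking starts on the first `add_accuracy`.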
pub struct Metrics {
    /// Recorded per-batch losses.
    losses: Vec<f64>,

    /// Recorded per-batch accuracies, if any were added.
    accuracies: Option<Vec<f64>>,

    /// Label used when displaying the metrics (e.g. "train" or "val").
    name: String,
}

impl Metrics {
    /// Creates an empty metrics accumulator with the given display name.
    pub fn new(name: &str) -> Self {
        Self {
            losses: Vec::new(),
            accuracies: None,
            name: name.to_string(),
        }
    }

    /// Records a batch loss.
    pub fn add_loss(&mut self, loss: f64) {
        self.losses.push(loss);
    }

    /// Records a batch accuracy, initializing accuracy tracking if needed.
    pub fn add_accuracy(&mut self, accuracy: f64) {
        if self.accuracies.is_none() {
            self.accuracies = Some(Vec::new());
        }

        if let Some(accuracies) = &mut self.accuracies {
            accuracies.push(accuracy);
        }
    }

    /// Mean of the recorded losses, or `None` if none were recorded.
    pub fn mean_loss(&self) -> Option<f64> {
        if self.losses.is_empty() {
            return None;
        }

        let sum: f64 = self.losses.iter().sum();
        Some(sum / self.losses.len() as f64)
    }

    /// Mean of the recorded accuracies, or `None` if none were recorded.
    pub fn mean_accuracy(&self) -> Option<f64> {
        if let Some(accuracies) = &self.accuracies {
            if accuracies.is_empty() {
                return None;
            }

            let sum: f64 = accuracies.iter().sum();
            Some(sum / accuracies.len() as f64)
        } else {
            None
        }
    }

    /// Clears all recorded values (typically at the start of an epoch).
    pub fn reset(&mut self) {
        self.losses.clear();
        if let Some(accuracies) = &mut self.accuracies {
            accuracies.clear();
        }
    }

    /// Returns the metrics' display name.
    pub fn name(&self) -> &str {
        &self.name
    }
}

impl fmt::Display for Metrics {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{}: loss = {:.4}",
            self.name,
            self.mean_loss().unwrap_or(0.0)
        )?;

        if let Some(acc) = self.mean_accuracy() {
            write!(f, ", accuracy = {acc:.4}")?;
        }

        Ok(())
    }
}

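/// Hooks invoked by the `Trainer` at the boundaries of training, epochs,
/// and batches, e.g. for logging, checkpointing, or early stopping.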
pub trait TrainingCallback {
    /// Called before each epoch begins.
    fn on_epoch_start(&mut self, epoch: usize, numepochs: usize);

    /// Called after each epoch, with that epoch's metrics.
    fn on_epoch_end(&mut self, epoch: usize, numepochs: usize, metrics: &Metrics);

    /// Called before each batch is processed.
    fn on_batch_start(&mut self, batch: usize, numbatches: usize);

    /// Called after each batch, with that batch's loss.
    fn on_batch_end(&mut self, batch: usize, numbatches: usize, loss: f64);

    /// Called once before training begins.
    fn on_train_start(&mut self, numepochs: usize);

    /// Called once after training completes, with the final metrics.
    fn on_train_end(&mut self, metrics: &Metrics);
}

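/// A `TrainingCallback` that prints progress and timing to stdout,
/// roughly every tenth batch and at each epoch boundary.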
pub struct ProgressCallback {
    /// Whether to print progress output.
    verbose: bool,

    /// When the current epoch started.
    epoch_start: Option<Instant>,

    /// When training started.
    train_start: Option<Instant>,
}

impl ProgressCallback {
    /// Creates a progress callback; pass `false` to silence all output.
    pub fn new(verbose: bool) -> Self {
        Self {
            verbose,
            epoch_start: None,
            train_start: None,
        }
    }
}

impl TrainingCallback for ProgressCallback {
    fn on_epoch_start(&mut self, epoch: usize, numepochs: usize) {
        if self.verbose {
            println!("Epoch {}/{}", epoch + 1, numepochs);
        }

        self.epoch_start = Some(Instant::now());
    }

    fn on_epoch_end(&mut self, _epoch: usize, _numepochs: usize, metrics: &Metrics) {
        if self.verbose {
            if let Some(start) = self.epoch_start {
                let duration = start.elapsed();
                println!("{} - {}ms", metrics, duration.as_millis());
            } else {
                println!("{metrics}");
            }
        }
    }

    fn on_batch_start(&mut self, _batch: usize, _numbatches: usize) {
        // No per-batch setup needed.
    }

    fn on_batch_end(&mut self, batch: usize, numbatches: usize, loss: f64) {
        // Print roughly every tenth batch (always at least once per epoch).
        if self.verbose && (batch + 1) % (numbatches / 10).max(1) == 0 {
            print!("\rBatch {}/{} - loss: {:.4}", batch + 1, numbatches, loss);
            if batch + 1 == numbatches {
                println!();
            }
        }
    }

    fn on_train_start(&mut self, numepochs: usize) {
        if self.verbose {
            println!("Starting training for {numepochs} epochs");
        }

        self.train_start = Some(Instant::now());
    }

    fn on_train_end(&mut self, metrics: &Metrics) {
        if self.verbose {
            if let Some(start) = self.train_start {
                let duration = start.elapsed();
                println!("Training completed in {}s", duration.as_secs());
            } else {
                println!("Training completed");
            }

            if let Some(acc) = metrics.mean_accuracy() {
                println!("Final accuracy: {acc:.4}");
            }
        }
    }
}

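/// Drives the training loop: batches from a `DataLoader` are run through the
/// model, the loss and its gradients are computed, parameters are updated,
/// and callbacks and metrics are notified along the way.
///
/// A minimal end-to-end sketch (illustrative only; `build_model` and
/// `make_optimizer` are hypothetical helpers, not APIs of this crate):
///
/// ```ignore
/// let model = build_model();                      // some `Sequential`
/// let optimizer: Box<dyn Optimizer> = make_optimizer();
/// let mut trainer = Trainer::new(model, optimizer, Box::new(MSELoss::new(None)));
/// trainer.add_callback(Box::new(ProgressCallback::new(true)));
/// trainer.train(&mut train_loader, 10, Some(&mut val_loader))?;
/// println!("{}", trainer.train_metrics());
/// ```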
pub struct Trainer {
    /// The model being trained.
    model: Sequential,

    /// The optimizer used to update model parameters.
    optimizer: Box<dyn Optimizer>,

    /// The loss function.
    lossfn: Box<dyn Loss>,

    /// Callbacks notified of training progress.
    callbacks: Vec<Box<dyn TrainingCallback>>,

    /// Metrics accumulated over training batches.
    train_metrics: Metrics,

    /// Metrics accumulated during validation, if a validation loader is used.
    val_metrics: Option<Metrics>,
}

impl Trainer {
    /// Creates a trainer from a model, an optimizer, and a loss function.
    pub fn new(model: Sequential, optimizer: Box<dyn Optimizer>, lossfn: Box<dyn Loss>) -> Self {
        Self {
            model,
            optimizer,
            lossfn,
            callbacks: Vec::new(),
            train_metrics: Metrics::new("train"),
            val_metrics: None,
        }
    }

    /// Registers a callback to be notified of training progress.
    pub fn add_callback(&mut self, callback: Box<dyn TrainingCallback>) {
        self.callbacks.push(callback);
    }

    /// Trains the model for `numepochs` epochs, optionally validating after
    /// each epoch when a validation loader is provided.
    pub fn train(
        &mut self,
        train_loader: &mut DataLoader,
        numepochs: usize,
        mut val_loader: Option<&mut DataLoader>,
    ) -> CoreResult<()> {
        for callback in &mut self.callbacks {
            callback.on_train_start(numepochs);
        }

        // Lazily create validation metrics the first time they are needed.
        if val_loader.is_some() && self.val_metrics.is_none() {
            self.val_metrics = Some(Metrics::new("val"));
        }

        for epoch in 0..numepochs {
            self.train_metrics.reset();
            if let Some(metrics) = &mut self.val_metrics {
                metrics.reset();
            }

            for callback in &mut self.callbacks {
                callback.on_epoch_start(epoch, numepochs);
            }

            self.train_epoch(train_loader)?;

            if let Some(ref mut val_loader) = val_loader {
                self.validate(val_loader)?;
            }

            // Report validation metrics when available, else training metrics.
            for callback in &mut self.callbacks {
                callback.on_epoch_end(
                    epoch,
                    numepochs,
                    if let Some(val_metrics) = &self.val_metrics {
                        val_metrics
                    } else {
                        &self.train_metrics
                    },
                );
            }
        }

        for callback in &mut self.callbacks {
            callback.on_train_end(if let Some(val_metrics) = &self.val_metrics {
                val_metrics
            } else {
                &self.train_metrics
            });
        }

        Ok(())
    }

    /// Runs one training epoch over all batches in `dataloader`.
    fn train_epoch(&mut self, dataloader: &mut DataLoader) -> CoreResult<()> {
        self.model.train();

        dataloader.reset();

        let numbatches = dataloader.numbatches();

        for batch_idx in 0..numbatches {
            let (inputs, targets) = dataloader.next_batch().unwrap();
            for callback in &mut self.callbacks {
                callback.on_batch_start(batch_idx, numbatches);
            }

            let batch_loss = self.train_batch(&inputs, &targets)?;

            self.train_metrics.add_loss(batch_loss);

            for callback in &mut self.callbacks {
                callback.on_batch_end(batch_idx, numbatches, batch_loss);
            }
        }

        Ok(())
    }

    /// Runs forward/backward passes over one batch, applies gradients, and
    /// returns the mean per-sample loss.
    fn train_batch(
        &mut self,
        inputs: &[Box<dyn ArrayProtocol>],
        targets: &[Box<dyn ArrayProtocol>],
    ) -> CoreResult<f64> {
        self.optimizer.zero_grad();

        let mut batch_loss = 0.0;

        for (input, target) in inputs.iter().zip(targets.iter()) {
            let output = self.model.forward(input.as_ref())?;

            let loss = self.lossfn.forward(output.as_ref(), target.as_ref())?;

            if let Some(loss_array) = loss
                .as_any()
                .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
            {
                let loss_value = loss_array.as_array().sum();
                batch_loss += loss_value;
            }

            let learningrate = 0.001;

            // Re-run the forward pass and loss so gradients are computed
            // against the current activations.
            let current_output = self.model.forward(input.as_ref())?;
            let current_loss = self
                .lossfn
                .forward(current_output.as_ref(), target.as_ref())?;
            let _current_loss_value = if let Some(loss_array) = current_loss
                .as_any()
                .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
            {
                loss_array.as_array().sum()
            } else {
                0.0
            };

            let gradients = self.compute_gradients(
                input.as_ref(),
                target.as_ref(),
                current_output.as_ref(),
                current_loss.as_ref(),
            )?;

            // Apply an immediate SGD-style update, then hand the gradients to
            // the optimizer for its own accumulation and step.
            self.apply_gradients(&gradients, learningrate)?;

            self.optimizer.accumulate_gradients(&gradients)?;
        }

        // Average the accumulated per-sample losses.
        let batch_loss = batch_loss / inputs.len() as f64;

        self.optimizer.step()?;

        Ok(batch_loss)
    }

    /// Computes parameter gradients by backpropagating the loss gradient
    /// through the model.
    fn compute_gradients(
        &self,
        input: &dyn ArrayProtocol,
        target: &dyn ArrayProtocol,
        output: &dyn ArrayProtocol,
        _loss: &dyn ArrayProtocol,
    ) -> CoreResult<GradientDict> {
        let mut gradients = GradientDict::new();

        // Gradient of the loss with respect to the model output.
        let loss_grad = self.lossfn.backward(output, target)?;

        // Backpropagate through the model to get per-parameter gradients.
        let model_gradients = self.model.backward(input, loss_grad.as_ref())?;

        gradients.merge(model_gradients);

        Ok(gradients)
    }

    /// Applies each gradient to its parameter with the given learning rate.
    fn apply_gradients(&mut self, gradients: &GradientDict, learningrate: f64) -> CoreResult<()> {
        for (param_name, gradient) in gradients.iter() {
            self.model
                .update_parameter(param_name, gradient.as_ref(), learningrate)?;
        }

        Ok(())
    }

    /// Evaluates the model on `dataloader`, recording loss and (for 2-D
    /// one-hot targets) argmax classification accuracy into `val_metrics`.
    fn validate(&mut self, dataloader: &mut DataLoader) -> CoreResult<()> {
        self.model.eval();

        if let Some(metrics) = &mut self.val_metrics {
            metrics.reset();
        } else {
            return Ok(());
        }

        dataloader.reset();

        let numbatches = dataloader.numbatches();

        for _ in 0..numbatches {
            let (inputs, targets) = dataloader.next_batch().unwrap();
            let mut batch_loss = 0.0;
            let mut batch_correct = 0;
            let mut batch_total = 0;

            for (input, target) in inputs.iter().zip(targets.iter()) {
                let output = self.model.forward(input.as_ref())?;

                let loss = self.lossfn.forward(output.as_ref(), target.as_ref())?;

                if let Some(loss_array) = loss
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
                {
                    let loss_value = loss_array.as_array().sum();
                    batch_loss += loss_value;
                }

                // Accuracy is only computed for 2-D outputs/targets, treating
                // each row as one prediction against a one-hot target row.
                if let (Some(output_array), Some(target_array)) = (
                    output
                        .as_any()
                        .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::Ix2>>(),
                    target
                        .as_any()
                        .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::Ix2>>(),
                ) {
                    let output_vec = output_array.as_array();
                    let target_vec = target_array.as_array();

                    if output_vec.ndim() == 2 && target_vec.ndim() == 2 {
                        for (out_row, target_row) in
                            output_vec.outer_iter().zip(target_vec.outer_iter())
                        {
                            // Argmax over the predicted row.
                            let mut max_idx = 0;
                            let mut max_val = out_row[0];

                            for (i, &val) in out_row.iter().enumerate().skip(1) {
                                if val > max_val {
                                    max_idx = i;
                                    max_val = val;
                                }
                            }

                            // Compare against the one-hot target index.
                            if let Some(target_idx) = target_row.iter().position(|&x| x == 1.0) {
                                if max_idx == target_idx {
                                    batch_correct += 1;
                                }
                            }

                            batch_total += 1;
                        }
                    }
                }
            }

            let batch_loss = batch_loss / inputs.len() as f64;
            let batch_accuracy = if batch_total > 0 {
                batch_correct as f64 / batch_total as f64
            } else {
                0.0
            };

            if let Some(metrics) = &mut self.val_metrics {
                metrics.add_loss(batch_loss);
                metrics.add_accuracy(batch_accuracy);
            }
        }

        Ok(())
    }

    /// Returns the training metrics.
    pub const fn train_metrics(&self) -> &Metrics {
        &self.train_metrics
    }

    /// Returns the validation metrics, if validation has been configured.
    pub fn val_metrics(&self) -> Option<&Metrics> {
        self.val_metrics.as_ref()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::array_protocol::{self, NdarrayWrapper};
    use ::ndarray::Array2;

    #[test]
    fn test_in_memory_dataset() {
        let inputs = Array2::<f64>::ones((10, 5));
        let targets = Array2::<f64>::zeros((10, 2));

        let dataset = InMemoryDataset::from_arrays(inputs, targets);

        assert_eq!(dataset.len(), 10);
        assert_eq!(dataset.inputshape(), vec![5]);
        assert_eq!(dataset.outputshape(), vec![2]);

        let (input, target) = dataset.get(0).unwrap();
        assert!(input
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
            .is_some());
        assert!(target
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::IxDyn>>()
            .is_some());
    }

    #[test]
    fn test_dataloader() {
        let inputs = Array2::<f64>::ones((10, 5));
        let targets = Array2::<f64>::zeros((10, 2));

        let dataset = Box::new(InMemoryDataset::from_arrays(inputs, targets));
        let mut loader = DataLoader::new(dataset, 4, true, Some(42));

        // 10 samples with batch size 4 -> 3 batches (4, 4, 2).
        assert_eq!(loader.numbatches(), 3);

        let (batch1_inputs, batch1_targets) = loader.next_batch().unwrap();
        assert_eq!(batch1_inputs.len(), 4);
        assert_eq!(batch1_targets.len(), 4);

        let (batch2_inputs, batch2_targets) = loader.next_batch().unwrap();
        assert_eq!(batch2_inputs.len(), 4);
        assert_eq!(batch2_targets.len(), 4);

        let (batch3_inputs, batch3_targets) = loader.next_batch().unwrap();
        assert_eq!(batch3_inputs.len(), 2);
        assert_eq!(batch3_targets.len(), 2);

        loader.reset();
        let (batch1_inputs, batch1_targets) = loader.next_batch().unwrap();
        assert_eq!(batch1_inputs.len(), 4);
        assert_eq!(batch1_targets.len(), 4);
    }

    #[test]
    fn test_mse_loss() {
        array_protocol::init();

        let predictions = Array2::<f64>::ones((2, 3));
        let targets = Array2::<f64>::zeros((2, 3));

        let predictions_wrapped = NdarrayWrapper::new(predictions);
        let targets_wrapped = NdarrayWrapper::new(targets);

        let mse = MSELoss::new(Some("mean"));

        match mse.forward(&predictions_wrapped, &targets_wrapped) {
            Ok(loss) => {
                if let Some(loss_array) = loss
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, crate::ndarray::Ix0>>()
                {
                    // mean((1 - 0)^2) == 1.0
                    assert_eq!(loss_array.as_array()[()], 1.0);
                } else {
                    println!("Loss not of expected type NdarrayWrapper<f64, Ix0>");
                }
            }
            Err(e) => {
                println!("MSE Loss forward not fully implemented: {e}");
            }
        }
    }

    #[test]
    fn test_metrics() {
        let mut metrics = Metrics::new("test");

        metrics.add_loss(1.0);
        metrics.add_loss(2.0);
        metrics.add_loss(3.0);

        metrics.add_accuracy(0.5);
        metrics.add_accuracy(0.6);
        metrics.add_accuracy(0.7);

        assert_eq!(metrics.mean_loss().unwrap(), 2.0);
        assert_eq!(metrics.mean_accuracy().unwrap(), 0.6);

        metrics.reset();
        assert!(metrics.mean_loss().is_none());
        assert!(metrics.mean_accuracy().is_none());
    }
}