//! Model training utilities built on the array protocol: datasets, batched
//! data loading, loss functions, metrics, training callbacks, and a
//! `Trainer` that ties them together.

use std::fmt;
use std::time::Instant;

use ndarray::{Array, Array0, Dimension};
use rand::seq::SliceRandom;
use rand::Rng;
use rand::SeedableRng;

use crate::array_protocol::grad::{GradientDict, Optimizer};
use crate::array_protocol::ml_ops::ActivationFunc;
use crate::array_protocol::neural::Sequential;
use crate::array_protocol::operations::{multiply, subtract};
use crate::array_protocol::{activation, ArrayProtocol, NdarrayWrapper};
use crate::error::{CoreError, CoreResult, ErrorContext};

/// A batch of data: a vector of inputs paired with a vector of targets.
pub type BatchData = (Vec<Box<dyn ArrayProtocol>>, Vec<Box<dyn ArrayProtocol>>);

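/// A source of (input, target) training samples addressable by index.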
pub trait Dataset {
    /// Returns the number of samples in the dataset.
    fn len(&self) -> usize;

    /// Returns `true` if the dataset contains no samples.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the (input, target) pair at `index`, or `None` if out of range.
    fn get(&self, index: usize) -> Option<(Box<dyn ArrayProtocol>, Box<dyn ArrayProtocol>)>;

    /// Shape of a single input sample (without the batch dimension).
    fn inputshape(&self) -> Vec<usize>;

    /// Shape of a single target sample (without the batch dimension).
    fn outputshape(&self) -> Vec<usize>;
}

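/// A dataset held entirely in memory as boxed `ArrayProtocol` samples.
///
/// A minimal usage sketch (`ignore`d in doctests; mirrors the unit test at
/// the bottom of this file):
///
/// ```ignore
/// use ndarray::Array2;
///
/// // 10 samples with 5 input features and 2 target values each.
/// let inputs = Array2::<f64>::ones((10, 5));
/// let targets = Array2::<f64>::zeros((10, 2));
/// let dataset = InMemoryDataset::from_arrays(inputs, targets);
/// assert_eq!(dataset.len(), 10);
/// ```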
pub struct InMemoryDataset {
    /// Input samples.
    inputs: Vec<Box<dyn ArrayProtocol>>,

    /// Target samples, parallel to `inputs`.
    targets: Vec<Box<dyn ArrayProtocol>>,

    /// Shape of a single input sample.
    inputshape: Vec<usize>,

    /// Shape of a single target sample.
    outputshape: Vec<usize>,
}

impl InMemoryDataset {
    /// Creates a dataset from pre-split sample vectors.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` and `targets` have different lengths.
    pub fn new(
        inputs: Vec<Box<dyn ArrayProtocol>>,
        targets: Vec<Box<dyn ArrayProtocol>>,
        inputshape: Vec<usize>,
        outputshape: Vec<usize>,
    ) -> Self {
        assert_eq!(
            inputs.len(),
            targets.len(),
            "Inputs and targets must have the same length"
        );

        Self {
            inputs,
            targets,
            inputshape,
            outputshape,
        }
    }

    /// Creates a dataset from `ndarray` arrays whose first axis indexes
    /// samples.
    ///
    /// # Panics
    ///
    /// Panics if the arrays disagree on the number of samples.
    pub fn from_arrays<T, D1, D2>(inputs: Array<T, D1>, targets: Array<T, D2>) -> Self
    where
        T: Clone + Send + Sync + 'static,
        D1: Dimension + Send + Sync,
        D2: Dimension + Send + Sync,
    {
        let inputshape = inputs.shape().to_vec();
        let outputshape = targets.shape().to_vec();

        let num_samples = inputshape[0];
        assert_eq!(
            num_samples, outputshape[0],
            "Inputs and targets must have the same number of samples"
        );

        let mut input_samples = Vec::with_capacity(num_samples);
        let mut target_samples = Vec::with_capacity(num_samples);

        let to_dyn_inputs = inputs.into_dyn();
        let to_dyn_targets = targets.into_dyn();

        // Split each array along the sample axis into owned per-sample arrays.
        for i in 0..num_samples {
            let input_view = to_dyn_inputs.index_axis(ndarray::Axis(0), i);
            let inputarray = input_view.to_owned();
            input_samples.push(Box::new(NdarrayWrapper::new(inputarray)) as Box<dyn ArrayProtocol>);

            let target_view = to_dyn_targets.index_axis(ndarray::Axis(0), i);
            let target_array = target_view.to_owned();
            target_samples
                .push(Box::new(NdarrayWrapper::new(target_array)) as Box<dyn ArrayProtocol>);
        }

        Self {
            inputs: input_samples,
            targets: target_samples,
            // Per-sample shapes exclude the leading sample axis.
            inputshape: inputshape[1..].to_vec(),
            outputshape: outputshape[1..].to_vec(),
        }
    }
}

impl Dataset for InMemoryDataset {
    fn len(&self) -> usize {
        self.inputs.len()
    }

    fn get(&self, index: usize) -> Option<(Box<dyn ArrayProtocol>, Box<dyn ArrayProtocol>)> {
        if index >= self.len() {
            return None;
        }

        Some((self.inputs[index].clone(), self.targets[index].clone()))
    }

    fn inputshape(&self) -> Vec<usize> {
        self.inputshape.clone()
    }

    fn outputshape(&self) -> Vec<usize> {
        self.outputshape.clone()
    }
}

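/// Iterates over a [`Dataset`] in mini-batches, optionally shuffling the
/// sample order on each [`reset`](DataLoader::reset).
///
/// A minimal epoch-loop sketch (`ignore`d in doctests; assumes a boxed
/// dataset as in the `InMemoryDataset` example above):
///
/// ```ignore
/// let mut loader = DataLoader::new(dataset, 4, true, Some(42));
/// loader.reset(); // shuffles, since `shuffle` is true
/// while let Some((inputs, targets)) = loader.next_batch() {
///     // process one batch of up to 4 samples...
/// }
/// ```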
pub struct DataLoader {
    /// The dataset to iterate over.
    dataset: Box<dyn Dataset>,

    /// Number of samples per batch.
    batch_size: usize,

    /// Whether to shuffle the sample order on each `reset`.
    shuffle: bool,

    /// Optional seed for reproducible shuffling.
    seed: Option<u64>,

    /// Current (possibly shuffled) sample order.
    indices: Vec<usize>,

    /// Position of the next sample within `indices`.
    position: usize,
}

impl DataLoader {
    /// Creates a data loader. The sample order starts sequential; call
    /// [`reset`](Self::reset) (done automatically by the training loop) to
    /// shuffle when `shuffle` is enabled.
    pub fn new(
        dataset: Box<dyn Dataset>,
        batch_size: usize,
        shuffle: bool,
        seed: Option<u64>,
    ) -> Self {
        let indices = (0..dataset.len()).collect();

        Self {
            dataset,
            batch_size,
            shuffle,
            seed,
            indices,
            position: 0,
        }
    }

    /// Rewinds to the first batch, reshuffling the sample order if enabled.
    pub fn reset(&mut self) {
        self.position = 0;

        if self.shuffle {
            let mut rng = match self.seed {
                Some(s) => rand::rngs::StdRng::seed_from_u64(s),
                None => {
                    // No fixed seed: derive one from the thread-local RNG.
                    let mut rng = rand::rng();
                    let random_seed: u64 = rng.random();
                    rand::rngs::StdRng::seed_from_u64(random_seed)
                }
            };

            self.indices.shuffle(&mut rng);
        }
    }

    /// Returns the next batch, or `None` when the epoch is exhausted. The
    /// final batch may contain fewer than `batch_size` samples.
    pub fn next_batch(&mut self) -> Option<BatchData> {
        if self.position >= self.dataset.len() {
            return None;
        }

        let remaining = self.dataset.len() - self.position;
        let batch_size = std::cmp::min(self.batch_size, remaining);

        let mut inputs = Vec::with_capacity(batch_size);
        let mut targets = Vec::with_capacity(batch_size);

        for i in 0..batch_size {
            let index = self.indices[self.position + i];
            if let Some((input, target)) = self.dataset.get(index) {
                inputs.push(input);
                targets.push(target);
            }
        }

        self.position += batch_size;

        Some((inputs, targets))
    }

    /// Number of batches per epoch (the last batch may be partial).
    pub fn numbatches(&self) -> usize {
        self.dataset.len().div_ceil(self.batch_size)
    }

    /// Returns a reference to the underlying dataset.
    pub fn dataset(&self) -> &dyn Dataset {
        self.dataset.as_ref()
    }
}

impl Iterator for DataLoader {
    type Item = BatchData;

    fn next(&mut self) -> Option<Self::Item> {
        self.next_batch()
    }
}

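/// A loss function with a forward pass (the loss value) and a backward pass
/// (the gradient of the loss with respect to the predictions).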
pub trait Loss {
    /// Computes the loss value for `predictions` against `targets`.
    fn forward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>>;

    /// Computes the gradient of the loss with respect to `predictions`.
    fn backward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>>;

    /// Returns the loss function's name.
    fn name(&self) -> &str;
}

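/// Mean squared error: reduces `(predictions - targets)^2` with `"mean"`,
/// `"sum"`, or `"none"` (element-wise) reduction.
///
/// A quick numeric sketch (`ignore`d in doctests; mirrors the unit test at
/// the bottom of this file):
///
/// ```ignore
/// use ndarray::Array2;
///
/// let preds = NdarrayWrapper::new(Array2::<f64>::ones((2, 3)));
/// let targets = NdarrayWrapper::new(Array2::<f64>::zeros((2, 3)));
/// // Every squared difference is 1.0, so the mean reduction yields 1.0.
/// let loss = MSELoss::new(Some("mean")).forward(&preds, &targets).unwrap();
/// ```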
pub struct MSELoss {
    /// Loss name.
    name: String,

    /// Reduction mode: `"mean"`, `"sum"`, or `"none"`.
    reduction: String,
}

impl MSELoss {
    /// Creates an MSE loss; `reduction` defaults to `"mean"`.
    pub fn new(reduction: Option<&str>) -> Self {
        Self {
            name: "MSELoss".to_string(),
            reduction: reduction.unwrap_or("mean").to_string(),
        }
    }
}

impl Loss for MSELoss {
    fn forward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>> {
        // Element-wise squared error: (predictions - targets)^2.
        let diff = subtract(predictions, targets)?;
        let squared = multiply(diff.as_ref(), diff.as_ref())?;

        match self.reduction.as_str() {
            "none" => Ok(squared),
            "mean" => {
                if let Some(array) = squared
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, ndarray::IxDyn>>()
                {
                    let mean = array.as_array().mean().unwrap();
                    let result = Array0::<f64>::from_elem((), mean);
                    Ok(Box::new(NdarrayWrapper::new(result)))
                } else {
                    Err(CoreError::NotImplementedError(ErrorContext::new(
                        "Mean reduction not implemented for this array type".to_string(),
                    )))
                }
            }
            "sum" => {
                if let Some(array) = squared
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, ndarray::IxDyn>>()
                {
                    let sum = array.as_array().sum();
                    let result = Array0::<f64>::from_elem((), sum);
                    Ok(Box::new(NdarrayWrapper::new(result)))
                } else {
                    Err(CoreError::NotImplementedError(ErrorContext::new(
                        "Sum reduction not implemented for this array type".to_string(),
                    )))
                }
            }
            _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
                "Unknown reduction: {reduction}",
                reduction = self.reduction
            )))),
        }
    }

    fn backward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>> {
        // d/d(predictions) of (predictions - targets)^2 is 2 * (predictions - targets).
        let diff = subtract(predictions, targets)?;
        let factor = Box::new(NdarrayWrapper::new(ndarray::Array0::<f64>::from_elem(
            (),
            2.0,
        )));
        let grad = multiply(factor.as_ref(), diff.as_ref())?;

        match self.reduction.as_str() {
            "none" => Ok(grad),
            "mean" => {
                // Mean reduction scales the gradient by 1/n.
                if let Some(array) = grad
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, ndarray::IxDyn>>()
                {
                    let n = array.as_array().len() as f64;
                    let scale_factor = Box::new(NdarrayWrapper::new(
                        ndarray::Array0::<f64>::from_elem((), 1.0 / n),
                    ));
                    Ok(multiply(scale_factor.as_ref(), grad.as_ref())?)
                } else {
                    Ok(grad)
                }
            }
            "sum" => Ok(grad),
            _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
                "Unknown reduction: {reduction}",
                reduction = self.reduction
            )))),
        }
    }

    fn name(&self) -> &str {
        &self.name
    }
}

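/// Softmax cross-entropy against one-hot targets:
/// `L = reduce(-targets * ln(softmax(predictions)))`. The backward pass uses
/// the standard simplification `softmax(predictions) - targets`.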
pub struct CrossEntropyLoss {
    /// Loss name.
    name: String,

    /// Reduction mode: `"mean"`, `"sum"`, or `"none"`.
    reduction: String,
}

impl CrossEntropyLoss {
    /// Creates a cross-entropy loss; `reduction` defaults to `"mean"`.
    pub fn new(reduction: Option<&str>) -> Self {
        Self {
            name: "CrossEntropyLoss".to_string(),
            reduction: reduction.unwrap_or("mean").to_string(),
        }
    }
}

impl Loss for CrossEntropyLoss {
    fn forward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>> {
        // Convert raw predictions into a probability distribution.
        let softmax_preds = activation(predictions, ActivationFunc::Softmax)?;

        if let (Some(preds_array), Some(targets_array)) = (
            softmax_preds
                .as_any()
                .downcast_ref::<NdarrayWrapper<f64, ndarray::IxDyn>>(),
            targets
                .as_any()
                .downcast_ref::<NdarrayWrapper<f64, ndarray::IxDyn>>(),
        ) {
            let preds = preds_array.as_array();
            let targets = targets_array.as_array();

            // Clamp probabilities away from zero so ln() stays finite.
            let log_preds = preds.mapv(|x| x.max(1e-10).ln());

            // Element-wise -target * ln(pred).
            let mut losses = targets.clone();
            losses.zip_mut_with(&log_preds, |t, l| *t = -(*t * *l));

            match self.reduction.as_str() {
                "none" => Ok(Box::new(NdarrayWrapper::new(losses))),
                "mean" => {
                    let mean = losses.mean().unwrap();
                    let result = Array0::<f64>::from_elem((), mean);
                    Ok(Box::new(NdarrayWrapper::new(result)))
                }
                "sum" => {
                    let sum = losses.sum();
                    let result = Array0::<f64>::from_elem((), sum);
                    Ok(Box::new(NdarrayWrapper::new(result)))
                }
                _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
                    "Unknown reduction: {reduction}",
                    reduction = self.reduction
                )))),
            }
        } else {
            Err(CoreError::NotImplementedError(ErrorContext::new(
                "CrossEntropy not implemented for these array types".to_string(),
            )))
        }
    }

    fn backward(
        &self,
        predictions: &dyn ArrayProtocol,
        targets: &dyn ArrayProtocol,
    ) -> CoreResult<Box<dyn ArrayProtocol>> {
        // For softmax cross-entropy with one-hot targets, the gradient
        // simplifies to softmax(predictions) - targets.
        let softmax_preds = activation(predictions, ActivationFunc::Softmax)?;
        let grad = subtract(softmax_preds.as_ref(), targets)?;

        match self.reduction.as_str() {
            "none" => Ok(grad),
            "mean" => {
                if let Some(array) = grad
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, ndarray::IxDyn>>()
                {
                    let n = array.as_array().len() as f64;
                    let scale_factor = Box::new(NdarrayWrapper::new(
                        ndarray::Array0::<f64>::from_elem((), 1.0 / n),
                    ));
                    Ok(multiply(scale_factor.as_ref(), grad.as_ref())?)
                } else {
                    Ok(grad)
                }
            }
            "sum" => Ok(grad),
            _ => Err(CoreError::InvalidArgument(ErrorContext::new(format!(
                "Unknown reduction: {reduction}",
                reduction = self.reduction
            )))),
        }
    }

    fn name(&self) -> &str {
        &self.name
    }
}

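/// Accumulates per-batch losses (and optionally accuracies) and reports
/// their means.
///
/// ```ignore
/// let mut metrics = Metrics::new("train");
/// metrics.add_loss(0.5);
/// metrics.add_accuracy(0.9);
/// println!("{metrics}"); // train: loss = 0.5000, accuracy = 0.9000
/// ```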
pub struct Metrics {
    /// Recorded per-batch losses.
    losses: Vec<f64>,

    /// Recorded per-batch accuracies, if accuracy is tracked.
    accuracies: Option<Vec<f64>>,

    /// Name used when displaying the metrics.
    name: String,
}

impl Metrics {
    /// Creates an empty metrics collector with the given name.
    pub fn new(name: &str) -> Self {
        Self {
            losses: Vec::new(),
            accuracies: None,
            name: name.to_string(),
        }
    }

    /// Records a loss value.
    pub fn add_loss(&mut self, loss: f64) {
        self.losses.push(loss);
    }

    /// Records an accuracy value, enabling accuracy tracking on first use.
    pub fn add_accuracy(&mut self, accuracy: f64) {
        self.accuracies.get_or_insert_with(Vec::new).push(accuracy);
    }

    /// Mean of all recorded losses, or `None` if none have been recorded.
    pub fn mean_loss(&self) -> Option<f64> {
        if self.losses.is_empty() {
            return None;
        }

        let sum: f64 = self.losses.iter().sum();
        Some(sum / self.losses.len() as f64)
    }

    /// Mean of all recorded accuracies, or `None` if none have been recorded.
    pub fn mean_accuracy(&self) -> Option<f64> {
        if let Some(accuracies) = &self.accuracies {
            if accuracies.is_empty() {
                return None;
            }

            let sum: f64 = accuracies.iter().sum();
            Some(sum / accuracies.len() as f64)
        } else {
            None
        }
    }

    /// Clears all recorded values.
    pub fn reset(&mut self) {
        self.losses.clear();
        if let Some(accuracies) = &mut self.accuracies {
            accuracies.clear();
        }
    }

    /// Returns the metrics name.
    pub fn name(&self) -> &str {
        &self.name
    }
}

impl fmt::Display for Metrics {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{}: loss = {:.4}",
            self.name,
            self.mean_loss().unwrap_or(0.0)
        )?;

        if let Some(acc) = self.mean_accuracy() {
            write!(f, ", accuracy = {acc:.4}")?;
        }

        Ok(())
    }
}

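/// Hooks invoked by the [`Trainer`] at the boundaries of training runs,
/// epochs, and batches; see [`ProgressCallback`] for a console logger.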
pub trait TrainingCallback {
    /// Called at the start of each epoch.
    fn on_epoch_start(&mut self, epoch: usize, numepochs: usize);

    /// Called at the end of each epoch with the epoch's metrics.
    fn on_epoch_end(&mut self, epoch: usize, numepochs: usize, metrics: &Metrics);

    /// Called before each batch is processed.
    fn on_batch_start(&mut self, batch: usize, numbatches: usize);

    /// Called after each batch with the batch loss.
    fn on_batch_end(&mut self, batch: usize, numbatches: usize, loss: f64);

    /// Called once before training begins.
    fn on_train_start(&mut self, numepochs: usize);

    /// Called once after training finishes with the final metrics.
    fn on_train_end(&mut self, metrics: &Metrics);
}

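/// A [`TrainingCallback`] that prints epoch/batch progress and timing to
/// stdout when `verbose` is enabled.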
pub struct ProgressCallback {
    /// Whether to print progress messages.
    verbose: bool,

    /// Start time of the current epoch.
    epoch_start: Option<Instant>,

    /// Start time of the whole training run.
    train_start: Option<Instant>,
}

impl ProgressCallback {
    /// Creates a progress callback; set `verbose` to print to stdout.
    pub fn new(verbose: bool) -> Self {
        Self {
            verbose,
            epoch_start: None,
            train_start: None,
        }
    }
}

impl TrainingCallback for ProgressCallback {
    fn on_epoch_start(&mut self, epoch: usize, numepochs: usize) {
        if self.verbose {
            println!("Epoch {}/{}", epoch + 1, numepochs);
        }

        self.epoch_start = Some(Instant::now());
    }

    fn on_epoch_end(&mut self, _epoch: usize, _numepochs: usize, metrics: &Metrics) {
        if self.verbose {
            if let Some(start) = self.epoch_start {
                let duration = start.elapsed();
                println!("{} - {}ms", metrics, duration.as_millis());
            } else {
                println!("{metrics}");
            }
        }
    }

    fn on_batch_start(&mut self, _batch: usize, _numbatches: usize) {}

    fn on_batch_end(&mut self, batch: usize, numbatches: usize, loss: f64) {
        // Report roughly every 10% of batches, and always on the last one.
        if self.verbose && (batch + 1).is_multiple_of((numbatches / 10).max(1)) {
            print!("\rBatch {}/{} - loss: {:.4}", batch + 1, numbatches, loss);
            // Flush so the carriage-return progress line appears immediately.
            let _ = std::io::Write::flush(&mut std::io::stdout());
            if batch + 1 == numbatches {
                println!();
            }
        }
    }

    fn on_train_start(&mut self, numepochs: usize) {
        if self.verbose {
            println!("Starting training for {numepochs} epochs");
        }

        self.train_start = Some(Instant::now());
    }

    fn on_train_end(&mut self, metrics: &Metrics) {
        if self.verbose {
            if let Some(start) = self.train_start {
                let duration = start.elapsed();
                println!("Training completed in {}s", duration.as_secs());
            } else {
                println!("Training completed");
            }

            if let Some(acc) = metrics.mean_accuracy() {
                println!("Final accuracy: {acc:.4}");
            }
        }
    }
}

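/// Drives the training loop: feeds batches through the model, computes the
/// loss and gradients, updates parameters, and reports progress through
/// registered callbacks.
///
/// A minimal end-to-end sketch (`ignore`d in doctests; the `Sequential` and
/// optimizer constructors here are placeholders, not exact APIs):
///
/// ```ignore
/// let model = Sequential::new(/* layers */);
/// let optimizer: Box<dyn Optimizer> = /* e.g. an SGD optimizer */;
/// let mut trainer = Trainer::new(model, optimizer, Box::new(MSELoss::new(None)));
/// trainer.add_callback(Box::new(ProgressCallback::new(true)));
/// trainer.train(&mut train_loader, 10, Some(&mut val_loader))?;
/// println!("{}", trainer.train_metrics());
/// ```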
pub struct Trainer {
    /// The model being trained.
    model: Sequential,

    /// Optimizer used to update model parameters.
    optimizer: Box<dyn Optimizer>,

    /// Loss function minimized during training.
    lossfn: Box<dyn Loss>,

    /// Callbacks notified of training progress.
    callbacks: Vec<Box<dyn TrainingCallback>>,

    /// Metrics accumulated over the training set.
    train_metrics: Metrics,

    /// Metrics accumulated over the validation set, if validation is used.
    val_metrics: Option<Metrics>,
}

impl Trainer {
    /// Creates a trainer from a model, an optimizer, and a loss function.
    pub fn new(model: Sequential, optimizer: Box<dyn Optimizer>, lossfn: Box<dyn Loss>) -> Self {
        Self {
            model,
            optimizer,
            lossfn,
            callbacks: Vec::new(),
            train_metrics: Metrics::new("train"),
            val_metrics: None,
        }
    }

    /// Registers a callback to be notified of training progress.
    pub fn add_callback(&mut self, callback: Box<dyn TrainingCallback>) {
        self.callbacks.push(callback);
    }

    /// Trains the model for `numepochs` epochs, optionally validating after
    /// each epoch.
    pub fn train(
        &mut self,
        train_loader: &mut DataLoader,
        numepochs: usize,
        mut val_loader: Option<&mut DataLoader>,
    ) -> CoreResult<()> {
        for callback in &mut self.callbacks {
            callback.on_train_start(numepochs);
        }

        if val_loader.is_some() && self.val_metrics.is_none() {
            self.val_metrics = Some(Metrics::new("val"));
        }

        for epoch in 0..numepochs {
            self.train_metrics.reset();
            if let Some(metrics) = &mut self.val_metrics {
                metrics.reset();
            }

            for callback in &mut self.callbacks {
                callback.on_epoch_start(epoch, numepochs);
            }

            self.train_epoch(train_loader)?;

            if let Some(ref mut val_loader) = val_loader {
                self.validate(val_loader)?;
            }

            // Report validation metrics when available, else training metrics.
            let metrics = self.val_metrics.as_ref().unwrap_or(&self.train_metrics);
            for callback in &mut self.callbacks {
                callback.on_epoch_end(epoch, numepochs, metrics);
            }
        }

        let metrics = self.val_metrics.as_ref().unwrap_or(&self.train_metrics);
        for callback in &mut self.callbacks {
            callback.on_train_end(metrics);
        }

        Ok(())
    }

    /// Runs one training epoch over the data loader.
    fn train_epoch(&mut self, dataloader: &mut DataLoader) -> CoreResult<()> {
        self.model.train();

        dataloader.reset();

        let numbatches = dataloader.numbatches();

        for batch_idx in 0..numbatches {
            let Some((inputs, targets)) = dataloader.next_batch() else {
                break;
            };

            for callback in &mut self.callbacks {
                callback.on_batch_start(batch_idx, numbatches);
            }

            let batch_loss = self.train_batch(&inputs, &targets)?;

            self.train_metrics.add_loss(batch_loss);

            for callback in &mut self.callbacks {
                callback.on_batch_end(batch_idx, numbatches, batch_loss);
            }
        }

        Ok(())
    }

    /// Runs one optimization step over a single batch; returns the mean
    /// per-sample loss.
    fn train_batch(
        &mut self,
        inputs: &[Box<dyn ArrayProtocol>],
        targets: &[Box<dyn ArrayProtocol>],
    ) -> CoreResult<f64> {
        self.optimizer.zero_grad();

        let mut batch_loss = 0.0;

        // Learning rate for the direct parameter update below.
        let learningrate = 0.001;

        for (input, target) in inputs.iter().zip(targets.iter()) {
            // Forward pass.
            let output = self.model.forward(input.as_ref())?;
            let loss = self.lossfn.forward(output.as_ref(), target.as_ref())?;

            if let Some(loss_array) = loss
                .as_any()
                .downcast_ref::<NdarrayWrapper<f64, ndarray::IxDyn>>()
            {
                batch_loss += loss_array.as_array().sum();
            }

            // Backward pass: gradients of the loss w.r.t. the parameters.
            let gradients = self.compute_gradients(
                input.as_ref(),
                target.as_ref(),
                output.as_ref(),
                loss.as_ref(),
            )?;

            // Apply an immediate SGD-style update, then hand the gradients to
            // the optimizer so its own update runs in `step()` below.
            self.apply_gradients(&gradients, learningrate)?;
            self.optimizer.accumulate_gradients(&gradients)?;
        }

        let batch_loss = batch_loss / inputs.len() as f64;

        self.optimizer.step()?;

        Ok(batch_loss)
    }

    /// Backpropagates the loss gradient through the model.
    fn compute_gradients(
        &self,
        input: &dyn ArrayProtocol,
        target: &dyn ArrayProtocol,
        output: &dyn ArrayProtocol,
        _loss: &dyn ArrayProtocol,
    ) -> CoreResult<GradientDict> {
        let mut gradients = GradientDict::new();

        // Gradient of the loss with respect to the model output.
        let loss_grad = self.lossfn.backward(output, target)?;

        // Propagate through the model to obtain parameter gradients.
        let model_gradients = self.model.backward(input, loss_grad.as_ref())?;

        gradients.merge(model_gradients);

        Ok(gradients)
    }

    /// Applies a plain gradient-descent update to every parameter.
    fn apply_gradients(&mut self, gradients: &GradientDict, learningrate: f64) -> CoreResult<()> {
        for (param_name, gradient) in gradients.iter() {
            self.model
                .update_parameter(param_name, gradient.as_ref(), learningrate)?;
        }

        Ok(())
    }

    /// Evaluates the model on a validation loader, recording loss and accuracy.
    fn validate(&mut self, dataloader: &mut DataLoader) -> CoreResult<()> {
        self.model.eval();

        if let Some(metrics) = &mut self.val_metrics {
            metrics.reset();
        } else {
            // No validation metrics configured; nothing to record.
            return Ok(());
        }

        dataloader.reset();

        let numbatches = dataloader.numbatches();

        for _ in 0..numbatches {
            let Some((inputs, targets)) = dataloader.next_batch() else {
                break;
            };

            let mut batch_loss = 0.0;
            let mut batch_correct = 0;
            let mut batch_total = 0;

            for (input, target) in inputs.iter().zip(targets.iter()) {
                let output = self.model.forward(input.as_ref())?;

                let loss = self.lossfn.forward(output.as_ref(), target.as_ref())?;

                if let Some(loss_array) = loss
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, ndarray::IxDyn>>()
                {
                    batch_loss += loss_array.as_array().sum();
                }

                // Accuracy: compare the argmax of each output row against the
                // one-hot position in the corresponding target row. The Ix2
                // downcast guarantees both arrays are two-dimensional.
                if let (Some(output_array), Some(target_array)) = (
                    output
                        .as_any()
                        .downcast_ref::<NdarrayWrapper<f64, ndarray::Ix2>>(),
                    target
                        .as_any()
                        .downcast_ref::<NdarrayWrapper<f64, ndarray::Ix2>>(),
                ) {
                    let output_vec = output_array.as_array();
                    let target_vec = target_array.as_array();

                    for (out_row, target_row) in
                        output_vec.outer_iter().zip(target_vec.outer_iter())
                    {
                        let mut max_idx = 0;
                        let mut max_val = out_row[0];

                        for (i, &val) in out_row.iter().enumerate().skip(1) {
                            if val > max_val {
                                max_idx = i;
                                max_val = val;
                            }
                        }

                        if let Some(target_idx) = target_row.iter().position(|&x| x == 1.0) {
                            if max_idx == target_idx {
                                batch_correct += 1;
                            }
                        }

                        batch_total += 1;
                    }
                }
            }

            let batch_loss = batch_loss / inputs.len() as f64;
            let batch_accuracy = if batch_total > 0 {
                batch_correct as f64 / batch_total as f64
            } else {
                0.0
            };

            if let Some(metrics) = &mut self.val_metrics {
                metrics.add_loss(batch_loss);
                metrics.add_accuracy(batch_accuracy);
            }
        }

        Ok(())
    }

    /// Returns the accumulated training metrics.
    pub const fn train_metrics(&self) -> &Metrics {
        &self.train_metrics
    }

    /// Returns the accumulated validation metrics, if validation ran.
    pub fn val_metrics(&self) -> Option<&Metrics> {
        self.val_metrics.as_ref()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::array_protocol::{self, NdarrayWrapper};
    use ndarray::Array2;

    #[test]
    fn test_in_memory_dataset() {
        let inputs = Array2::<f64>::ones((10, 5));
        let targets = Array2::<f64>::zeros((10, 2));

        let dataset = InMemoryDataset::from_arrays(inputs, targets);

        assert_eq!(dataset.len(), 10);
        assert_eq!(dataset.inputshape(), vec![5]);
        assert_eq!(dataset.outputshape(), vec![2]);

        let (input, target) = dataset.get(0).unwrap();
        assert!(input
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, ndarray::IxDyn>>()
            .is_some());
        assert!(target
            .as_any()
            .downcast_ref::<NdarrayWrapper<f64, ndarray::IxDyn>>()
            .is_some());
    }

    #[test]
    fn test_dataloader() {
        let inputs = Array2::<f64>::ones((10, 5));
        let targets = Array2::<f64>::zeros((10, 2));

        let dataset = Box::new(InMemoryDataset::from_arrays(inputs, targets));
        let mut loader = DataLoader::new(dataset, 4, true, Some(42));

        // 10 samples with batch size 4 -> 3 batches (4, 4, 2).
        assert_eq!(loader.numbatches(), 3);

        let (batch1_inputs, batch1_targets) = loader.next_batch().unwrap();
        assert_eq!(batch1_inputs.len(), 4);
        assert_eq!(batch1_targets.len(), 4);

        let (batch2_inputs, batch2_targets) = loader.next_batch().unwrap();
        assert_eq!(batch2_inputs.len(), 4);
        assert_eq!(batch2_targets.len(), 4);

        let (batch3_inputs, batch3_targets) = loader.next_batch().unwrap();
        assert_eq!(batch3_inputs.len(), 2);
        assert_eq!(batch3_targets.len(), 2);

        // Resetting rewinds to the first batch.
        loader.reset();
        let (batch1_inputs, batch1_targets) = loader.next_batch().unwrap();
        assert_eq!(batch1_inputs.len(), 4);
        assert_eq!(batch1_targets.len(), 4);
    }

    #[test]
    fn test_mse_loss() {
        array_protocol::init();

        let predictions = Array2::<f64>::ones((2, 3));
        let targets = Array2::<f64>::zeros((2, 3));

        let predictions_wrapped = NdarrayWrapper::new(predictions);
        let targets_wrapped = NdarrayWrapper::new(targets);

        let mse = MSELoss::new(Some("mean"));

        match mse.forward(&predictions_wrapped, &targets_wrapped) {
            Ok(loss) => {
                if let Some(loss_array) = loss
                    .as_any()
                    .downcast_ref::<NdarrayWrapper<f64, ndarray::Ix0>>()
                {
                    // All squared differences are 1.0, so the mean is 1.0.
                    assert_eq!(loss_array.as_array()[()], 1.0);
                } else {
                    println!("Loss not of expected type NdarrayWrapper<f64, Ix0>");
                }
            }
            Err(e) => {
                println!("MSE Loss forward not fully implemented: {e}");
            }
        }
    }

    #[test]
    fn test_metrics() {
        let mut metrics = Metrics::new("test");

        metrics.add_loss(1.0);
        metrics.add_loss(2.0);
        metrics.add_loss(3.0);

        metrics.add_accuracy(0.5);
        metrics.add_accuracy(0.6);
        metrics.add_accuracy(0.7);

        assert_eq!(metrics.mean_loss().unwrap(), 2.0);
        assert_eq!(metrics.mean_accuracy().unwrap(), 0.6);

        metrics.reset();
        assert!(metrics.mean_loss().is_none());
        assert!(metrics.mean_accuracy().is_none());
    }
}