//! Training utilities: optimizers (SGD with momentum, Adam), loss functions
//! (MSE, cross-entropy, binary cross-entropy), contrastive (InfoNCE) losses,
//! and the associated configuration types.

use crate::error::{GnnError, Result};
use crate::search::cosine_similarity;
use ndarray::Array2;

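/// Optimizer configurations supported by the training module.
///
/// `Sgd` is stochastic gradient descent with an optional momentum term
/// (set `momentum` to `0.0` for plain SGD). `Adam` is the Adam optimizer
/// with the usual exponential decay rates `beta1`/`beta2` and a small
/// `epsilon` for numerical stability.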
#[derive(Debug, Clone)]
pub enum OptimizerType {
    Sgd {
        learning_rate: f32,
        momentum: f32,
    },
    Adam {
        learning_rate: f32,
        beta1: f32,
        beta2: f32,
        epsilon: f32,
    },
}

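/// Internal optimizer buffers: the momentum velocity for SGD, and the first
/// and second moment estimates plus the step counter for Adam. Buffers are
/// allocated lazily on the first call to [`Optimizer::step`].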
#[derive(Debug)]
enum OptimizerState {
    Sgd {
        velocity: Option<Array2<f32>>,
    },
    Adam {
        m: Option<Array2<f32>>,
        v: Option<Array2<f32>>,
        t: usize,
    },
}

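/// A stateful optimizer that applies in-place parameter updates according to
/// its [`OptimizerType`].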
pub struct Optimizer {
    optimizer_type: OptimizerType,
    state: OptimizerState,
}

impl Optimizer {
    pub fn new(optimizer_type: OptimizerType) -> Self {
        let state = match &optimizer_type {
            OptimizerType::Sgd { .. } => OptimizerState::Sgd { velocity: None },
            OptimizerType::Adam { .. } => OptimizerState::Adam {
                m: None,
                v: None,
                t: 0,
            },
        };

        Self {
            optimizer_type,
            state,
        }
    }

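    /// Performs a single optimization step, updating `params` in place from
    /// `grads`.
    ///
    /// Returns an error if `params` and `grads` have different shapes, or if
    /// the optimizer type and its internal state do not match.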
    pub fn step(&mut self, params: &mut Array2<f32>, grads: &Array2<f32>) -> Result<()> {
        if params.shape() != grads.shape() {
            return Err(GnnError::dimension_mismatch(
                format!("{:?}", params.shape()),
                format!("{:?}", grads.shape()),
            ));
        }

        match (&self.optimizer_type, &mut self.state) {
            (
                OptimizerType::Sgd {
                    learning_rate,
                    momentum,
                },
                OptimizerState::Sgd { velocity },
            ) => Self::sgd_step_with_momentum(params, grads, *learning_rate, *momentum, velocity),
            (
                OptimizerType::Adam {
                    learning_rate,
                    beta1,
                    beta2,
                    epsilon,
                },
                OptimizerState::Adam { m, v, t },
            ) => Self::adam_step(
                params,
                grads,
                *learning_rate,
                *beta1,
                *beta2,
                *epsilon,
                m,
                v,
                t,
            ),
            _ => Err(GnnError::invalid_input("Optimizer type and state mismatch")),
        }
    }

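    /// SGD update with optional momentum:
    /// `v = momentum * v + learning_rate * grads`, then `params -= v`.
    /// With `momentum == 0.0` this reduces to `params -= learning_rate * grads`.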
    fn sgd_step_with_momentum(
        params: &mut Array2<f32>,
        grads: &Array2<f32>,
        learning_rate: f32,
        momentum: f32,
        velocity: &mut Option<Array2<f32>>,
    ) -> Result<()> {
        if momentum == 0.0 {
            *params -= &(grads * learning_rate);
        } else {
            if velocity.is_none() {
                *velocity = Some(Array2::zeros(params.dim()));
            }

            if let Some(v) = velocity {
                let new_velocity = v.mapv(|x| x * momentum) + grads * learning_rate;
                *v = new_velocity;

                *params -= &*v;
            }
        }

        Ok(())
    }

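    /// Adam update (Kingma & Ba, 2015), with bias-corrected moment estimates:
    /// `m = beta1 * m + (1 - beta1) * g`,
    /// `v = beta2 * v + (1 - beta2) * g^2`,
    /// `m_hat = m / (1 - beta1^t)`, `v_hat = v / (1 - beta2^t)`,
    /// `params -= learning_rate * m_hat / (sqrt(v_hat) + epsilon)`.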
    #[allow(clippy::too_many_arguments)]
    fn adam_step(
        params: &mut Array2<f32>,
        grads: &Array2<f32>,
        learning_rate: f32,
        beta1: f32,
        beta2: f32,
        epsilon: f32,
        m: &mut Option<Array2<f32>>,
        v: &mut Option<Array2<f32>>,
        t: &mut usize,
    ) -> Result<()> {
        if m.is_none() {
            *m = Some(Array2::zeros(params.dim()));
        }
        if v.is_none() {
            *v = Some(Array2::zeros(params.dim()));
        }

        *t += 1;

        if let (Some(m_buf), Some(v_buf)) = (m, v) {
            let new_m = m_buf.mapv(|x| x * beta1) + grads * (1.0 - beta1);
            *m_buf = new_m;

            let grads_squared = grads.mapv(|x| x * x);
            let new_v = v_buf.mapv(|x| x * beta2) + grads_squared * (1.0 - beta2);
            *v_buf = new_v;

            let bias_correction1 = 1.0 - beta1.powi(*t as i32);
            let m_hat = m_buf.mapv(|x| x / bias_correction1);

            let bias_correction2 = 1.0 - beta2.powi(*t as i32);
            let v_hat = v_buf.mapv(|x| x / bias_correction2);

            let update = m_hat
                .iter()
                .zip(v_hat.iter())
                .map(|(&m_val, &v_val)| learning_rate * m_val / (v_val.sqrt() + epsilon));

            for (param, upd) in params.iter_mut().zip(update) {
                *param -= upd;
            }
        }

        Ok(())
    }
}

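/// Loss functions supported by [`Loss::compute`] and [`Loss::gradient`].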
#[derive(Debug, Clone, Copy)]
pub enum LossType {
    Mse,
    CrossEntropy,
    BinaryCrossEntropy,
}

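/// Namespace for loss functions and their gradients with respect to the
/// predictions. The public entry points [`Loss::compute`] and
/// [`Loss::gradient`] validate that `predictions` and `targets` are
/// non-empty and share the same shape.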
pub struct Loss;

impl Loss {
    /// Small constant used to keep logarithms and divisions away from zero.
    const EPS: f32 = 1e-7;

    /// Bound applied when clamping gradients for numerical stability.
    const MAX_GRAD: f32 = 1e6;

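    /// Computes the scalar loss of `predictions` against `targets` for the
    /// given [`LossType`].
    ///
    /// Returns an error on shape mismatch or empty input.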
    pub fn compute(
        loss_type: LossType,
        predictions: &Array2<f32>,
        targets: &Array2<f32>,
    ) -> Result<f32> {
        if predictions.shape() != targets.shape() {
            return Err(GnnError::dimension_mismatch(
                format!("{:?}", predictions.shape()),
                format!("{:?}", targets.shape()),
            ));
        }

        if predictions.is_empty() {
            return Err(GnnError::invalid_input(
                "Cannot compute loss on empty arrays",
            ));
        }

        match loss_type {
            LossType::Mse => Self::mse_forward(predictions, targets),
            LossType::CrossEntropy => Self::cross_entropy_forward(predictions, targets),
            LossType::BinaryCrossEntropy => Self::bce_forward(predictions, targets),
        }
    }

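    /// Computes the gradient of the loss with respect to `predictions`,
    /// returned with the same shape as `predictions`.
    ///
    /// Returns an error on shape mismatch or empty input.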
    pub fn gradient(
        loss_type: LossType,
        predictions: &Array2<f32>,
        targets: &Array2<f32>,
    ) -> Result<Array2<f32>> {
        if predictions.shape() != targets.shape() {
            return Err(GnnError::dimension_mismatch(
                format!("{:?}", predictions.shape()),
                format!("{:?}", targets.shape()),
            ));
        }

        if predictions.is_empty() {
            return Err(GnnError::invalid_input(
                "Cannot compute gradient on empty arrays",
            ));
        }

        match loss_type {
            LossType::Mse => Self::mse_backward(predictions, targets),
            LossType::CrossEntropy => Self::cross_entropy_backward(predictions, targets),
            LossType::BinaryCrossEntropy => Self::bce_backward(predictions, targets),
        }
    }

    fn mse_forward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<f32> {
        let diff = predictions - targets;
        let squared = diff.mapv(|x| x * x);
        Ok(squared.mean().unwrap_or(0.0))
    }

    fn mse_backward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<Array2<f32>> {
        let n = predictions.len() as f32;
        let diff = predictions - targets;
        Ok(diff.mapv(|x| 2.0 * x / n))
    }

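    /// Categorical cross-entropy over probability inputs (no softmax is
    /// applied here): `L = -(1/rows) * sum(targets * ln(max(pred, EPS)))`.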
    fn cross_entropy_forward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<f32> {
        let log_pred = predictions.mapv(|x| (x.max(Self::EPS)).ln());
        let elementwise = targets * &log_pred;
        let loss = -elementwise.sum() / predictions.nrows() as f32;
        Ok(loss)
    }

    fn cross_entropy_backward(
        predictions: &Array2<f32>,
        targets: &Array2<f32>,
    ) -> Result<Array2<f32>> {
        let n = predictions.nrows() as f32;
        let safe_pred = predictions.mapv(|x| x.max(Self::EPS));
        let grad = targets / &safe_pred;
        Ok(grad.mapv(|x| (-x / n).clamp(-Self::MAX_GRAD, Self::MAX_GRAD)))
    }

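    /// Binary cross-entropy averaged over all elements, with predictions
    /// clamped to `[EPS, 1 - EPS]`:
    /// `L = -(1/n) * sum(t * ln(p) + (1 - t) * ln(1 - p))`.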
    fn bce_forward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<f32> {
        let n = predictions.len() as f32;
        let loss: f32 = predictions
            .iter()
            .zip(targets.iter())
            .map(|(&p, &t)| {
                let p_safe = p.clamp(Self::EPS, 1.0 - Self::EPS);
                -(t * p_safe.ln() + (1.0 - t) * (1.0 - p_safe).ln())
            })
            .sum();
        Ok(loss / n)
    }

    fn bce_backward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<Array2<f32>> {
        let n = predictions.len() as f32;
        let grad_vec: Vec<f32> = predictions
            .iter()
            .zip(targets.iter())
            .map(|(&p, &t)| {
                let p_safe = p.clamp(Self::EPS, 1.0 - Self::EPS);
                let grad = (-t / p_safe + (1.0 - t) / (1.0 - p_safe)) / n;
                grad.clamp(-Self::MAX_GRAD, Self::MAX_GRAD)
            })
            .collect();

        Array2::from_shape_vec(predictions.dim(), grad_vec)
            .map_err(|e| GnnError::training(format!("Failed to reshape gradient: {}", e)))
    }
}

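/// Hyperparameters for loss-driven training: epoch count, batch size,
/// learning rate, loss function, and optimizer choice.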
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    pub epochs: usize,
    pub batch_size: usize,
    pub learning_rate: f32,
    pub loss_type: LossType,
    pub optimizer_type: OptimizerType,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            epochs: 100,
            batch_size: 32,
            learning_rate: 0.001,
            loss_type: LossType::Mse,
            optimizer_type: OptimizerType::Adam {
                learning_rate: 0.001,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
        }
    }
}

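/// Hyperparameters for contrastive training: batch size, number of negative
/// samples per anchor, InfoNCE `temperature`, learning rate, and a
/// `flush_threshold` for pending updates.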
#[derive(Debug, Clone)]
pub struct TrainConfig {
    pub batch_size: usize,
    pub n_negatives: usize,
    pub temperature: f32,
    pub learning_rate: f32,
    pub flush_threshold: usize,
}

impl Default for TrainConfig {
    fn default() -> Self {
        Self {
            batch_size: 256,
            n_negatives: 64,
            temperature: 0.07,
            learning_rate: 0.001,
            flush_threshold: 1000,
        }
    }
}

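/// Settings for online (incremental) updates: the number of local gradient
/// steps per update and whether the resulting changes should be propagated
/// further.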
#[derive(Debug, Clone)]
pub struct OnlineConfig {
    pub local_steps: usize,
    pub propagate_updates: bool,
}

impl Default for OnlineConfig {
    fn default() -> Self {
        Self {
            local_steps: 5,
            propagate_updates: true,
        }
    }
}

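/// InfoNCE contrastive loss over cosine similarities.
///
/// For each positive `p`, with similarities divided by `temperature`,
/// `L_p = -ln( exp(s_p) / (exp(s_p) + sum_j exp(s_neg_j)) )`,
/// computed with a max-shifted log-sum-exp for numerical stability; the
/// result is the mean over all positives. Returns `0.0` if `positives`
/// is empty.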
pub fn info_nce_loss(
    anchor: &[f32],
    positives: &[&[f32]],
    negatives: &[&[f32]],
    temperature: f32,
) -> f32 {
    if positives.is_empty() {
        return 0.0;
    }

    let pos_sims: Vec<f32> = positives
        .iter()
        .map(|pos| cosine_similarity(anchor, pos) / temperature)
        .collect();

    let neg_sims: Vec<f32> = negatives
        .iter()
        .map(|neg| cosine_similarity(anchor, neg) / temperature)
        .collect();

    let mut total_loss = 0.0;
    for &pos_sim in &pos_sims {
        let mut all_logits = vec![pos_sim];
        all_logits.extend(&neg_sims);

        let max_logit = all_logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let log_sum_exp = max_logit
            + all_logits
                .iter()
                .map(|&x| (x - max_logit).exp())
                .sum::<f32>()
                .ln();

        total_loss -= pos_sim - log_sum_exp;
    }

    total_loss / positives.len() as f32
}

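/// Contrastive loss for a single node, treating its neighbors as positives
/// and non-neighbors as negatives in [`info_nce_loss`]. Returns `0.0` when
/// there are no neighbors.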
pub fn local_contrastive_loss(
    node_embedding: &[f32],
    neighbor_embeddings: &[Vec<f32>],
    non_neighbor_embeddings: &[Vec<f32>],
    temperature: f32,
) -> f32 {
    if neighbor_embeddings.is_empty() {
        return 0.0;
    }

    let positives: Vec<&[f32]> = neighbor_embeddings.iter().map(|v| v.as_slice()).collect();
    let negatives: Vec<&[f32]> = non_neighbor_embeddings
        .iter()
        .map(|v| v.as_slice())
        .collect();

    info_nce_loss(node_embedding, &positives, &negatives, temperature)
}

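/// In-place SGD update on a flat embedding vector:
/// `embedding[i] -= learning_rate * grad[i]`.
///
/// # Panics
///
/// Panics if `embedding` and `grad` have different lengths.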
pub fn sgd_step(embedding: &mut [f32], grad: &[f32], learning_rate: f32) {
    assert_eq!(
        embedding.len(),
        grad.len(),
        "Embedding and gradient must have the same length"
    );

    for (emb, &g) in embedding.iter_mut().zip(grad.iter()) {
        *emb -= learning_rate * g;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_train_config_default() {
        let config = TrainConfig::default();
        assert_eq!(config.batch_size, 256);
        assert_eq!(config.n_negatives, 64);
        assert_eq!(config.temperature, 0.07);
        assert_eq!(config.learning_rate, 0.001);
        assert_eq!(config.flush_threshold, 1000);
    }

    #[test]
    fn test_online_config_default() {
        let config = OnlineConfig::default();
        assert_eq!(config.local_steps, 5);
        assert!(config.propagate_updates);
    }

    #[test]
    fn test_info_nce_loss_basic() {
        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];

        let negative1 = vec![0.0, 1.0, 0.0];
        let negative2 = vec![0.0, 0.0, 1.0];

        let loss = info_nce_loss(&anchor, &[&positive], &[&negative1, &negative2], 0.07);

        assert!(loss > 0.0);
        assert!(loss.is_finite());
    }

    #[test]
    fn test_info_nce_loss_perfect_match() {
        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![1.0, 0.0, 0.0];

        let negative1 = vec![0.0, 1.0, 0.0];
        let negative2 = vec![0.0, 0.0, 1.0];

        let loss = info_nce_loss(&anchor, &[&positive], &[&negative1, &negative2], 0.07);

        assert!(loss < 1.0);
        assert!(loss.is_finite());
    }

    #[test]
    fn test_info_nce_loss_no_positives() {
        let anchor = vec![1.0, 0.0, 0.0];
        let negative1 = vec![0.0, 1.0, 0.0];

        let loss = info_nce_loss(&anchor, &[], &[&negative1], 0.07);

        assert_eq!(loss, 0.0);
    }

    #[test]
    fn test_info_nce_loss_temperature_effect() {
        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        let negative = vec![0.0, 1.0, 0.0];

        let loss_low_temp = info_nce_loss(&anchor, &[&positive], &[&negative], 0.07);
        let loss_high_temp = info_nce_loss(&anchor, &[&positive], &[&negative], 1.0);

        assert!(
            loss_low_temp > 0.0 && loss_low_temp.is_finite(),
            "Low temp loss should be positive and finite, got: {}",
            loss_low_temp
        );
        assert!(
            loss_high_temp > 0.0 && loss_high_temp.is_finite(),
            "High temp loss should be positive and finite, got: {}",
            loss_high_temp
        );

        assert!(loss_low_temp < 10.0, "Loss should not be too large");
        assert!(loss_high_temp < 10.0, "Loss should not be too large");
    }

    #[test]
    fn test_local_contrastive_loss_basic() {
        let node = vec![1.0, 0.0, 0.0];
        let neighbor = vec![0.9, 0.1, 0.0];
        let non_neighbor1 = vec![0.0, 1.0, 0.0];
        let non_neighbor2 = vec![0.0, 0.0, 1.0];

        let loss =
            local_contrastive_loss(&node, &[neighbor], &[non_neighbor1, non_neighbor2], 0.07);

        assert!(loss > 0.0);
        assert!(loss.is_finite());
    }

    #[test]
    fn test_local_contrastive_loss_multiple_neighbors() {
        let node = vec![1.0, 0.0, 0.0];
        let neighbor1 = vec![0.9, 0.1, 0.0];
        let neighbor2 = vec![0.95, 0.05, 0.0];
        let non_neighbor = vec![0.0, 1.0, 0.0];

        let loss = local_contrastive_loss(&node, &[neighbor1, neighbor2], &[non_neighbor], 0.07);

        assert!(loss > 0.0);
        assert!(loss.is_finite());
    }

    #[test]
    fn test_local_contrastive_loss_no_neighbors() {
        let node = vec![1.0, 0.0, 0.0];
        let non_neighbor = vec![0.0, 1.0, 0.0];

        let loss = local_contrastive_loss(&node, &[], &[non_neighbor], 0.07);

        assert_eq!(loss, 0.0);
    }

    #[test]
    fn test_sgd_step_basic() {
        let mut embedding = vec![1.0, 2.0, 3.0];
        let gradient = vec![0.1, -0.2, 0.3];
        let learning_rate = 0.01;

        sgd_step(&mut embedding, &gradient, learning_rate);

        // Each component moves by -learning_rate * grad.
        assert!((embedding[0] - 0.999).abs() < 1e-6);
        assert!((embedding[1] - 2.002).abs() < 1e-6);
        assert!((embedding[2] - 2.997).abs() < 1e-6);
    }

    #[test]
    fn test_sgd_step_zero_gradient() {
        let mut embedding = vec![1.0, 2.0, 3.0];
        let original = embedding.clone();
        let gradient = vec![0.0, 0.0, 0.0];
        let learning_rate = 0.01;

        sgd_step(&mut embedding, &gradient, learning_rate);

        assert_eq!(embedding, original);
    }

    #[test]
    fn test_sgd_step_zero_learning_rate() {
        let mut embedding = vec![1.0, 2.0, 3.0];
        let original = embedding.clone();
        let gradient = vec![0.1, 0.2, 0.3];
        let learning_rate = 0.0;

        sgd_step(&mut embedding, &gradient, learning_rate);

        assert_eq!(embedding, original);
    }

    #[test]
    fn test_sgd_step_large_learning_rate() {
        let mut embedding = vec![10.0, 20.0, 30.0];
        let gradient = vec![1.0, 2.0, 3.0];
        let learning_rate = 5.0;

        sgd_step(&mut embedding, &gradient, learning_rate);

        assert!((embedding[0] - 5.0).abs() < 1e-5);
        assert!((embedding[1] - 10.0).abs() < 1e-5);
        assert!((embedding[2] - 15.0).abs() < 1e-5);
    }

    #[test]
    #[should_panic(expected = "Embedding and gradient must have the same length")]
    fn test_sgd_step_mismatched_lengths() {
        let mut embedding = vec![1.0, 2.0, 3.0];
        let gradient = vec![0.1, 0.2];

        sgd_step(&mut embedding, &gradient, 0.01);
    }

    #[test]
    fn test_info_nce_loss_multiple_positives() {
        let anchor = vec![1.0, 0.0, 0.0];
        let positive1 = vec![0.9, 0.1, 0.0];
        let positive2 = vec![0.95, 0.05, 0.0];
        let negative = vec![0.0, 1.0, 0.0];

        let loss = info_nce_loss(&anchor, &[&positive1, &positive2], &[&negative], 0.07);

        assert!(loss > 0.0);
        assert!(loss.is_finite());
    }

    #[test]
    fn test_contrastive_loss_gradient_property() {
        let anchor = vec![1.0, 0.0, 0.0];
        let positive_far = vec![0.5, 0.5, 0.0];
        let positive_close = vec![0.9, 0.1, 0.0];
        let negative = vec![0.0, 1.0, 0.0];

        let loss_far = info_nce_loss(&anchor, &[&positive_far], &[&negative], 0.07);
        let loss_close = info_nce_loss(&anchor, &[&positive_close], &[&negative], 0.07);

        assert!(loss_close < loss_far);
    }

    #[test]
    fn test_sgd_optimizer_basic() {
        let optimizer_type = OptimizerType::Sgd {
            learning_rate: 0.1,
            momentum: 0.0,
        };
        let mut optimizer = Optimizer::new(optimizer_type);

        let mut params = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let grads = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();

        let result = optimizer.step(&mut params, &grads);
        assert!(result.is_ok());

        assert!((params[[0, 0]] - 0.99).abs() < 1e-6);
        assert!((params[[0, 1]] - 1.98).abs() < 1e-6);
        assert!((params[[1, 0]] - 2.97).abs() < 1e-6);
        assert!((params[[1, 1]] - 3.96).abs() < 1e-6);
    }

    #[test]
    fn test_sgd_optimizer_with_momentum() {
        let optimizer_type = OptimizerType::Sgd {
            learning_rate: 0.1,
            momentum: 0.9,
        };
        let mut optimizer = Optimizer::new(optimizer_type);

        let mut params = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let grads = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();

        // First step: velocity starts at zero, so the update equals lr * grad.
        let result = optimizer.step(&mut params, &grads);
        assert!(result.is_ok());

        assert!((params[[0, 0]] - 0.99).abs() < 1e-6);

        // Second step: momentum accumulates, so the parameter keeps decreasing.
        let result = optimizer.step(&mut params, &grads);
        assert!(result.is_ok());

        assert!(params[[0, 0]] < 0.99);
    }

    #[test]
    fn test_adam_optimizer_basic() {
        let optimizer_type = OptimizerType::Adam {
            learning_rate: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        };
        let mut optimizer = Optimizer::new(optimizer_type);

        let mut params = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let grads = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();

        let original_params = params.clone();
        let result = optimizer.step(&mut params, &grads);
        assert!(result.is_ok());

        assert!(params[[0, 0]] < original_params[[0, 0]]);
        assert!(params[[0, 1]] < original_params[[0, 1]]);
        assert!(params[[1, 0]] < original_params[[1, 0]]);
        assert!(params[[1, 1]] < original_params[[1, 1]]);

        assert!(params.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_adam_optimizer_multiple_steps() {
        let optimizer_type = OptimizerType::Adam {
            learning_rate: 0.01,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        };
        let mut optimizer = Optimizer::new(optimizer_type);

        let mut params = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let grads = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
        let initial_params = params.clone();

        for _ in 0..10 {
            let result = optimizer.step(&mut params, &grads);
            assert!(result.is_ok());
            assert!(params.iter().all(|&x| x.is_finite()));
        }

        assert!(params[[0, 0]] < initial_params[[0, 0]]);
        assert!(params[[1, 1]] < initial_params[[1, 1]]);
        for i in 0..2 {
            for j in 0..2 {
                assert!(params[[i, j]] < initial_params[[i, j]]);
            }
        }
    }

    #[test]
    fn test_adam_bias_correction() {
        let optimizer_type = OptimizerType::Adam {
            learning_rate: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        };
        let mut optimizer = Optimizer::new(optimizer_type.clone());

        let mut params = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
        let grads = Array2::from_shape_vec((1, 1), vec![0.1]).unwrap();

        // With bias correction, even the very first step should move the
        // parameter by a meaningful amount.
        let result = optimizer.step(&mut params, &grads);
        assert!(result.is_ok());
        let first_update = 1.0 - params[[0, 0]];

        // A fresh optimizer run for many steps should stay numerically stable.
        let mut optimizer = Optimizer::new(optimizer_type);
        let mut params = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();

        for _ in 0..100 {
            let _ = optimizer.step(&mut params, &grads);
        }
        assert!(params[[0, 0]].is_finite());

        assert!(first_update > 0.0);
    }

    #[test]
    fn test_optimizer_shape_mismatch() {
        let optimizer_type = OptimizerType::Adam {
            learning_rate: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        };
        let mut optimizer = Optimizer::new(optimizer_type);

        let mut params = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let grads = Array2::from_shape_vec((3, 2), vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6]).unwrap();

        let result = optimizer.step(&mut params, &grads);
        assert!(result.is_err());
        if let Err(GnnError::DimensionMismatch { expected, actual }) = result {
            assert!(expected.contains("2, 2"));
            assert!(actual.contains("3, 2"));
        } else {
            panic!("Expected DimensionMismatch error");
        }
    }

    #[test]
    fn test_adam_convergence() {
        let optimizer_type = OptimizerType::Adam {
            learning_rate: 0.5,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        };
        let mut optimizer = Optimizer::new(optimizer_type);

        let mut params = Array2::from_shape_vec((1, 2), vec![5.0, 5.0]).unwrap();

        // Minimize f(x) = x^2 elementwise; the gradient is 2x.
        for _ in 0..200 {
            let grads =
                Array2::from_shape_vec((1, 2), vec![2.0 * params[[0, 0]], 2.0 * params[[0, 1]]])
                    .unwrap();
            let _ = optimizer.step(&mut params, &grads);
        }

        assert!(params[[0, 0]].abs() < 0.5);
        assert!(params[[0, 1]].abs() < 0.5);
    }

    #[test]
    fn test_sgd_momentum_convergence() {
        let optimizer_type = OptimizerType::Sgd {
            learning_rate: 0.01,
            momentum: 0.9,
        };
        let mut optimizer = Optimizer::new(optimizer_type);

        let mut params = Array2::from_shape_vec((1, 2), vec![5.0, 5.0]).unwrap();

        // Minimize f(x) = x^2 elementwise; the gradient is 2x.
        for _ in 0..200 {
            let grads =
                Array2::from_shape_vec((1, 2), vec![2.0 * params[[0, 0]], 2.0 * params[[0, 1]]])
                    .unwrap();
            let _ = optimizer.step(&mut params, &grads);
        }

        assert!(params[[0, 0]].abs() < 0.5);
        assert!(params[[0, 1]].abs() < 0.5);
    }

    #[test]
    fn test_mse_loss_zero_when_equal() {
        let pred = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let target = pred.clone();
        let loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
        assert!(
            (loss - 0.0).abs() < 1e-6,
            "MSE should be 0 when pred == target"
        );
    }

    #[test]
    fn test_mse_loss_positive() {
        let pred = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let target = Array2::from_shape_vec((2, 2), vec![2.0, 3.0, 4.0, 5.0]).unwrap();
        let loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
        assert!((loss - 1.0).abs() < 1e-6, "MSE should be 1.0, got {}", loss);
    }

    #[test]
    fn test_mse_loss_varying_diffs() {
        let pred = Array2::from_shape_vec((1, 4), vec![0.0, 0.0, 0.0, 0.0]).unwrap();
        let target = Array2::from_shape_vec((1, 4), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
        assert!((loss - 7.5).abs() < 1e-6, "MSE should be 7.5, got {}", loss);
    }

    #[test]
    fn test_mse_gradient_shape() {
        let pred = Array2::from_shape_vec((2, 3), vec![0.0; 6]).unwrap();
        let target = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
        let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
        assert_eq!(grad.shape(), pred.shape());
    }

    #[test]
    fn test_mse_gradient_direction() {
        let pred = Array2::from_shape_vec((1, 2), vec![0.0, 2.0]).unwrap();
        let target = Array2::from_shape_vec((1, 2), vec![1.0, 1.0]).unwrap();
        let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
        assert!(
            grad[[0, 0]] < 0.0,
            "Gradient should be negative when pred < target"
        );
        assert!(
            grad[[0, 1]] > 0.0,
            "Gradient should be positive when pred > target"
        );
    }

    #[test]
    fn test_mse_gradient_zero_when_equal() {
        let pred = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let target = pred.clone();
        let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
        assert!(
            grad.iter().all(|&x| x.abs() < 1e-6),
            "Gradient should be zero when pred == target"
        );
    }

    #[test]
    fn test_bce_loss_perfect_predictions() {
        let pred = Array2::from_shape_vec((1, 2), vec![0.999, 0.001]).unwrap();
        let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
        let loss = Loss::compute(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
        assert!(
            loss < 0.1,
            "BCE should be low for good predictions, got {}",
            loss
        );
    }

    #[test]
    fn test_bce_loss_bad_predictions() {
        let pred = Array2::from_shape_vec((1, 2), vec![0.001, 0.999]).unwrap();
        let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
        let loss = Loss::compute(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
        assert!(
            loss > 1.0,
            "BCE should be high for bad predictions, got {}",
            loss
        );
    }

    #[test]
    fn test_bce_loss_numerical_stability() {
        let pred = Array2::from_shape_vec((1, 2), vec![0.0, 1.0]).unwrap();
        let target = Array2::from_shape_vec((1, 2), vec![0.0, 1.0]).unwrap();
        let loss = Loss::compute(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
        assert!(
            loss.is_finite(),
            "BCE should be finite even with extreme values"
        );
    }

    #[test]
    fn test_bce_gradient_shape() {
        let pred = Array2::from_shape_vec((3, 2), vec![0.5; 6]).unwrap();
        let target = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0]).unwrap();
        let grad = Loss::gradient(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
        assert_eq!(grad.shape(), pred.shape());
    }

    #[test]
    fn test_bce_gradient_direction() {
        let pred = Array2::from_shape_vec((1, 2), vec![0.3, 0.7]).unwrap();
        let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
        let grad = Loss::gradient(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
        assert!(
            grad[[0, 0]] < 0.0,
            "Gradient should be negative to increase pred towards 1"
        );
        assert!(
            grad[[0, 1]] > 0.0,
            "Gradient should be positive to decrease pred towards 0"
        );
    }

    #[test]
    fn test_cross_entropy_one_hot() {
        let pred = Array2::from_shape_vec((2, 3), vec![0.7, 0.2, 0.1, 0.1, 0.8, 0.1]).unwrap();
        let target = Array2::from_shape_vec((2, 3), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0]).unwrap();
        let loss = Loss::compute(LossType::CrossEntropy, &pred, &target).unwrap();
        assert!(
            loss > 0.0 && loss < 1.0,
            "CE should be reasonable for good predictions, got {}",
            loss
        );
    }

    #[test]
    fn test_cross_entropy_wrong_class() {
        let pred = Array2::from_shape_vec((1, 3), vec![0.1, 0.1, 0.8]).unwrap();
        let target = Array2::from_shape_vec((1, 3), vec![1.0, 0.0, 0.0]).unwrap();
        let loss = Loss::compute(LossType::CrossEntropy, &pred, &target).unwrap();
        assert!(
            loss > 1.0,
            "CE should be high for wrong predictions, got {}",
            loss
        );
    }

    #[test]
    fn test_cross_entropy_gradient_shape() {
        let pred = Array2::from_shape_vec((2, 4), vec![0.25; 8]).unwrap();
        let target =
            Array2::from_shape_vec((2, 4), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();
        let grad = Loss::gradient(LossType::CrossEntropy, &pred, &target).unwrap();
        assert_eq!(grad.shape(), pred.shape());
    }

    #[test]
    fn test_loss_dimension_mismatch_error() {
        let pred = Array2::from_shape_vec((2, 2), vec![1.0; 4]).unwrap();
        let target = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();

        let result = Loss::compute(LossType::Mse, &pred, &target);
        assert!(result.is_err(), "Should error on dimension mismatch");

        let result = Loss::gradient(LossType::Mse, &pred, &target);
        assert!(
            result.is_err(),
            "Gradient should error on dimension mismatch"
        );
    }

    #[test]
    fn test_loss_empty_array_error() {
        let pred = Array2::from_shape_vec((0, 2), vec![]).unwrap();
        let target = Array2::from_shape_vec((0, 2), vec![]).unwrap();

        let result = Loss::compute(LossType::Mse, &pred, &target);
        assert!(result.is_err(), "Should error on empty arrays");

        let result = Loss::gradient(LossType::Mse, &pred, &target);
        assert!(result.is_err(), "Gradient should error on empty arrays");
    }

    #[test]
    fn test_loss_gradient_numerical_check() {
        let pred = Array2::from_shape_vec((1, 2), vec![0.5, 0.8]).unwrap();
        let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();

        let analytical_grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();

        // Compare against a central-difference approximation of the gradient.
        let eps = 1e-5;
        for i in 0..2 {
            let mut pred_plus = pred.clone();
            let mut pred_minus = pred.clone();
            pred_plus[[0, i]] += eps;
            pred_minus[[0, i]] -= eps;

            let loss_plus = Loss::compute(LossType::Mse, &pred_plus, &target).unwrap();
            let loss_minus = Loss::compute(LossType::Mse, &pred_minus, &target).unwrap();

            let numerical_grad = (loss_plus - loss_minus) / (2.0 * eps);
            let error = (analytical_grad[[0, i]] - numerical_grad).abs();

            assert!(
                error < 1e-3,
                "Numerical gradient check failed: analytical={}, numerical={}",
                analytical_grad[[0, i]],
                numerical_grad
            );
        }
    }

    #[test]
    fn test_training_loop_integration() {
        let mut optimizer = Optimizer::new(OptimizerType::Sgd {
            learning_rate: 0.1,
            momentum: 0.0,
        });

        let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
        let mut pred = Array2::from_shape_vec((1, 2), vec![0.5, 0.5]).unwrap();

        let initial_loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();

        for _ in 0..10 {
            let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
            optimizer.step(&mut pred, &grad).unwrap();
        }

        let final_loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();

        assert!(
            final_loss < initial_loss,
            "Loss should decrease during training"
        );
    }
}