use scirs2_core::ndarray::{ArrayD, IxDyn, Zip};
use std::collections::HashMap;

/// Errors that can occur while computing a tensor loss.
#[derive(Debug, Clone)]
pub enum TensorLossError {
    /// Prediction and target shapes do not match.
    ShapeMismatch {
        expected: Vec<usize>,
        got: Vec<usize>,
    },
    /// The target tensor is invalid for this loss.
    InvalidTarget(String),
    /// A division by zero was encountered.
    DivisionByZero,
    /// The input tensor contains no elements.
    EmptyInput,
    /// The loss configuration is invalid.
    InvalidConfig(String),
}

impl std::fmt::Display for TensorLossError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ShapeMismatch { expected, got } => {
                write!(f, "shape mismatch: expected {:?}, got {:?}", expected, got)
            }
            Self::InvalidTarget(msg) => write!(f, "invalid target: {}", msg),
            Self::DivisionByZero => write!(f, "division by zero encountered"),
            Self::EmptyInput => write!(f, "input tensor is empty"),
            Self::InvalidConfig(msg) => write!(f, "invalid configuration: {}", msg),
        }
    }
}

impl std::error::Error for TensorLossError {}

/// How element-wise losses are reduced to the reported scalar.
#[derive(Debug, Clone, PartialEq)]
pub enum LossReduction {
    /// Average the element-wise losses.
    Mean,
    /// Sum the element-wise losses.
    Sum,
    /// No reduction; return the element-wise loss tensor.
    None,
}

/// The result of a loss computation.
#[derive(Debug, Clone)]
pub struct TensorLossOutput {
    /// Reduced scalar loss (left at 0.0 when the reduction is `None`).
    pub loss: f64,
    /// Element-wise loss tensor, populated only for `LossReduction::None`.
    pub loss_tensor: Option<ArrayD<f64>>,
    /// Gradient of the loss with respect to the predictions, if requested.
    pub grad: Option<ArrayD<f64>>,
}

/// Common interface for tensor losses.
pub trait TensorLoss: std::fmt::Debug {
    /// Computes the loss (and optionally its gradient) of `pred` against `target`.
    fn compute(
        &self,
        pred: &ArrayD<f64>,
        target: &ArrayD<f64>,
    ) -> Result<TensorLossOutput, TensorLossError>;

    /// A short, stable identifier for this loss.
    fn name(&self) -> &'static str;
}

/// Shared configuration for tensor losses.
#[derive(Debug, Clone)]
pub struct TensorLossConfig {
    /// Reduction applied to the element-wise losses.
    pub reduction: LossReduction,
    /// Whether to compute the gradient alongside the loss.
    pub compute_grad: bool,
    /// Small constant guarding logarithms and divisions against zero.
    pub epsilon: f64,
}

impl Default for TensorLossConfig {
    fn default() -> Self {
        Self {
            reduction: LossReduction::Mean,
            compute_grad: true,
            epsilon: 1e-8,
        }
    }
}

/// Checks that `pred` is non-empty and that both tensors share a shape,
/// returning the element count.
fn validate_shapes(pred: &ArrayD<f64>, target: &ArrayD<f64>) -> Result<usize, TensorLossError> {
    let n = pred.len();
    if n == 0 {
        return Err(TensorLossError::EmptyInput);
    }
    if pred.shape() != target.shape() {
        return Err(TensorLossError::ShapeMismatch {
            expected: pred.shape().to_vec(),
            got: target.shape().to_vec(),
        });
    }
    Ok(n)
}

/// Applies the configured reduction to an element-wise loss tensor.
///
/// For `LossReduction::None`, the scalar `loss` field is left at 0.0 and the
/// element-wise tensor is returned instead.
fn apply_reduction(
    loss_elem: ArrayD<f64>,
    grad_elem: Option<ArrayD<f64>>,
    reduction: &LossReduction,
    n: usize,
) -> TensorLossOutput {
    match reduction {
        LossReduction::None => TensorLossOutput {
            loss: 0.0,
            loss_tensor: Some(loss_elem),
            grad: grad_elem,
        },
        LossReduction::Sum => {
            let loss = loss_elem.sum();
            TensorLossOutput {
                loss,
                loss_tensor: None,
                grad: grad_elem,
            }
        }
        LossReduction::Mean => {
            let loss = loss_elem.sum() / n as f64;
            TensorLossOutput {
                loss,
                loss_tensor: None,
                grad: grad_elem,
            }
        }
    }
}

/// Mean squared error: `mean((pred - target)^2)`.
#[derive(Debug, Clone)]
pub struct TensorMseLoss {
    pub config: TensorLossConfig,
}

impl TensorMseLoss {
    /// Creates an MSE loss with the default configuration.
    pub fn new() -> Self {
        Self {
            config: TensorLossConfig::default(),
        }
    }

    /// Creates an MSE loss with a custom configuration.
    pub fn with_config(config: TensorLossConfig) -> Self {
        Self { config }
    }
}

impl Default for TensorMseLoss {
    fn default() -> Self {
        Self::new()
    }
}

impl TensorLoss for TensorMseLoss {
    fn name(&self) -> &'static str {
        "mse"
    }

    fn compute(
        &self,
        pred: &ArrayD<f64>,
        target: &ArrayD<f64>,
    ) -> Result<TensorLossOutput, TensorLossError> {
        let n = validate_shapes(pred, target)?;

        let diff = pred - target;
        let loss_elem = diff.mapv(|x| x * x);

        let grad = if self.config.compute_grad {
            // d/dpred of mean((pred - target)^2) is 2 * diff / n; without the
            // mean reduction the 1/n factor is dropped.
            let scale = match self.config.reduction {
                LossReduction::Mean => 2.0 / n as f64,
                LossReduction::Sum | LossReduction::None => 2.0,
            };
            Some(diff.mapv(|x| x * scale))
        } else {
            None
        };

        Ok(apply_reduction(loss_elem, grad, &self.config.reduction, n))
    }
}

/// Binary cross-entropy over element-wise probabilities.
///
/// Predictions are clamped to `[epsilon, 1 - epsilon]` for numerical
/// stability; targets are expected to lie in `[0, 1]`.
#[derive(Debug, Clone)]
pub struct TensorBCELoss {
    pub config: TensorLossConfig,
}

impl TensorBCELoss {
    /// Creates a BCE loss with the default configuration.
    pub fn new() -> Self {
        Self {
            config: TensorLossConfig::default(),
        }
    }
}

impl Default for TensorBCELoss {
    fn default() -> Self {
        Self::new()
    }
}

impl TensorLoss for TensorBCELoss {
    fn name(&self) -> &'static str {
        "bce"
    }

    fn compute(
        &self,
        pred: &ArrayD<f64>,
        target: &ArrayD<f64>,
    ) -> Result<TensorLossOutput, TensorLossError> {
        let n = validate_shapes(pred, target)?;
        let eps = self.config.epsilon;

        // Clamp predictions away from 0 and 1 so the logarithms stay finite.
        let p = pred.mapv(|x| x.clamp(eps, 1.0 - eps));

        let mut loss_elem = ArrayD::zeros(IxDyn(pred.shape()));
        let mut grad_elem = if self.config.compute_grad {
            Some(ArrayD::zeros(IxDyn(pred.shape())))
        } else {
            None
        };

        Zip::from(&mut loss_elem)
            .and(&p)
            .and(target)
            .for_each(|l, &pi, &ti| {
                *l = -(ti * pi.ln() + (1.0 - ti) * (1.0 - pi).ln());
            });

        if let Some(ref mut g) = grad_elem {
            Zip::from(g).and(&p).and(target).for_each(|gi, &pi, &ti| {
                *gi = -(ti / pi - (1.0 - ti) / (1.0 - pi));
            });
        }

        Ok(apply_reduction(
            loss_elem,
            grad_elem,
            &self.config.reduction,
            n,
        ))
    }
}

/// Cross-entropy between a predicted distribution and a target distribution,
/// with optional label smoothing and an optional softmax over the inputs.
#[derive(Debug, Clone)]
pub struct TensorCrossEntropyLoss {
    pub config: TensorLossConfig,
    /// Label smoothing factor in `[0, 1)`; 0.0 disables smoothing.
    pub label_smoothing: f64,
    /// When true, inputs are treated as logits and passed through a softmax.
    pub apply_softmax: bool,
}

impl TensorCrossEntropyLoss {
    /// Creates a cross-entropy loss with the default configuration.
    pub fn new() -> Self {
        Self {
            config: TensorLossConfig::default(),
            label_smoothing: 0.0,
            apply_softmax: false,
        }
    }
}

impl Default for TensorCrossEntropyLoss {
    fn default() -> Self {
        Self::new()
    }
}

/// Numerically stable softmax over all elements of a flattened tensor.
fn softmax_flat(logits: &ArrayD<f64>) -> ArrayD<f64> {
    let max_val = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let shifted = logits.mapv(|x| (x - max_val).exp());
    let sum = shifted.sum();
    if sum == 0.0 {
        shifted
    } else {
        shifted.mapv(|x| x / sum)
    }
}

impl TensorLoss for TensorCrossEntropyLoss {
    fn name(&self) -> &'static str {
        "cross_entropy"
    }

    fn compute(
        &self,
        pred: &ArrayD<f64>,
        target: &ArrayD<f64>,
    ) -> Result<TensorLossOutput, TensorLossError> {
        let n = validate_shapes(pred, target)?;
        let eps = self.config.epsilon;
        let k = n as f64;

        // Treat the inputs as logits if requested; otherwise assume they are
        // already probabilities.
        let p = if self.apply_softmax {
            softmax_flat(pred)
        } else {
            pred.clone()
        };

        // Smooth hard targets toward the uniform distribution over n classes.
        let t_smooth = if self.label_smoothing > 0.0 {
            let ls = self.label_smoothing;
            target.mapv(|ti| ti * (1.0 - ls) + ls / k)
        } else {
            target.clone()
        };

        let mut loss_elem = ArrayD::zeros(IxDyn(pred.shape()));
        Zip::from(&mut loss_elem)
            .and(&p)
            .and(&t_smooth)
            .for_each(|l, &pi, &ti| {
                *l = -(ti * (pi + eps).ln());
            });

        let grad = if self.config.compute_grad {
            if self.apply_softmax {
                // With softmax applied, the gradient with respect to the
                // logits simplifies to p - t.
                Some(&p - &t_smooth)
            } else {
                let mut g = ArrayD::zeros(IxDyn(pred.shape()));
                Zip::from(&mut g)
                    .and(&p)
                    .and(&t_smooth)
                    .for_each(|gi, &pi, &ti| {
                        *gi = -ti / (pi + eps);
                    });
                Some(g)
            }
        } else {
            None
        };

        Ok(apply_reduction(loss_elem, grad, &self.config.reduction, n))
    }
}

/// Focal loss for binary classification (Lin et al., 2017):
/// `-alpha_t * (1 - p_t)^gamma * ln(p_t)`.
#[derive(Debug, Clone)]
pub struct TensorFocalLoss {
    pub config: TensorLossConfig,
    /// Focusing parameter; `gamma = 0` reduces the loss to BCE.
    pub gamma: f64,
    /// Optional class-balancing weight for the positive class.
    pub alpha: Option<f64>,
}

impl TensorFocalLoss {
    /// Creates a focal loss with `gamma = 2.0` and no alpha weighting.
    pub fn new() -> Self {
        Self {
            config: TensorLossConfig::default(),
            gamma: 2.0,
            alpha: None,
        }
    }

    /// Creates a focal loss with a custom focusing parameter.
    pub fn with_gamma(gamma: f64) -> Self {
        Self {
            config: TensorLossConfig::default(),
            gamma,
            alpha: None,
        }
    }
}

impl Default for TensorFocalLoss {
    fn default() -> Self {
        Self::new()
    }
}

impl TensorLoss for TensorFocalLoss {
    fn name(&self) -> &'static str {
        "focal"
    }

    fn compute(
        &self,
        pred: &ArrayD<f64>,
        target: &ArrayD<f64>,
    ) -> Result<TensorLossOutput, TensorLossError> {
        let n = validate_shapes(pred, target)?;
        let eps = self.config.epsilon;
        let gamma = self.gamma;

        let p = pred.mapv(|x| x.clamp(eps, 1.0 - eps));

        let mut loss_elem = ArrayD::zeros(IxDyn(pred.shape()));
        let mut grad_elem = if self.config.compute_grad {
            Some(ArrayD::zeros(IxDyn(pred.shape())))
        } else {
            None
        };

        Zip::from(&mut loss_elem)
            .and(&p)
            .and(target)
            .for_each(|l, &pi, &ti| {
                // p_t is the predicted probability of the true class.
                let p_t = if ti > 0.5 { pi } else { 1.0 - pi };
                let modulator = (1.0 - p_t).powf(gamma);
                let weight = match self.alpha {
                    Some(a) => {
                        if ti > 0.5 {
                            a
                        } else {
                            1.0 - a
                        }
                    }
                    None => 1.0,
                };
                *l = -weight * modulator * (p_t + eps).ln();
            });

        if let Some(ref mut g) = grad_elem {
            Zip::from(g).and(&p).and(target).for_each(|gi, &pi, &ti| {
                let p_t = if ti > 0.5 { pi } else { 1.0 - pi };
                // dp_t/dpred is +1 for positives and -1 for negatives.
                let sign = if ti > 0.5 { 1.0_f64 } else { -1.0_f64 };
                let modulator = (1.0 - p_t).powf(gamma);
                let weight = match self.alpha {
                    Some(a) => {
                        if ti > 0.5 {
                            a
                        } else {
                            1.0 - a
                        }
                    }
                    None => 1.0,
                };
                // dL/dp_t = weight * (gamma * (1 - p_t)^(gamma - 1) * ln(p_t)
                //           - (1 - p_t)^gamma / p_t); chain through dp_t/dpred.
                let term1 = if gamma > 0.0 {
                    gamma * (1.0 - p_t).powf(gamma - 1.0) * (p_t + eps).ln()
                } else {
                    0.0
                };
                let term2 = modulator / (p_t + eps);
                *gi = weight * (term1 - term2) * sign;
            });
        }

        Ok(apply_reduction(
            loss_elem,
            grad_elem,
            &self.config.reduction,
            n,
        ))
    }
}

/// Huber (smooth-L1) loss: quadratic for `|pred - target| < delta`, linear
/// beyond, using the smooth-L1 parameterization `0.5 * d^2 / delta`.
#[derive(Debug, Clone)]
pub struct TensorHuberLoss {
    pub config: TensorLossConfig,
    /// Transition point between the quadratic and linear regimes.
    pub delta: f64,
}

impl TensorHuberLoss {
    /// Creates a Huber loss with `delta = 1.0`.
    pub fn new() -> Self {
        Self {
            config: TensorLossConfig::default(),
            delta: 1.0,
        }
    }

    /// Creates a Huber loss with a custom delta.
    pub fn with_delta(delta: f64) -> Self {
        Self {
            config: TensorLossConfig::default(),
            delta,
        }
    }
}

impl Default for TensorHuberLoss {
    fn default() -> Self {
        Self::new()
    }
}

impl TensorLoss for TensorHuberLoss {
    fn name(&self) -> &'static str {
        "huber"
    }

    fn compute(
        &self,
        pred: &ArrayD<f64>,
        target: &ArrayD<f64>,
    ) -> Result<TensorLossOutput, TensorLossError> {
        let n = validate_shapes(pred, target)?;
        let delta = self.delta;

        if delta <= 0.0 {
            return Err(TensorLossError::InvalidConfig(format!(
                "delta must be positive, got {}",
                delta
            )));
        }

        let diff = pred - target;
        let mut loss_elem = ArrayD::zeros(IxDyn(pred.shape()));
        let mut grad_elem = if self.config.compute_grad {
            Some(ArrayD::zeros(IxDyn(pred.shape())))
        } else {
            None
        };

        Zip::from(&mut loss_elem).and(&diff).for_each(|l, &d| {
            let abs_d = d.abs();
            if abs_d < delta {
                *l = 0.5 * d * d / delta;
            } else {
                *l = abs_d - 0.5 * delta;
            }
        });

        if let Some(ref mut g) = grad_elem {
            Zip::from(g).and(&diff).for_each(|gi, &d| {
                let abs_d = d.abs();
                let sign = if d > 0.0 {
                    1.0
                } else if d < 0.0 {
                    -1.0
                } else {
                    0.0
                };
                // d/delta in the quadratic regime, saturating at +-1 beyond it.
                *gi = sign * (abs_d / delta).min(1.0);
            });
        }

        Ok(apply_reduction(
            loss_elem,
            grad_elem,
            &self.config.reduction,
            n,
        ))
    }
}

/// Kullback-Leibler divergence `KL(target || pred) = sum(t * (ln t - ln p))`,
/// accumulated only over elements whose target mass exceeds epsilon.
#[derive(Debug, Clone)]
pub struct TensorKLDivLoss {
    pub config: TensorLossConfig,
}

impl TensorKLDivLoss {
    /// Creates a KL-divergence loss with the default configuration.
    pub fn new() -> Self {
        Self {
            config: TensorLossConfig::default(),
        }
    }
}

impl Default for TensorKLDivLoss {
    fn default() -> Self {
        Self::new()
    }
}

impl TensorLoss for TensorKLDivLoss {
    fn name(&self) -> &'static str {
        "kl_div"
    }

    fn compute(
        &self,
        pred: &ArrayD<f64>,
        target: &ArrayD<f64>,
    ) -> Result<TensorLossOutput, TensorLossError> {
        let n = validate_shapes(pred, target)?;
        let eps = self.config.epsilon;

        let mut loss_elem = ArrayD::zeros(IxDyn(pred.shape()));
        let mut grad_elem = if self.config.compute_grad {
            Some(ArrayD::zeros(IxDyn(pred.shape())))
        } else {
            None
        };

        Zip::from(&mut loss_elem)
            .and(pred)
            .and(target)
            .for_each(|l, &pi, &ti| {
                // Terms with negligible target mass contribute zero.
                if ti > eps {
                    let p_safe = pi.max(eps);
                    *l = ti * (ti.ln() - p_safe.ln());
                }
            });

        if let Some(ref mut g) = grad_elem {
            Zip::from(g).and(pred).and(target).for_each(|gi, &pi, &ti| {
                if ti > eps {
                    *gi = -ti / (pi + eps);
                }
            });
        }

        Ok(apply_reduction(
            loss_elem,
            grad_elem,
            &self.config.reduction,
            n,
        ))
    }
}

/// Cosine embedding loss `1 - cos(pred, target)` over the flattened tensors.
#[derive(Debug, Clone)]
pub struct TensorCosineEmbeddingLoss {
    pub config: TensorLossConfig,
}

impl TensorCosineEmbeddingLoss {
    /// Creates a cosine embedding loss with the default configuration.
    pub fn new() -> Self {
        Self {
            config: TensorLossConfig::default(),
        }
    }
}

impl Default for TensorCosineEmbeddingLoss {
    fn default() -> Self {
        Self::new()
    }
}

impl TensorLoss for TensorCosineEmbeddingLoss {
    fn name(&self) -> &'static str {
        "cosine_embedding"
    }

    fn compute(
        &self,
        pred: &ArrayD<f64>,
        target: &ArrayD<f64>,
    ) -> Result<TensorLossOutput, TensorLossError> {
        let n = validate_shapes(pred, target)?;
        let eps = self.config.epsilon;

        let dot: f64 = pred.iter().zip(target.iter()).map(|(p, t)| p * t).sum();
        let norm_p: f64 = pred.iter().map(|x| x * x).sum::<f64>().sqrt();
        let norm_t: f64 = target.iter().map(|x| x * x).sum::<f64>().sqrt();
        let denom = norm_p * norm_t + eps;

        let similarity = dot / denom;
        let scalar_loss = 1.0 - similarity;

        let grad = if self.config.compute_grad {
            // d(cos)/dpred_i = t_i / (|p||t|) - dot * p_i / (|p|^3 |t|);
            // the loss gradient is its negation.
            let mut g = ArrayD::zeros(IxDyn(pred.shape()));
            let norm_p_sq = norm_p * norm_p + eps;
            Zip::from(&mut g)
                .and(pred)
                .and(target)
                .for_each(|gi, &pi, &ti| {
                    let d_sim = ti / denom - dot * pi / (norm_p_sq * denom);
                    *gi = -d_sim;
                });
            Some(g)
        } else {
            None
        };

        // This loss is inherently a scalar, so Mean and Sum coincide; None
        // distributes the scalar evenly across an output tensor.
        match self.config.reduction {
            LossReduction::None => {
                let loss_tensor = ArrayD::from_elem(IxDyn(pred.shape()), scalar_loss / n as f64);
                Ok(TensorLossOutput {
                    loss: 0.0,
                    loss_tensor: Some(loss_tensor),
                    grad,
                })
            }
            LossReduction::Mean | LossReduction::Sum => Ok(TensorLossOutput {
                loss: scalar_loss,
                loss_tensor: None,
                grad,
            }),
        }
    }
}

/// A name-indexed registry of tensor losses.
#[derive(Debug)]
pub struct TensorLossRegistry {
    losses: HashMap<String, Box<dyn TensorLoss>>,
}

impl TensorLossRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            losses: HashMap::new(),
        }
    }

    /// Creates a registry pre-populated with all built-in losses.
    pub fn with_all_defaults() -> Self {
        let mut reg = Self::new();
        reg.register("mse", Box::new(TensorMseLoss::new()));
        reg.register("bce", Box::new(TensorBCELoss::new()));
        reg.register("cross_entropy", Box::new(TensorCrossEntropyLoss::new()));
        reg.register("focal", Box::new(TensorFocalLoss::new()));
        reg.register("huber", Box::new(TensorHuberLoss::new()));
        reg.register("kl_div", Box::new(TensorKLDivLoss::new()));
        reg.register(
            "cosine_embedding",
            Box::new(TensorCosineEmbeddingLoss::new()),
        );
        reg
    }

    /// Registers a loss under `name`, replacing any existing entry.
    pub fn register(&mut self, name: impl Into<String>, loss: Box<dyn TensorLoss>) {
        self.losses.insert(name.into(), loss);
    }

    /// Computes the loss registered under `name`.
    pub fn compute(
        &self,
        name: &str,
        pred: &ArrayD<f64>,
        target: &ArrayD<f64>,
    ) -> Result<TensorLossOutput, TensorLossError> {
        let loss = self.losses.get(name).ok_or_else(|| {
            TensorLossError::InvalidConfig(format!("no loss registered under name '{}'", name))
        })?;
        loss.compute(pred, target)
    }

    /// Returns the names of all registered losses.
    pub fn names(&self) -> Vec<&str> {
        self.losses.keys().map(|s| s.as_str()).collect()
    }

    /// Returns true if a loss is registered under `name`.
    pub fn contains(&self, name: &str) -> bool {
        self.losses.contains_key(name)
    }
}

impl Default for TensorLossRegistry {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr1;

    /// Converts a Vec into a 1-D dynamic-dimension array.
    fn to_arrayd(v: Vec<f64>) -> ArrayD<f64> {
        arr1(&v).into_dyn()
    }

    #[test]
    fn test_mse_zero_loss_identical_arrays() {
        let a = to_arrayd(vec![1.0, 2.0, 3.0]);
        let loss = TensorMseLoss::new().compute(&a, &a).unwrap();
        assert!(
            (loss.loss).abs() < 1e-10,
            "identical arrays should yield zero loss"
        );
    }

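    // A hedged addition to the original suite: validate_shapes should reject
    // an empty prediction tensor with TensorLossError::EmptyInput.
    #[test]
    fn test_empty_input_rejected() {
        let empty = to_arrayd(vec![]);
        let res = TensorMseLoss::new().compute(&empty, &empty);
        assert!(matches!(res, Err(TensorLossError::EmptyInput)));
    }
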
    #[test]
    fn test_mse_loss_value_correct() {
        let pred = to_arrayd(vec![1.0, 2.0]);
        let target = to_arrayd(vec![0.0, 0.0]);
        let out = TensorMseLoss::new().compute(&pred, &target).unwrap();
        // mean((1 - 0)^2, (2 - 0)^2) = (1 + 4) / 2 = 2.5
        assert!((out.loss - 2.5).abs() < 1e-10);
    }

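    // A hedged addition: mismatched shapes should surface as ShapeMismatch
    // with the offending shapes attached.
    #[test]
    fn test_shape_mismatch_rejected() {
        let pred = to_arrayd(vec![1.0, 2.0, 3.0]);
        let target = to_arrayd(vec![0.0, 0.0]);
        let res = TensorMseLoss::new().compute(&pred, &target);
        assert!(matches!(res, Err(TensorLossError::ShapeMismatch { .. })));
    }
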
    #[test]
    fn test_mse_gradient_shape() {
        let pred = to_arrayd(vec![1.0, 2.0, 3.0]);
        let target = to_arrayd(vec![0.0, 0.0, 0.0]);
        let out = TensorMseLoss::new().compute(&pred, &target).unwrap();
        let grad = out.grad.unwrap();
        assert_eq!(grad.shape(), pred.shape());
    }

    #[test]
    fn test_mse_gradient_direction() {
        let pred = to_arrayd(vec![3.0, 2.0]);
        let target = to_arrayd(vec![1.0, 1.0]);
        let out = TensorMseLoss::new().compute(&pred, &target).unwrap();
        let grad = out.grad.unwrap();
        for &g in grad.iter() {
            assert!(g > 0.0, "gradient should be positive when pred > target");
        }
    }

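    // A hedged sanity check, not part of the original suite: the analytic MSE
    // gradient should agree with a central finite difference of the scalar
    // loss. The step size 1e-5 is an assumption chosen to balance truncation
    // and rounding error for f64.
    #[test]
    fn test_mse_gradient_matches_finite_difference() {
        let pred = to_arrayd(vec![0.5, -1.0, 2.0]);
        let target = to_arrayd(vec![0.0, 1.0, 1.5]);
        let loss_fn = TensorMseLoss::new();
        let grad = loss_fn.compute(&pred, &target).unwrap().grad.unwrap();
        let h = 1e-5;
        for i in 0..pred.len() {
            let mut plus = pred.clone();
            plus[[i]] += h;
            let mut minus = pred.clone();
            minus[[i]] -= h;
            let lp = loss_fn.compute(&plus, &target).unwrap().loss;
            let lm = loss_fn.compute(&minus, &target).unwrap().loss;
            let numeric = (lp - lm) / (2.0 * h);
            assert!(
                (grad[[i]] - numeric).abs() < 1e-6,
                "analytic {} vs numeric {} at index {}",
                grad[[i]],
                numeric,
                i
            );
        }
    }
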
    #[test]
    fn test_bce_perfect_prediction_near_zero() {
        let pred = to_arrayd(vec![0.9999, 0.0001]);
        let target = to_arrayd(vec![1.0, 0.0]);
        let out = TensorBCELoss::new().compute(&pred, &target).unwrap();
        assert!(out.loss < 1e-3, "near-perfect predictions → near-zero loss");
    }

    #[test]
    fn test_bce_gradient_shape() {
        let pred = to_arrayd(vec![0.5, 0.7]);
        let target = to_arrayd(vec![1.0, 0.0]);
        let out = TensorBCELoss::new().compute(&pred, &target).unwrap();
        let grad = out.grad.unwrap();
        assert_eq!(grad.shape(), pred.shape());
    }

    #[test]
    fn test_cross_entropy_uniform_target() {
        let eps = 1e-8_f64;
        let p = 1.0_f64 / 3.0;
        let pred = to_arrayd(vec![p; 3]);
        let target = to_arrayd(vec![p; 3]);
        let out = TensorCrossEntropyLoss::new()
            .compute(&pred, &target)
            .unwrap();
        // Each element contributes -(1/3) * ln(1/3 + eps); the mean reduction
        // over three equal elements leaves that per-element value unchanged.
        let expected = -(p * (p + eps).ln());
        assert!(
            (out.loss - expected).abs() < 1e-6,
            "expected {}, got {}",
            expected,
            out.loss
        );
    }

    #[test]
    fn test_cross_entropy_label_smoothing() {
        let pred = to_arrayd(vec![0.9, 0.05, 0.05]);
        let target = to_arrayd(vec![1.0, 0.0, 0.0]);

        let no_smooth = TensorCrossEntropyLoss::new()
            .compute(&pred, &target)
            .unwrap();

        let with_smooth = TensorCrossEntropyLoss {
            label_smoothing: 0.1,
            ..TensorCrossEntropyLoss::new()
        }
        .compute(&pred, &target)
        .unwrap();

        assert!(
            (no_smooth.loss - with_smooth.loss).abs() > 1e-6,
            "label smoothing should change the loss"
        );
    }

    #[test]
    fn test_focal_gamma_zero_equals_bce() {
        let pred = to_arrayd(vec![0.7, 0.3, 0.8]);
        let target = to_arrayd(vec![1.0, 0.0, 1.0]);

        let focal = TensorFocalLoss::with_gamma(0.0)
            .compute(&pred, &target)
            .unwrap();
        let bce = TensorBCELoss::new().compute(&pred, &target).unwrap();

        assert!(
            (focal.loss - bce.loss).abs() < 1e-6,
            "focal(gamma=0) ≈ BCE, got focal={} bce={}",
            focal.loss,
            bce.loss
        );
    }

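    // A hedged companion to the value check above: at gamma = 0 the focal
    // gradient should also coincide with the BCE gradient (up to the epsilon
    // used for numerical stability).
    #[test]
    fn test_focal_gamma_zero_gradient_matches_bce() {
        let pred = to_arrayd(vec![0.7, 0.3, 0.8]);
        let target = to_arrayd(vec![1.0, 0.0, 1.0]);

        let focal = TensorFocalLoss::with_gamma(0.0)
            .compute(&pred, &target)
            .unwrap();
        let bce = TensorBCELoss::new().compute(&pred, &target).unwrap();

        let gf = focal.grad.unwrap();
        let gb = bce.grad.unwrap();
        for (a, b) in gf.iter().zip(gb.iter()) {
            assert!((a - b).abs() < 1e-6, "focal grad {} vs bce grad {}", a, b);
        }
    }
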
    #[test]
    fn test_focal_high_confidence_downweighted() {
        let pred_high = to_arrayd(vec![0.99]);
        let pred_low = to_arrayd(vec![0.6]);
        let target = to_arrayd(vec![1.0]);

        let focal = TensorFocalLoss::new();
        let out_high = focal.compute(&pred_high, &target).unwrap();
        let out_low = focal.compute(&pred_low, &target).unwrap();
        assert!(
            out_high.loss < out_low.loss,
            "high-confidence correct prediction should be downweighted"
        );
    }

    #[test]
    fn test_huber_small_error_quadratic() {
        let pred = to_arrayd(vec![0.5]);
        let target = to_arrayd(vec![0.0]);
        let out = TensorHuberLoss::new().compute(&pred, &target).unwrap();
        // Quadratic regime: 0.5 * 0.5^2 / 1.0 = 0.125
        assert!((out.loss - 0.125).abs() < 1e-10);
    }

    #[test]
    fn test_huber_large_error_linear() {
        let pred = to_arrayd(vec![2.0]);
        let target = to_arrayd(vec![0.0]);
        let out = TensorHuberLoss::new().compute(&pred, &target).unwrap();
        // Linear regime: |2.0| - 0.5 * 1.0 = 1.5
        assert!((out.loss - 1.5).abs() < 1e-10);
    }

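    // A hedged addition: delta is validated at compute time, so a
    // non-positive delta should produce an InvalidConfig error.
    #[test]
    fn test_huber_nonpositive_delta_rejected() {
        let pred = to_arrayd(vec![1.0]);
        let target = to_arrayd(vec![0.0]);
        let res = TensorHuberLoss::with_delta(0.0).compute(&pred, &target);
        assert!(matches!(res, Err(TensorLossError::InvalidConfig(_))));
    }
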
    #[test]
    fn test_kl_div_identical_distributions_zero() {
        let p = to_arrayd(vec![0.3, 0.5, 0.2]);
        let out = TensorKLDivLoss::new().compute(&p, &p).unwrap();
        assert!(out.loss.abs() < 1e-6);
    }

    #[test]
    fn test_kl_div_gradient_shape() {
        let pred = to_arrayd(vec![0.3, 0.5, 0.2]);
        let target = to_arrayd(vec![0.4, 0.4, 0.2]);
        let out = TensorKLDivLoss::new().compute(&pred, &target).unwrap();
        let grad = out.grad.unwrap();
        assert_eq!(grad.shape(), pred.shape());
    }

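    // A hedged addition: KL(target || pred) is non-negative for valid
    // probability distributions, so a genuinely different pair should give a
    // strictly positive loss. Sum reduction keeps the textbook definition.
    #[test]
    fn test_kl_div_positive_for_different_distributions() {
        let pred = to_arrayd(vec![0.1, 0.8, 0.1]);
        let target = to_arrayd(vec![0.4, 0.4, 0.2]);
        let out = TensorKLDivLoss {
            config: TensorLossConfig {
                reduction: LossReduction::Sum,
                ..Default::default()
            },
        }
        .compute(&pred, &target)
        .unwrap();
        assert!(out.loss > 0.0, "KL divergence should be positive here");
    }
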
    #[test]
    fn test_cosine_parallel_loss_zero() {
        let pred = to_arrayd(vec![1.0, 0.0, 0.0]);
        let target = to_arrayd(vec![2.0, 0.0, 0.0]);
        let out = TensorCosineEmbeddingLoss::new()
            .compute(&pred, &target)
            .unwrap();
        assert!(out.loss.abs() < 1e-6, "parallel vectors → loss ≈ 0");
    }

    #[test]
    fn test_cosine_orthogonal_loss_one() {
        let pred = to_arrayd(vec![1.0, 0.0]);
        let target = to_arrayd(vec![0.0, 1.0]);
        let out = TensorCosineEmbeddingLoss::new()
            .compute(&pred, &target)
            .unwrap();
        assert!(
            (out.loss - 1.0).abs() < 1e-6,
            "orthogonal vectors → loss ≈ 1"
        );
    }

    #[test]
    fn test_reduction_sum_vs_mean() {
        let pred = to_arrayd(vec![1.0, 2.0, 3.0]);
        let target = to_arrayd(vec![0.0, 0.0, 0.0]);

        let mean_loss = TensorMseLoss::with_config(TensorLossConfig {
            reduction: LossReduction::Mean,
            ..Default::default()
        })
        .compute(&pred, &target)
        .unwrap();

        let sum_loss = TensorMseLoss::with_config(TensorLossConfig {
            reduction: LossReduction::Sum,
            ..Default::default()
        })
        .compute(&pred, &target)
        .unwrap();

        assert!(
            (sum_loss.loss - mean_loss.loss).abs() > 1e-6,
            "sum != mean for non-unit arrays"
        );
    }

    #[test]
    fn test_reduction_none_returns_tensor() {
        let pred = to_arrayd(vec![1.0, 2.0]);
        let target = to_arrayd(vec![0.0, 0.0]);

        let out = TensorMseLoss::with_config(TensorLossConfig {
            reduction: LossReduction::None,
            ..Default::default()
        })
        .compute(&pred, &target)
        .unwrap();

        assert!(
            out.loss_tensor.is_some(),
            "None reduction should return a loss tensor"
        );
        let lt = out.loss_tensor.unwrap();
        assert_eq!(lt.shape(), pred.shape());
    }

    #[test]
    fn test_registry_with_all_defaults() {
        let reg = TensorLossRegistry::with_all_defaults();
        assert_eq!(
            reg.names().len(),
            7,
            "registry should contain 7 built-in losses"
        );
        for name in &[
            "mse",
            "bce",
            "cross_entropy",
            "focal",
            "huber",
            "kl_div",
            "cosine_embedding",
        ] {
            assert!(reg.contains(name), "missing: {}", name);
        }
    }

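    // A hedged addition: looking up an unregistered name should fail with the
    // InvalidConfig error produced by TensorLossRegistry::compute.
    #[test]
    fn test_registry_unknown_name_errors() {
        let reg = TensorLossRegistry::with_all_defaults();
        let a = to_arrayd(vec![0.5]);
        let res = reg.compute("not_a_loss", &a, &a);
        assert!(matches!(res, Err(TensorLossError::InvalidConfig(_))));
    }
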
    #[test]
    fn test_registry_compute_by_name() {
        let reg = TensorLossRegistry::with_all_defaults();
        let pred = to_arrayd(vec![0.5, 0.5]);
        let target = to_arrayd(vec![1.0, 0.0]);
        let out = reg.compute("bce", &pred, &target).unwrap();
        assert!(
            out.loss > 0.0,
            "BCE of non-perfect prediction should be positive"
        );
    }
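
    // A hedged sketch of extending the registry with a user-defined loss.
    // `MaeLoss` (mean absolute error) is a hypothetical example type written
    // for this test; it is not part of the module's public API.
    #[derive(Debug)]
    struct MaeLoss;

    impl TensorLoss for MaeLoss {
        fn name(&self) -> &'static str {
            "mae"
        }

        fn compute(
            &self,
            pred: &ArrayD<f64>,
            target: &ArrayD<f64>,
        ) -> Result<TensorLossOutput, TensorLossError> {
            let n = validate_shapes(pred, target)?;
            let loss_elem = (pred - target).mapv(f64::abs);
            Ok(apply_reduction(loss_elem, None, &LossReduction::Mean, n))
        }
    }

    #[test]
    fn test_registry_accepts_custom_loss() {
        let mut reg = TensorLossRegistry::new();
        reg.register("mae", Box::new(MaeLoss));
        let pred = to_arrayd(vec![1.0, -1.0]);
        let target = to_arrayd(vec![0.0, 0.0]);
        let out = reg.compute("mae", &pred, &target).unwrap();
        // mean(|1|, |-1|) = 1.0
        assert!((out.loss - 1.0).abs() < 1e-10);
    }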
}