tensorlogic_scirs_backend/tensor_loss.rs

//! Tensor-level loss functions operating on `ArrayD<f64>` with optional gradient output.
//!
//! This module provides production-ready implementations of common loss functions used
//! in machine learning, each operating on N-dimensional tensors. Unlike scalar-level
//! losses (see `tensorlogic-train`), these functions accept and return `ArrayD<f64>`
//! and support configurable reductions and gradient computation.
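//!
//! # Example
//!
//! A minimal usage sketch. The crate/module import path below is assumed from the file
//! location and may need adjusting for your build, so the example is not doctested:
//!
//! ```ignore
//! use scirs2_core::ndarray::arr1;
//! use tensorlogic_scirs_backend::tensor_loss::{TensorLoss, TensorMseLoss};
//!
//! let pred = arr1(&[1.0, 2.0]).into_dyn();
//! let target = arr1(&[0.0, 0.0]).into_dyn();
//! let out = TensorMseLoss::new().compute(&pred, &target).unwrap();
//! assert!((out.loss - 2.5).abs() < 1e-12); // mean([1, 4]) = 2.5
//! ```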

use scirs2_core::ndarray::{ArrayD, IxDyn, Zip};
use std::collections::HashMap;

// ───────────────────────────────────────────────────────────────────────────────
// Error type
// ───────────────────────────────────────────────────────────────────────────────

/// Errors that can occur during tensor-level loss computation.
#[derive(Debug, Clone)]
pub enum TensorLossError {
    /// The prediction and target tensors have different shapes.
    ShapeMismatch {
        expected: Vec<usize>,
        got: Vec<usize>,
    },
    /// The target tensor contains an invalid value (e.g. out of `[0,1]` for BCE).
    InvalidTarget(String),
    /// A division-by-zero was encountered (e.g. zero-norm vector in cosine loss).
    DivisionByZero,
    /// The input tensor has no elements.
    EmptyInput,
    /// The loss was configured with an invalid parameter value.
    InvalidConfig(String),
}

impl std::fmt::Display for TensorLossError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ShapeMismatch { expected, got } => {
                write!(f, "shape mismatch: expected {:?}, got {:?}", expected, got)
            }
            Self::InvalidTarget(msg) => write!(f, "invalid target: {}", msg),
            Self::DivisionByZero => write!(f, "division by zero encountered"),
            Self::EmptyInput => write!(f, "input tensor is empty"),
            Self::InvalidConfig(msg) => write!(f, "invalid configuration: {}", msg),
        }
    }
}

impl std::error::Error for TensorLossError {}

// ───────────────────────────────────────────────────────────────────────────────
// Reduction modes
// ───────────────────────────────────────────────────────────────────────────────

/// How to aggregate element-wise losses into a scalar.
#[derive(Debug, Clone, PartialEq)]
pub enum LossReduction {
    /// Divide the summed loss by the number of elements.
    Mean,
    /// Sum all element-wise losses.
    Sum,
    /// Return the element-wise loss tensor without any aggregation.
    None,
}

// ───────────────────────────────────────────────────────────────────────────────
// Output type
// ───────────────────────────────────────────────────────────────────────────────

/// The result of computing a tensor-level loss.
#[derive(Debug, Clone)]
pub struct TensorLossOutput {
    /// Scalar loss value. When `reduction == None` this is `0.0`.
    pub loss: f64,
    /// Element-wise loss tensor. Present only when `reduction == None`.
    pub loss_tensor: Option<ArrayD<f64>>,
    /// Gradient of the loss with respect to `pred`. Present when `compute_grad == true`.
    pub grad: Option<ArrayD<f64>>,
}

// ───────────────────────────────────────────────────────────────────────────────
// Trait
// ───────────────────────────────────────────────────────────────────────────────

/// Trait implemented by all tensor-level loss functions.
pub trait TensorLoss: std::fmt::Debug {
    /// Compute the loss (and optionally the gradient) between `pred` and `target`.
    fn compute(
        &self,
        pred: &ArrayD<f64>,
        target: &ArrayD<f64>,
    ) -> Result<TensorLossOutput, TensorLossError>;

    /// Human-readable name used by the registry.
    fn name(&self) -> &'static str;
}

// ───────────────────────────────────────────────────────────────────────────────
// Shared configuration
// ───────────────────────────────────────────────────────────────────────────────

/// Configuration options shared by all built-in loss functions.
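///
/// # Example
///
/// A minimal sketch of a non-default configuration (assumes the type is in scope;
/// not doctested):
///
/// ```ignore
/// let cfg = TensorLossConfig {
///     reduction: LossReduction::Sum,
///     compute_grad: false,
///     ..Default::default()
/// };
/// ```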
#[derive(Debug, Clone)]
pub struct TensorLossConfig {
    /// How to reduce the element-wise losses to a scalar.
    pub reduction: LossReduction,
    /// Whether to compute and return the gradient w.r.t. predictions.
    pub compute_grad: bool,
    /// Small constant for numerical stability (default `1e-8`).
    pub epsilon: f64,
}

impl Default for TensorLossConfig {
    fn default() -> Self {
        Self {
            reduction: LossReduction::Mean,
            compute_grad: true,
            epsilon: 1e-8,
        }
    }
}

// ───────────────────────────────────────────────────────────────────────────────
// Internal helpers
// ───────────────────────────────────────────────────────────────────────────────

/// Validate that `pred` and `target` have identical shapes and are non-empty.
fn validate_shapes(pred: &ArrayD<f64>, target: &ArrayD<f64>) -> Result<usize, TensorLossError> {
    let n = pred.len();
    if n == 0 {
        return Err(TensorLossError::EmptyInput);
    }
    if pred.shape() != target.shape() {
        return Err(TensorLossError::ShapeMismatch {
            expected: pred.shape().to_vec(),
            got: target.shape().to_vec(),
        });
    }
    Ok(n)
}

/// Apply a reduction to an element-wise loss tensor and an element-wise gradient.
fn apply_reduction(
    loss_elem: ArrayD<f64>,
    grad_elem: Option<ArrayD<f64>>,
    reduction: &LossReduction,
    n: usize,
) -> TensorLossOutput {
    match reduction {
        LossReduction::None => TensorLossOutput {
            loss: 0.0,
            loss_tensor: Some(loss_elem),
            grad: grad_elem,
        },
        LossReduction::Sum => {
            let loss = loss_elem.sum();
            TensorLossOutput {
                loss,
                loss_tensor: None,
                grad: grad_elem,
            }
        }
        LossReduction::Mean => {
            let loss = loss_elem.sum() / n as f64;
            TensorLossOutput {
                loss,
                loss_tensor: None,
                grad: grad_elem,
            }
        }
    }
}

// ───────────────────────────────────────────────────────────────────────────────
// MSE Loss
// ───────────────────────────────────────────────────────────────────────────────

/// Mean Squared Error loss: `mean((pred - target)^2)`.
///
/// Gradient: `2 * (pred - target) / N` (for Mean reduction).
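///
/// # Example (worked values, from the formula above)
///
/// For `pred = [1, 2]` and `target = [0, 0]` with `Mean` reduction:
/// `loss = mean([1, 4]) = 2.5` and `grad = 2 * [1, 2] / 2 = [1, 2]`.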
#[derive(Debug, Clone)]
pub struct TensorMseLoss {
    pub config: TensorLossConfig,
}

impl TensorMseLoss {
    /// Create with default configuration (Mean reduction, gradient enabled).
    pub fn new() -> Self {
        Self {
            config: TensorLossConfig::default(),
        }
    }

    /// Create with a custom configuration.
    pub fn with_config(config: TensorLossConfig) -> Self {
        Self { config }
    }
}

impl Default for TensorMseLoss {
    fn default() -> Self {
        Self::new()
    }
}

impl TensorLoss for TensorMseLoss {
    fn name(&self) -> &'static str {
        "mse"
    }

    fn compute(
        &self,
        pred: &ArrayD<f64>,
        target: &ArrayD<f64>,
    ) -> Result<TensorLossOutput, TensorLossError> {
        let n = validate_shapes(pred, target)?;

        let diff = pred - target;
        let loss_elem = diff.mapv(|x| x * x);

        let grad = if self.config.compute_grad {
            let scale = match self.config.reduction {
                LossReduction::Mean => 2.0 / n as f64,
                LossReduction::Sum | LossReduction::None => 2.0,
            };
            Some(diff.mapv(|x| x * scale))
        } else {
            None
        };

        Ok(apply_reduction(loss_elem, grad, &self.config.reduction, n))
    }
}

// ───────────────────────────────────────────────────────────────────────────────
// Binary Cross-Entropy Loss
// ───────────────────────────────────────────────────────────────────────────────

/// Binary Cross-Entropy loss: `-[t*log(p) + (1-t)*log(1-p)]`.
///
/// Predictions are clamped to `[eps, 1-eps]` for numerical stability.
/// Gradient: `-(t/p - (1-t)/(1-p))`.
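///
/// # Example (worked values, from the formula above)
///
/// For `pred = [0.5]` and `target = [1.0]`: `loss = -ln(0.5) ≈ 0.693` and
/// `grad = -(1/0.5 - 0) = -2`.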
#[derive(Debug, Clone)]
pub struct TensorBCELoss {
    pub config: TensorLossConfig,
}

impl TensorBCELoss {
    /// Create with default configuration.
    pub fn new() -> Self {
        Self {
            config: TensorLossConfig::default(),
        }
    }
}

impl Default for TensorBCELoss {
    fn default() -> Self {
        Self::new()
    }
}

impl TensorLoss for TensorBCELoss {
    fn name(&self) -> &'static str {
        "bce"
    }

    fn compute(
        &self,
        pred: &ArrayD<f64>,
        target: &ArrayD<f64>,
    ) -> Result<TensorLossOutput, TensorLossError> {
        let n = validate_shapes(pred, target)?;
        let eps = self.config.epsilon;

        // Clamp predictions for numerical stability
        let p = pred.mapv(|x| x.clamp(eps, 1.0 - eps));

        let mut loss_elem = ArrayD::zeros(IxDyn(pred.shape()));
        let mut grad_elem = if self.config.compute_grad {
            Some(ArrayD::zeros(IxDyn(pred.shape())))
        } else {
            None
        };

        Zip::from(&mut loss_elem)
            .and(&p)
            .and(target)
            .for_each(|l, &pi, &ti| {
                *l = -(ti * pi.ln() + (1.0 - ti) * (1.0 - pi).ln());
            });

        if let Some(ref mut g) = grad_elem {
            Zip::from(g).and(&p).and(target).for_each(|gi, &pi, &ti| {
                *gi = -(ti / pi - (1.0 - ti) / (1.0 - pi));
            });
        }

        Ok(apply_reduction(
            loss_elem,
            grad_elem,
            &self.config.reduction,
            n,
        ))
    }
}

// ───────────────────────────────────────────────────────────────────────────────
// Categorical Cross-Entropy Loss
// ───────────────────────────────────────────────────────────────────────────────

/// Categorical Cross-Entropy loss: `-sum(t * log(p + eps))`.
///
/// Optionally applies softmax to predictions before computing the loss,
/// and supports label smoothing.
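///
/// # Example (worked values, from the formula above)
///
/// For `pred = [0.9, 0.05, 0.05]` and `target = [1, 0, 0]` with no softmax and no
/// smoothing, the element-wise loss is `[-ln(0.9), 0, 0] ≈ [0.105, 0, 0]`, so the
/// `Mean`-reduced loss is ≈ 0.035.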
#[derive(Debug, Clone)]
pub struct TensorCrossEntropyLoss {
    pub config: TensorLossConfig,
    /// Label smoothing coefficient in `[0, 1)`. `0.0` means no smoothing.
    pub label_smoothing: f64,
    /// If `true`, apply a numerically stable softmax to predictions first.
    pub apply_softmax: bool,
}

impl TensorCrossEntropyLoss {
    /// Create with default configuration, no label smoothing, no softmax.
    pub fn new() -> Self {
        Self {
            config: TensorLossConfig::default(),
            label_smoothing: 0.0,
            apply_softmax: false,
        }
    }
}

impl Default for TensorCrossEntropyLoss {
    fn default() -> Self {
        Self::new()
    }
}

/// Numerically stable softmax computed over all elements of the tensor, treated as one flat vector.
fn softmax_flat(logits: &ArrayD<f64>) -> ArrayD<f64> {
    let max_val = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let shifted = logits.mapv(|x| (x - max_val).exp());
    let sum = shifted.sum();
    if sum == 0.0 {
        shifted
    } else {
        shifted.mapv(|x| x / sum)
    }
}

impl TensorLoss for TensorCrossEntropyLoss {
    fn name(&self) -> &'static str {
        "cross_entropy"
    }

    fn compute(
        &self,
        pred: &ArrayD<f64>,
        target: &ArrayD<f64>,
    ) -> Result<TensorLossOutput, TensorLossError> {
        let n = validate_shapes(pred, target)?;
        let eps = self.config.epsilon;
        let k = n as f64;

        // Optional softmax on predictions
        let p = if self.apply_softmax {
            softmax_flat(pred)
        } else {
            pred.clone()
        };

        // Label smoothing
        let t_smooth = if self.label_smoothing > 0.0 {
            let ls = self.label_smoothing;
            target.mapv(|ti| ti * (1.0 - ls) + ls / k)
        } else {
            target.clone()
        };

        let mut loss_elem = ArrayD::zeros(IxDyn(pred.shape()));
        Zip::from(&mut loss_elem)
            .and(&p)
            .and(&t_smooth)
            .for_each(|l, &pi, &ti| {
                *l = -(ti * (pi + eps).ln());
            });

        let grad = if self.config.compute_grad {
            if self.apply_softmax {
                // Chaining through the softmax Jacobian collapses to: g = p - t_smooth
                Some(&p - &t_smooth)
            } else {
                // Gradient of -sum(t * log(p + eps)) w.r.t. p is -t/(p+eps)
                let mut g = ArrayD::zeros(IxDyn(pred.shape()));
                Zip::from(&mut g)
                    .and(&p)
                    .and(&t_smooth)
                    .for_each(|gi, &pi, &ti| {
                        *gi = -ti / (pi + eps);
                    });
                Some(g)
            }
        } else {
            None
        };

        Ok(apply_reduction(loss_elem, grad, &self.config.reduction, n))
    }
}

// ───────────────────────────────────────────────────────────────────────────────
// Focal Loss
// ───────────────────────────────────────────────────────────────────────────────

/// Focal loss for binary classification: `-(1 - p_t)^gamma * log(p_t + eps)`.
///
/// Downweights easy examples so the model focuses on hard ones.
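///
/// # Example (worked values, from the formula above)
///
/// With `gamma = 2`, a confident correct prediction `p = 0.9` for `t = 1` has
/// `p_t = 0.9` and modulator `(1 - 0.9)^2 = 0.01`, so its loss is `≈ 0.01 * 0.105 ≈ 0.001`,
/// roughly 100x smaller than the corresponding BCE term of ≈ 0.105.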
#[derive(Debug, Clone)]
pub struct TensorFocalLoss {
    pub config: TensorLossConfig,
    /// Focusing parameter (default `2.0`). Higher values increase focus on hard examples.
    pub gamma: f64,
    /// Optional class-balance weight applied to the positive class.
    pub alpha: Option<f64>,
}

impl TensorFocalLoss {
    /// Create with default configuration and `gamma = 2.0`.
    pub fn new() -> Self {
        Self {
            config: TensorLossConfig::default(),
            gamma: 2.0,
            alpha: None,
        }
    }

    /// Create with a custom `gamma` value.
    pub fn with_gamma(gamma: f64) -> Self {
        Self {
            config: TensorLossConfig::default(),
            gamma,
            alpha: None,
        }
    }
}

impl Default for TensorFocalLoss {
    fn default() -> Self {
        Self::new()
    }
}

impl TensorLoss for TensorFocalLoss {
    fn name(&self) -> &'static str {
        "focal"
    }

    fn compute(
        &self,
        pred: &ArrayD<f64>,
        target: &ArrayD<f64>,
    ) -> Result<TensorLossOutput, TensorLossError> {
        let n = validate_shapes(pred, target)?;
        let eps = self.config.epsilon;
        let gamma = self.gamma;

        // p clamped for safety
        let p = pred.mapv(|x| x.clamp(eps, 1.0 - eps));

        let mut loss_elem = ArrayD::zeros(IxDyn(pred.shape()));
        let mut grad_elem = if self.config.compute_grad {
            Some(ArrayD::zeros(IxDyn(pred.shape())))
        } else {
            None
        };

        Zip::from(&mut loss_elem)
            .and(&p)
            .and(target)
            .for_each(|l, &pi, &ti| {
                // p_t = p if target == 1, else (1 - p)
                let p_t = if ti > 0.5 { pi } else { 1.0 - pi };
                let modulator = (1.0 - p_t).powf(gamma);
                let weight = match self.alpha {
                    Some(a) => {
                        if ti > 0.5 {
                            a
                        } else {
                            1.0 - a
                        }
                    }
                    None => 1.0,
                };
                *l = -weight * modulator * (p_t + eps).ln();
            });

        if let Some(ref mut g) = grad_elem {
            Zip::from(g).and(&p).and(target).for_each(|gi, &pi, &ti| {
                let p_t = if ti > 0.5 { pi } else { 1.0 - pi };
                let sign = if ti > 0.5 { 1.0_f64 } else { -1.0_f64 };
                let modulator = (1.0 - p_t).powf(gamma);
                let weight = match self.alpha {
                    Some(a) => {
                        if ti > 0.5 {
                            a
                        } else {
                            1.0 - a
                        }
                    }
                    None => 1.0,
                };
                // d/dp_t [ -(1-p_t)^g * ln(p_t) ]
                //   = gamma*(1-p_t)^(g-1)*ln(p_t) - (1-p_t)^g / p_t
                let term1 = if gamma > 0.0 {
                    gamma * (1.0 - p_t).powf(gamma - 1.0) * (p_t + eps).ln()
                } else {
                    0.0
                };
                let term2 = modulator / (p_t + eps);
                // Chain rule: dp_t/dp = sign, so dL/dp = weight * (term1 - term2) * sign.
                // (The per-element derivative above already carries the loss's leading
                // minus, so no extra negation is applied here.)
                *gi = weight * (term1 - term2) * sign;
            });
        }

        Ok(apply_reduction(
            loss_elem,
            grad_elem,
            &self.config.reduction,
            n,
        ))
    }
}

// ───────────────────────────────────────────────────────────────────────────────
// Huber Loss
// ───────────────────────────────────────────────────────────────────────────────

/// Huber (Smooth L1) loss.
///
/// For element-wise absolute error `|x|`:
/// - If `|x| < delta`: `0.5 * x^2 / delta`
/// - Otherwise: `|x| - 0.5 * delta`
///
/// Gradient: `sign(x) * min(|x|/delta, 1)`.
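///
/// # Example (worked values, from the formulas above)
///
/// With `delta = 1`: an error of `0.5` is in the quadratic regime,
/// `0.5 * 0.5^2 / 1 = 0.125`; an error of `2.0` is in the linear regime,
/// `2.0 - 0.5 = 1.5`.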
#[derive(Debug, Clone)]
pub struct TensorHuberLoss {
    pub config: TensorLossConfig,
    /// Threshold between quadratic and linear regime (default `1.0`).
    pub delta: f64,
}

impl TensorHuberLoss {
    /// Create with default configuration and `delta = 1.0`.
    pub fn new() -> Self {
        Self {
            config: TensorLossConfig::default(),
            delta: 1.0,
        }
    }

    /// Create with a custom `delta` value.
    pub fn with_delta(delta: f64) -> Self {
        Self {
            config: TensorLossConfig::default(),
            delta,
        }
    }
}

impl Default for TensorHuberLoss {
    fn default() -> Self {
        Self::new()
    }
}

impl TensorLoss for TensorHuberLoss {
    fn name(&self) -> &'static str {
        "huber"
    }

    fn compute(
        &self,
        pred: &ArrayD<f64>,
        target: &ArrayD<f64>,
    ) -> Result<TensorLossOutput, TensorLossError> {
        let n = validate_shapes(pred, target)?;
        let delta = self.delta;

        if delta <= 0.0 {
            return Err(TensorLossError::InvalidConfig(format!(
                "delta must be positive, got {}",
                delta
            )));
        }

        let diff = pred - target;
        let mut loss_elem = ArrayD::zeros(IxDyn(pred.shape()));
        let mut grad_elem = if self.config.compute_grad {
            Some(ArrayD::zeros(IxDyn(pred.shape())))
        } else {
            None
        };

        Zip::from(&mut loss_elem).and(&diff).for_each(|l, &d| {
            let abs_d = d.abs();
            if abs_d < delta {
                *l = 0.5 * d * d / delta;
            } else {
                *l = abs_d - 0.5 * delta;
            }
        });

        if let Some(ref mut g) = grad_elem {
            Zip::from(g).and(&diff).for_each(|gi, &d| {
                let abs_d = d.abs();
                let sign = if d > 0.0 {
                    1.0
                } else if d < 0.0 {
                    -1.0
                } else {
                    0.0
                };
                *gi = sign * (abs_d / delta).min(1.0);
            });
        }

        Ok(apply_reduction(
            loss_elem,
            grad_elem,
            &self.config.reduction,
            n,
        ))
    }
}

// ───────────────────────────────────────────────────────────────────────────────
// KL Divergence Loss
// ───────────────────────────────────────────────────────────────────────────────

/// Kullback-Leibler Divergence: `sum(target * log(target / (pred + eps)))`.
///
/// Elements where `target ≈ 0` contribute zero (following the convention `0 * log(0) = 0`).
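///
/// # Example (worked values, from the formula above)
///
/// For `target = [0.5, 0.5]` and `pred = [0.25, 0.75]`:
/// `KL = 0.5*ln(0.5/0.25) + 0.5*ln(0.5/0.75) ≈ 0.347 - 0.203 = 0.144` with `Sum`
/// reduction, or ≈ 0.072 with the default `Mean` reduction.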
#[derive(Debug, Clone)]
pub struct TensorKLDivLoss {
    pub config: TensorLossConfig,
}

impl TensorKLDivLoss {
    /// Create with default configuration.
    pub fn new() -> Self {
        Self {
            config: TensorLossConfig::default(),
        }
    }
}

impl Default for TensorKLDivLoss {
    fn default() -> Self {
        Self::new()
    }
}

impl TensorLoss for TensorKLDivLoss {
    fn name(&self) -> &'static str {
        "kl_div"
    }

    fn compute(
        &self,
        pred: &ArrayD<f64>,
        target: &ArrayD<f64>,
    ) -> Result<TensorLossOutput, TensorLossError> {
        let n = validate_shapes(pred, target)?;
        let eps = self.config.epsilon;

        let mut loss_elem = ArrayD::zeros(IxDyn(pred.shape()));
        let mut grad_elem = if self.config.compute_grad {
            Some(ArrayD::zeros(IxDyn(pred.shape())))
        } else {
            None
        };

        Zip::from(&mut loss_elem)
            .and(pred)
            .and(target)
            .for_each(|l, &pi, &ti| {
                if ti > eps {
                    let p_safe = pi.max(eps);
                    // KL(T || P) = T * (ln T - ln P)
                    *l = ti * (ti.ln() - p_safe.ln());
                }
                // else: 0 * log(0) = 0, leave as 0
            });

        if let Some(ref mut g) = grad_elem {
            // d KL / d p_i = -t_i / (p_i + eps)
            Zip::from(g).and(pred).and(target).for_each(|gi, &pi, &ti| {
                if ti > eps {
                    *gi = -ti / (pi + eps);
                }
            });
        }

        Ok(apply_reduction(
            loss_elem,
            grad_elem,
            &self.config.reduction,
            n,
        ))
    }
}

// ───────────────────────────────────────────────────────────────────────────────
// Cosine Embedding Loss
// ───────────────────────────────────────────────────────────────────────────────

/// Cosine Embedding loss: `1 - cosine_similarity(pred, target)`.
///
/// Treats inputs as flat vectors (all dimensions collapsed).
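///
/// # Example (worked values, from the formula above)
///
/// Parallel vectors such as `[1, 0]` and `[2, 0]` have similarity 1, so the loss is ≈ 0;
/// orthogonal vectors such as `[1, 0]` and `[0, 1]` have similarity 0, so the loss is ≈ 1.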
#[derive(Debug, Clone)]
pub struct TensorCosineEmbeddingLoss {
    pub config: TensorLossConfig,
}

impl TensorCosineEmbeddingLoss {
    /// Create with default configuration.
    pub fn new() -> Self {
        Self {
            config: TensorLossConfig::default(),
        }
    }
}

impl Default for TensorCosineEmbeddingLoss {
    fn default() -> Self {
        Self::new()
    }
}

impl TensorLoss for TensorCosineEmbeddingLoss {
    fn name(&self) -> &'static str {
        "cosine_embedding"
    }

    fn compute(
        &self,
        pred: &ArrayD<f64>,
        target: &ArrayD<f64>,
    ) -> Result<TensorLossOutput, TensorLossError> {
        let n = validate_shapes(pred, target)?;
        let eps = self.config.epsilon;

        let dot: f64 = pred.iter().zip(target.iter()).map(|(p, t)| p * t).sum();
        let norm_p: f64 = pred.iter().map(|x| x * x).sum::<f64>().sqrt();
        let norm_t: f64 = target.iter().map(|x| x * x).sum::<f64>().sqrt();
        let denom = norm_p * norm_t + eps;

        let similarity = dot / denom;
        let scalar_loss = 1.0 - similarity;

        // For the gradient: d(1 - cos) / d(pred_i)
        //   = -(d cos / d pred_i)
        //   = -(target_i / denom - dot * pred_i / ((norm_p^2 + eps) * denom))
        let grad = if self.config.compute_grad {
            let mut g = ArrayD::zeros(IxDyn(pred.shape()));
            let norm_p_sq = norm_p * norm_p + eps;
            Zip::from(&mut g)
                .and(pred)
                .and(target)
                .for_each(|gi, &pi, &ti| {
                    let d_sim = ti / denom - dot * pi / (norm_p_sq * denom);
                    *gi = -d_sim;
                });
            Some(g)
        } else {
            None
        };

        // Cosine loss is inherently a single scalar; build a uniform tensor for consistency.
        match self.config.reduction {
            LossReduction::None => {
                // Return an element-wise tensor filled with scalar_loss / n
                // (so it sums to scalar_loss).
                let loss_tensor = ArrayD::from_elem(IxDyn(pred.shape()), scalar_loss / n as f64);
                Ok(TensorLossOutput {
                    loss: 0.0,
                    loss_tensor: Some(loss_tensor),
                    grad,
                })
            }
            LossReduction::Mean | LossReduction::Sum => Ok(TensorLossOutput {
                loss: scalar_loss,
                loss_tensor: None,
                grad,
            }),
        }
    }
}

// ───────────────────────────────────────────────────────────────────────────────
// Registry
// ───────────────────────────────────────────────────────────────────────────────

/// Dynamic registry for named tensor-level loss functions.
///
/// Use [`TensorLossRegistry::with_all_defaults`] to get a registry pre-populated
/// with all seven built-in losses.
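///
/// # Example
///
/// A minimal sketch (not doctested, since the external import path for this crate
/// is assumed rather than known):
///
/// ```ignore
/// use scirs2_core::ndarray::arr1;
/// // (import of TensorLossRegistry omitted; the path depends on how the crate is consumed)
///
/// let reg = TensorLossRegistry::with_all_defaults();
/// assert!(reg.contains("mse"));
///
/// let pred = arr1(&[1.0, 2.0]).into_dyn();
/// let target = arr1(&[0.0, 0.0]).into_dyn();
/// let out = reg.compute("mse", &pred, &target).unwrap();
/// assert!((out.loss - 2.5).abs() < 1e-12); // mean([1, 4]) = 2.5
/// ```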
#[derive(Debug)]
pub struct TensorLossRegistry {
    losses: HashMap<String, Box<dyn TensorLoss>>,
}

impl TensorLossRegistry {
    /// Create an empty registry.
    pub fn new() -> Self {
        Self {
            losses: HashMap::new(),
        }
    }

    /// Create a registry pre-populated with all seven built-in losses:
    /// `"mse"`, `"bce"`, `"cross_entropy"`, `"focal"`, `"huber"`, `"kl_div"`,
    /// `"cosine_embedding"`.
    pub fn with_all_defaults() -> Self {
        let mut reg = Self::new();
        reg.register("mse", Box::new(TensorMseLoss::new()));
        reg.register("bce", Box::new(TensorBCELoss::new()));
        reg.register("cross_entropy", Box::new(TensorCrossEntropyLoss::new()));
        reg.register("focal", Box::new(TensorFocalLoss::new()));
        reg.register("huber", Box::new(TensorHuberLoss::new()));
        reg.register("kl_div", Box::new(TensorKLDivLoss::new()));
        reg.register(
            "cosine_embedding",
            Box::new(TensorCosineEmbeddingLoss::new()),
        );
        reg
    }

    /// Register a loss under a name. Overwrites any previous entry with the same name.
    pub fn register(&mut self, name: impl Into<String>, loss: Box<dyn TensorLoss>) {
        self.losses.insert(name.into(), loss);
    }

    /// Compute a named loss.
    ///
    /// Returns [`TensorLossError::InvalidConfig`] if the name is not registered.
    pub fn compute(
        &self,
        name: &str,
        pred: &ArrayD<f64>,
        target: &ArrayD<f64>,
    ) -> Result<TensorLossOutput, TensorLossError> {
        let loss = self.losses.get(name).ok_or_else(|| {
            TensorLossError::InvalidConfig(format!("no loss registered under name '{}'", name))
        })?;
        loss.compute(pred, target)
    }

    /// Return all registered loss names (order is not guaranteed).
    pub fn names(&self) -> Vec<&str> {
        self.losses.keys().map(|s| s.as_str()).collect()
    }

    /// Return `true` if a loss is registered under `name`.
    pub fn contains(&self, name: &str) -> bool {
        self.losses.contains_key(name)
    }
}

impl Default for TensorLossRegistry {
    fn default() -> Self {
        Self::new()
    }
}

// ───────────────────────────────────────────────────────────────────────────────
// Tests
// ───────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr1;

    fn to_arrayd(v: Vec<f64>) -> ArrayD<f64> {
        arr1(&v).into_dyn()
    }

    // ── MSE ─────────────────────────────────────────────────────────────────────

    #[test]
    fn test_mse_zero_loss_identical_arrays() {
        let a = to_arrayd(vec![1.0, 2.0, 3.0]);
        let loss = TensorMseLoss::new().compute(&a, &a).unwrap();
        assert!(
            (loss.loss).abs() < 1e-10,
            "identical arrays should yield zero loss"
        );
    }

    #[test]
    fn test_mse_loss_value_correct() {
        // pred = [1, 2], target = [0, 0], mse = mean([1, 4]) = 2.5
        let pred = to_arrayd(vec![1.0, 2.0]);
        let target = to_arrayd(vec![0.0, 0.0]);
        let out = TensorMseLoss::new().compute(&pred, &target).unwrap();
        assert!((out.loss - 2.5).abs() < 1e-10);
    }

    #[test]
    fn test_mse_gradient_shape() {
        let pred = to_arrayd(vec![1.0, 2.0, 3.0]);
        let target = to_arrayd(vec![0.0, 0.0, 0.0]);
        let out = TensorMseLoss::new().compute(&pred, &target).unwrap();
        let grad = out.grad.unwrap();
        assert_eq!(grad.shape(), pred.shape());
    }

    #[test]
    fn test_mse_gradient_direction() {
        // When pred > target, gradient should be positive.
        let pred = to_arrayd(vec![3.0, 2.0]);
        let target = to_arrayd(vec![1.0, 1.0]);
        let out = TensorMseLoss::new().compute(&pred, &target).unwrap();
        let grad = out.grad.unwrap();
        for &g in grad.iter() {
            assert!(g > 0.0, "gradient should be positive when pred > target");
        }
    }

    // ── BCE ─────────────────────────────────────────────────────────────────────

    #[test]
    fn test_bce_perfect_prediction_near_zero() {
        // Near-perfect binary predictions → very small loss
        let pred = to_arrayd(vec![0.9999, 0.0001]);
        let target = to_arrayd(vec![1.0, 0.0]);
        let out = TensorBCELoss::new().compute(&pred, &target).unwrap();
        assert!(out.loss < 1e-3, "near-perfect predictions → near-zero loss");
    }

    #[test]
    fn test_bce_gradient_shape() {
        let pred = to_arrayd(vec![0.5, 0.7]);
        let target = to_arrayd(vec![1.0, 0.0]);
        let out = TensorBCELoss::new().compute(&pred, &target).unwrap();
        let grad = out.grad.unwrap();
        assert_eq!(grad.shape(), pred.shape());
    }

    // ── Cross-Entropy ────────────────────────────────────────────────────────────

    #[test]
    fn test_cross_entropy_uniform_target() {
        // Uniform prediction and target 1/3 each.
        // element loss = -(1/3) * ln(1/3 + eps), 3 elements, mean reduction.
        let eps = 1e-8_f64;
        let p = 1.0_f64 / 3.0;
        let pred = to_arrayd(vec![p; 3]);
        let target = to_arrayd(vec![p; 3]);
        let out = TensorCrossEntropyLoss::new()
            .compute(&pred, &target)
            .unwrap();
        // mean of 3 identical elements: -(p * ln(p + eps))
        let expected = -(p * (p + eps).ln());
        assert!(
            (out.loss - expected).abs() < 1e-6,
            "expected {}, got {}",
            expected,
            out.loss
        );
    }

    #[test]
    fn test_cross_entropy_label_smoothing() {
        // With label smoothing, loss should differ from no-smoothing version
        let pred = to_arrayd(vec![0.9, 0.05, 0.05]);
        let target = to_arrayd(vec![1.0, 0.0, 0.0]);

        let no_smooth = TensorCrossEntropyLoss::new()
            .compute(&pred, &target)
            .unwrap();

        let with_smooth = TensorCrossEntropyLoss {
            label_smoothing: 0.1,
            ..TensorCrossEntropyLoss::new()
        }
        .compute(&pred, &target)
        .unwrap();

        assert!(
            (no_smooth.loss - with_smooth.loss).abs() > 1e-6,
            "label smoothing should change the loss"
        );
    }

    // ── Focal ────────────────────────────────────────────────────────────────────

    #[test]
    fn test_focal_gamma_zero_equals_bce() {
        // focal loss with gamma=0 should approximate BCE
        let pred = to_arrayd(vec![0.7, 0.3, 0.8]);
        let target = to_arrayd(vec![1.0, 0.0, 1.0]);

        let focal = TensorFocalLoss::with_gamma(0.0)
            .compute(&pred, &target)
            .unwrap();
        let bce = TensorBCELoss::new().compute(&pred, &target).unwrap();

        assert!(
            (focal.loss - bce.loss).abs() < 1e-6,
            "focal(gamma=0) ≈ BCE, got focal={} bce={}",
            focal.loss,
            bce.loss
        );
    }

    #[test]
    fn test_focal_high_confidence_downweighted() {
        // A high-confidence correct prediction should incur less focal loss than a low-confidence one.
        let pred_high = to_arrayd(vec![0.99]);
        let pred_low = to_arrayd(vec![0.6]);
        let target = to_arrayd(vec![1.0]);

        let focal = TensorFocalLoss::new(); // gamma=2
        let out_high = focal.compute(&pred_high, &target).unwrap();
        let out_low = focal.compute(&pred_low, &target).unwrap();
        assert!(
            out_high.loss < out_low.loss,
            "high-confidence correct prediction should be downweighted"
        );
    }

    // ── Huber ────────────────────────────────────────────────────────────────────

    #[test]
    fn test_huber_small_error_quadratic() {
        // |x| = 0.5 < delta=1 → quadratic: 0.5 * x^2 / delta = 0.5 * 0.25 / 1 = 0.125
        let pred = to_arrayd(vec![0.5]);
        let target = to_arrayd(vec![0.0]);
        let out = TensorHuberLoss::new().compute(&pred, &target).unwrap();
        assert!((out.loss - 0.125).abs() < 1e-10);
    }

    #[test]
    fn test_huber_large_error_linear() {
        // |x| = 2.0 > delta=1 → linear: |x| - 0.5*delta = 2 - 0.5 = 1.5
        let pred = to_arrayd(vec![2.0]);
        let target = to_arrayd(vec![0.0]);
        let out = TensorHuberLoss::new().compute(&pred, &target).unwrap();
        assert!((out.loss - 1.5).abs() < 1e-10);
    }

    // ── KL Divergence ────────────────────────────────────────────────────────────

    #[test]
    fn test_kl_div_identical_distributions_zero() {
        let p = to_arrayd(vec![0.3, 0.5, 0.2]);
        let out = TensorKLDivLoss::new().compute(&p, &p).unwrap();
        // KL(P||P) should be ≈ 0 (small due to eps)
        assert!(out.loss.abs() < 1e-6);
    }

    #[test]
    fn test_kl_div_gradient_shape() {
        let pred = to_arrayd(vec![0.3, 0.5, 0.2]);
        let target = to_arrayd(vec![0.4, 0.4, 0.2]);
        let out = TensorKLDivLoss::new().compute(&pred, &target).unwrap();
        let grad = out.grad.unwrap();
        assert_eq!(grad.shape(), pred.shape());
    }

    // ── Cosine Embedding ─────────────────────────────────────────────────────────

    #[test]
    fn test_cosine_parallel_loss_zero() {
        // Same direction → cosine similarity = 1 → loss = 0
        let pred = to_arrayd(vec![1.0, 0.0, 0.0]);
        let target = to_arrayd(vec![2.0, 0.0, 0.0]); // parallel, different magnitude
        let out = TensorCosineEmbeddingLoss::new()
            .compute(&pred, &target)
            .unwrap();
        assert!(out.loss.abs() < 1e-6, "parallel vectors → loss ≈ 0");
    }

    #[test]
    fn test_cosine_orthogonal_loss_one() {
        // Orthogonal vectors → cosine similarity = 0 → loss = 1
        let pred = to_arrayd(vec![1.0, 0.0]);
        let target = to_arrayd(vec![0.0, 1.0]);
        let out = TensorCosineEmbeddingLoss::new()
            .compute(&pred, &target)
            .unwrap();
        assert!(
            (out.loss - 1.0).abs() < 1e-6,
            "orthogonal vectors → loss ≈ 1"
        );
    }

    // ── Reduction ────────────────────────────────────────────────────────────────

    #[test]
    fn test_reduction_sum_vs_mean() {
        let pred = to_arrayd(vec![1.0, 2.0, 3.0]);
        let target = to_arrayd(vec![0.0, 0.0, 0.0]);

        let mean_loss = TensorMseLoss::with_config(TensorLossConfig {
            reduction: LossReduction::Mean,
            ..Default::default()
        })
        .compute(&pred, &target)
        .unwrap();

        let sum_loss = TensorMseLoss::with_config(TensorLossConfig {
            reduction: LossReduction::Sum,
            ..Default::default()
        })
        .compute(&pred, &target)
        .unwrap();

        assert!(
            (sum_loss.loss - mean_loss.loss).abs() > 1e-6,
            "sum != mean for non-unit arrays"
        );
    }

    #[test]
    fn test_reduction_none_returns_tensor() {
        let pred = to_arrayd(vec![1.0, 2.0]);
        let target = to_arrayd(vec![0.0, 0.0]);

        let out = TensorMseLoss::with_config(TensorLossConfig {
            reduction: LossReduction::None,
            ..Default::default()
        })
        .compute(&pred, &target)
        .unwrap();

        assert!(
            out.loss_tensor.is_some(),
            "None reduction should return a loss tensor"
        );
        let lt = out.loss_tensor.unwrap();
        assert_eq!(lt.shape(), pred.shape());
    }

    // ── Registry ─────────────────────────────────────────────────────────────────

    #[test]
    fn test_registry_with_all_defaults() {
        let reg = TensorLossRegistry::with_all_defaults();
        assert_eq!(
            reg.names().len(),
            7,
            "registry should contain 7 built-in losses"
        );
        for name in &[
            "mse",
            "bce",
            "cross_entropy",
            "focal",
            "huber",
            "kl_div",
            "cosine_embedding",
        ] {
            assert!(reg.contains(name), "missing: {}", name);
        }
    }

    #[test]
    fn test_registry_compute_by_name() {
        let reg = TensorLossRegistry::with_all_defaults();
        let pred = to_arrayd(vec![0.5, 0.5]);
        let target = to_arrayd(vec![1.0, 0.0]);
        let out = reg.compute("bce", &pred, &target).unwrap();
        assert!(
            out.loss > 0.0,
            "BCE of non-perfect prediction should be positive"
        );
    }
}