tensorlogic_train/loss.rs

//! Loss functions for training.
//!
//! Provides both standard ML loss functions and logical constraint-based losses.

use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, ArrayView, Ix2};
use std::fmt::Debug;

/// Configuration for loss functions.
#[derive(Debug, Clone)]
pub struct LossConfig {
    /// Weight for supervised loss component.
    pub supervised_weight: f64,
    /// Weight for constraint violation loss component.
    pub constraint_weight: f64,
    /// Weight for rule satisfaction loss component.
    pub rule_weight: f64,
    /// Temperature for soft constraint penalties.
    pub temperature: f64,
}

impl Default for LossConfig {
    fn default() -> Self {
        Self {
            supervised_weight: 1.0,
            constraint_weight: 1.0,
            rule_weight: 1.0,
            temperature: 1.0,
        }
    }
}

/// Trait for loss functions.
pub trait Loss: Debug {
    /// Compute loss value.
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64>;

    /// Compute loss gradient with respect to predictions.
    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>>;

    /// Get the name of the loss function.
    fn name(&self) -> &str {
        "unknown"
    }
}
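
// Illustrative usage sketch (added for clarity; not part of the original file).
// Every concrete loss below is normally driven through this trait, often as a
// `Box<dyn Loss>` so components can be mixed inside `LogicalLoss`:
//
//     let loss: Box<dyn Loss> = Box::new(MseLoss);
//     let value = loss.compute(&predictions.view(), &targets.view())?;
//     let grad = loss.gradient(&predictions.view(), &targets.view())?;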

/// Cross-entropy loss for classification.
#[derive(Debug, Clone)]
pub struct CrossEntropyLoss {
    /// Epsilon for numerical stability.
    pub epsilon: f64,
}

impl Default for CrossEntropyLoss {
    fn default() -> Self {
        Self { epsilon: 1e-10 }
    }
}

impl Loss for CrossEntropyLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.nrows() as f64;
        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]]
                    .max(self.epsilon)
                    .min(1.0 - self.epsilon);
                let target = targets[[i, j]];
                total_loss -= target * pred.ln();
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.nrows() as f64;
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]]
                    .max(self.epsilon)
                    .min(1.0 - self.epsilon);
                let target = targets[[i, j]];
                grad[[i, j]] = -(target / pred) / n;
            }
        }

        Ok(grad)
    }
}

/// Mean squared error loss for regression.
#[derive(Debug, Clone, Default)]
pub struct MseLoss;

impl Loss for MseLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.len() as f64;
        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let diff = predictions[[i, j]] - targets[[i, j]];
                total_loss += diff * diff;
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.len() as f64;
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                grad[[i, j]] = 2.0 * (predictions[[i, j]] - targets[[i, j]]) / n;
            }
        }

        Ok(grad)
    }
}

/// Logical loss combining multiple objectives.
#[derive(Debug)]
pub struct LogicalLoss {
    /// Configuration.
    pub config: LossConfig,
    /// Supervised loss component.
    pub supervised_loss: Box<dyn Loss>,
    /// Rule satisfaction components.
    pub rule_losses: Vec<Box<dyn Loss>>,
    /// Constraint violation components.
    pub constraint_losses: Vec<Box<dyn Loss>>,
}

impl LogicalLoss {
    /// Create a new logical loss.
    pub fn new(
        config: LossConfig,
        supervised_loss: Box<dyn Loss>,
        rule_losses: Vec<Box<dyn Loss>>,
        constraint_losses: Vec<Box<dyn Loss>>,
    ) -> Self {
        Self {
            config,
            supervised_loss,
            rule_losses,
            constraint_losses,
        }
    }

    /// Compute total loss with all components.
    pub fn compute_total(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
        rule_values: &[ArrayView<f64, Ix2>],
        constraint_values: &[ArrayView<f64, Ix2>],
    ) -> TrainResult<f64> {
        let mut total = 0.0;

        // Supervised loss
        let supervised = self.supervised_loss.compute(predictions, targets)?;
        total += self.config.supervised_weight * supervised;

        // Rule satisfaction losses
        if !rule_values.is_empty() && !self.rule_losses.is_empty() {
            let expected_true = Array::ones((rule_values[0].nrows(), rule_values[0].ncols()));
            let expected_true_view = expected_true.view();

            for (rule_val, rule_loss) in rule_values.iter().zip(self.rule_losses.iter()) {
                let rule_loss_val = rule_loss.compute(rule_val, &expected_true_view)?;
                total += self.config.rule_weight * rule_loss_val;
            }
        }

        // Constraint violation losses
        if !constraint_values.is_empty() && !self.constraint_losses.is_empty() {
            let expected_zero =
                Array::zeros((constraint_values[0].nrows(), constraint_values[0].ncols()));
            let expected_zero_view = expected_zero.view();

            for (constraint_val, constraint_loss) in
                constraint_values.iter().zip(self.constraint_losses.iter())
            {
                let constraint_loss_val =
                    constraint_loss.compute(constraint_val, &expected_zero_view)?;
                total += self.config.constraint_weight * constraint_loss_val;
            }
        }

        Ok(total)
    }
}

/// Rule satisfaction loss - measures how well rules are satisfied.
#[derive(Debug, Clone)]
pub struct RuleSatisfactionLoss {
    /// Temperature for soft satisfaction.
    pub temperature: f64,
}

impl Default for RuleSatisfactionLoss {
    fn default() -> Self {
        Self { temperature: 1.0 }
    }
}

impl Loss for RuleSatisfactionLoss {
    fn compute(
        &self,
        rule_values: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if rule_values.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: rule_values {:?} vs targets {:?}",
                rule_values.shape(),
                targets.shape()
            )));
        }

        let n = rule_values.len() as f64;
        let mut total_loss = 0.0;

        // Penalize deviations from expected rule satisfaction (typically 1.0)
        for i in 0..rule_values.nrows() {
            for j in 0..rule_values.ncols() {
                let diff = targets[[i, j]] - rule_values[[i, j]];
                // Soft penalty with temperature
                total_loss += (diff / self.temperature).powi(2);
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        rule_values: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if rule_values.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: rule_values {:?} vs targets {:?}",
                rule_values.shape(),
                targets.shape()
            )));
        }

        let n = rule_values.len() as f64;
        let mut grad = Array::zeros(rule_values.raw_dim());

        for i in 0..rule_values.nrows() {
            for j in 0..rule_values.ncols() {
                let diff = targets[[i, j]] - rule_values[[i, j]];
                grad[[i, j]] = -2.0 * diff / (self.temperature * self.temperature * n);
            }
        }

        Ok(grad)
    }
}

/// Constraint violation loss - penalizes constraint violations.
#[derive(Debug, Clone)]
pub struct ConstraintViolationLoss {
    /// Penalty weight for violations.
    pub penalty_weight: f64,
}

impl Default for ConstraintViolationLoss {
    fn default() -> Self {
        Self {
            penalty_weight: 10.0,
        }
    }
}

impl Loss for ConstraintViolationLoss {
    fn compute(
        &self,
        constraint_values: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if constraint_values.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: constraint_values {:?} vs targets {:?}",
                constraint_values.shape(),
                targets.shape()
            )));
        }

        let n = constraint_values.len() as f64;
        let mut total_loss = 0.0;

        // Penalize any positive violation (constraint_values should be <= 0)
        for i in 0..constraint_values.nrows() {
            for j in 0..constraint_values.ncols() {
                let violation = (constraint_values[[i, j]] - targets[[i, j]]).max(0.0);
                total_loss += self.penalty_weight * violation * violation;
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        constraint_values: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if constraint_values.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: constraint_values {:?} vs targets {:?}",
                constraint_values.shape(),
                targets.shape()
            )));
        }

        let n = constraint_values.len() as f64;
        let mut grad = Array::zeros(constraint_values.raw_dim());

        for i in 0..constraint_values.nrows() {
            for j in 0..constraint_values.ncols() {
                let violation = constraint_values[[i, j]] - targets[[i, j]];
                if violation > 0.0 {
                    grad[[i, j]] = 2.0 * self.penalty_weight * violation / n;
                }
            }
        }

        Ok(grad)
    }
}

/// Focal loss for addressing class imbalance.
/// Reference: Lin et al., "Focal Loss for Dense Object Detection"
#[derive(Debug, Clone)]
pub struct FocalLoss {
    /// Alpha weighting factor for positive class (range: [0, 1]).
    pub alpha: f64,
    /// Gamma focusing parameter (typically 2.0).
    pub gamma: f64,
    /// Epsilon for numerical stability.
    pub epsilon: f64,
}

impl Default for FocalLoss {
    fn default() -> Self {
        Self {
            alpha: 0.25,
            gamma: 2.0,
            epsilon: 1e-10,
        }
    }
}

impl Loss for FocalLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.nrows() as f64;
        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]]
                    .max(self.epsilon)
                    .min(1.0 - self.epsilon);
                let target = targets[[i, j]];

                // Focal loss: -alpha * (1 - p)^gamma * log(p) for positive class
                //             -(1 - alpha) * p^gamma * log(1 - p) for negative class
                if target > 0.5 {
                    // Positive class
                    let focal_weight = (1.0 - pred).powf(self.gamma);
                    total_loss -= self.alpha * focal_weight * pred.ln();
                } else {
                    // Negative class
                    let focal_weight = pred.powf(self.gamma);
                    total_loss -= (1.0 - self.alpha) * focal_weight * (1.0 - pred).ln();
                }
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.nrows() as f64;
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]]
                    .max(self.epsilon)
                    .min(1.0 - self.epsilon);
                let target = targets[[i, j]];

                if target > 0.5 {
                    // Positive class gradient
                    let focal_weight = (1.0 - pred).powf(self.gamma);
                    let d_focal = self.gamma * (1.0 - pred).powf(self.gamma - 1.0);
                    grad[[i, j]] = -self.alpha * (focal_weight / pred - d_focal * pred.ln()) / n;
                } else {
                    // Negative class gradient
                    let focal_weight = pred.powf(self.gamma);
                    let d_focal = self.gamma * pred.powf(self.gamma - 1.0);
                    grad[[i, j]] = -(1.0 - self.alpha)
                        * (d_focal * (1.0 - pred).ln() - focal_weight / (1.0 - pred))
                        / n;
                }
            }
        }

        Ok(grad)
    }
}

/// Huber loss for robust regression.
#[derive(Debug, Clone)]
pub struct HuberLoss {
    /// Delta threshold for switching between L1 and L2.
    pub delta: f64,
}

impl Default for HuberLoss {
    fn default() -> Self {
        Self { delta: 1.0 }
    }
}

impl Loss for HuberLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.len() as f64;
        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let diff = (predictions[[i, j]] - targets[[i, j]]).abs();
                if diff <= self.delta {
                    // Quadratic for small errors
                    total_loss += 0.5 * diff * diff;
                } else {
                    // Linear for large errors
                    total_loss += self.delta * (diff - 0.5 * self.delta);
                }
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.len() as f64;
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let diff = predictions[[i, j]] - targets[[i, j]];
                let abs_diff = diff.abs();

                if abs_diff <= self.delta {
                    grad[[i, j]] = diff / n;
                } else {
                    grad[[i, j]] = self.delta * diff.signum() / n;
                }
            }
        }

        Ok(grad)
    }
}

/// Binary cross-entropy with logits loss (numerically stable).
#[derive(Debug, Clone, Default)]
pub struct BCEWithLogitsLoss;

impl Loss for BCEWithLogitsLoss {
    fn compute(
        &self,
        logits: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if logits.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: logits {:?} vs targets {:?}",
                logits.shape(),
                targets.shape()
            )));
        }

        let n = logits.len() as f64;
        let mut total_loss = 0.0;

        for i in 0..logits.nrows() {
            for j in 0..logits.ncols() {
                let logit = logits[[i, j]];
                let target = targets[[i, j]];

                // Numerically stable BCE: max(x, 0) - x * z + log(1 + exp(-|x|))
                // where x = logit, z = target
                let max_val = logit.max(0.0);
                total_loss += max_val - logit * target + (1.0 + (-logit.abs()).exp()).ln();
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        logits: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if logits.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: logits {:?} vs targets {:?}",
                logits.shape(),
                targets.shape()
            )));
        }

        let n = logits.len() as f64;
        let mut grad = Array::zeros(logits.raw_dim());

        for i in 0..logits.nrows() {
            for j in 0..logits.ncols() {
                let logit = logits[[i, j]];
                let target = targets[[i, j]];

                // Gradient: sigmoid(logit) - target
                let sigmoid = 1.0 / (1.0 + (-logit).exp());
                grad[[i, j]] = (sigmoid - target) / n;
            }
        }

        Ok(grad)
    }
}

/// Dice loss for segmentation tasks.
#[derive(Debug, Clone)]
pub struct DiceLoss {
    /// Smoothing factor to avoid division by zero.
    pub smooth: f64,
}

impl Default for DiceLoss {
    fn default() -> Self {
        Self { smooth: 1.0 }
    }
}

impl Loss for DiceLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut intersection = 0.0;
        let mut pred_sum = 0.0;
        let mut target_sum = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]];
                let target = targets[[i, j]];

                intersection += pred * target;
                pred_sum += pred;
                target_sum += target;
            }
        }

        // Dice coefficient: 2 * |X ∩ Y| / (|X| + |Y|)
        // Dice loss: 1 - Dice coefficient
        let dice_coef = (2.0 * intersection + self.smooth) / (pred_sum + target_sum + self.smooth);
        Ok(1.0 - dice_coef)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut intersection = 0.0;
        let mut pred_sum = 0.0;
        let mut target_sum = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                intersection += predictions[[i, j]] * targets[[i, j]];
                pred_sum += predictions[[i, j]];
                target_sum += targets[[i, j]];
            }
        }

        let denominator = pred_sum + target_sum + self.smooth;
        let numerator = 2.0 * intersection + self.smooth;

        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let target = targets[[i, j]];
                // Gradient of the soft Dice loss w.r.t. predictions:
                // d/dp [1 - (2I + s) / (P + T + s)] = -(2 * t * D - N) / D^2
                // with N = 2I + s and D = P + T + s (the denominator also depends on p).
                grad[[i, j]] =
                    -(2.0 * target * denominator - numerator) / (denominator * denominator);
            }
        }

        Ok(grad)
    }
}

/// Tversky loss (generalization of Dice loss).
/// Useful for handling class imbalance in segmentation.
#[derive(Debug, Clone)]
pub struct TverskyLoss {
    /// Alpha parameter (weight for false positives).
    pub alpha: f64,
    /// Beta parameter (weight for false negatives).
    pub beta: f64,
    /// Smoothing factor.
    pub smooth: f64,
}

impl Default for TverskyLoss {
    fn default() -> Self {
        Self {
            alpha: 0.5,
            beta: 0.5,
            smooth: 1.0,
        }
    }
}

impl Loss for TverskyLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut true_pos = 0.0;
        let mut false_pos = 0.0;
        let mut false_neg = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]];
                let target = targets[[i, j]];

                true_pos += pred * target;
                false_pos += pred * (1.0 - target);
                false_neg += (1.0 - pred) * target;
            }
        }

        // Tversky index: TP / (TP + alpha * FP + beta * FN)
        let tversky_index = (true_pos + self.smooth)
            / (true_pos + self.alpha * false_pos + self.beta * false_neg + self.smooth);

        Ok(1.0 - tversky_index)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut true_pos = 0.0;
        let mut false_pos = 0.0;
        let mut false_neg = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]];
                let target = targets[[i, j]];

                true_pos += pred * target;
                false_pos += pred * (1.0 - target);
                false_neg += (1.0 - pred) * target;
            }
        }

        let denominator = true_pos + self.alpha * false_pos + self.beta * false_neg + self.smooth;
        let numerator = true_pos + self.smooth;

        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let target = targets[[i, j]];

                // Gradient of Tversky loss
                let d_tp = target;
                let d_fp = self.alpha * (1.0 - target);
                let d_fn = -self.beta * target;

                grad[[i, j]] = -(d_tp * denominator - numerator * (d_tp + d_fp + d_fn))
                    / (denominator * denominator);
            }
        }

        Ok(grad)
    }
}

/// Contrastive loss for metric learning.
/// Used to learn embeddings where similar pairs are close and dissimilar pairs are far apart.
#[derive(Debug, Clone)]
pub struct ContrastiveLoss {
    /// Margin for dissimilar pairs.
    pub margin: f64,
}

impl Default for ContrastiveLoss {
    fn default() -> Self {
        Self { margin: 1.0 }
    }
}

impl Loss for ContrastiveLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.ncols() != 2 || targets.ncols() != 1 {
            return Err(TrainError::LossError(format!(
                "ContrastiveLoss expects predictions shape [N, 2] (distances) and targets shape [N, 1] (labels), got {:?} and {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut total_loss = 0.0;
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            let distance = predictions[[i, 0]];
            let label = targets[[i, 0]];

            if label > 0.5 {
                // Similar pair: minimize distance
                total_loss += distance * distance;
            } else {
                // Dissimilar pair: maximize distance up to margin
                total_loss += (self.margin - distance).max(0.0).powi(2);
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        let mut grad = Array::zeros(predictions.raw_dim());
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            let distance = predictions[[i, 0]];
            let label = targets[[i, 0]];

            if label > 0.5 {
                // Similar pair gradient
                grad[[i, 0]] = 2.0 * distance / n;
            } else {
                // Dissimilar pair gradient
                if distance < self.margin {
                    grad[[i, 0]] = -2.0 * (self.margin - distance) / n;
                }
            }
        }

        Ok(grad)
    }
}

/// Triplet loss for metric learning.
/// Learns embeddings where anchor-positive distance < anchor-negative distance + margin.
#[derive(Debug, Clone)]
pub struct TripletLoss {
    /// Margin between positive and negative distances.
    pub margin: f64,
}

impl Default for TripletLoss {
    fn default() -> Self {
        Self { margin: 1.0 }
    }
}

impl Loss for TripletLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        _targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.ncols() != 2 {
            return Err(TrainError::LossError(format!(
                "TripletLoss expects predictions shape [N, 2] (pos_dist, neg_dist), got {:?}",
                predictions.shape()
            )));
        }

        let mut total_loss = 0.0;
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            let pos_distance = predictions[[i, 0]];
            let neg_distance = predictions[[i, 1]];

            // Loss = max(0, pos_dist - neg_dist + margin)
            let loss = (pos_distance - neg_distance + self.margin).max(0.0);
            total_loss += loss;
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        _targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        let mut grad = Array::zeros(predictions.raw_dim());
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            let pos_distance = predictions[[i, 0]];
            let neg_distance = predictions[[i, 1]];

            if pos_distance - neg_distance + self.margin > 0.0 {
                // Gradient w.r.t. positive distance
                grad[[i, 0]] = 1.0 / n;
                // Gradient w.r.t. negative distance
                grad[[i, 1]] = -1.0 / n;
            }
        }

        Ok(grad)
    }
}

/// Hinge loss for maximum-margin classification (SVM-style).
#[derive(Debug, Clone)]
pub struct HingeLoss {
    /// Margin for classification.
    pub margin: f64,
}

impl Default for HingeLoss {
    fn default() -> Self {
        Self { margin: 1.0 }
    }
}

impl Loss for HingeLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut total_loss = 0.0;
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]];
                let target = targets[[i, j]];

                // targets should be +1 or -1
                let loss = (self.margin - target * pred).max(0.0);
                total_loss += loss;
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        let mut grad = Array::zeros(predictions.raw_dim());
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]];
                let target = targets[[i, j]];

                if self.margin - target * pred > 0.0 {
                    grad[[i, j]] = -target / n;
                }
            }
        }

        Ok(grad)
    }
}

/// Kullback-Leibler divergence loss.
/// Measures how one probability distribution diverges from a reference distribution.
#[derive(Debug, Clone)]
pub struct KLDivergenceLoss {
    /// Epsilon for numerical stability.
    pub epsilon: f64,
}

impl Default for KLDivergenceLoss {
    fn default() -> Self {
        Self { epsilon: 1e-10 }
    }
}

impl Loss for KLDivergenceLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]].max(self.epsilon);
                let target = targets[[i, j]].max(self.epsilon);

                // KL(target || pred) = sum(target * log(target / pred))
                total_loss += target * (target / pred).ln();
            }
        }

        Ok(total_loss)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]].max(self.epsilon);
                let target = targets[[i, j]].max(self.epsilon);

                // Gradient of KL divergence w.r.t. predictions
                grad[[i, j]] = -target / pred;
            }
        }

        Ok(grad)
    }
}

/// Poly Loss - Polynomial Expansion of Cross-Entropy Loss.
///
/// Paper: "PolyLoss: A Polynomial Expansion Perspective of Classification Loss Functions" (Leng et al., 2022)
/// <https://arxiv.org/abs/2204.12511>
///
/// PolyLoss adds polynomial terms to cross-entropy to provide better gradient flow
/// for well-classified examples. It helps with:
/// - Label noise robustness
/// - Improved generalization
/// - Better handling of class imbalance
///
/// The loss is defined as:
/// L_poly = CE + ε₁(1 - p_t) + ε₂(1 - p_t)² + ... + εⱼ(1 - p_t)^j
///
/// where p_t is the predicted probability of the target class, and εⱼ are polynomial coefficients.
/// In practice, Poly-1 (j=1) is most commonly used.
#[derive(Debug, Clone)]
pub struct PolyLoss {
    /// Epsilon for numerical stability
    pub epsilon: f64,
    /// Polynomial coefficient (typically between 0.5 and 2.0)
    pub poly_coeff: f64,
}

impl Default for PolyLoss {
    fn default() -> Self {
        Self {
            epsilon: 1e-10,
            poly_coeff: 1.0, // Poly-1 Loss
        }
    }
}

impl PolyLoss {
    /// Create a new Poly Loss with custom coefficient.
    pub fn new(poly_coeff: f64) -> Self {
        Self {
            epsilon: 1e-10,
            poly_coeff,
        }
    }
}

impl Loss for PolyLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.nrows() as f64;
        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]]
                    .max(self.epsilon)
                    .min(1.0 - self.epsilon);
                let target = targets[[i, j]];

                // Cross-entropy term
                let ce = -target * pred.ln();

                // Poly term: ε * (1 - p_t) where p_t is probability of target class
                // For multi-class, we use the predicted probability at the target position
                let poly_term = if target > 0.5 {
                    // Target is 1, so p_t = pred
                    self.poly_coeff * (1.0 - pred)
                } else {
                    // Target is 0, so p_t = 1 - pred
                    self.poly_coeff * pred
                };

                total_loss += ce + poly_term;
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        let n = predictions.nrows() as f64;
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]]
                    .max(self.epsilon)
                    .min(1.0 - self.epsilon);
                let target = targets[[i, j]];

                // Gradient of cross-entropy: -target / pred
                let ce_grad = -target / pred;

                // Gradient of poly term
                let poly_grad = if target > 0.5 {
                    // d/dp [ε * (1 - p)] = -ε
                    -self.poly_coeff
                } else {
                    // d/dp [ε * p] = ε
                    self.poly_coeff
                };

                grad[[i, j]] = (ce_grad + poly_grad) / n;
            }
        }

        Ok(grad)
    }

    fn name(&self) -> &str {
        "poly_loss"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_cross_entropy_loss() {
        let loss = CrossEntropyLoss::default();
        let predictions = array![[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]];
        let targets = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val > 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }

    #[test]
    fn test_mse_loss() {
        let loss = MseLoss;
        let predictions = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![[1.5, 2.5], [3.5, 4.5]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!((loss_val - 0.25).abs() < 1e-6);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }
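
    // Added illustrative check (not in the original tests): the analytic MSE
    // gradient should agree with a central finite-difference estimate of
    // `compute`, which ties the two trait methods together.
    #[test]
    fn test_mse_gradient_matches_finite_difference() {
        let loss = MseLoss;
        let predictions = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![[1.5, 2.5], [3.5, 4.5]];
        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();

        let eps = 1e-6;
        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let mut plus = predictions.clone();
                plus[[i, j]] += eps;
                let mut minus = predictions.clone();
                minus[[i, j]] -= eps;
                let numeric = (loss.compute(&plus.view(), &targets.view()).unwrap()
                    - loss.compute(&minus.view(), &targets.view()).unwrap())
                    / (2.0 * eps);
                assert!((grad[[i, j]] - numeric).abs() < 1e-5);
            }
        }
    }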

    #[test]
    fn test_rule_satisfaction_loss() {
        let loss = RuleSatisfactionLoss::default();
        let rule_values = array![[0.9, 0.8], [0.95, 0.85]];
        let targets = array![[1.0, 1.0], [1.0, 1.0]];

        let loss_val = loss.compute(&rule_values.view(), &targets.view()).unwrap();
        assert!(loss_val > 0.0);

        let grad = loss.gradient(&rule_values.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), rule_values.shape());
    }

    #[test]
    fn test_constraint_violation_loss() {
        let loss = ConstraintViolationLoss::default();
        let constraint_values = array![[0.1, -0.1], [0.2, -0.2]];
        let targets = array![[0.0, 0.0], [0.0, 0.0]];

        let loss_val = loss
            .compute(&constraint_values.view(), &targets.view())
            .unwrap();
        assert!(loss_val > 0.0);

        let grad = loss
            .gradient(&constraint_values.view(), &targets.view())
            .unwrap();
        assert_eq!(grad.shape(), constraint_values.shape());
    }
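
    // Added illustrative test (not in the original suite): a minimal sketch of
    // how `LogicalLoss::compute_total` combines supervised, rule, and
    // constraint components under the default weights.
    #[test]
    fn test_logical_loss_compute_total() {
        let rule_losses: Vec<Box<dyn Loss>> = vec![Box::new(RuleSatisfactionLoss::default())];
        let constraint_losses: Vec<Box<dyn Loss>> =
            vec![Box::new(ConstraintViolationLoss::default())];
        let logical = LogicalLoss::new(
            LossConfig::default(),
            Box::new(MseLoss),
            rule_losses,
            constraint_losses,
        );

        let predictions = array![[0.8, 0.2], [0.3, 0.7]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];
        let rule_values = array![[0.9, 0.95], [0.85, 1.0]];
        let constraint_values = array![[0.1, -0.2], [0.0, 0.3]];

        let total = logical
            .compute_total(
                &predictions.view(),
                &targets.view(),
                &[rule_values.view()],
                &[constraint_values.view()],
            )
            .unwrap();

        // Every component is non-negative, so the weighted total must be too.
        assert!(total >= 0.0);
    }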

    #[test]
    fn test_focal_loss() {
        let loss = FocalLoss::default();
        let predictions = array![[0.9, 0.1], [0.2, 0.8]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }

    #[test]
    fn test_huber_loss() {
        let loss = HuberLoss::default();
        let predictions = array![[1.0, 3.0], [2.0, 5.0]];
        let targets = array![[1.5, 2.0], [2.5, 4.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val > 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }
1335    #[test]
1336    fn test_bce_with_logits_loss() {
1337        let loss = BCEWithLogitsLoss;
1338        let logits = array![[0.5, -0.5], [1.0, -1.0]];
1339        let targets = array![[1.0, 0.0], [1.0, 0.0]];
1340
1341        let loss_val = loss.compute(&logits.view(), &targets.view()).unwrap();
1342        assert!(loss_val >= 0.0);
1343
1344        let grad = loss.gradient(&logits.view(), &targets.view()).unwrap();
1345        assert_eq!(grad.shape(), logits.shape());
1346    }
1347
1348    #[test]
1349    fn test_dice_loss() {
1350        let loss = DiceLoss::default();
1351        let predictions = array![[0.9, 0.1], [0.8, 0.2]];
1352        let targets = array![[1.0, 0.0], [1.0, 0.0]];
1353
1354        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
1355        assert!(loss_val >= 0.0);
1356        assert!(loss_val <= 1.0);
1357
1358        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
1359        assert_eq!(grad.shape(), predictions.shape());
1360    }
1361
1362    #[test]
1363    fn test_tversky_loss() {
1364        let loss = TverskyLoss::default();
1365        let predictions = array![[0.9, 0.1], [0.8, 0.2]];
1366        let targets = array![[1.0, 0.0], [1.0, 0.0]];
1367
1368        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
1369        assert!(loss_val >= 0.0);
1370        assert!(loss_val <= 1.0);
1371
1372        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
1373        assert_eq!(grad.shape(), predictions.shape());
1374    }
1375
1376    #[test]
1377    fn test_contrastive_loss() {
1378        let loss = ContrastiveLoss::default();
1379        // Predictions: [N, 2] where first column is distance, second is unused
1380        // Targets: [N, 1] where 1.0 = similar pair, 0.0 = dissimilar pair
1381        let predictions = array![[0.5, 0.0], [1.5, 0.0], [0.2, 0.0]];
1382        let targets = array![[1.0], [0.0], [1.0]];
1383
1384        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
1385        assert!(loss_val >= 0.0);
1386
1387        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
1388        assert_eq!(grad.shape(), predictions.shape());
1389
1390        // For similar pair (label=1), gradient should push distance down
1391        assert!(grad[[0, 0]] > 0.0);
1392        // For dissimilar pair beyond margin, gradient should be 0
1393        assert_eq!(grad[[1, 0]], 0.0);
1394    }
1395
1396    #[test]
1397    fn test_triplet_loss() {
1398        let loss = TripletLoss::default();
1399        // Predictions: [N, 2] where columns are (positive_distance, negative_distance)
1400        let predictions = array![[0.5, 2.0], [1.0, 0.5], [0.3, 1.5]];
1401        let targets = array![[0.0], [0.0], [0.0]]; // Not used but required for interface
1402
1403        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
1404        assert!(loss_val >= 0.0);
1405
1406        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
1407        assert_eq!(grad.shape(), predictions.shape());
1408
1409        // First triplet: pos_dist < neg_dist - margin, so no gradient
1410        assert_eq!(grad[[0, 0]], 0.0);
1411        assert_eq!(grad[[0, 1]], 0.0);
1412
1413        // Second triplet: pos_dist > neg_dist, so should have gradient
1414        assert!(grad[[1, 0]] > 0.0);
1415        assert!(grad[[1, 1]] < 0.0);
1416    }
1417
1418    #[test]
1419    fn test_hinge_loss() {
1420        let loss = HingeLoss::default();
1421        // Predictions are raw scores, targets should be +1 or -1
1422        let predictions = array![[0.5, -0.5], [2.0, -2.0]];
1423        let targets = array![[1.0, -1.0], [1.0, -1.0]];
1424
1425        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
1426        assert!(loss_val >= 0.0);
1427
1428        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
1429        assert_eq!(grad.shape(), predictions.shape());
1430
1431        // For correct predictions with large margin, gradient should be 0
1432        assert_eq!(grad[[1, 0]], 0.0);
1433        assert_eq!(grad[[1, 1]], 0.0);
1434    }
1435
1436    #[test]
1437    fn test_kl_divergence_loss() {
1438        let loss = KLDivergenceLoss::default();
1439        // Both predictions and targets should be probability distributions
1440        let predictions = array![[0.6, 0.4], [0.7, 0.3]];
1441        let targets = array![[0.5, 0.5], [0.8, 0.2]];
1442
1443        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
1444        assert!(loss_val >= 0.0);
1445
1446        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
1447        assert_eq!(grad.shape(), predictions.shape());
1448
1449        // KL divergence should be 0 when distributions are identical
1450        let identical_preds = array![[0.5, 0.5]];
1451        let identical_targets = array![[0.5, 0.5]];
1452        let identical_loss = loss
1453            .compute(&identical_preds.view(), &identical_targets.view())
1454            .unwrap();
1455        assert!(identical_loss.abs() < 1e-6);
1456    }
1457
1458    #[test]
1459    fn test_poly_loss() {
1460        let loss = PolyLoss::default();
1461        let predictions = array![[0.9, 0.1], [0.2, 0.8]];
1462        let targets = array![[1.0, 0.0], [0.0, 1.0]];
1463
1464        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
1465        assert!(loss_val > 0.0);
1466
1467        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
1468        assert_eq!(grad.shape(), predictions.shape());
1469
1470        // Poly loss should be greater than standard cross-entropy for well-classified examples
1471        let ce_loss = CrossEntropyLoss::default();
1472        let ce_val = ce_loss
1473            .compute(&predictions.view(), &targets.view())
1474            .unwrap();
1475
1476        // With default poly_coeff = 1.0, poly loss includes additional penalty term
1477        assert!(loss_val >= ce_val);
1478    }
1479
1480    #[test]
1481    fn test_poly_loss_custom_coefficient() {
1482        let loss = PolyLoss::new(2.0);
1483        let predictions = array![[0.8, 0.2]];
1484        let targets = array![[1.0, 0.0]];
1485
1486        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
1487        assert!(loss_val > 0.0);
1488
1489        // Higher coefficient should result in larger poly term
1490        let loss_low_coeff = PolyLoss::new(0.5);
1491        let loss_val_low = loss_low_coeff
1492            .compute(&predictions.view(), &targets.view())
1493            .unwrap();
1494
1495        assert!(loss_val > loss_val_low);
1496    }
1497}