tensorlogic_train/loss.rs

//! Loss functions for training.
//!
//! Provides both standard ML loss functions and logical constraint-based losses.
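//!
//! A minimal usage sketch (the `tensorlogic_train::loss` path is an assumption about how
//! this module is exported, so the block is marked `ignore` rather than run as a doctest):
//!
//! ```ignore
//! use scirs2_core::ndarray::array;
//! use tensorlogic_train::loss::{Loss, MseLoss};
//!
//! let loss = MseLoss;
//! let predictions = array![[1.0, 2.0], [3.0, 4.0]];
//! let targets = array![[1.5, 2.5], [3.5, 4.5]];
//!
//! // Scalar loss value and the gradient with respect to the predictions.
//! let value = loss.compute(&predictions.view(), &targets.view()).unwrap();
//! let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
//! assert_eq!(grad.shape(), predictions.shape());
//! ```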

use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, ArrayView, Ix2};
use std::fmt::Debug;

/// Configuration for loss functions.
#[derive(Debug, Clone)]
pub struct LossConfig {
    /// Weight for supervised loss component.
    pub supervised_weight: f64,
    /// Weight for constraint violation loss component.
    pub constraint_weight: f64,
    /// Weight for rule satisfaction loss component.
    pub rule_weight: f64,
    /// Temperature for soft constraint penalties.
    pub temperature: f64,
}

impl Default for LossConfig {
    fn default() -> Self {
        Self {
            supervised_weight: 1.0,
            constraint_weight: 1.0,
            rule_weight: 1.0,
            temperature: 1.0,
        }
    }
}

/// Trait for loss functions.
pub trait Loss: Debug {
    /// Compute loss value.
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64>;

    /// Compute loss gradient with respect to predictions.
    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>>;
}

/// Cross-entropy loss for classification.
#[derive(Debug, Clone)]
pub struct CrossEntropyLoss {
    /// Epsilon for numerical stability.
    pub epsilon: f64,
}

impl Default for CrossEntropyLoss {
    fn default() -> Self {
        Self { epsilon: 1e-10 }
    }
}

impl Loss for CrossEntropyLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.nrows() as f64;
        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]]
                    .max(self.epsilon)
                    .min(1.0 - self.epsilon);
                let target = targets[[i, j]];
                total_loss -= target * pred.ln();
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.nrows() as f64;
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]]
                    .max(self.epsilon)
                    .min(1.0 - self.epsilon);
                let target = targets[[i, j]];
                grad[[i, j]] = -(target / pred) / n;
            }
        }

        Ok(grad)
    }
}

/// Mean squared error loss for regression.
#[derive(Debug, Clone, Default)]
pub struct MseLoss;

impl Loss for MseLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.len() as f64;
        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let diff = predictions[[i, j]] - targets[[i, j]];
                total_loss += diff * diff;
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.len() as f64;
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                grad[[i, j]] = 2.0 * (predictions[[i, j]] - targets[[i, j]]) / n;
            }
        }

        Ok(grad)
    }
}

/// Logical loss combining multiple objectives.
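///
/// A minimal assembly sketch (the import path and the input arrays are illustrative
/// assumptions, so the block is marked `ignore` rather than run as a doctest):
///
/// ```ignore
/// use tensorlogic_train::loss::{
///     ConstraintViolationLoss, CrossEntropyLoss, LogicalLoss, LossConfig, RuleSatisfactionLoss,
/// };
///
/// let logical = LogicalLoss::new(
///     LossConfig::default(),
///     Box::new(CrossEntropyLoss::default()),
///     vec![Box::new(RuleSatisfactionLoss::default())],
///     vec![Box::new(ConstraintViolationLoss::default())],
/// );
///
/// // `rule_values` and `constraint_values` are slices of 2-D views produced elsewhere;
/// // internally rules are compared against 1.0 and constraints against 0.0.
/// let total = logical
///     .compute_total(&predictions.view(), &targets.view(), &rule_values, &constraint_values)
///     .unwrap();
/// ```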
#[derive(Debug)]
pub struct LogicalLoss {
    /// Configuration.
    pub config: LossConfig,
    /// Supervised loss component.
    pub supervised_loss: Box<dyn Loss>,
    /// Rule satisfaction components.
    pub rule_losses: Vec<Box<dyn Loss>>,
    /// Constraint violation components.
    pub constraint_losses: Vec<Box<dyn Loss>>,
}

impl LogicalLoss {
    /// Create a new logical loss.
    pub fn new(
        config: LossConfig,
        supervised_loss: Box<dyn Loss>,
        rule_losses: Vec<Box<dyn Loss>>,
        constraint_losses: Vec<Box<dyn Loss>>,
    ) -> Self {
        Self {
            config,
            supervised_loss,
            rule_losses,
            constraint_losses,
        }
    }

    /// Compute total loss with all components.
    pub fn compute_total(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
        rule_values: &[ArrayView<f64, Ix2>],
        constraint_values: &[ArrayView<f64, Ix2>],
    ) -> TrainResult<f64> {
        let mut total = 0.0;

        // Supervised loss
        let supervised = self.supervised_loss.compute(predictions, targets)?;
        total += self.config.supervised_weight * supervised;

        // Rule satisfaction losses
        if !rule_values.is_empty() && !self.rule_losses.is_empty() {
            let expected_true = Array::ones((rule_values[0].nrows(), rule_values[0].ncols()));
            let expected_true_view = expected_true.view();

            for (rule_val, rule_loss) in rule_values.iter().zip(self.rule_losses.iter()) {
                let rule_loss_val = rule_loss.compute(rule_val, &expected_true_view)?;
                total += self.config.rule_weight * rule_loss_val;
            }
        }

        // Constraint violation losses
        if !constraint_values.is_empty() && !self.constraint_losses.is_empty() {
            let expected_zero =
                Array::zeros((constraint_values[0].nrows(), constraint_values[0].ncols()));
            let expected_zero_view = expected_zero.view();

            for (constraint_val, constraint_loss) in
                constraint_values.iter().zip(self.constraint_losses.iter())
            {
                let constraint_loss_val =
                    constraint_loss.compute(constraint_val, &expected_zero_view)?;
                total += self.config.constraint_weight * constraint_loss_val;
            }
        }

        Ok(total)
    }
}

/// Rule satisfaction loss - measures how well rules are satisfied.
#[derive(Debug, Clone)]
pub struct RuleSatisfactionLoss {
    /// Temperature for soft satisfaction.
    pub temperature: f64,
}

impl Default for RuleSatisfactionLoss {
    fn default() -> Self {
        Self { temperature: 1.0 }
    }
}

impl Loss for RuleSatisfactionLoss {
    fn compute(
        &self,
        rule_values: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if rule_values.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: rule_values {:?} vs targets {:?}",
                rule_values.shape(),
                targets.shape()
            )));
        }

        let n = rule_values.len() as f64;
        let mut total_loss = 0.0;

        // Penalize deviations from expected rule satisfaction (typically 1.0)
        for i in 0..rule_values.nrows() {
            for j in 0..rule_values.ncols() {
                let diff = targets[[i, j]] - rule_values[[i, j]];
                // Soft penalty with temperature
                total_loss += (diff / self.temperature).powi(2);
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        rule_values: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if rule_values.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: rule_values {:?} vs targets {:?}",
                rule_values.shape(),
                targets.shape()
            )));
        }

        let n = rule_values.len() as f64;
        let mut grad = Array::zeros(rule_values.raw_dim());

        for i in 0..rule_values.nrows() {
            for j in 0..rule_values.ncols() {
                let diff = targets[[i, j]] - rule_values[[i, j]];
                grad[[i, j]] = -2.0 * diff / (self.temperature * self.temperature * n);
            }
        }

        Ok(grad)
    }
}

/// Constraint violation loss - penalizes constraint violations.
#[derive(Debug, Clone)]
pub struct ConstraintViolationLoss {
    /// Penalty weight for violations.
    pub penalty_weight: f64,
}

impl Default for ConstraintViolationLoss {
    fn default() -> Self {
        Self {
            penalty_weight: 10.0,
        }
    }
}

impl Loss for ConstraintViolationLoss {
    fn compute(
        &self,
        constraint_values: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if constraint_values.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: constraint_values {:?} vs targets {:?}",
                constraint_values.shape(),
                targets.shape()
            )));
        }

        let n = constraint_values.len() as f64;
        let mut total_loss = 0.0;

        // Penalize any positive violation (constraint_values should be <= 0)
        for i in 0..constraint_values.nrows() {
            for j in 0..constraint_values.ncols() {
                let violation = (constraint_values[[i, j]] - targets[[i, j]]).max(0.0);
                total_loss += self.penalty_weight * violation * violation;
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        constraint_values: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if constraint_values.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: constraint_values {:?} vs targets {:?}",
                constraint_values.shape(),
                targets.shape()
            )));
        }

        let n = constraint_values.len() as f64;
        let mut grad = Array::zeros(constraint_values.raw_dim());

        for i in 0..constraint_values.nrows() {
            for j in 0..constraint_values.ncols() {
                let violation = constraint_values[[i, j]] - targets[[i, j]];
                if violation > 0.0 {
                    grad[[i, j]] = 2.0 * self.penalty_weight * violation / n;
                }
            }
        }

        Ok(grad)
    }
}

/// Focal loss for addressing class imbalance.
/// Reference: Lin et al., "Focal Loss for Dense Object Detection"
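///
/// For a positive example the per-element loss is `-alpha * (1 - p)^gamma * ln(p)`;
/// for a negative example it is `-(1 - alpha) * p^gamma * ln(1 - p)`. A small worked
/// example of the focusing effect with the default `gamma = 2.0`: a well-classified
/// positive with `p = 0.9` is scaled by `(1 - 0.9)^2 = 0.01`, while a poorly
/// classified positive with `p = 0.1` is scaled by `(1 - 0.1)^2 = 0.81`, so hard
/// examples dominate the total loss.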
#[derive(Debug, Clone)]
pub struct FocalLoss {
    /// Alpha weighting factor for positive class (range: [0, 1]).
    pub alpha: f64,
    /// Gamma focusing parameter (typically 2.0).
    pub gamma: f64,
    /// Epsilon for numerical stability.
    pub epsilon: f64,
}

impl Default for FocalLoss {
    fn default() -> Self {
        Self {
            alpha: 0.25,
            gamma: 2.0,
            epsilon: 1e-10,
        }
    }
}

impl Loss for FocalLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.nrows() as f64;
        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]]
                    .max(self.epsilon)
                    .min(1.0 - self.epsilon);
                let target = targets[[i, j]];

                // Focal loss: -alpha * (1 - p)^gamma * log(p) for positive class
                //             -(1 - alpha) * p^gamma * log(1 - p) for negative class
                if target > 0.5 {
                    // Positive class
                    let focal_weight = (1.0 - pred).powf(self.gamma);
                    total_loss -= self.alpha * focal_weight * pred.ln();
                } else {
                    // Negative class
                    let focal_weight = pred.powf(self.gamma);
                    total_loss -= (1.0 - self.alpha) * focal_weight * (1.0 - pred).ln();
                }
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.nrows() as f64;
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]]
                    .max(self.epsilon)
                    .min(1.0 - self.epsilon);
                let target = targets[[i, j]];

                if target > 0.5 {
                    // Positive class gradient
                    let focal_weight = (1.0 - pred).powf(self.gamma);
                    let d_focal = self.gamma * (1.0 - pred).powf(self.gamma - 1.0);
                    grad[[i, j]] = -self.alpha * (focal_weight / pred - d_focal * pred.ln()) / n;
                } else {
                    // Negative class gradient
                    let focal_weight = pred.powf(self.gamma);
                    let d_focal = self.gamma * pred.powf(self.gamma - 1.0);
                    grad[[i, j]] = -(1.0 - self.alpha)
                        * (d_focal * (1.0 - pred).ln() - focal_weight / (1.0 - pred))
                        / n;
                }
            }
        }

        Ok(grad)
    }
}

/// Huber loss for robust regression.
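/// Quadratic in the residual for `|diff| <= delta` (`0.5 * diff^2`) and linear beyond it
/// (`delta * (|diff| - 0.5 * delta)`), which matches the branches in `compute` below and
/// keeps the loss continuous and once-differentiable at the switch point.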
#[derive(Debug, Clone)]
pub struct HuberLoss {
    /// Delta threshold for switching between L1 and L2.
    pub delta: f64,
}

impl Default for HuberLoss {
    fn default() -> Self {
        Self { delta: 1.0 }
    }
}

impl Loss for HuberLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.len() as f64;
        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let diff = (predictions[[i, j]] - targets[[i, j]]).abs();
                if diff <= self.delta {
                    // Quadratic for small errors
                    total_loss += 0.5 * diff * diff;
                } else {
                    // Linear for large errors
                    total_loss += self.delta * (diff - 0.5 * self.delta);
                }
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.len() as f64;
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let diff = predictions[[i, j]] - targets[[i, j]];
                let abs_diff = diff.abs();

                if abs_diff <= self.delta {
                    grad[[i, j]] = diff / n;
                } else {
                    grad[[i, j]] = self.delta * diff.signum() / n;
                }
            }
        }

        Ok(grad)
    }
}

/// Binary cross-entropy with logits loss (numerically stable).
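/// For a logit `x` and target `z`, the naive form `-[z * ln(sigmoid(x)) + (1 - z) * ln(1 - sigmoid(x))]`
/// overflows for large `|x|`; it is algebraically equal to `max(x, 0) - x * z + ln(1 + exp(-|x|))`,
/// which is the expression used in `compute` below and stays finite for any logit.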
#[derive(Debug, Clone, Default)]
pub struct BCEWithLogitsLoss;

impl Loss for BCEWithLogitsLoss {
    fn compute(
        &self,
        logits: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if logits.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: logits {:?} vs targets {:?}",
                logits.shape(),
                targets.shape()
            )));
        }

        let n = logits.len() as f64;
        let mut total_loss = 0.0;

        for i in 0..logits.nrows() {
            for j in 0..logits.ncols() {
                let logit = logits[[i, j]];
                let target = targets[[i, j]];

                // Numerically stable BCE: max(x, 0) - x * z + log(1 + exp(-|x|))
                // where x = logit, z = target
                let max_val = logit.max(0.0);
                total_loss += max_val - logit * target + (1.0 + (-logit.abs()).exp()).ln();
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        logits: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if logits.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: logits {:?} vs targets {:?}",
                logits.shape(),
                targets.shape()
            )));
        }

        let n = logits.len() as f64;
        let mut grad = Array::zeros(logits.raw_dim());

        for i in 0..logits.nrows() {
            for j in 0..logits.ncols() {
                let logit = logits[[i, j]];
                let target = targets[[i, j]];

                // Gradient: sigmoid(logit) - target
                let sigmoid = 1.0 / (1.0 + (-logit).exp());
                grad[[i, j]] = (sigmoid - target) / n;
            }
        }

        Ok(grad)
    }
}

/// Dice loss for segmentation tasks.
#[derive(Debug, Clone)]
pub struct DiceLoss {
    /// Smoothing factor to avoid division by zero.
    pub smooth: f64,
}

impl Default for DiceLoss {
    fn default() -> Self {
        Self { smooth: 1.0 }
    }
}

impl Loss for DiceLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut intersection = 0.0;
        let mut pred_sum = 0.0;
        let mut target_sum = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]];
                let target = targets[[i, j]];

                intersection += pred * target;
                pred_sum += pred;
                target_sum += target;
            }
        }

        // Dice coefficient: 2 * |X ∩ Y| / (|X| + |Y|)
        // Dice loss: 1 - Dice coefficient
        let dice_coef = (2.0 * intersection + self.smooth) / (pred_sum + target_sum + self.smooth);
        Ok(1.0 - dice_coef)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut intersection = 0.0;
        let mut pred_sum = 0.0;
        let mut target_sum = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                intersection += predictions[[i, j]] * targets[[i, j]];
                pred_sum += predictions[[i, j]];
                target_sum += targets[[i, j]];
            }
        }

        let denominator = pred_sum + target_sum + self.smooth;
        let numerator = 2.0 * intersection + self.smooth;

        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let target = targets[[i, j]];
                // Gradient of Dice loss w.r.t. predictions:
                // d/dp [1 - N/D] = -(dN/dp * D - N * dD/dp) / D^2 with dN/dp = 2 * target, dD/dp = 1.
                grad[[i, j]] =
                    -(2.0 * target * denominator - numerator) / (denominator * denominator);
            }
        }

        Ok(grad)
    }
}

/// Tversky loss (generalization of Dice loss).
/// Useful for handling class imbalance in segmentation.
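/// The Tversky index reduces to the Dice coefficient when `alpha == beta == 0.5` and to the
/// Jaccard (IoU) index when `alpha == beta == 1.0`; increasing `beta` relative to `alpha`
/// penalizes false negatives more heavily.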
#[derive(Debug, Clone)]
pub struct TverskyLoss {
    /// Alpha parameter (weight for false positives).
    pub alpha: f64,
    /// Beta parameter (weight for false negatives).
    pub beta: f64,
    /// Smoothing factor.
    pub smooth: f64,
}

impl Default for TverskyLoss {
    fn default() -> Self {
        Self {
            alpha: 0.5,
            beta: 0.5,
            smooth: 1.0,
        }
    }
}

impl Loss for TverskyLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut true_pos = 0.0;
        let mut false_pos = 0.0;
        let mut false_neg = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]];
                let target = targets[[i, j]];

                true_pos += pred * target;
                false_pos += pred * (1.0 - target);
                false_neg += (1.0 - pred) * target;
            }
        }

        // Tversky index: TP / (TP + alpha * FP + beta * FN)
        let tversky_index = (true_pos + self.smooth)
            / (true_pos + self.alpha * false_pos + self.beta * false_neg + self.smooth);

        Ok(1.0 - tversky_index)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut true_pos = 0.0;
        let mut false_pos = 0.0;
        let mut false_neg = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]];
                let target = targets[[i, j]];

                true_pos += pred * target;
                false_pos += pred * (1.0 - target);
                false_neg += (1.0 - pred) * target;
            }
        }

        let denominator = true_pos + self.alpha * false_pos + self.beta * false_neg + self.smooth;
        let numerator = true_pos + self.smooth;

        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let target = targets[[i, j]];

                // Gradient of Tversky loss
                let d_tp = target;
                let d_fp = self.alpha * (1.0 - target);
                let d_fn = -self.beta * target;

                grad[[i, j]] = -(d_tp * denominator - numerator * (d_tp + d_fp + d_fn))
                    / (denominator * denominator);
            }
        }

        Ok(grad)
    }
}

/// Contrastive loss for metric learning.
/// Used to learn embeddings where similar pairs are close and dissimilar pairs are far apart.
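///
/// Inputs follow the convention used in `compute`: `predictions` is `[N, 2]` with the pair
/// distance in column 0 (column 1 is unused), and `targets` is `[N, 1]` with label 1.0 for
/// similar pairs and 0.0 for dissimilar pairs. Per pair the loss is `d^2` for similar pairs
/// and `max(margin - d, 0)^2` for dissimilar pairs.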
#[derive(Debug, Clone)]
pub struct ContrastiveLoss {
    /// Margin for dissimilar pairs.
    pub margin: f64,
}

impl Default for ContrastiveLoss {
    fn default() -> Self {
        Self { margin: 1.0 }
    }
}

impl Loss for ContrastiveLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.ncols() != 2 || targets.ncols() != 1 {
            return Err(TrainError::LossError(format!(
                "ContrastiveLoss expects predictions shape [N, 2] (distances) and targets shape [N, 1] (labels), got {:?} and {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut total_loss = 0.0;
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            let distance = predictions[[i, 0]];
            let label = targets[[i, 0]];

            if label > 0.5 {
                // Similar pair: minimize distance
                total_loss += distance * distance;
            } else {
                // Dissimilar pair: maximize distance up to margin
                total_loss += (self.margin - distance).max(0.0).powi(2);
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        let mut grad = Array::zeros(predictions.raw_dim());
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            let distance = predictions[[i, 0]];
            let label = targets[[i, 0]];

            if label > 0.5 {
                // Similar pair gradient
                grad[[i, 0]] = 2.0 * distance / n;
            } else {
                // Dissimilar pair gradient
                if distance < self.margin {
                    grad[[i, 0]] = -2.0 * (self.margin - distance) / n;
                }
            }
        }

        Ok(grad)
    }
}

/// Triplet loss for metric learning.
/// Learns embeddings where the anchor-positive distance is smaller than the
/// anchor-negative distance by at least the margin.
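///
/// Inputs follow the convention used in `compute`: `predictions` is `[N, 2]` holding
/// `(anchor_positive_distance, anchor_negative_distance)` per row, and `targets` is ignored.
/// Per triplet the loss is `max(0, pos_dist - neg_dist + margin)`.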
#[derive(Debug, Clone)]
pub struct TripletLoss {
    /// Margin between positive and negative distances.
    pub margin: f64,
}

impl Default for TripletLoss {
    fn default() -> Self {
        Self { margin: 1.0 }
    }
}

impl Loss for TripletLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        _targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.ncols() != 2 {
            return Err(TrainError::LossError(format!(
                "TripletLoss expects predictions shape [N, 2] (pos_dist, neg_dist), got {:?}",
                predictions.shape()
            )));
        }

        let mut total_loss = 0.0;
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            let pos_distance = predictions[[i, 0]];
            let neg_distance = predictions[[i, 1]];

            // Loss = max(0, pos_dist - neg_dist + margin)
            let loss = (pos_distance - neg_distance + self.margin).max(0.0);
            total_loss += loss;
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        _targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        let mut grad = Array::zeros(predictions.raw_dim());
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            let pos_distance = predictions[[i, 0]];
            let neg_distance = predictions[[i, 1]];

            if pos_distance - neg_distance + self.margin > 0.0 {
                // Gradient w.r.t. positive distance
                grad[[i, 0]] = 1.0 / n;
                // Gradient w.r.t. negative distance
                grad[[i, 1]] = -1.0 / n;
            }
        }

        Ok(grad)
    }
}

/// Hinge loss for maximum-margin classification (SVM-style).
#[derive(Debug, Clone)]
pub struct HingeLoss {
    /// Margin for classification.
    pub margin: f64,
}

impl Default for HingeLoss {
    fn default() -> Self {
        Self { margin: 1.0 }
    }
}

impl Loss for HingeLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut total_loss = 0.0;
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]];
                let target = targets[[i, j]];

                // targets should be +1 or -1
                let loss = (self.margin - target * pred).max(0.0);
                total_loss += loss;
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        let mut grad = Array::zeros(predictions.raw_dim());
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]];
                let target = targets[[i, j]];

                if self.margin - target * pred > 0.0 {
                    grad[[i, j]] = -target / n;
                }
            }
        }

        Ok(grad)
    }
}

/// Kullback-Leibler Divergence loss.
/// Measures how one probability distribution diverges from a reference distribution.
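/// Computes `KL(targets || predictions) = sum(target * ln(target / prediction))` over all
/// elements; unlike the other losses in this module, the result is a sum rather than a mean,
/// matching the behaviour of `compute` below.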
#[derive(Debug, Clone)]
pub struct KLDivergenceLoss {
    /// Epsilon for numerical stability.
    pub epsilon: f64,
}

impl Default for KLDivergenceLoss {
    fn default() -> Self {
        Self { epsilon: 1e-10 }
    }
}

impl Loss for KLDivergenceLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]].max(self.epsilon);
                let target = targets[[i, j]].max(self.epsilon);

                // KL(target || pred) = sum(target * log(target / pred))
                total_loss += target * (target / pred).ln();
            }
        }

        Ok(total_loss)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]].max(self.epsilon);
                let target = targets[[i, j]].max(self.epsilon);

                // Gradient of KL divergence w.r.t. predictions
                grad[[i, j]] = -target / pred;
            }
        }

        Ok(grad)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_cross_entropy_loss() {
        let loss = CrossEntropyLoss::default();
        let predictions = array![[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]];
        let targets = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val > 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }

    #[test]
    fn test_mse_loss() {
        let loss = MseLoss;
        let predictions = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![[1.5, 2.5], [3.5, 4.5]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!((loss_val - 0.25).abs() < 1e-6);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }

    #[test]
    fn test_rule_satisfaction_loss() {
        let loss = RuleSatisfactionLoss::default();
        let rule_values = array![[0.9, 0.8], [0.95, 0.85]];
        let targets = array![[1.0, 1.0], [1.0, 1.0]];

        let loss_val = loss.compute(&rule_values.view(), &targets.view()).unwrap();
        assert!(loss_val > 0.0);

        let grad = loss.gradient(&rule_values.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), rule_values.shape());
    }

    #[test]
    fn test_constraint_violation_loss() {
        let loss = ConstraintViolationLoss::default();
        let constraint_values = array![[0.1, -0.1], [0.2, -0.2]];
        let targets = array![[0.0, 0.0], [0.0, 0.0]];

        let loss_val = loss
            .compute(&constraint_values.view(), &targets.view())
            .unwrap();
        assert!(loss_val > 0.0);

        let grad = loss
            .gradient(&constraint_values.view(), &targets.view())
            .unwrap();
        assert_eq!(grad.shape(), constraint_values.shape());
    }

    #[test]
    fn test_focal_loss() {
        let loss = FocalLoss::default();
        let predictions = array![[0.9, 0.1], [0.2, 0.8]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }

    #[test]
    fn test_huber_loss() {
        let loss = HuberLoss::default();
        let predictions = array![[1.0, 3.0], [2.0, 5.0]];
        let targets = array![[1.5, 2.0], [2.5, 4.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val > 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }

    #[test]
    fn test_bce_with_logits_loss() {
        let loss = BCEWithLogitsLoss;
        let logits = array![[0.5, -0.5], [1.0, -1.0]];
        let targets = array![[1.0, 0.0], [1.0, 0.0]];

        let loss_val = loss.compute(&logits.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);

        let grad = loss.gradient(&logits.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), logits.shape());
    }

    #[test]
    fn test_dice_loss() {
        let loss = DiceLoss::default();
        let predictions = array![[0.9, 0.1], [0.8, 0.2]];
        let targets = array![[1.0, 0.0], [1.0, 0.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);
        assert!(loss_val <= 1.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }

    #[test]
    fn test_tversky_loss() {
        let loss = TverskyLoss::default();
        let predictions = array![[0.9, 0.1], [0.8, 0.2]];
        let targets = array![[1.0, 0.0], [1.0, 0.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);
        assert!(loss_val <= 1.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }

    #[test]
    fn test_contrastive_loss() {
        let loss = ContrastiveLoss::default();
        // Predictions: [N, 2] where first column is distance, second is unused
        // Targets: [N, 1] where 1.0 = similar pair, 0.0 = dissimilar pair
        let predictions = array![[0.5, 0.0], [1.5, 0.0], [0.2, 0.0]];
        let targets = array![[1.0], [0.0], [1.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());

        // For similar pair (label=1), gradient should push distance down
        assert!(grad[[0, 0]] > 0.0);
        // For dissimilar pair beyond margin, gradient should be 0
        assert_eq!(grad[[1, 0]], 0.0);
    }

    #[test]
    fn test_triplet_loss() {
        let loss = TripletLoss::default();
        // Predictions: [N, 2] where columns are (positive_distance, negative_distance)
        let predictions = array![[0.5, 2.0], [1.0, 0.5], [0.3, 1.5]];
        let targets = array![[0.0], [0.0], [0.0]]; // Not used but required for interface

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());

        // First triplet: pos_dist < neg_dist - margin, so no gradient
        assert_eq!(grad[[0, 0]], 0.0);
        assert_eq!(grad[[0, 1]], 0.0);

        // Second triplet: pos_dist > neg_dist, so should have gradient
        assert!(grad[[1, 0]] > 0.0);
        assert!(grad[[1, 1]] < 0.0);
    }

    #[test]
    fn test_hinge_loss() {
        let loss = HingeLoss::default();
        // Predictions are raw scores, targets should be +1 or -1
        let predictions = array![[0.5, -0.5], [2.0, -2.0]];
        let targets = array![[1.0, -1.0], [1.0, -1.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());

        // For correct predictions with large margin, gradient should be 0
        assert_eq!(grad[[1, 0]], 0.0);
        assert_eq!(grad[[1, 1]], 0.0);
    }

    #[test]
    fn test_kl_divergence_loss() {
        let loss = KLDivergenceLoss::default();
        // Both predictions and targets should be probability distributions
        let predictions = array![[0.6, 0.4], [0.7, 0.3]];
        let targets = array![[0.5, 0.5], [0.8, 0.2]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());

        // KL divergence should be 0 when distributions are identical
        let identical_preds = array![[0.5, 0.5]];
        let identical_targets = array![[0.5, 0.5]];
        let identical_loss = loss
            .compute(&identical_preds.view(), &identical_targets.view())
            .unwrap();
        assert!(identical_loss.abs() < 1e-6);
    }
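
    // Illustrative numerical check that `gradient` matches a central finite difference of
    // `compute` for MseLoss; the step size and tolerance are arbitrary choices for this sketch.
    #[test]
    fn test_mse_gradient_matches_finite_difference() {
        let loss = MseLoss;
        let predictions = array![[0.2, 1.3], [2.7, -0.4]];
        let targets = array![[0.0, 1.0], [3.0, -1.0]];
        let eps = 1e-6;

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                // Perturb one entry at a time and compare the slope of `compute`
                // with the corresponding analytic gradient entry.
                let mut plus = predictions.clone();
                plus[[i, j]] += eps;
                let mut minus = predictions.clone();
                minus[[i, j]] -= eps;

                let loss_plus = loss.compute(&plus.view(), &targets.view()).unwrap();
                let loss_minus = loss.compute(&minus.view(), &targets.view()).unwrap();
                let numeric = (loss_plus - loss_minus) / (2.0 * eps);

                assert!((grad[[i, j]] - numeric).abs() < 1e-5);
            }
        }
    }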
}