use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, ArrayView, Ix2};
use std::fmt::Debug;

/// Weights and hyperparameters for combining the terms of a composite loss.
#[derive(Debug, Clone)]
pub struct LossConfig {
    /// Weight of the supervised (data-fitting) term.
    pub supervised_weight: f64,
    /// Weight of the constraint-violation term.
    pub constraint_weight: f64,
    /// Weight of the rule-satisfaction term.
    pub rule_weight: f64,
    /// Temperature used by temperature-scaled losses.
    pub temperature: f64,
}

impl Default for LossConfig {
    fn default() -> Self {
        Self {
            supervised_weight: 1.0,
            constraint_weight: 1.0,
            rule_weight: 1.0,
            temperature: 1.0,
        }
    }
}

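/// A differentiable loss over 2-D prediction/target arrays.
///
/// A minimal sketch of a custom implementation, assuming only the trait as
/// declared below; `MaeLoss` (mean absolute error) is a hypothetical example
/// type, not part of this module:
///
/// ```ignore
/// #[derive(Debug, Clone, Default)]
/// struct MaeLoss;
///
/// impl Loss for MaeLoss {
///     fn compute(
///         &self,
///         predictions: &ArrayView<f64, Ix2>,
///         targets: &ArrayView<f64, Ix2>,
///     ) -> TrainResult<f64> {
///         // Mean of |prediction - target| over all elements.
///         let n = predictions.len() as f64;
///         let total: f64 = predictions
///             .iter()
///             .zip(targets.iter())
///             .map(|(p, t)| (p - t).abs())
///             .sum();
///         Ok(total / n)
///     }
///
///     fn gradient(
///         &self,
///         predictions: &ArrayView<f64, Ix2>,
///         targets: &ArrayView<f64, Ix2>,
///     ) -> TrainResult<Array<f64, Ix2>> {
///         // d|x|/dx = sign(x), averaged over all elements.
///         let n = predictions.len() as f64;
///         let mut grad = Array::zeros(predictions.raw_dim());
///         for i in 0..predictions.nrows() {
///             for j in 0..predictions.ncols() {
///                 grad[[i, j]] = (predictions[[i, j]] - targets[[i, j]]).signum() / n;
///             }
///         }
///         Ok(grad)
///     }
/// }
/// ```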
pub trait Loss: Debug {
    /// Computes the scalar loss value.
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64>;

    /// Computes the gradient of the loss with respect to `predictions`.
    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>>;
}

/// Cross-entropy loss for probabilistic predictions, averaged over rows.
#[derive(Debug, Clone)]
pub struct CrossEntropyLoss {
    /// Clamp value that keeps predictions away from 0 and 1, avoiding `ln(0)`.
    pub epsilon: f64,
}

impl Default for CrossEntropyLoss {
    fn default() -> Self {
        Self { epsilon: 1e-10 }
    }
}

impl Loss for CrossEntropyLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.nrows() as f64;
        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]]
                    .max(self.epsilon)
                    .min(1.0 - self.epsilon);
                let target = targets[[i, j]];
                total_loss -= target * pred.ln();
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.nrows() as f64;
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]]
                    .max(self.epsilon)
                    .min(1.0 - self.epsilon);
                let target = targets[[i, j]];
                grad[[i, j]] = -(target / pred) / n;
            }
        }

        Ok(grad)
    }
}

/// Mean squared error over all elements.
#[derive(Debug, Clone, Default)]
pub struct MseLoss;

impl Loss for MseLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.len() as f64;
        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let diff = predictions[[i, j]] - targets[[i, j]];
                total_loss += diff * diff;
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.len() as f64;
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                grad[[i, j]] = 2.0 * (predictions[[i, j]] - targets[[i, j]]) / n;
            }
        }

        Ok(grad)
    }
}

/// Composite loss combining a supervised term with rule-satisfaction and
/// constraint-violation penalties.
#[derive(Debug)]
pub struct LogicalLoss {
    /// Weights applied to each term.
    pub config: LossConfig,
    /// Loss applied to predictions against labelled targets.
    pub supervised_loss: Box<dyn Loss>,
    /// One loss per rule output; rule outputs are pushed towards 1 (true).
    pub rule_losses: Vec<Box<dyn Loss>>,
    /// One loss per constraint output; violations are pushed towards 0.
    pub constraint_losses: Vec<Box<dyn Loss>>,
}

impl LogicalLoss {
    /// Creates a composite loss from its parts.
    pub fn new(
        config: LossConfig,
        supervised_loss: Box<dyn Loss>,
        rule_losses: Vec<Box<dyn Loss>>,
        constraint_losses: Vec<Box<dyn Loss>>,
    ) -> Self {
        Self {
            config,
            supervised_loss,
            rule_losses,
            constraint_losses,
        }
    }

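    /// Computes the weighted total: the supervised term plus one penalty per
    /// rule output (compared against 1.0, i.e. "true") and per constraint
    /// output (compared against 0.0, i.e. "no violation").
    ///
    /// A sketch of typical usage, assuming the loss types defined in this
    /// module; the arrays and their shapes are illustrative:
    ///
    /// ```ignore
    /// let logical = LogicalLoss::new(
    ///     LossConfig::default(),
    ///     Box::new(CrossEntropyLoss::default()),
    ///     vec![Box::new(RuleSatisfactionLoss::default())],
    ///     vec![Box::new(ConstraintViolationLoss::default())],
    /// );
    /// let total = logical.compute_total(
    ///     &predictions.view(),
    ///     &targets.view(),
    ///     &[rule_values.view()],
    ///     &[constraint_values.view()],
    /// )?;
    /// ```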
    pub fn compute_total(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
        rule_values: &[ArrayView<f64, Ix2>],
        constraint_values: &[ArrayView<f64, Ix2>],
    ) -> TrainResult<f64> {
        let mut total = 0.0;

        let supervised = self.supervised_loss.compute(predictions, targets)?;
        total += self.config.supervised_weight * supervised;

        if !rule_values.is_empty() && !self.rule_losses.is_empty() {
            // Rules should evaluate to true, so compare against all-ones.
            let expected_true = Array::ones((rule_values[0].nrows(), rule_values[0].ncols()));
            let expected_true_view = expected_true.view();

            for (rule_val, rule_loss) in rule_values.iter().zip(self.rule_losses.iter()) {
                let rule_loss_val = rule_loss.compute(rule_val, &expected_true_view)?;
                total += self.config.rule_weight * rule_loss_val;
            }
        }

        if !constraint_values.is_empty() && !self.constraint_losses.is_empty() {
            // Constraints should not be violated, so compare against all-zeros.
            let expected_zero =
                Array::zeros((constraint_values[0].nrows(), constraint_values[0].ncols()));
            let expected_zero_view = expected_zero.view();

            for (constraint_val, constraint_loss) in
                constraint_values.iter().zip(self.constraint_losses.iter())
            {
                let constraint_loss_val =
                    constraint_loss.compute(constraint_val, &expected_zero_view)?;
                total += self.config.constraint_weight * constraint_loss_val;
            }
        }

        Ok(total)
    }
}

/// Temperature-scaled squared error pushing rule outputs towards their
/// target truth values.
#[derive(Debug, Clone)]
pub struct RuleSatisfactionLoss {
    /// Scale applied to the error; larger temperatures soften the penalty.
    pub temperature: f64,
}

impl Default for RuleSatisfactionLoss {
    fn default() -> Self {
        Self { temperature: 1.0 }
    }
}

impl Loss for RuleSatisfactionLoss {
    fn compute(
        &self,
        rule_values: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if rule_values.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: rule_values {:?} vs targets {:?}",
                rule_values.shape(),
                targets.shape()
            )));
        }

        let n = rule_values.len() as f64;
        let mut total_loss = 0.0;

        for i in 0..rule_values.nrows() {
            for j in 0..rule_values.ncols() {
                let diff = targets[[i, j]] - rule_values[[i, j]];
                total_loss += (diff / self.temperature).powi(2);
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        rule_values: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if rule_values.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: rule_values {:?} vs targets {:?}",
                rule_values.shape(),
                targets.shape()
            )));
        }

        let n = rule_values.len() as f64;
        let mut grad = Array::zeros(rule_values.raw_dim());

        for i in 0..rule_values.nrows() {
            for j in 0..rule_values.ncols() {
                let diff = targets[[i, j]] - rule_values[[i, j]];
                grad[[i, j]] = -2.0 * diff / (self.temperature * self.temperature * n);
            }
        }

        Ok(grad)
    }
}

/// One-sided quadratic penalty on constraint outputs that exceed their
/// target values; values at or below the target incur no loss.
#[derive(Debug, Clone)]
pub struct ConstraintViolationLoss {
    /// Multiplier applied to each squared violation.
    pub penalty_weight: f64,
}

impl Default for ConstraintViolationLoss {
    fn default() -> Self {
        Self {
            penalty_weight: 10.0,
        }
    }
}

impl Loss for ConstraintViolationLoss {
    fn compute(
        &self,
        constraint_values: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if constraint_values.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: constraint_values {:?} vs targets {:?}",
                constraint_values.shape(),
                targets.shape()
            )));
        }

        let n = constraint_values.len() as f64;
        let mut total_loss = 0.0;

        for i in 0..constraint_values.nrows() {
            for j in 0..constraint_values.ncols() {
                // Only positive excess over the target counts as a violation.
                let violation = (constraint_values[[i, j]] - targets[[i, j]]).max(0.0);
                total_loss += self.penalty_weight * violation * violation;
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        constraint_values: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if constraint_values.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: constraint_values {:?} vs targets {:?}",
                constraint_values.shape(),
                targets.shape()
            )));
        }

        let n = constraint_values.len() as f64;
        let mut grad = Array::zeros(constraint_values.raw_dim());

        for i in 0..constraint_values.nrows() {
            for j in 0..constraint_values.ncols() {
                let violation = constraint_values[[i, j]] - targets[[i, j]];
                if violation > 0.0 {
                    grad[[i, j]] = 2.0 * self.penalty_weight * violation / n;
                }
            }
        }

        Ok(grad)
    }
}

/// Focal loss for class-imbalanced binary targets; down-weights
/// well-classified examples so training focuses on hard ones.
#[derive(Debug, Clone)]
pub struct FocalLoss {
    /// Balance factor between positive and negative classes.
    pub alpha: f64,
    /// Focusing exponent; higher values down-weight easy examples more.
    pub gamma: f64,
    /// Clamp value that keeps predictions away from 0 and 1.
    pub epsilon: f64,
}

impl Default for FocalLoss {
    fn default() -> Self {
        Self {
            alpha: 0.25,
            gamma: 2.0,
            epsilon: 1e-10,
        }
    }
}

impl Loss for FocalLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.nrows() as f64;
        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]]
                    .max(self.epsilon)
                    .min(1.0 - self.epsilon);
                let target = targets[[i, j]];

                if target > 0.5 {
                    // Positive class: -alpha * (1 - p)^gamma * ln(p).
                    let focal_weight = (1.0 - pred).powf(self.gamma);
                    total_loss -= self.alpha * focal_weight * pred.ln();
                } else {
                    // Negative class: -(1 - alpha) * p^gamma * ln(1 - p).
                    let focal_weight = pred.powf(self.gamma);
                    total_loss -= (1.0 - self.alpha) * focal_weight * (1.0 - pred).ln();
                }
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.nrows() as f64;
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]]
                    .max(self.epsilon)
                    .min(1.0 - self.epsilon);
                let target = targets[[i, j]];

                if target > 0.5 {
                    // Product rule on (1 - p)^gamma * ln(p).
                    let focal_weight = (1.0 - pred).powf(self.gamma);
                    let d_focal = self.gamma * (1.0 - pred).powf(self.gamma - 1.0);
                    grad[[i, j]] = -self.alpha * (focal_weight / pred - d_focal * pred.ln()) / n;
                } else {
                    // Product rule on p^gamma * ln(1 - p).
                    let focal_weight = pred.powf(self.gamma);
                    let d_focal = self.gamma * pred.powf(self.gamma - 1.0);
                    grad[[i, j]] = -(1.0 - self.alpha)
                        * (d_focal * (1.0 - pred).ln() - focal_weight / (1.0 - pred))
                        / n;
                }
            }
        }

        Ok(grad)
    }
}

/// Huber loss: quadratic for small errors, linear beyond `delta`, which
/// makes it less sensitive to outliers than plain MSE.
#[derive(Debug, Clone)]
pub struct HuberLoss {
    /// Error magnitude at which the loss switches from quadratic to linear.
    pub delta: f64,
}

impl Default for HuberLoss {
    fn default() -> Self {
        Self { delta: 1.0 }
    }
}

impl Loss for HuberLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.len() as f64;
        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let diff = (predictions[[i, j]] - targets[[i, j]]).abs();
                if diff <= self.delta {
                    total_loss += 0.5 * diff * diff;
                } else {
                    total_loss += self.delta * (diff - 0.5 * self.delta);
                }
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.len() as f64;
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let diff = predictions[[i, j]] - targets[[i, j]];
                let abs_diff = diff.abs();

                if abs_diff <= self.delta {
                    grad[[i, j]] = diff / n;
                } else {
                    grad[[i, j]] = self.delta * diff.signum() / n;
                }
            }
        }

        Ok(grad)
    }
}

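/// Binary cross-entropy computed directly from raw logits using the
/// numerically stable form `max(x, 0) - x * t + ln(1 + exp(-|x|))`, which
/// avoids overflow for large-magnitude logits.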
#[derive(Debug, Clone, Default)]
pub struct BCEWithLogitsLoss;

impl Loss for BCEWithLogitsLoss {
    fn compute(
        &self,
        logits: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if logits.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: logits {:?} vs targets {:?}",
                logits.shape(),
                targets.shape()
            )));
        }

        let n = logits.len() as f64;
        let mut total_loss = 0.0;

        for i in 0..logits.nrows() {
            for j in 0..logits.ncols() {
                let logit = logits[[i, j]];
                let target = targets[[i, j]];

                // Stable formulation: max(x, 0) - x * t + ln(1 + exp(-|x|)).
                let max_val = logit.max(0.0);
                total_loss += max_val - logit * target + (1.0 + (-logit.abs()).exp()).ln();
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        logits: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if logits.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: logits {:?} vs targets {:?}",
                logits.shape(),
                targets.shape()
            )));
        }

        let n = logits.len() as f64;
        let mut grad = Array::zeros(logits.raw_dim());

        for i in 0..logits.nrows() {
            for j in 0..logits.ncols() {
                let logit = logits[[i, j]];
                let target = targets[[i, j]];

                // The gradient with respect to the logit is sigmoid(x) - t.
                let sigmoid = 1.0 / (1.0 + (-logit).exp());
                grad[[i, j]] = (sigmoid - target) / n;
            }
        }

        Ok(grad)
    }
}

/// Dice loss (1 - Dice coefficient), commonly used for segmentation-style
/// overlap targets.
#[derive(Debug, Clone)]
pub struct DiceLoss {
    /// Smoothing constant that avoids division by zero for empty masks.
    pub smooth: f64,
}

impl Default for DiceLoss {
    fn default() -> Self {
        Self { smooth: 1.0 }
    }
}

impl Loss for DiceLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut intersection = 0.0;
        let mut pred_sum = 0.0;
        let mut target_sum = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]];
                let target = targets[[i, j]];

                intersection += pred * target;
                pred_sum += pred;
                target_sum += target;
            }
        }

        let dice_coef = (2.0 * intersection + self.smooth) / (pred_sum + target_sum + self.smooth);
        Ok(1.0 - dice_coef)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut intersection = 0.0;
        let mut pred_sum = 0.0;
        let mut target_sum = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                intersection += predictions[[i, j]] * targets[[i, j]];
                pred_sum += predictions[[i, j]];
                target_sum += targets[[i, j]];
            }
        }

        let denominator = pred_sum + target_sum + self.smooth;
        let numerator = 2.0 * intersection + self.smooth;

        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let target = targets[[i, j]];
                // Quotient rule on dice = numerator / denominator, where
                // d(numerator)/dp = 2 * target and d(denominator)/dp = 1.
                grad[[i, j]] =
                    -(2.0 * target * denominator - numerator) / (denominator * denominator);
            }
        }

        Ok(grad)
    }
}

/// Tversky loss, a generalization of Dice that weights false positives and
/// false negatives separately; with `alpha == beta == 0.5` it matches Dice.
#[derive(Debug, Clone)]
pub struct TverskyLoss {
    /// Weight on false positives.
    pub alpha: f64,
    /// Weight on false negatives.
    pub beta: f64,
    /// Smoothing constant that avoids division by zero.
    pub smooth: f64,
}

impl Default for TverskyLoss {
    fn default() -> Self {
        Self {
            alpha: 0.5,
            beta: 0.5,
            smooth: 1.0,
        }
    }
}

impl Loss for TverskyLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut true_pos = 0.0;
        let mut false_pos = 0.0;
        let mut false_neg = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]];
                let target = targets[[i, j]];

                true_pos += pred * target;
                false_pos += pred * (1.0 - target);
                false_neg += (1.0 - pred) * target;
            }
        }

        let tversky_index = (true_pos + self.smooth)
            / (true_pos + self.alpha * false_pos + self.beta * false_neg + self.smooth);

        Ok(1.0 - tversky_index)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut true_pos = 0.0;
        let mut false_pos = 0.0;
        let mut false_neg = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]];
                let target = targets[[i, j]];

                true_pos += pred * target;
                false_pos += pred * (1.0 - target);
                false_neg += (1.0 - pred) * target;
            }
        }

        let denominator = true_pos + self.alpha * false_pos + self.beta * false_neg + self.smooth;
        let numerator = true_pos + self.smooth;

        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let target = targets[[i, j]];

                // Partial derivatives of TP, FP, and FN with respect to this prediction.
                let d_tp = target;
                let d_fp = self.alpha * (1.0 - target);
                let d_fn = -self.beta * target;

                grad[[i, j]] = -(d_tp * denominator - numerator * (d_tp + d_fp + d_fn))
                    / (denominator * denominator);
            }
        }

        Ok(grad)
    }
}

/// Contrastive loss over pairwise distances: similar pairs are pulled
/// together, dissimilar pairs are pushed apart up to `margin`.
#[derive(Debug, Clone)]
pub struct ContrastiveLoss {
    /// Minimum desired distance between dissimilar pairs.
    pub margin: f64,
}

impl Default for ContrastiveLoss {
    fn default() -> Self {
        Self { margin: 1.0 }
    }
}

impl Loss for ContrastiveLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.ncols() != 2 || targets.ncols() != 1 {
            return Err(TrainError::LossError(format!(
                "ContrastiveLoss expects predictions shape [N, 2] (distances) and targets shape [N, 1] (labels), got {:?} and {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut total_loss = 0.0;
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            let distance = predictions[[i, 0]];
            let label = targets[[i, 0]];

            if label > 0.5 {
                // Similar pair: penalize any distance.
                total_loss += distance * distance;
            } else {
                // Dissimilar pair: penalize only distances inside the margin.
                total_loss += (self.margin - distance).max(0.0).powi(2);
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.ncols() != 2 || targets.ncols() != 1 {
            return Err(TrainError::LossError(format!(
                "ContrastiveLoss expects predictions shape [N, 2] (distances) and targets shape [N, 1] (labels), got {:?} and {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut grad = Array::zeros(predictions.raw_dim());
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            let distance = predictions[[i, 0]];
            let label = targets[[i, 0]];

            if label > 0.5 {
                grad[[i, 0]] = 2.0 * distance / n;
            } else if distance < self.margin {
                grad[[i, 0]] = -2.0 * (self.margin - distance) / n;
            }
        }

        Ok(grad)
    }
}

/// Triplet loss over precomputed anchor-positive and anchor-negative
/// distances; targets are ignored.
#[derive(Debug, Clone)]
pub struct TripletLoss {
    /// Required separation between positive and negative distances.
    pub margin: f64,
}

impl Default for TripletLoss {
    fn default() -> Self {
        Self { margin: 1.0 }
    }
}

impl Loss for TripletLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        _targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.ncols() != 2 {
            return Err(TrainError::LossError(format!(
                "TripletLoss expects predictions shape [N, 2] (pos_dist, neg_dist), got {:?}",
                predictions.shape()
            )));
        }

        let mut total_loss = 0.0;
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            let pos_distance = predictions[[i, 0]];
            let neg_distance = predictions[[i, 1]];

            // Hinge on the margin between positive and negative distances.
            let loss = (pos_distance - neg_distance + self.margin).max(0.0);
            total_loss += loss;
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        _targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.ncols() != 2 {
            return Err(TrainError::LossError(format!(
                "TripletLoss expects predictions shape [N, 2] (pos_dist, neg_dist), got {:?}",
                predictions.shape()
            )));
        }

        let mut grad = Array::zeros(predictions.raw_dim());
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            let pos_distance = predictions[[i, 0]];
            let neg_distance = predictions[[i, 1]];

            // Only triplets still violating the margin receive gradient.
            if pos_distance - neg_distance + self.margin > 0.0 {
                grad[[i, 0]] = 1.0 / n;
                grad[[i, 1]] = -1.0 / n;
            }
        }

        Ok(grad)
    }
}

/// Hinge loss for max-margin classification with targets in {-1, +1}.
#[derive(Debug, Clone)]
pub struct HingeLoss {
    /// Required classification margin.
    pub margin: f64,
}

impl Default for HingeLoss {
    fn default() -> Self {
        Self { margin: 1.0 }
    }
}

impl Loss for HingeLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut total_loss = 0.0;
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]];
                let target = targets[[i, j]];

                // Zero loss once the prediction clears the margin.
                let loss = (self.margin - target * pred).max(0.0);
                total_loss += loss;
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut grad = Array::zeros(predictions.raw_dim());
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]];
                let target = targets[[i, j]];

                if self.margin - target * pred > 0.0 {
                    grad[[i, j]] = -target / n;
                }
            }
        }

        Ok(grad)
    }
}

/// Kullback-Leibler divergence KL(targets || predictions), summed over all
/// elements (not averaged).
#[derive(Debug, Clone)]
pub struct KLDivergenceLoss {
    /// Clamp value that keeps both distributions away from zero.
    pub epsilon: f64,
}

impl Default for KLDivergenceLoss {
    fn default() -> Self {
        Self { epsilon: 1e-10 }
    }
}

impl Loss for KLDivergenceLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]].max(self.epsilon);
                let target = targets[[i, j]].max(self.epsilon);

                total_loss += target * (target / pred).ln();
            }
        }

        Ok(total_loss)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]].max(self.epsilon);
                let target = targets[[i, j]].max(self.epsilon);

                // d/dp [t * ln(t / p)] = -t / p.
                grad[[i, j]] = -target / pred;
            }
        }

        Ok(grad)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_cross_entropy_loss() {
        let loss = CrossEntropyLoss::default();
        let predictions = array![[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]];
        let targets = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val > 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }

    #[test]
    fn test_mse_loss() {
        let loss = MseLoss;
        let predictions = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![[1.5, 2.5], [3.5, 4.5]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!((loss_val - 0.25).abs() < 1e-6);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }

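    // A sketch of a finite-difference check on `Loss::gradient`, added as an
    // illustration; MseLoss, the step size, and the tolerance are arbitrary
    // choices, not fixtures from the original suite.
    #[test]
    fn test_mse_gradient_matches_finite_differences() {
        let loss = MseLoss;
        let predictions = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![[1.5, 2.5], [3.5, 4.5]];

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();

        let eps = 1e-6;
        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                // Central difference: (L(p + eps) - L(p - eps)) / (2 * eps).
                let mut plus = predictions.clone();
                plus[[i, j]] += eps;
                let mut minus = predictions.clone();
                minus[[i, j]] -= eps;

                let l_plus = loss.compute(&plus.view(), &targets.view()).unwrap();
                let l_minus = loss.compute(&minus.view(), &targets.view()).unwrap();
                let numeric = (l_plus - l_minus) / (2.0 * eps);

                assert!((grad[[i, j]] - numeric).abs() < 1e-4);
            }
        }
    }
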
    #[test]
    fn test_rule_satisfaction_loss() {
        let loss = RuleSatisfactionLoss::default();
        let rule_values = array![[0.9, 0.8], [0.95, 0.85]];
        let targets = array![[1.0, 1.0], [1.0, 1.0]];

        let loss_val = loss.compute(&rule_values.view(), &targets.view()).unwrap();
        assert!(loss_val > 0.0);

        let grad = loss.gradient(&rule_values.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), rule_values.shape());
    }

    #[test]
    fn test_constraint_violation_loss() {
        let loss = ConstraintViolationLoss::default();
        let constraint_values = array![[0.1, -0.1], [0.2, -0.2]];
        let targets = array![[0.0, 0.0], [0.0, 0.0]];

        let loss_val = loss
            .compute(&constraint_values.view(), &targets.view())
            .unwrap();
        assert!(loss_val > 0.0);

        let grad = loss
            .gradient(&constraint_values.view(), &targets.view())
            .unwrap();
        assert_eq!(grad.shape(), constraint_values.shape());
    }

    #[test]
    fn test_focal_loss() {
        let loss = FocalLoss::default();
        let predictions = array![[0.9, 0.1], [0.2, 0.8]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }

    #[test]
    fn test_huber_loss() {
        let loss = HuberLoss::default();
        let predictions = array![[1.0, 3.0], [2.0, 5.0]];
        let targets = array![[1.5, 2.0], [2.5, 4.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val > 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }

    #[test]
    fn test_bce_with_logits_loss() {
        let loss = BCEWithLogitsLoss;
        let logits = array![[0.5, -0.5], [1.0, -1.0]];
        let targets = array![[1.0, 0.0], [1.0, 0.0]];

        let loss_val = loss.compute(&logits.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);

        let grad = loss.gradient(&logits.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), logits.shape());
    }

    #[test]
    fn test_dice_loss() {
        let loss = DiceLoss::default();
        let predictions = array![[0.9, 0.1], [0.8, 0.2]];
        let targets = array![[1.0, 0.0], [1.0, 0.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);
        assert!(loss_val <= 1.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }

    #[test]
    fn test_tversky_loss() {
        let loss = TverskyLoss::default();
        let predictions = array![[0.9, 0.1], [0.8, 0.2]];
        let targets = array![[1.0, 0.0], [1.0, 0.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);
        assert!(loss_val <= 1.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }

    #[test]
    fn test_contrastive_loss() {
        let loss = ContrastiveLoss::default();
        // Column 0 holds the pairwise distance; column 1 is unused padding.
        let predictions = array![[0.5, 0.0], [1.5, 0.0], [0.2, 0.0]];
        let targets = array![[1.0], [0.0], [1.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());

        // A similar pair with nonzero distance gets a positive gradient.
        assert!(grad[[0, 0]] > 0.0);
        // A dissimilar pair already beyond the margin gets no gradient.
        assert_eq!(grad[[1, 0]], 0.0);
    }

    #[test]
    fn test_triplet_loss() {
        let loss = TripletLoss::default();
        // Columns are (positive distance, negative distance) per triplet.
        let predictions = array![[0.5, 2.0], [1.0, 0.5], [0.3, 1.5]];
        // Targets are ignored by TripletLoss.
        let targets = array![[0.0], [0.0], [0.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());

        // The first triplet already satisfies the margin, so no gradient.
        assert_eq!(grad[[0, 0]], 0.0);
        assert_eq!(grad[[0, 1]], 0.0);

        // The second triplet violates the margin and gets pushed apart.
        assert!(grad[[1, 0]] > 0.0);
        assert!(grad[[1, 1]] < 0.0);
    }

    #[test]
    fn test_hinge_loss() {
        let loss = HingeLoss::default();
        let predictions = array![[0.5, -0.5], [2.0, -2.0]];
        let targets = array![[1.0, -1.0], [1.0, -1.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());

        // Predictions that clear the margin receive zero gradient.
        assert_eq!(grad[[1, 0]], 0.0);
        assert_eq!(grad[[1, 1]], 0.0);
    }

    #[test]
    fn test_kl_divergence_loss() {
        let loss = KLDivergenceLoss::default();
        let predictions = array![[0.6, 0.4], [0.7, 0.3]];
        let targets = array![[0.5, 0.5], [0.8, 0.2]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());

        // Identical distributions should give (near) zero divergence.
        let identical_preds = array![[0.5, 0.5]];
        let identical_targets = array![[0.5, 0.5]];
        let identical_loss = loss
            .compute(&identical_preds.view(), &identical_targets.view())
            .unwrap();
        assert!(identical_loss.abs() < 1e-6);
    }
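
    // A sketch of exercising LogicalLoss end to end; the component losses,
    // weights, and shapes are illustrative choices, not fixtures from the
    // original suite.
    #[test]
    fn test_logical_loss_compute_total() {
        let logical = LogicalLoss::new(
            LossConfig::default(),
            Box::new(MseLoss),
            vec![Box::new(RuleSatisfactionLoss::default())],
            vec![Box::new(ConstraintViolationLoss::default())],
        );

        let predictions = array![[0.9, 0.1], [0.2, 0.8]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];
        // Rule outputs below 1.0 and constraint outputs above 0.0 are penalized.
        let rule_values = array![[0.9, 0.8], [0.95, 0.85]];
        let constraint_values = array![[0.1, 0.0], [0.2, 0.0]];

        let total = logical
            .compute_total(
                &predictions.view(),
                &targets.view(),
                &[rule_values.view()],
                &[constraint_values.view()],
            )
            .unwrap();
        assert!(total > 0.0);
    }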
}