use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, ArrayView, Ix2};
use std::fmt::Debug;

/// Weights for combining the supervised, rule, and constraint loss terms.
#[derive(Debug, Clone)]
pub struct LossConfig {
    /// Weight applied to the supervised loss term.
    pub supervised_weight: f64,
    /// Weight applied to the constraint-violation loss term.
    pub constraint_weight: f64,
    /// Weight applied to the rule-satisfaction loss term.
    pub rule_weight: f64,
    /// Temperature used by temperature-scaled loss terms.
    pub temperature: f64,
}

impl Default for LossConfig {
    fn default() -> Self {
        Self {
            supervised_weight: 1.0,
            constraint_weight: 1.0,
            rule_weight: 1.0,
            temperature: 1.0,
        }
    }
}

/// A differentiable loss over 2-D prediction/target arrays.
pub trait Loss: Debug {
    /// Computes the scalar loss value.
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64>;

    /// Computes the gradient of the loss with respect to the predictions.
    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>>;

    /// A human-readable name for this loss.
    fn name(&self) -> &str {
        "unknown"
    }
}

/// Cross-entropy loss for probability predictions.
#[derive(Debug, Clone)]
pub struct CrossEntropyLoss {
    /// Small constant used to clamp predictions away from 0 and 1.
    pub epsilon: f64,
}

impl Default for CrossEntropyLoss {
    fn default() -> Self {
        Self { epsilon: 1e-10 }
    }
}

impl Loss for CrossEntropyLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.nrows() as f64;
        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]].clamp(self.epsilon, 1.0 - self.epsilon);
                let target = targets[[i, j]];
                total_loss -= target * pred.ln();
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.nrows() as f64;
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]].clamp(self.epsilon, 1.0 - self.epsilon);
                let target = targets[[i, j]];
                grad[[i, j]] = -(target / pred) / n;
            }
        }

        Ok(grad)
    }
}

/// Mean squared error loss.
#[derive(Debug, Clone, Default)]
pub struct MseLoss;

impl Loss for MseLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.len() as f64;
        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let diff = predictions[[i, j]] - targets[[i, j]];
                total_loss += diff * diff;
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.len() as f64;
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                grad[[i, j]] = 2.0 * (predictions[[i, j]] - targets[[i, j]]) / n;
            }
        }

        Ok(grad)
    }
}

/// Composite loss that combines a supervised loss with rule-satisfaction
/// and constraint-violation terms, weighted by a `LossConfig`.
#[derive(Debug)]
pub struct LogicalLoss {
    /// Weights for combining the individual terms.
    pub config: LossConfig,
    /// Loss applied to predictions versus labeled targets.
    pub supervised_loss: Box<dyn Loss>,
    /// Losses applied to rule-satisfaction values (driven toward 1.0).
    pub rule_losses: Vec<Box<dyn Loss>>,
    /// Losses applied to constraint values (driven toward 0.0).
    pub constraint_losses: Vec<Box<dyn Loss>>,
}

impl LogicalLoss {
    /// Creates a new composite loss from its parts.
    pub fn new(
        config: LossConfig,
        supervised_loss: Box<dyn Loss>,
        rule_losses: Vec<Box<dyn Loss>>,
        constraint_losses: Vec<Box<dyn Loss>>,
    ) -> Self {
        Self {
            config,
            supervised_loss,
            rule_losses,
            constraint_losses,
        }
    }

    /// Computes the weighted sum of the supervised, rule, and constraint terms.
    pub fn compute_total(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
        rule_values: &[ArrayView<f64, Ix2>],
        constraint_values: &[ArrayView<f64, Ix2>],
    ) -> TrainResult<f64> {
        let mut total = 0.0;

        let supervised = self.supervised_loss.compute(predictions, targets)?;
        total += self.config.supervised_weight * supervised;

        // Rules should hold, so each rule matrix is compared against an
        // all-ones target of its own shape.
        for (rule_val, rule_loss) in rule_values.iter().zip(self.rule_losses.iter()) {
            let expected_true = Array::ones(rule_val.raw_dim());
            let rule_loss_val = rule_loss.compute(rule_val, &expected_true.view())?;
            total += self.config.rule_weight * rule_loss_val;
        }

        // Constraints should not be violated, so each constraint matrix is
        // compared against an all-zeros target of its own shape.
        for (constraint_val, constraint_loss) in
            constraint_values.iter().zip(self.constraint_losses.iter())
        {
            let expected_zero = Array::zeros(constraint_val.raw_dim());
            let constraint_loss_val =
                constraint_loss.compute(constraint_val, &expected_zero.view())?;
            total += self.config.constraint_weight * constraint_loss_val;
        }

        Ok(total)
    }
}

/// Temperature-scaled squared error that drives rule-satisfaction values
/// toward their targets (typically all ones).
#[derive(Debug, Clone)]
pub struct RuleSatisfactionLoss {
    /// Temperature that scales the error before squaring.
    pub temperature: f64,
}

impl Default for RuleSatisfactionLoss {
    fn default() -> Self {
        Self { temperature: 1.0 }
    }
}

impl Loss for RuleSatisfactionLoss {
    fn compute(
        &self,
        rule_values: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if rule_values.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: rule_values {:?} vs targets {:?}",
                rule_values.shape(),
                targets.shape()
            )));
        }

        let n = rule_values.len() as f64;
        let mut total_loss = 0.0;

        for i in 0..rule_values.nrows() {
            for j in 0..rule_values.ncols() {
                let diff = targets[[i, j]] - rule_values[[i, j]];
                total_loss += (diff / self.temperature).powi(2);
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        rule_values: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if rule_values.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: rule_values {:?} vs targets {:?}",
                rule_values.shape(),
                targets.shape()
            )));
        }

        let n = rule_values.len() as f64;
        let mut grad = Array::zeros(rule_values.raw_dim());

        for i in 0..rule_values.nrows() {
            for j in 0..rule_values.ncols() {
                let diff = targets[[i, j]] - rule_values[[i, j]];
                grad[[i, j]] = -2.0 * diff / (self.temperature * self.temperature * n);
            }
        }

        Ok(grad)
    }
}

/// One-sided quadratic penalty on constraint values that exceed their
/// targets; values at or below the target incur no loss.
#[derive(Debug, Clone)]
pub struct ConstraintViolationLoss {
    /// Multiplier applied to each squared violation.
    pub penalty_weight: f64,
}

impl Default for ConstraintViolationLoss {
    fn default() -> Self {
        Self {
            penalty_weight: 10.0,
        }
    }
}

impl Loss for ConstraintViolationLoss {
    fn compute(
        &self,
        constraint_values: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if constraint_values.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: constraint_values {:?} vs targets {:?}",
                constraint_values.shape(),
                targets.shape()
            )));
        }

        let n = constraint_values.len() as f64;
        let mut total_loss = 0.0;

        for i in 0..constraint_values.nrows() {
            for j in 0..constraint_values.ncols() {
                // Only positive deviations above the target count as violations.
                let violation = (constraint_values[[i, j]] - targets[[i, j]]).max(0.0);
                total_loss += self.penalty_weight * violation * violation;
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        constraint_values: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if constraint_values.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: constraint_values {:?} vs targets {:?}",
                constraint_values.shape(),
                targets.shape()
            )));
        }

        let n = constraint_values.len() as f64;
        let mut grad = Array::zeros(constraint_values.raw_dim());

        for i in 0..constraint_values.nrows() {
            for j in 0..constraint_values.ncols() {
                let violation = constraint_values[[i, j]] - targets[[i, j]];
                if violation > 0.0 {
                    grad[[i, j]] = 2.0 * self.penalty_weight * violation / n;
                }
            }
        }

        Ok(grad)
    }
}

/// Focal loss for class-imbalanced binary classification: down-weights
/// well-classified examples via the focusing parameter `gamma`.
#[derive(Debug, Clone)]
pub struct FocalLoss {
    /// Class-balancing weight for the positive class.
    pub alpha: f64,
    /// Focusing parameter; larger values down-weight easy examples more.
    pub gamma: f64,
    /// Small constant used to clamp predictions away from 0 and 1.
    pub epsilon: f64,
}

impl Default for FocalLoss {
    fn default() -> Self {
        Self {
            alpha: 0.25,
            gamma: 2.0,
            epsilon: 1e-10,
        }
    }
}

impl Loss for FocalLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.nrows() as f64;
        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]].clamp(self.epsilon, 1.0 - self.epsilon);
                let target = targets[[i, j]];

                if target > 0.5 {
                    // Positive class: down-weight easy examples by (1 - p)^gamma.
                    let focal_weight = (1.0 - pred).powf(self.gamma);
                    total_loss -= self.alpha * focal_weight * pred.ln();
                } else {
                    // Negative class: down-weight easy examples by p^gamma.
                    let focal_weight = pred.powf(self.gamma);
                    total_loss -= (1.0 - self.alpha) * focal_weight * (1.0 - pred).ln();
                }
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.nrows() as f64;
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]].clamp(self.epsilon, 1.0 - self.epsilon);
                let target = targets[[i, j]];

                if target > 0.5 {
                    // Product rule over (1 - p)^gamma * ln(p).
                    let focal_weight = (1.0 - pred).powf(self.gamma);
                    let d_focal = self.gamma * (1.0 - pred).powf(self.gamma - 1.0);
                    grad[[i, j]] = -self.alpha * (focal_weight / pred - d_focal * pred.ln()) / n;
                } else {
                    // Product rule over p^gamma * ln(1 - p).
                    let focal_weight = pred.powf(self.gamma);
                    let d_focal = self.gamma * pred.powf(self.gamma - 1.0);
                    grad[[i, j]] = -(1.0 - self.alpha)
                        * (d_focal * (1.0 - pred).ln() - focal_weight / (1.0 - pred))
                        / n;
                }
            }
        }

        Ok(grad)
    }
}

/// Huber loss: quadratic for small residuals, linear beyond `delta`.
#[derive(Debug, Clone)]
pub struct HuberLoss {
    /// Threshold at which the loss switches from quadratic to linear.
    pub delta: f64,
}

impl Default for HuberLoss {
    fn default() -> Self {
        Self { delta: 1.0 }
    }
}

impl Loss for HuberLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.len() as f64;
        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let diff = (predictions[[i, j]] - targets[[i, j]]).abs();
                if diff <= self.delta {
                    // Quadratic region for small residuals.
                    total_loss += 0.5 * diff * diff;
                } else {
                    // Linear region for large residuals.
                    total_loss += self.delta * (diff - 0.5 * self.delta);
                }
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.len() as f64;
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let diff = predictions[[i, j]] - targets[[i, j]];
                let abs_diff = diff.abs();

                if abs_diff <= self.delta {
                    grad[[i, j]] = diff / n;
                } else {
                    grad[[i, j]] = self.delta * diff.signum() / n;
                }
            }
        }

        Ok(grad)
    }
}

/// Numerically stable binary cross-entropy that takes raw logits instead
/// of probabilities.
#[derive(Debug, Clone, Default)]
pub struct BCEWithLogitsLoss;

impl Loss for BCEWithLogitsLoss {
    fn compute(
        &self,
        logits: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if logits.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: logits {:?} vs targets {:?}",
                logits.shape(),
                targets.shape()
            )));
        }

        let n = logits.len() as f64;
        let mut total_loss = 0.0;

        for i in 0..logits.nrows() {
            for j in 0..logits.ncols() {
                let logit = logits[[i, j]];
                let target = targets[[i, j]];

                // Numerically stable form: max(x, 0) - x * t + ln(1 + e^(-|x|)).
                let max_val = logit.max(0.0);
                total_loss += max_val - logit * target + (-logit.abs()).exp().ln_1p();
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        logits: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if logits.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: logits {:?} vs targets {:?}",
                logits.shape(),
                targets.shape()
            )));
        }

        let n = logits.len() as f64;
        let mut grad = Array::zeros(logits.raw_dim());

        for i in 0..logits.nrows() {
            for j in 0..logits.ncols() {
                let logit = logits[[i, j]];
                let target = targets[[i, j]];

                // The gradient with respect to the logit is sigmoid(x) - t.
                let sigmoid = 1.0 / (1.0 + (-logit).exp());
                grad[[i, j]] = (sigmoid - target) / n;
            }
        }

        Ok(grad)
    }
}

/// Dice loss (one minus the Dice coefficient), commonly used for
/// segmentation and overlap-style objectives.
#[derive(Debug, Clone)]
pub struct DiceLoss {
    /// Smoothing constant that avoids division by zero.
    pub smooth: f64,
}

impl Default for DiceLoss {
    fn default() -> Self {
        Self { smooth: 1.0 }
    }
}

impl Loss for DiceLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut intersection = 0.0;
        let mut pred_sum = 0.0;
        let mut target_sum = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]];
                let target = targets[[i, j]];

                intersection += pred * target;
                pred_sum += pred;
                target_sum += target;
            }
        }

        let dice_coef = (2.0 * intersection + self.smooth) / (pred_sum + target_sum + self.smooth);
        Ok(1.0 - dice_coef)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut intersection = 0.0;
        let mut pred_sum = 0.0;
        let mut target_sum = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                intersection += predictions[[i, j]] * targets[[i, j]];
                pred_sum += predictions[[i, j]];
                target_sum += targets[[i, j]];
            }
        }

        let denominator = pred_sum + target_sum + self.smooth;
        let numerator = 2.0 * intersection + self.smooth;

        let mut grad = Array::zeros(predictions.raw_dim());

        // With N = 2 * intersection + smooth and D = pred_sum + target_sum + smooth,
        // the quotient rule gives d(dice)/dp = (2 * t * D - N) / D^2; the loss
        // gradient is its negation.
        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let target = targets[[i, j]];
                grad[[i, j]] =
                    -(2.0 * target * denominator - numerator) / (denominator * denominator);
            }
        }

        Ok(grad)
    }
}

/// Tversky loss: a generalization of Dice loss with separate weights for
/// false positives (`alpha`) and false negatives (`beta`).
#[derive(Debug, Clone)]
pub struct TverskyLoss {
    /// Weight on false positives.
    pub alpha: f64,
    /// Weight on false negatives.
    pub beta: f64,
    /// Smoothing constant that avoids division by zero.
    pub smooth: f64,
}

impl Default for TverskyLoss {
    fn default() -> Self {
        Self {
            alpha: 0.5,
            beta: 0.5,
            smooth: 1.0,
        }
    }
}

impl Loss for TverskyLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut true_pos = 0.0;
        let mut false_pos = 0.0;
        let mut false_neg = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]];
                let target = targets[[i, j]];

                true_pos += pred * target;
                false_pos += pred * (1.0 - target);
                false_neg += (1.0 - pred) * target;
            }
        }

        let tversky_index = (true_pos + self.smooth)
            / (true_pos + self.alpha * false_pos + self.beta * false_neg + self.smooth);

        Ok(1.0 - tversky_index)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut true_pos = 0.0;
        let mut false_pos = 0.0;
        let mut false_neg = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]];
                let target = targets[[i, j]];

                true_pos += pred * target;
                false_pos += pred * (1.0 - target);
                false_neg += (1.0 - pred) * target;
            }
        }

        let denominator = true_pos + self.alpha * false_pos + self.beta * false_neg + self.smooth;
        let numerator = true_pos + self.smooth;

        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let target = targets[[i, j]];

                // Per-element derivatives of TP, alpha * FP, and beta * FN.
                let d_tp = target;
                let d_fp = self.alpha * (1.0 - target);
                let d_fn = -self.beta * target;

                // Quotient rule on the Tversky index, negated for the loss.
                grad[[i, j]] = -(d_tp * denominator - numerator * (d_tp + d_fp + d_fn))
                    / (denominator * denominator);
            }
        }

        Ok(grad)
    }
}

/// Contrastive loss over precomputed pairwise distances: similar pairs are
/// pulled together, dissimilar pairs pushed apart up to `margin`.
#[derive(Debug, Clone)]
pub struct ContrastiveLoss {
    /// Minimum desired distance between dissimilar pairs.
    pub margin: f64,
}

impl Default for ContrastiveLoss {
    fn default() -> Self {
        Self { margin: 1.0 }
    }
}

impl Loss for ContrastiveLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.ncols() != 2 || targets.ncols() != 1 {
            return Err(TrainError::LossError(format!(
                "ContrastiveLoss expects predictions shape [N, 2] (distances) and targets shape [N, 1] (labels), got {:?} and {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut total_loss = 0.0;
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            let distance = predictions[[i, 0]];
            let label = targets[[i, 0]];

            if label > 0.5 {
                // Similar pair: penalize any nonzero distance.
                total_loss += distance * distance;
            } else {
                // Dissimilar pair: penalize only distances inside the margin.
                total_loss += (self.margin - distance).max(0.0).powi(2);
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.ncols() != 2 || targets.ncols() != 1 {
            return Err(TrainError::LossError(format!(
                "ContrastiveLoss expects predictions shape [N, 2] (distances) and targets shape [N, 1] (labels), got {:?} and {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut grad = Array::zeros(predictions.raw_dim());
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            let distance = predictions[[i, 0]];
            let label = targets[[i, 0]];

            if label > 0.5 {
                grad[[i, 0]] = 2.0 * distance / n;
            } else if distance < self.margin {
                grad[[i, 0]] = -2.0 * (self.margin - distance) / n;
            }
        }

        Ok(grad)
    }
}

/// Triplet loss over precomputed anchor-positive and anchor-negative
/// distances, enforcing a margin between them.
#[derive(Debug, Clone)]
pub struct TripletLoss {
    /// Required separation between positive and negative distances.
    pub margin: f64,
}

impl Default for TripletLoss {
    fn default() -> Self {
        Self { margin: 1.0 }
    }
}

impl Loss for TripletLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        _targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.ncols() != 2 {
            return Err(TrainError::LossError(format!(
                "TripletLoss expects predictions shape [N, 2] (pos_dist, neg_dist), got {:?}",
                predictions.shape()
            )));
        }

        let mut total_loss = 0.0;
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            let pos_distance = predictions[[i, 0]];
            let neg_distance = predictions[[i, 1]];

            // Hinge on the margin between positive and negative distances.
            let loss = (pos_distance - neg_distance + self.margin).max(0.0);
            total_loss += loss;
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        _targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.ncols() != 2 {
            return Err(TrainError::LossError(format!(
                "TripletLoss expects predictions shape [N, 2] (pos_dist, neg_dist), got {:?}",
                predictions.shape()
            )));
        }

        let mut grad = Array::zeros(predictions.raw_dim());
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            let pos_distance = predictions[[i, 0]];
            let neg_distance = predictions[[i, 1]];

            // Only active (margin-violating) triplets receive gradient.
            if pos_distance - neg_distance + self.margin > 0.0 {
                grad[[i, 0]] = 1.0 / n;
                grad[[i, 1]] = -1.0 / n;
            }
        }

        Ok(grad)
    }
}

/// Hinge loss for max-margin classification with targets in {-1, +1}.
#[derive(Debug, Clone)]
pub struct HingeLoss {
    /// Margin below which predictions are penalized.
    pub margin: f64,
}

impl Default for HingeLoss {
    fn default() -> Self {
        Self { margin: 1.0 }
    }
}

impl Loss for HingeLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut total_loss = 0.0;
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]];
                let target = targets[[i, j]];

                // Penalize predictions inside the margin or on the wrong side.
                let loss = (self.margin - target * pred).max(0.0);
                total_loss += loss;
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut grad = Array::zeros(predictions.raw_dim());
        let n = predictions.nrows() as f64;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]];
                let target = targets[[i, j]];

                if self.margin - target * pred > 0.0 {
                    grad[[i, j]] = -target / n;
                }
            }
        }

        Ok(grad)
    }
}

/// Kullback-Leibler divergence KL(targets || predictions) between two
/// distributions over the same support.
#[derive(Debug, Clone)]
pub struct KLDivergenceLoss {
    /// Small constant used to keep both distributions away from zero.
    pub epsilon: f64,
}

impl Default for KLDivergenceLoss {
    fn default() -> Self {
        Self { epsilon: 1e-10 }
    }
}

impl Loss for KLDivergenceLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]].max(self.epsilon);
                let target = targets[[i, j]].max(self.epsilon);

                // KL(t || p) accumulates t * ln(t / p) over all elements.
                total_loss += target * (target / pred).ln();
            }
        }

        Ok(total_loss)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]].max(self.epsilon);
                let target = targets[[i, j]].max(self.epsilon);

                // d/dp of t * ln(t / p) is -t / p.
                grad[[i, j]] = -target / pred;
            }
        }

        Ok(grad)
    }
}

/// Poly-1 style loss: cross-entropy plus a polynomial correction term
/// `poly_coeff * (1 - p_t)`, where `p_t` is the probability assigned to
/// the true class.
#[derive(Debug, Clone)]
pub struct PolyLoss {
    /// Small constant used to clamp predictions away from 0 and 1.
    pub epsilon: f64,
    /// Coefficient of the leading polynomial term.
    pub poly_coeff: f64,
}

impl Default for PolyLoss {
    fn default() -> Self {
        Self {
            epsilon: 1e-10,
            poly_coeff: 1.0,
        }
    }
}

impl PolyLoss {
    /// Creates a `PolyLoss` with the given polynomial coefficient.
    pub fn new(poly_coeff: f64) -> Self {
        Self {
            epsilon: 1e-10,
            poly_coeff,
        }
    }
}

impl Loss for PolyLoss {
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.nrows() as f64;
        let mut total_loss = 0.0;

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]].clamp(self.epsilon, 1.0 - self.epsilon);
                let target = targets[[i, j]];

                // Cross-entropy contribution.
                let ce = -target * pred.ln();

                // Poly-1 correction: poly_coeff * (1 - p_t), where p_t is
                // the probability assigned to the true class.
                let poly_term = if target > 0.5 {
                    self.poly_coeff * (1.0 - pred)
                } else {
                    self.poly_coeff * pred
                };

                total_loss += ce + poly_term;
            }
        }

        Ok(total_loss / n)
    }

    fn gradient(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<Array<f64, Ix2>> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::LossError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n = predictions.nrows() as f64;
        let mut grad = Array::zeros(predictions.raw_dim());

        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                let pred = predictions[[i, j]].clamp(self.epsilon, 1.0 - self.epsilon);
                let target = targets[[i, j]];

                let ce_grad = -target / pred;

                // d/dp of the poly term: -poly_coeff for positives, +poly_coeff otherwise.
                let poly_grad = if target > 0.5 {
                    -self.poly_coeff
                } else {
                    self.poly_coeff
                };

                grad[[i, j]] = (ce_grad + poly_grad) / n;
            }
        }

        Ok(grad)
    }

    fn name(&self) -> &str {
        "poly_loss"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_cross_entropy_loss() {
        let loss = CrossEntropyLoss::default();
        let predictions = array![[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]];
        let targets = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val > 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }

    #[test]
    fn test_mse_loss() {
        let loss = MseLoss;
        let predictions = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![[1.5, 2.5], [3.5, 4.5]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!((loss_val - 0.25).abs() < 1e-6);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }
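
    // An added sanity check (not in the original suite): `gradient` should
    // agree with a central finite difference of `compute`. `MseLoss` is used
    // here, but any `Loss` implementation above could be substituted.
    #[test]
    fn test_mse_gradient_matches_finite_difference() {
        let loss = MseLoss;
        let predictions = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![[1.5, 2.5], [3.5, 4.5]];
        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();

        let h = 1e-6;
        for i in 0..predictions.nrows() {
            for j in 0..predictions.ncols() {
                // Perturb one element in each direction and difference the loss.
                let mut plus = predictions.clone();
                plus[[i, j]] += h;
                let mut minus = predictions.clone();
                minus[[i, j]] -= h;
                let numeric = (loss.compute(&plus.view(), &targets.view()).unwrap()
                    - loss.compute(&minus.view(), &targets.view()).unwrap())
                    / (2.0 * h);
                assert!((grad[[i, j]] - numeric).abs() < 1e-6);
            }
        }
    }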

    #[test]
    fn test_rule_satisfaction_loss() {
        let loss = RuleSatisfactionLoss::default();
        let rule_values = array![[0.9, 0.8], [0.95, 0.85]];
        let targets = array![[1.0, 1.0], [1.0, 1.0]];

        let loss_val = loss.compute(&rule_values.view(), &targets.view()).unwrap();
        assert!(loss_val > 0.0);

        let grad = loss.gradient(&rule_values.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), rule_values.shape());
    }

    #[test]
    fn test_constraint_violation_loss() {
        let loss = ConstraintViolationLoss::default();
        let constraint_values = array![[0.1, -0.1], [0.2, -0.2]];
        let targets = array![[0.0, 0.0], [0.0, 0.0]];

        let loss_val = loss
            .compute(&constraint_values.view(), &targets.view())
            .unwrap();
        assert!(loss_val > 0.0);

        let grad = loss
            .gradient(&constraint_values.view(), &targets.view())
            .unwrap();
        assert_eq!(grad.shape(), constraint_values.shape());
    }

    #[test]
    fn test_focal_loss() {
        let loss = FocalLoss::default();
        let predictions = array![[0.9, 0.1], [0.2, 0.8]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }

    #[test]
    fn test_huber_loss() {
        let loss = HuberLoss::default();
        let predictions = array![[1.0, 3.0], [2.0, 5.0]];
        let targets = array![[1.5, 2.0], [2.5, 4.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val > 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }

    #[test]
    fn test_bce_with_logits_loss() {
        let loss = BCEWithLogitsLoss;
        let logits = array![[0.5, -0.5], [1.0, -1.0]];
        let targets = array![[1.0, 0.0], [1.0, 0.0]];

        let loss_val = loss.compute(&logits.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);

        let grad = loss.gradient(&logits.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), logits.shape());
    }
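
    // An added check (not in the original suite): the stable formulation in
    // `BCEWithLogitsLoss::compute` should stay finite for large-magnitude
    // logits, where a naive ln(sigmoid(x)) would overflow to -inf.
    #[test]
    fn test_bce_with_logits_extreme_logits() {
        let loss = BCEWithLogitsLoss;
        let logits = array![[100.0, -100.0]];
        let targets = array![[1.0, 0.0]];

        let loss_val = loss.compute(&logits.view(), &targets.view()).unwrap();
        assert!(loss_val.is_finite());
        // Confident, correct predictions give a near-zero loss.
        assert!(loss_val < 1e-6);
    }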

    #[test]
    fn test_dice_loss() {
        let loss = DiceLoss::default();
        let predictions = array![[0.9, 0.1], [0.8, 0.2]];
        let targets = array![[1.0, 0.0], [1.0, 0.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);
        assert!(loss_val <= 1.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }

    #[test]
    fn test_tversky_loss() {
        let loss = TverskyLoss::default();
        let predictions = array![[0.9, 0.1], [0.8, 0.2]];
        let targets = array![[1.0, 0.0], [1.0, 0.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);
        assert!(loss_val <= 1.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());
    }

    #[test]
    fn test_contrastive_loss() {
        let loss = ContrastiveLoss::default();
        // Column 0 holds the pairwise distance; targets are 1.0 for similar
        // pairs and 0.0 for dissimilar pairs.
        let predictions = array![[0.5, 0.0], [1.5, 0.0], [0.2, 0.0]];
        let targets = array![[1.0], [0.0], [1.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());

        // A similar pair at nonzero distance is pulled inward; a dissimilar
        // pair already beyond the margin receives no gradient.
        assert!(grad[[0, 0]] > 0.0);
        assert_eq!(grad[[1, 0]], 0.0);
    }

    #[test]
    fn test_triplet_loss() {
        let loss = TripletLoss::default();
        // Columns are (positive distance, negative distance); targets are unused.
        let predictions = array![[0.5, 2.0], [1.0, 0.5], [0.3, 1.5]];
        let targets = array![[0.0], [0.0], [0.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());

        // The first triplet already satisfies the margin, so its gradient is zero.
        assert_eq!(grad[[0, 0]], 0.0);
        assert_eq!(grad[[0, 1]], 0.0);

        // The second triplet violates the margin and is pushed apart.
        assert!(grad[[1, 0]] > 0.0);
        assert!(grad[[1, 1]] < 0.0);
    }

    #[test]
    fn test_hinge_loss() {
        let loss = HingeLoss::default();
        let predictions = array![[0.5, -0.5], [2.0, -2.0]];
        let targets = array![[1.0, -1.0], [1.0, -1.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());

        // Examples classified beyond the margin receive zero gradient.
        assert_eq!(grad[[1, 0]], 0.0);
        assert_eq!(grad[[1, 1]], 0.0);
    }

    #[test]
    fn test_kl_divergence_loss() {
        let loss = KLDivergenceLoss::default();
        let predictions = array![[0.6, 0.4], [0.7, 0.3]];
        let targets = array![[0.5, 0.5], [0.8, 0.2]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val >= 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());

        // KL divergence between identical distributions is zero.
        let identical_preds = array![[0.5, 0.5]];
        let identical_targets = array![[0.5, 0.5]];
        let identical_loss = loss
            .compute(&identical_preds.view(), &identical_targets.view())
            .unwrap();
        assert!(identical_loss.abs() < 1e-6);
    }

    #[test]
    fn test_poly_loss() {
        let loss = PolyLoss::default();
        let predictions = array![[0.9, 0.1], [0.2, 0.8]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val > 0.0);

        let grad = loss.gradient(&predictions.view(), &targets.view()).unwrap();
        assert_eq!(grad.shape(), predictions.shape());

        // With a positive coefficient the poly term only adds loss, so
        // PolyLoss upper-bounds plain cross-entropy on the same inputs.
        let ce_loss = CrossEntropyLoss::default();
        let ce_val = ce_loss
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        assert!(loss_val >= ce_val);
    }

    #[test]
    fn test_poly_loss_custom_coefficient() {
        let loss = PolyLoss::new(2.0);
        let predictions = array![[0.8, 0.2]];
        let targets = array![[1.0, 0.0]];

        let loss_val = loss.compute(&predictions.view(), &targets.view()).unwrap();
        assert!(loss_val > 0.0);

        // A larger coefficient enlarges the poly term and thus the total loss.
        let loss_low_coeff = PolyLoss::new(0.5);
        let loss_val_low = loss_low_coeff
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        assert!(loss_val > loss_val_low);
    }
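
    // An added integration-style check (not in the original suite): combine
    // a supervised loss with one rule term and one constraint term through
    // `LogicalLoss`. The inputs reuse fixtures from the tests above; the
    // default weights are illustrative, not prescribed.
    #[test]
    fn test_logical_loss_compute_total() {
        let logical = LogicalLoss::new(
            LossConfig::default(),
            Box::new(CrossEntropyLoss::default()),
            vec![Box::new(RuleSatisfactionLoss::default())],
            vec![Box::new(ConstraintViolationLoss::default())],
        );

        let predictions = array![[0.7, 0.3], [0.2, 0.8]];
        let targets = array![[1.0, 0.0], [0.0, 1.0]];
        let rule_values = array![[0.9, 0.8], [0.95, 0.85]];
        let constraint_values = array![[0.1, -0.1], [0.2, -0.2]];

        let total = logical
            .compute_total(
                &predictions.view(),
                &targets.view(),
                &[rule_values.view()],
                &[constraint_values.view()],
            )
            .unwrap();

        // Every component is non-negative, so the weighted sum must be positive.
        assert!(total > 0.0);
    }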
}