tensorlogic_train/loss/
constraintviolationloss_traits.rs1use crate::{TrainError, TrainResult};
11use scirs2_core::ndarray::{Array, ArrayView, Ix2};
12
13use super::functions::Loss;
14use super::types::ConstraintViolationLoss;
15
16impl Default for ConstraintViolationLoss {
17 fn default() -> Self {
18 Self {
19 penalty_weight: 10.0,
20 }
21 }
22}
23
24impl Loss for ConstraintViolationLoss {
25 fn compute(
26 &self,
27 constraint_values: &ArrayView<f64, Ix2>,
28 targets: &ArrayView<f64, Ix2>,
29 ) -> TrainResult<f64> {
30 if constraint_values.shape() != targets.shape() {
31 return Err(TrainError::LossError(format!(
32 "Shape mismatch: constraint_values {:?} vs targets {:?}",
33 constraint_values.shape(),
34 targets.shape()
35 )));
36 }
37 let n = constraint_values.len() as f64;
38 let mut total_loss = 0.0;
39 for i in 0..constraint_values.nrows() {
40 for j in 0..constraint_values.ncols() {
41 let violation = (constraint_values[[i, j]] - targets[[i, j]]).max(0.0);
42 total_loss += self.penalty_weight * violation * violation;
43 }
44 }
45 Ok(total_loss / n)
46 }
47 fn gradient(
48 &self,
49 constraint_values: &ArrayView<f64, Ix2>,
50 targets: &ArrayView<f64, Ix2>,
51 ) -> TrainResult<Array<f64, Ix2>> {
52 if constraint_values.shape() != targets.shape() {
53 return Err(TrainError::LossError(format!(
54 "Shape mismatch: constraint_values {:?} vs targets {:?}",
55 constraint_values.shape(),
56 targets.shape()
57 )));
58 }
59 let n = constraint_values.len() as f64;
60 let mut grad = Array::zeros(constraint_values.raw_dim());
61 for i in 0..constraint_values.nrows() {
62 for j in 0..constraint_values.ncols() {
63 let violation = constraint_values[[i, j]] - targets[[i, j]];
64 if violation > 0.0 {
65 grad[[i, j]] = 2.0 * self.penalty_weight * violation / n;
66 }
67 }
68 }
69 Ok(grad)
70 }
71}