Skip to main content

tensorlogic_train/loss/
constraintviolationloss_traits.rs

//! # ConstraintViolationLoss - Trait Implementations
//!
//! This module contains trait implementations for `ConstraintViolationLoss`.
//!
//! ## Implemented Traits
//!
//! - `Default`
//! - `Loss`

10use crate::{TrainError, TrainResult};
11use scirs2_core::ndarray::{Array, ArrayView, Ix2};
12
13use super::functions::Loss;
14use super::types::ConstraintViolationLoss;
15
16impl Default for ConstraintViolationLoss {
17    fn default() -> Self {
18        Self {
19            penalty_weight: 10.0,
20        }
21    }
22}
23
24impl Loss for ConstraintViolationLoss {
25    fn compute(
26        &self,
27        constraint_values: &ArrayView<f64, Ix2>,
28        targets: &ArrayView<f64, Ix2>,
29    ) -> TrainResult<f64> {
30        if constraint_values.shape() != targets.shape() {
31            return Err(TrainError::LossError(format!(
32                "Shape mismatch: constraint_values {:?} vs targets {:?}",
33                constraint_values.shape(),
34                targets.shape()
35            )));
36        }
37        let n = constraint_values.len() as f64;
38        let mut total_loss = 0.0;
39        for i in 0..constraint_values.nrows() {
40            for j in 0..constraint_values.ncols() {
41                let violation = (constraint_values[[i, j]] - targets[[i, j]]).max(0.0);
42                total_loss += self.penalty_weight * violation * violation;
43            }
44        }
45        Ok(total_loss / n)
46    }
47    fn gradient(
48        &self,
49        constraint_values: &ArrayView<f64, Ix2>,
50        targets: &ArrayView<f64, Ix2>,
51    ) -> TrainResult<Array<f64, Ix2>> {
52        if constraint_values.shape() != targets.shape() {
53            return Err(TrainError::LossError(format!(
54                "Shape mismatch: constraint_values {:?} vs targets {:?}",
55                constraint_values.shape(),
56                targets.shape()
57            )));
58        }
59        let n = constraint_values.len() as f64;
60        let mut grad = Array::zeros(constraint_values.raw_dim());
61        for i in 0..constraint_values.nrows() {
62            for j in 0..constraint_values.ncols() {
63                let violation = constraint_values[[i, j]] - targets[[i, j]];
64                if violation > 0.0 {
65                    grad[[i, j]] = 2.0 * self.penalty_weight * violation / n;
66                }
67            }
68        }
69        Ok(grad)
70    }
71}