Skip to main content

tensorlogic_train/loss/
rulesatisfactionloss_traits.rs

//! # RuleSatisfactionLoss - Trait Implementations
//!
//! This module contains trait implementations for `RuleSatisfactionLoss`.
//!
//! ## Implemented Traits
//!
//! - `Default`
//! - `Loss`
9
10use crate::{TrainError, TrainResult};
11use scirs2_core::ndarray::{Array, ArrayView, Ix2};
12
13use super::functions::Loss;
14use super::types::RuleSatisfactionLoss;
15
16impl Default for RuleSatisfactionLoss {
17    fn default() -> Self {
18        Self { temperature: 1.0 }
19    }
20}
21
22impl Loss for RuleSatisfactionLoss {
23    fn compute(
24        &self,
25        rule_values: &ArrayView<f64, Ix2>,
26        targets: &ArrayView<f64, Ix2>,
27    ) -> TrainResult<f64> {
28        if rule_values.shape() != targets.shape() {
29            return Err(TrainError::LossError(format!(
30                "Shape mismatch: rule_values {:?} vs targets {:?}",
31                rule_values.shape(),
32                targets.shape()
33            )));
34        }
35        let n = rule_values.len() as f64;
36        let mut total_loss = 0.0;
37        for i in 0..rule_values.nrows() {
38            for j in 0..rule_values.ncols() {
39                let diff = targets[[i, j]] - rule_values[[i, j]];
40                total_loss += (diff / self.temperature).powi(2);
41            }
42        }
43        Ok(total_loss / n)
44    }
45    fn gradient(
46        &self,
47        rule_values: &ArrayView<f64, Ix2>,
48        targets: &ArrayView<f64, Ix2>,
49    ) -> TrainResult<Array<f64, Ix2>> {
50        if rule_values.shape() != targets.shape() {
51            return Err(TrainError::LossError(format!(
52                "Shape mismatch: rule_values {:?} vs targets {:?}",
53                rule_values.shape(),
54                targets.shape()
55            )));
56        }
57        let n = rule_values.len() as f64;
58        let mut grad = Array::zeros(rule_values.raw_dim());
59        for i in 0..rule_values.nrows() {
60            for j in 0..rule_values.ncols() {
61                let diff = targets[[i, j]] - rule_values[[i, j]];
62                grad[[i, j]] = -2.0 * diff / (self.temperature * self.temperature * n);
63            }
64        }
65        Ok(grad)
66    }
67}