// tensorlogic_train/loss/rulesatisfactionloss_traits.rs
use crate::{TrainError, TrainResult};

use scirs2_core::ndarray::{Array, ArrayView, Ix2};

use super::functions::Loss;
use super::types::RuleSatisfactionLoss;

16impl Default for RuleSatisfactionLoss {
17 fn default() -> Self {
18 Self { temperature: 1.0 }
19 }
20}
21
22impl Loss for RuleSatisfactionLoss {
23 fn compute(
24 &self,
25 rule_values: &ArrayView<f64, Ix2>,
26 targets: &ArrayView<f64, Ix2>,
27 ) -> TrainResult<f64> {
28 if rule_values.shape() != targets.shape() {
29 return Err(TrainError::LossError(format!(
30 "Shape mismatch: rule_values {:?} vs targets {:?}",
31 rule_values.shape(),
32 targets.shape()
33 )));
34 }
35 let n = rule_values.len() as f64;
36 let mut total_loss = 0.0;
37 for i in 0..rule_values.nrows() {
38 for j in 0..rule_values.ncols() {
39 let diff = targets[[i, j]] - rule_values[[i, j]];
40 total_loss += (diff / self.temperature).powi(2);
41 }
42 }
43 Ok(total_loss / n)
44 }
45 fn gradient(
46 &self,
47 rule_values: &ArrayView<f64, Ix2>,
48 targets: &ArrayView<f64, Ix2>,
49 ) -> TrainResult<Array<f64, Ix2>> {
50 if rule_values.shape() != targets.shape() {
51 return Err(TrainError::LossError(format!(
52 "Shape mismatch: rule_values {:?} vs targets {:?}",
53 rule_values.shape(),
54 targets.shape()
55 )));
56 }
57 let n = rule_values.len() as f64;
58 let mut grad = Array::zeros(rule_values.raw_dim());
59 for i in 0..rule_values.nrows() {
60 for j in 0..rule_values.ncols() {
61 let diff = targets[[i, j]] - rule_values[[i, j]];
62 grad[[i, j]] = -2.0 * diff / (self.temperature * self.temperature * n);
63 }
64 }
65 Ok(grad)
66 }
67}