smos_domain/value_objects/
nli.rs1use crate::config::NliConfig;
8use crate::enums::{MergeReason, NliLabel};
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
13pub struct NliScores {
14 pub entailment: f32,
15 pub neutral: f32,
16 pub contradiction: f32,
17}
18
19#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
25pub struct NliResult {
26 pub label: NliLabel,
27 pub scores: NliScores,
28 pub available: bool,
29}
30
31impl NliResult {
32 pub fn exact_match_result() -> Self {
37 Self {
38 label: NliLabel::Entailment,
39 scores: NliScores {
40 entailment: 1.0,
41 neutral: 0.0,
42 contradiction: 0.0,
43 },
44 available: true,
45 }
46 }
47
48 pub fn is_contradiction(&self, cfg: &NliConfig) -> bool {
50 self.label == NliLabel::Contradiction
51 && self.scores.contradiction >= cfg.contradiction_threshold
52 }
53
54 pub fn is_entailment(&self, cfg: &NliConfig) -> bool {
56 self.label == NliLabel::Entailment && self.scores.entailment >= cfg.entailment_threshold
57 }
58
59 pub fn decide_merge(&self, cfg: &NliConfig) -> MergeReason {
66 if !self.available {
67 return MergeReason::NeutralSkipped;
68 }
69 if self.is_contradiction(cfg) {
70 return MergeReason::Drift;
71 }
72 if self.is_entailment(cfg) {
73 return MergeReason::Merged;
74 }
75 MergeReason::NeutralSkipped
76 }
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    /// Default NLI configuration. Boundary tests read thresholds from it instead
    /// of hard-coding numbers, so they stay valid if the defaults change.
    fn cfg() -> NliConfig {
        NliConfig::default()
    }

    /// Builds an `NliResult` with an explicit label, availability flag and scores.
    fn nli(
        label: NliLabel,
        available: bool,
        entailment: f32,
        neutral: f32,
        contradiction: f32,
    ) -> NliResult {
        NliResult {
            label,
            scores: NliScores {
                entailment,
                neutral,
                contradiction,
            },
            available,
        }
    }

    #[test]
    fn exact_match_result_is_entailment_with_full_scores() {
        let r = NliResult::exact_match_result();
        assert_eq!(r.label, NliLabel::Entailment);
        assert_eq!(r.scores.entailment, 1.0);
        assert_eq!(r.scores.neutral, 0.0);
        assert_eq!(r.scores.contradiction, 0.0);
        assert!(r.available);
    }

    #[test]
    fn exact_match_result_decides_merged() {
        // A perfect entailment score must clear any sane entailment threshold.
        let r = NliResult::exact_match_result();
        assert_eq!(r.decide_merge(&cfg()), MergeReason::Merged);
    }

    #[test]
    fn is_contradiction_true_when_label_and_score_dominate() {
        let r = nli(NliLabel::Contradiction, true, 0.1, 0.2, 0.7);
        assert!(r.is_contradiction(&cfg()));
    }

    #[test]
    fn is_contradiction_false_when_label_is_entailment() {
        let r = nli(NliLabel::Entailment, true, 0.9, 0.05, 0.05);
        assert!(!r.is_contradiction(&cfg()));
    }

    #[test]
    fn is_contradiction_false_when_score_below_threshold() {
        let r = nli(NliLabel::Contradiction, true, 0.4, 0.2, 0.4);
        assert!(!r.is_contradiction(&cfg()));
    }

    #[test]
    fn is_contradiction_true_at_exact_threshold() {
        // Threshold comparison is inclusive (`>=`).
        let c = cfg();
        let r = nli(NliLabel::Contradiction, true, 0.0, 0.0, c.contradiction_threshold);
        assert!(r.is_contradiction(&c));
    }

    #[test]
    fn is_entailment_true_when_label_and_score_dominate() {
        let r = nli(NliLabel::Entailment, true, 0.8, 0.1, 0.1);
        assert!(r.is_entailment(&cfg()));
    }

    #[test]
    fn is_entailment_false_when_label_is_contradiction() {
        let r = nli(NliLabel::Contradiction, true, 0.3, 0.2, 0.5);
        assert!(!r.is_entailment(&cfg()));
    }

    #[test]
    fn is_entailment_false_when_score_below_threshold() {
        let r = nli(NliLabel::Entailment, true, 0.5, 0.3, 0.2);
        assert!(!r.is_entailment(&cfg()));
    }

    #[test]
    fn is_entailment_true_at_exact_threshold() {
        // Threshold comparison is inclusive (`>=`).
        let c = cfg();
        let r = nli(NliLabel::Entailment, true, c.entailment_threshold, 0.0, 0.0);
        assert!(r.is_entailment(&c));
    }

    #[test]
    fn decide_merge_entailment_yields_merged() {
        let r = nli(NliLabel::Entailment, true, 0.8, 0.1, 0.1);
        assert_eq!(r.decide_merge(&cfg()), MergeReason::Merged);
    }

    #[test]
    fn decide_merge_contradiction_yields_drift() {
        let r = nli(NliLabel::Contradiction, true, 0.1, 0.1, 0.8);
        assert_eq!(r.decide_merge(&cfg()), MergeReason::Drift);
    }

    #[test]
    fn decide_merge_neutral_yields_neutral_skipped() {
        let r = nli(NliLabel::Neutral, true, 0.1, 0.8, 0.1);
        assert_eq!(r.decide_merge(&cfg()), MergeReason::NeutralSkipped);
    }

    #[test]
    fn decide_merge_unavailable_yields_neutral_skipped() {
        let r = nli(NliLabel::Entailment, false, 1.0, 0.0, 0.0);
        assert_eq!(r.decide_merge(&cfg()), MergeReason::NeutralSkipped);
    }

    #[test]
    fn decide_merge_unavailable_contradiction_yields_neutral_skipped() {
        // Availability is checked before contradiction, so even a certain
        // contradiction must not be reported as drift when unavailable.
        let r = nli(NliLabel::Contradiction, false, 0.0, 0.0, 1.0);
        assert_eq!(r.decide_merge(&cfg()), MergeReason::NeutralSkipped);
    }

    #[test]
    fn decide_merge_entailment_below_threshold_yields_neutral_skipped() {
        let r = nli(NliLabel::Entailment, true, 0.5, 0.3, 0.2);
        assert_eq!(r.decide_merge(&cfg()), MergeReason::NeutralSkipped);
    }

    #[test]
    fn decide_merge_contradiction_below_threshold_yields_neutral_skipped() {
        // A weak contradiction is neither drift nor (wrong label) entailment.
        let r = nli(NliLabel::Contradiction, true, 0.4, 0.2, 0.4);
        assert_eq!(r.decide_merge(&cfg()), MergeReason::NeutralSkipped);
    }

    #[test]
    fn serde_roundtrip_preserves_nli_result() {
        let r = nli(NliLabel::Contradiction, true, 0.1, 0.2, 0.7);
        let json = serde_json::to_string(&r).unwrap();
        let back: NliResult = serde_json::from_str(&json).unwrap();
        assert_eq!(r, back);
    }
}