Skip to main content

smos_domain/value_objects/
nli.rs

//! `NliResult` / `NliScores` — natural-language-inference verdict value objects.
//!
//! The DeBERTa classifier itself lives in an adapter; this domain layer owns the
//! policy: the canonical exact-match result and the threshold-based predicates
//! that downstream merge and confidence logic consume.
6
7use crate::config::NliConfig;
8use crate::enums::{MergeReason, NliLabel};
9use serde::{Deserialize, Serialize};
10
11/// Per-label softmax scores produced by an NLI classifier.
12#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
13pub struct NliScores {
14    pub entailment: f32,
15    pub neutral: f32,
16    pub contradiction: f32,
17}
18
19/// Full NLI verdict for a fact pair.
20///
21/// `available = false` marks graceful-degradation placeholders emitted when the
22/// classifier is unreachable. Downstream code must NOT treat those as "no
23/// contradiction detected" — they mean "not checked".
24#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
25pub struct NliResult {
26    pub label: NliLabel,
27    pub scores: NliScores,
28    pub available: bool,
29}
30
31impl NliResult {
32    /// The canonical NLI result returned for an exact text match.
33    ///
34    /// Identical text is entailment by definition; this bypasses the model and
35    /// avoids DeBERTa's known quirk of returning `neutral` on identical pairs.
36    pub fn exact_match_result() -> Self {
37        Self {
38            label: NliLabel::Entailment,
39            scores: NliScores {
40                entailment: 1.0,
41                neutral: 0.0,
42                contradiction: 0.0,
43            },
44            available: true,
45        }
46    }
47
48    /// `true` iff the contradiction label dominates above the configured threshold.
49    pub fn is_contradiction(&self, cfg: &NliConfig) -> bool {
50        self.label == NliLabel::Contradiction
51            && self.scores.contradiction >= cfg.contradiction_threshold
52    }
53
54    /// `true` iff the entailment label dominates above the configured threshold.
55    pub fn is_entailment(&self, cfg: &NliConfig) -> bool {
56        self.label == NliLabel::Entailment && self.scores.entailment >= cfg.entailment_threshold
57    }
58
59    /// Classify the NLI verdict into a [`MergeReason`].
60    ///
61    /// - `available = false`  → `NeutralSkipped` (refuse to guess, §5.3).
62    /// - contradiction         → `Drift`.
63    /// - entailment            → `Merged`.
64    /// - neutral               → `NeutralSkipped`.
65    pub fn decide_merge(&self, cfg: &NliConfig) -> MergeReason {
66        if !self.available {
67            return MergeReason::NeutralSkipped;
68        }
69        if self.is_contradiction(cfg) {
70            return MergeReason::Drift;
71        }
72        if self.is_entailment(cfg) {
73            return MergeReason::Merged;
74        }
75        MergeReason::NeutralSkipped
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    /// The default NLI configuration shared by every test below.
    fn default_cfg() -> NliConfig {
        NliConfig::default()
    }

    /// Builds an `NliResult` from a label, an availability flag and an
    /// `(entailment, neutral, contradiction)` score triple.
    fn verdict(
        label: NliLabel,
        available: bool,
        (entailment, neutral, contradiction): (f32, f32, f32),
    ) -> NliResult {
        let scores = NliScores {
            entailment,
            neutral,
            contradiction,
        };
        NliResult {
            label,
            scores,
            available,
        }
    }

    #[test]
    fn exact_match_result_is_entailment_with_full_scores() {
        let canonical = NliResult::exact_match_result();
        let s = &canonical.scores;
        assert_eq!(canonical.label, NliLabel::Entailment);
        assert_eq!((s.entailment, s.neutral, s.contradiction), (1.0, 0.0, 0.0));
        assert!(canonical.available);
    }

    // --- is_contradiction -------------------------------------------------

    #[test]
    fn is_contradiction_true_when_label_and_score_dominate() {
        let v = verdict(NliLabel::Contradiction, true, (0.1, 0.2, 0.7));
        assert!(v.is_contradiction(&default_cfg()));
    }

    #[test]
    fn is_contradiction_false_when_label_is_entailment() {
        let v = verdict(NliLabel::Entailment, true, (0.9, 0.05, 0.05));
        assert!(!v.is_contradiction(&default_cfg()));
    }

    #[test]
    fn is_contradiction_false_when_score_below_threshold() {
        let v = verdict(NliLabel::Contradiction, true, (0.4, 0.2, 0.4));
        assert!(!v.is_contradiction(&default_cfg()));
    }

    // --- is_entailment ----------------------------------------------------

    #[test]
    fn is_entailment_true_when_label_and_score_dominate() {
        let v = verdict(NliLabel::Entailment, true, (0.8, 0.1, 0.1));
        assert!(v.is_entailment(&default_cfg()));
    }

    #[test]
    fn is_entailment_false_when_label_is_contradiction() {
        let v = verdict(NliLabel::Contradiction, true, (0.3, 0.2, 0.5));
        assert!(!v.is_entailment(&default_cfg()));
    }

    #[test]
    fn is_entailment_false_when_score_below_threshold() {
        let v = verdict(NliLabel::Entailment, true, (0.5, 0.3, 0.2));
        assert!(!v.is_entailment(&default_cfg()));
    }

    // --- decide_merge -----------------------------------------------------

    #[test]
    fn decide_merge_entailment_yields_merged() {
        let v = verdict(NliLabel::Entailment, true, (0.8, 0.1, 0.1));
        assert_eq!(v.decide_merge(&default_cfg()), MergeReason::Merged);
    }

    #[test]
    fn decide_merge_contradiction_yields_drift() {
        let v = verdict(NliLabel::Contradiction, true, (0.1, 0.1, 0.8));
        assert_eq!(v.decide_merge(&default_cfg()), MergeReason::Drift);
    }

    #[test]
    fn decide_merge_neutral_yields_neutral_skipped() {
        let v = verdict(NliLabel::Neutral, true, (0.1, 0.8, 0.1));
        assert_eq!(v.decide_merge(&default_cfg()), MergeReason::NeutralSkipped);
    }

    #[test]
    fn decide_merge_unavailable_yields_neutral_skipped() {
        // Even a maximally confident entailment is skipped when unavailable.
        let v = verdict(NliLabel::Entailment, false, (1.0, 0.0, 0.0));
        assert_eq!(v.decide_merge(&default_cfg()), MergeReason::NeutralSkipped);
    }

    #[test]
    fn decide_merge_entailment_below_threshold_yields_neutral_skipped() {
        let v = verdict(NliLabel::Entailment, true, (0.5, 0.3, 0.2));
        assert_eq!(v.decide_merge(&default_cfg()), MergeReason::NeutralSkipped);
    }

    // --- serialization ----------------------------------------------------

    #[test]
    fn serde_roundtrip_preserves_nli_result() {
        let original = verdict(NliLabel::Contradiction, true, (0.1, 0.2, 0.7));
        let encoded = serde_json::to_string(&original).expect("NliResult serializes to JSON");
        let decoded: NliResult =
            serde_json::from_str(&encoded).expect("serialized NliResult deserializes");
        assert_eq!(decoded, original);
    }
}