1use std::collections::{BTreeMap, BTreeSet};
20
21use chrono::{DateTime, Utc};
22use serde::{Deserialize, Serialize};
23
24use super::types::default_policy_schema_version;
25
/// Gain applied to the centred smoothed true-positive rate (`rate - 0.5`) when
/// `learned_confidence_adjustments` derives a per-rule delta.
const CONFIDENCE_LEARN_RATE: f32 = 0.3;
/// Symmetric cap on a learned delta: every delta is clamped to
/// `[-MAX_CONFIDENCE_DELTA, MAX_CONFIDENCE_DELTA]`.
///
/// The smoothed rate lies strictly between 0 and 1, so the unclamped product
/// `CONFIDENCE_LEARN_RATE * (rate - 0.5)` already stays within `0.3 * 0.5 = 0.15`;
/// the clamp only becomes binding if the learn rate is raised.
const MAX_CONFIDENCE_DELTA: f32 = 0.15;
/// Floor that `adjust_confidence` enforces on an adjusted confidence.
const MIN_LEARNED_CONFIDENCE: f32 = 0.05;
/// Ceiling that `adjust_confidence` enforces on an adjusted confidence.
const MAX_LEARNED_CONFIDENCE: f32 = 0.99;
/// Minimum number of non-true-positive verdicts a rule must have accumulated
/// before `learned_allowlist` will consider it.
const ALLOWLIST_MIN_FP_SAMPLES: usize = 8;
/// Exclusive upper bound on a rule's smoothed true-positive rate for it to be
/// placed on the learned allowlist.
const ALLOWLIST_MAX_TP_RATE: f32 = 0.15;
44
45#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
47#[serde(rename_all = "snake_case")]
48pub enum Disposition {
49 TruePositive,
51 FalsePositive,
53 Benign,
56}
57
58impl Disposition {
59 fn is_true_positive(self) -> bool {
60 matches!(self, Disposition::TruePositive)
61 }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66#[serde(deny_unknown_fields)]
67pub struct DispositionRecord {
68 pub finding_fingerprint: String,
69 pub rule_id: String,
70 #[serde(default, skip_serializing_if = "Option::is_none")]
71 pub sha256: Option<String>,
72 pub analyst_disposition: Disposition,
73 pub recorded_at: DateTime<Utc>,
74 #[serde(default, skip_serializing_if = "Option::is_none")]
75 pub note: Option<String>,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
81#[serde(deny_unknown_fields)]
82pub struct DispositionOverlay {
83 #[serde(default = "default_policy_schema_version")]
84 pub schema_version: String,
85 #[serde(default)]
86 pub records: Vec<DispositionRecord>,
87}
88
89impl Default for DispositionOverlay {
90 fn default() -> Self {
91 Self {
92 schema_version: default_policy_schema_version(),
93 records: Vec::new(),
94 }
95 }
96}
97
98fn per_rule_counts(overlay: &DispositionOverlay) -> BTreeMap<String, (usize, usize)> {
101 let mut counts: BTreeMap<String, (usize, usize)> = BTreeMap::new();
102 for r in &overlay.records {
103 let entry = counts.entry(r.rule_id.clone()).or_insert((0, 0));
104 if r.analyst_disposition.is_true_positive() {
105 entry.0 += 1;
106 } else {
107 entry.1 += 1;
108 }
109 }
110 counts
111}
112
/// Add-one (Laplace) smoothed true-positive rate: `(tp + 1) / (tp + fp + 2)`.
///
/// With no observations the result is exactly `0.5`, and it never reaches
/// `0.0` or `1.0` for realistic counts, so a handful of verdicts cannot pin
/// the estimate to an extreme.
fn smoothed_tp_rate(tp: usize, fp: usize) -> f32 {
    let hits = tp as f32;
    let observations = hits + fp as f32;
    (hits + 1.0) / (observations + 2.0)
}
117
118#[must_use]
121pub fn learned_confidence_adjustments(overlay: &DispositionOverlay) -> BTreeMap<String, f32> {
122 per_rule_counts(overlay)
123 .into_iter()
124 .map(|(rule, (tp, fp))| {
125 let delta = (CONFIDENCE_LEARN_RATE * (smoothed_tp_rate(tp, fp) - 0.5))
126 .clamp(-MAX_CONFIDENCE_DELTA, MAX_CONFIDENCE_DELTA);
127 (rule, delta)
128 })
129 .collect()
130}
131
132#[must_use]
136pub fn learned_allowlist(overlay: &DispositionOverlay) -> BTreeSet<String> {
137 per_rule_counts(overlay)
138 .into_iter()
139 .filter(|&(_, (tp, fp))| {
140 fp >= ALLOWLIST_MIN_FP_SAMPLES && smoothed_tp_rate(tp, fp) < ALLOWLIST_MAX_TP_RATE
141 })
142 .map(|(rule, _)| rule)
143 .collect()
144}
145
146#[must_use]
149pub fn adjust_confidence(base: f32, delta: f32) -> f32 {
150 (base + delta).clamp(MIN_LEARNED_CONFIDENCE, MAX_LEARNED_CONFIDENCE)
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a record for `rule` carrying verdict `d`; the remaining fields
    /// are filler that the learning functions never read.
    fn rec(rule: &str, d: Disposition) -> DispositionRecord {
        DispositionRecord {
            finding_fingerprint: format!("fp-{rule}-{d:?}"),
            rule_id: rule.to_string(),
            sha256: None,
            analyst_disposition: d,
            recorded_at: Utc::now(),
            note: None,
        }
    }

    /// Wraps `records` in an overlay with a fixed schema version.
    fn overlay(records: Vec<DispositionRecord>) -> DispositionOverlay {
        DispositionOverlay {
            schema_version: "1".into(),
            records,
        }
    }

    /// A TP-heavy history must raise confidence and an FP-heavy one lower it.
    #[test]
    fn confidence_delta_is_monotone_in_tp_ratio() {
        let mostly_tp = overlay(vec![
            rec("R", Disposition::TruePositive),
            rec("R", Disposition::TruePositive),
            rec("R", Disposition::TruePositive),
            rec("R", Disposition::FalsePositive),
        ]);
        let mostly_fp = overlay(vec![
            rec("R", Disposition::FalsePositive),
            rec("R", Disposition::FalsePositive),
            rec("R", Disposition::FalsePositive),
            rec("R", Disposition::TruePositive),
        ]);
        let up = learned_confidence_adjustments(&mostly_tp)["R"];
        let down = learned_confidence_adjustments(&mostly_fp)["R"];
        assert!(up > 0.0, "TP-heavy must raise confidence: {up}");
        assert!(down < 0.0, "FP-heavy must lower confidence: {down}");
        assert!(up > down);
    }

    /// No amount of one-sided evidence may push the delta past the cap in
    /// either direction, and the adjusted confidence must stay inside its
    /// floor/ceiling.
    #[test]
    fn confidence_delta_is_hard_bounded() {
        let tp_flood: Vec<_> = (0..10_000)
            .map(|_| rec("R", Disposition::TruePositive))
            .collect();
        let up = learned_confidence_adjustments(&overlay(tp_flood))["R"];
        assert!(up <= MAX_CONFIDENCE_DELTA, "delta exceeded the cap: {up}");

        // The negative direction is bounded by the same cap.
        let fp_flood: Vec<_> = (0..10_000)
            .map(|_| rec("R", Disposition::FalsePositive))
            .collect();
        let down = learned_confidence_adjustments(&overlay(fp_flood))["R"];
        assert!(down < 0.0, "FP flood must lower confidence: {down}");
        assert!(
            down >= -MAX_CONFIDENCE_DELTA,
            "delta fell below the negative cap: {down}"
        );

        assert!(adjust_confidence(0.95, up) <= MAX_LEARNED_CONFIDENCE);
        assert!(adjust_confidence(0.10, down) >= MIN_LEARNED_CONFIDENCE);
        assert!(adjust_confidence(0.0, -1.0) >= MIN_LEARNED_CONFIDENCE);
    }

    /// Allowlisting needs both enough non-TP samples and a low TP rate.
    #[test]
    fn allowlist_requires_min_samples_and_low_tp_rate() {
        let few_fp = overlay(vec![
            rec("R", Disposition::FalsePositive),
            rec("R", Disposition::FalsePositive),
            rec("R", Disposition::FalsePositive),
        ]);
        assert!(
            !learned_allowlist(&few_fp).contains("R"),
            "3 FP must not allowlist"
        );

        let many_fp = overlay(
            (0..ALLOWLIST_MIN_FP_SAMPLES)
                .map(|_| rec("R", Disposition::FalsePositive))
                .collect(),
        );
        assert!(
            learned_allowlist(&many_fp).contains("R"),
            "{ALLOWLIST_MIN_FP_SAMPLES} FP with ~0 TP rate must allowlist"
        );

        let mut mixed: Vec<_> = (0..ALLOWLIST_MIN_FP_SAMPLES)
            .map(|_| rec("R", Disposition::FalsePositive))
            .collect();
        mixed.extend((0..ALLOWLIST_MIN_FP_SAMPLES).map(|_| rec("R", Disposition::TruePositive)));
        assert!(
            !learned_allowlist(&overlay(mixed)).contains("R"),
            "a high TP rate must keep the rule active even with many FP"
        );
    }

    /// With no records there is nothing to learn: no deltas, no allowlist.
    #[test]
    fn empty_overlay_is_identity() {
        let o = overlay(vec![]);
        assert!(learned_confidence_adjustments(&o).is_empty());
        assert!(learned_allowlist(&o).is_empty());
    }

    /// A document that omits `schema_version` and the optional record fields
    /// still parses, with the schema version filled in by its default.
    #[test]
    fn overlay_deserialises_additively() {
        let json = r#"{"records":[{"finding_fingerprint":"x","rule_id":"R","analyst_disposition":"false_positive","recorded_at":"2026-01-01T00:00:00Z"}]}"#;
        let o: DispositionOverlay = serde_json::from_str(json).unwrap();
        assert_eq!(o.records.len(), 1);
        assert_eq!(o.records[0].analyst_disposition, Disposition::FalsePositive);
        assert!(!o.schema_version.is_empty());
    }
}