Skip to main content

talon_cli/mcp/session/
suppression.rs

1use super::ledger::{InjectedChunk, SuppressedRecall, SuppressionReason, TurnLedger};
2
/// Minimum score a candidate must reach to be injected.
///
/// Applied both to the raw recall score and to the decayed score of chunks or
/// notes that were injected in earlier turns.
const CONFIDENCE_GATE: f64 = 0.60;

/// Default per-turn score decay for chunks seen in prior turns.
///
/// Lower = more aggressive suppression; higher = more permissive re-injection.
pub const DEFAULT_DECAY: f64 = 0.85;
9
/// Decayed score of a chunk last injected `turns_since` turns ago.
///
/// Computes `raw × decay^turns_since`; callers compare the result against
/// `CONFIDENCE_GATE` and let the chunk through only when it is `>=` the gate.
///
/// Special cases:
/// - `turns_since == 0` (injected in the same turn) yields `0.0`, so the chunk
///   can never be re-injected.
/// - a `turns_since` too large for `i32` saturates the exponent at `i32::MAX`.
fn effective_score(raw: f64, turns_since: usize, decay: f64) -> f64 {
    match turns_since {
        0 => 0.0,
        elapsed => {
            let exponent = i32::try_from(elapsed).unwrap_or(i32::MAX);
            raw * decay.powi(exponent)
        }
    }
}
23
/// A candidate from recall output before suppression filtering.
#[derive(Debug, Clone, PartialEq)]
pub struct RecallCandidate {
    /// Identifier of the recalled chunk; looked up in the ledger for chunk-level decay.
    pub chunk_id: String,
    /// Path of the note the chunk belongs to; looked up in the ledger for note-level decay.
    pub path: String,
    /// Recall score compared (raw and decayed) against the confidence gate.
    pub score: f64,
    /// Display title; not consulted by suppression.
    pub title: String,
    /// Text snippet; not consulted by suppression.
    pub snippet: String,
}
33
/// Outcome of [`apply_suppression`]: every input candidate ends up in exactly
/// one of the two lists.
#[derive(Debug)]
pub struct SuppressionResult {
    /// Candidates that passed every suppression check, in their input order.
    /// Empty when everything was suppressed; the caller must then skip injection.
    pub injected: Vec<RecallCandidate>,
    /// Records for dropped candidates, each tagged with the reason it was suppressed.
    pub suppressed: Vec<SuppressedRecall>,
}
39
40/// Apply output-level suppression to a list of recall candidates.
41///
42/// Suppresses chunks below the confidence gate or whose score, after applying
43/// the per-turn decay multiplier, falls below the gate. Does NOT use query
44/// similarity — we deduplicate injected context, not input messages.
45///
46/// `decay` is the per-turn multiplier (e.g. 0.85). A chunk last injected N
47/// turns ago has its raw score multiplied by `decay^N` before comparing to
48/// the confidence gate. If all candidates are suppressed, `injected` is empty
49/// and the caller must skip injection entirely rather than substituting
50/// lower-ranked results.
51#[must_use]
52pub fn apply_suppression(
53    candidates: Vec<RecallCandidate>,
54    ledger: &TurnLedger,
55    decay: f64,
56) -> SuppressionResult {
57    let mut injected = Vec::new();
58    let mut suppressed = Vec::new();
59
60    for candidate in candidates {
61        if candidate.score < CONFIDENCE_GATE {
62            suppressed.push(SuppressedRecall {
63                chunk_id: candidate.chunk_id,
64                path: candidate.path,
65                score: candidate.score,
66                reason: SuppressionReason::BelowConfidenceGate,
67            });
68            continue;
69        }
70
71        // Chunk-level decay.
72        if ledger
73            .turns_since_chunk_last_injected(&candidate.chunk_id)
74            .is_some_and(|n| effective_score(candidate.score, n, decay) < CONFIDENCE_GATE)
75        {
76            suppressed.push(SuppressedRecall {
77                chunk_id: candidate.chunk_id,
78                path: candidate.path,
79                score: candidate.score,
80                reason: SuppressionReason::SameChunkRecentlyInjected,
81            });
82            continue;
83        }
84
85        // Note-level decay: same multiplier as chunk, applied to the whole note path.
86        if ledger
87            .turns_since_note_last_injected(&candidate.path)
88            .is_some_and(|n| effective_score(candidate.score, n, decay) < CONFIDENCE_GATE)
89        {
90            suppressed.push(SuppressedRecall {
91                chunk_id: candidate.chunk_id,
92                path: candidate.path,
93                score: candidate.score,
94                reason: SuppressionReason::SameNoteRecentlyInjected,
95            });
96            continue;
97        }
98
99        injected.push(candidate);
100    }
101
102    SuppressionResult {
103        injected,
104        suppressed,
105    }
106}
107
108/// Builds an [`InjectedChunk`] record from a suppression-approved candidate.
109#[must_use]
110pub fn to_injected_chunk(candidate: &RecallCandidate) -> InjectedChunk {
111    InjectedChunk {
112        chunk_id: candidate.chunk_id.clone(),
113        path: candidate.path.clone(),
114        score: candidate.score,
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::{DEFAULT_DECAY, RecallCandidate, apply_suppression};
    use crate::mcp::session::ledger::{InjectedChunk, SuppressionReason, TurnLedger, TurnRecord};

    fn candidate(chunk_id: &str, path: &str, score: f64) -> RecallCandidate {
        RecallCandidate {
            chunk_id: chunk_id.to_owned(),
            path: path.to_owned(),
            score,
            title: "Test".to_owned(),
            snippet: "snippet".to_owned(),
        }
    }

    fn ledger_with_chunk_n_turns_ago(chunk_id: &str, path: &str, n: usize) -> TurnLedger {
        let mut ledger = TurnLedger::new();
        // Insert the target chunk injection in turn 0.
        ledger.record_turn(TurnRecord {
            turn_id: "t0".to_owned(),
            query_fingerprint: String::new(),
            injected: vec![InjectedChunk {
                chunk_id: chunk_id.to_owned(),
                path: path.to_owned(),
                score: 0.9,
            }],
            suppressed: vec![],
            skipped: false,
        });
        // Add n subsequent empty turns so the chunk is n turns in the past.
        for i in 0..n {
            ledger.record_turn(TurnRecord {
                turn_id: format!("e{i}"),
                query_fingerprint: String::new(),
                injected: vec![],
                suppressed: vec![],
                skipped: false,
            });
        }
        ledger
    }

    // Verify the decay formula at DEFAULT_DECAY = 0.85, CONFIDENCE_GATE = 0.60:
    //   effective = score × 0.85^turns_since
    //   inject if effective >= 0.60

    #[test]
    fn moderate_score_suppressed_one_turn_ago() {
        // 0.65 clears the raw gate (0.65 >= 0.60), but after decay
        // 0.65 × 0.85 = 0.553 < 0.60, so chunk-level decay suppresses it.
        let ledger = ledger_with_chunk_n_turns_ago("c", "notes/foo.md", 1);
        let result = apply_suppression(
            vec![candidate("c", "notes/foo.md", 0.65)],
            &ledger,
            DEFAULT_DECAY,
        );
        assert_eq!(result.injected.len(), 0);
        assert_eq!(
            result.suppressed[0].reason,
            SuppressionReason::SameChunkRecentlyInjected
        );
    }

    #[test]
    fn high_score_chunk_passes_one_turn_ago() {
        // 0.85 × 0.85^1 = 0.72 >= 0.60 → high-confidence chunks still eligible after 1 turn
        let ledger = ledger_with_chunk_n_turns_ago("c", "notes/foo.md", 1);
        let result = apply_suppression(
            vec![candidate("c", "notes/foo.md", 0.85)],
            &ledger,
            DEFAULT_DECAY,
        );
        assert_eq!(
            result.injected.len(),
            1,
            "score 0.85 should pass after 1 turn with decay 0.85"
        );
    }

    #[test]
    fn moderate_score_suppressed_three_turns_ago() {
        // 0.65 × 0.85^3 = 0.65 × 0.614 = 0.399 < 0.60 → suppressed by chunk decay
        let ledger = ledger_with_chunk_n_turns_ago("c", "notes/foo.md", 3);
        let result = apply_suppression(
            vec![candidate("c", "notes/foo.md", 0.65)],
            &ledger,
            DEFAULT_DECAY,
        );
        assert_eq!(result.injected.len(), 0);
        assert_eq!(
            result.suppressed[0].reason,
            SuppressionReason::SameChunkRecentlyInjected
        );
    }

    #[test]
    fn high_score_eligible_three_turns_ago() {
        // 0.98 × 0.85^3 = 0.98 × 0.614 = 0.602 >= 0.60 → passes
        let ledger = ledger_with_chunk_n_turns_ago("c", "notes/foo.md", 3);
        let result = apply_suppression(
            vec![candidate("c", "notes/foo.md", 0.98)],
            &ledger,
            DEFAULT_DECAY,
        );
        assert_eq!(
            result.injected.len(),
            1,
            "score 0.98 should re-emerge after 3 turns with gate 0.60"
        );
    }

    #[test]
    fn below_confidence_gate_suppressed() {
        let result = apply_suppression(
            vec![candidate("new", "notes/bar.md", 0.2)],
            &TurnLedger::new(),
            DEFAULT_DECAY,
        );
        assert_eq!(result.injected.len(), 0);
        assert_eq!(
            result.suppressed[0].reason,
            SuppressionReason::BelowConfidenceGate
        );
    }

    #[test]
    fn novel_chunk_passes_through() {
        let result = apply_suppression(
            vec![candidate("new", "notes/new.md", 0.85)],
            &TurnLedger::new(),
            DEFAULT_DECAY,
        );
        assert_eq!(result.injected.len(), 1);
        assert!(result.suppressed.is_empty());
    }

    #[test]
    fn sibling_chunk_suppressed_by_note_decay() {
        // Chunk "d" was never injected, but its note was 1 turn ago:
        // 0.65 × 0.85 = 0.553 < 0.60 → note-level decay suppresses it.
        let ledger = ledger_with_chunk_n_turns_ago("c", "notes/foo.md", 1);
        let result = apply_suppression(
            vec![candidate("d", "notes/foo.md", 0.65)],
            &ledger,
            DEFAULT_DECAY,
        );
        assert!(result.injected.is_empty());
        assert_eq!(
            result.suppressed[0].reason,
            SuppressionReason::SameNoteRecentlyInjected
        );
    }

    #[test]
    fn unrelated_note_unaffected_by_ledger() {
        // Neither the chunk nor its note appears in the ledger, so only the raw gate applies.
        let ledger = ledger_with_chunk_n_turns_ago("c", "notes/foo.md", 1);
        let result = apply_suppression(
            vec![candidate("x", "notes/other.md", 0.65)],
            &ledger,
            DEFAULT_DECAY,
        );
        assert_eq!(result.injected.len(), 1);
        assert!(result.suppressed.is_empty());
    }

    #[test]
    fn all_suppressed_means_empty_injected() {
        // Both candidates clear the raw gate (0.65 >= 0.60) but decay below it after
        // 1 turn (0.65 × 0.85 = 0.553): "c" via its own chunk, "d" via the shared note.
        let ledger = ledger_with_chunk_n_turns_ago("c", "notes/foo.md", 1);
        let result = apply_suppression(
            vec![
                candidate("c", "notes/foo.md", 0.65),
                candidate("d", "notes/foo.md", 0.65),
            ],
            &ledger,
            DEFAULT_DECAY,
        );
        assert_eq!(
            result.injected.len(),
            0,
            "caller must skip injection, not substitute lower-ranked results"
        );
        assert_eq!(result.suppressed.len(), 2);
        assert_eq!(
            result.suppressed[0].reason,
            SuppressionReason::SameChunkRecentlyInjected
        );
        assert_eq!(
            result.suppressed[1].reason,
            SuppressionReason::SameNoteRecentlyInjected
        );
    }
}