Skip to main content

ski/
pipeline.rs

//! The shared decision pipeline: stage selection + gating, single-sourced so the
//! hook (hot path), `ski why` (tuning aid), and `examples/eval` run identical math.
//!
//! Previously each of the three re-implemented the stage cascade. `ski why` had
//! drifted the furthest — it ranked with [`crate::rank::rank_all`] (no
//! context/file/project channels) and reranked the *bare* prompt, so it could star
//! a different skill than the hook would inject on the same prompt. The tool meant
//! to explain the ranker didn't reproduce it.
//!
//! The caller owns the inputs that depend on conversation state (the query vector,
//! the context blend, file/project channels, the rerank query); this module owns
//! the cascade over the resulting `hits`:
//! 1. a dominant lexical (BM25) winner injects directly, unless stage-1 already has
//!    a confident lone dense winner;
//! 2. otherwise the cross-encoder arbitrates the ambiguous middle;
//! 3. otherwise the cheap stage-1 cosine result stands.
//!
//! It returns the winning stage, its rows (for display), and the hits that clear
//! that stage's gate — *before* deny / session-dedup / slash-removal / `max_skills`,
//! which stay with the caller (only the hook has a session).

17//!
18//! It returns the winning stage, its rows (for display), and the hits that clear
19//! that stage's gate — *before* deny / session-dedup / slash-removal / `max_skills`,
20//! which stay with the caller (only the hook has a session).
21
22use crate::confidence::Stage;
23use crate::config::Config;
24use crate::index::Index;
25use crate::lexical::{self, Lex};
26use crate::rank::Hit;
27use crate::rerank;
28
29/// The outcome of the decision cascade for one prompt.
30#[derive(Debug)]
31pub struct Plan {
32    /// Which stage produced the decision.
33    pub stage: Stage,
34    /// The winning stage's ranking, for display: the reranked list when the
35    /// cross-encoder fired, otherwise the stage-1 hits. (For the lexical fast-path
36    /// these are still the stage-1 hits; the winner is in [`Plan::lexical`].)
37    pub rows: Vec<Hit>,
38    /// The lexical fast-path winner, if one fired.
39    pub lexical: Option<Lex>,
40    /// Hits that clear the winning stage's gate, in rank order, *before*
41    /// deny / dedup / slash-removal / cap. For the lexical stage this is the single
42    /// dominant winner (its stage-1 [`Hit`], pulled from `rows`).
43    pub passed: Vec<Hit>,
44    /// Display threshold for the winning stage (`min_similarity` / `rerank_min` /
45    /// `lexical_min`).
46    pub threshold: f32,
47}
48
49/// Stage-1 cosine gate: the hits clearing the absolute floor (`min_similarity`)
50/// and within the relative margin (`score_margin`) of the leader, plus any `force`d
51/// skill on a keyword hit. Pre deny/dedup/cap; pure (no IO), so it is the unit-test
52/// seam for the gate the hook used to inline in `select`.
53pub fn cosine_passed(hits: &[Hit], cfg: &Config) -> Vec<Hit> {
54    let top = hits.first().map(|h| h.score).unwrap_or(0.0);
55    hits.iter()
56        .filter(|h| {
57            let forced = cfg.force.contains(&h.id) && h.keyword > 0.0;
58            forced || (h.score >= cfg.min_similarity && h.score >= top - cfg.score_margin)
59        })
60        .cloned()
61        .collect()
62}
63
64/// Run the stage cascade over already-ranked `hits` (from
65/// [`crate::rank::rank_all_ctx`]). `prompt` is the bare user prompt (for the lexical
66/// channel); `rerank_query` is the context-enriched query the cross-encoder reads.
67pub fn decide(hits: &[Hit], idx: &Index, prompt: &str, rerank_query: &str, cfg: &Config) -> Plan {
68    // Stage 1.5: a dominant lexical (BM25-over-description) winner injects directly
69    // — high precision exactly where the bi-encoder cosine is muddy — unless stage-1
70    // already has a confident lone dense winner, which is trusted outright.
71    if !rerank::confident_winner(hits, cfg) {
72        if let Some(win) = lexical::dominant(prompt, idx, cfg) {
73            let passed = hits.iter().filter(|h| h.id == win.id).cloned().collect();
74            return Plan {
75                stage: Stage::Lexical,
76                rows: hits.to_vec(),
77                lexical: Some(win),
78                passed,
79                threshold: cfg.lexical_min,
80            };
81        }
82    }
83    // Stage 2: the cross-encoder arbitrates the ambiguous middle; a confident winner
84    // / nothing-relevant keeps the cheap stage-1 result.
85    match rerank::is_ambiguous(hits, cfg)
86        .then(|| rerank::rerank(hits, idx, rerank_query, cfg))
87        .flatten()
88    {
89        Some(reranked) => {
90            let passed = rerank::passes(&reranked, cfg);
91            Plan {
92                stage: Stage::Rerank,
93                rows: reranked,
94                lexical: None,
95                passed,
96                threshold: cfg.rerank_min,
97            }
98        }
99        None => Plan {
100            stage: Stage::Cosine,
101            passed: cosine_passed(hits, cfg),
102            rows: hits.to_vec(),
103            lexical: None,
104            threshold: cfg.min_similarity,
105        },
106    }
107}
108
109/// Human-readable stage label for `ski why` / `examples/eval` display.
110/// `model` names the stage-1 embedder (only shown for the cosine stage).
111pub fn stage_label(stage: Stage, model: &str) -> String {
112    match stage {
113        Stage::Cosine => format!("stage1:{model}"),
114        Stage::Rerank => "rerank:turbo".to_string(),
115        Stage::Lexical => "lexical(BM25)".to_string(),
116    }
117}
118
119#[cfg(test)]
120mod tests {
121    use super::*;
122
123    fn hit(id: &str, score: f32, keyword: f32) -> Hit {
124        Hit {
125            id: id.to_string(),
126            name: id.to_string(),
127            cosine: score - keyword,
128            context: 0.0,
129            file: 0.0,
130            project: 0.0,
131            keyword,
132            phrase: 0.0,
133            score,
134        }
135    }
136
137    #[test]
138    fn cosine_passed_applies_floor_and_margin() {
139        let cfg = Config::default(); // min 0.30, margin 0.15
140        let hits = vec![
141            hit("a", 0.90, 0.0),
142            hit("b", 0.80, 0.0), // within margin of 0.90
143            hit("c", 0.50, 0.0), // below 0.90 - 0.15 margin
144            hit("d", 0.10, 0.0), // below the floor
145        ];
146        let got: Vec<String> = cosine_passed(&hits, &cfg)
147            .into_iter()
148            .map(|h| h.id)
149            .collect();
150        assert_eq!(got, ["a", "b"]);
151    }
152
153    #[test]
154    fn cosine_passed_force_bypasses_floor_on_keyword() {
155        let cfg = Config {
156            force: vec!["x".to_string()],
157            ..Default::default()
158        };
159        // x is sub-floor but forced with a keyword hit; y is sub-floor, not forced.
160        let hits = vec![hit("x", 0.10, 0.15), hit("y", 0.20, 0.0)];
161        let got: Vec<String> = cosine_passed(&hits, &cfg)
162            .into_iter()
163            .map(|h| h.id)
164            .collect();
165        assert_eq!(got, ["x"]);
166    }
167
168    #[test]
169    fn stage_label_renders_each_stage() {
170        assert_eq!(stage_label(Stage::Cosine, "bge"), "stage1:bge");
171        assert_eq!(stage_label(Stage::Rerank, "bge"), "rerank:turbo");
172        assert_eq!(stage_label(Stage::Lexical, "bge"), "lexical(BM25)");
173    }
174}