ski/pipeline.rs
1//! The shared decision pipeline: stage selection + gating, single-sourced so the
2//! hook (hot path), `ski why` (tuning aid), and `examples/eval` run identical math.
3//!
4//! Previously each of the three re-implemented the stage cascade. `ski why` had
5//! drifted the furthest — it ranked with [`crate::rank::rank_all`] (no
6//! context/file/project channels) and reranked the *bare* prompt, so it could star
7//! a different skill than the hook would inject on the same prompt. The tool meant
8//! to explain the ranker didn't reproduce it.
9//!
10//! The caller owns the inputs that depend on conversation state (the query vector,
11//! the context blend, file/project channels, the rerank query); this module owns
12//! the cascade over the resulting `hits`:
13//! 1. a dominant lexical (BM25) winner injects directly, unless stage-1 already has
14//! a confident lone dense winner;
15//! 2. otherwise the cross-encoder arbitrates the ambiguous middle;
16//! 3. otherwise the cheap stage-1 cosine result stands.
17//!
18//! It returns the winning stage, its rows (for display), and the hits that clear
19//! that stage's gate — *before* deny / session-dedup / slash-removal / `max_skills`,
20//! which stay with the caller (only the hook has a session).
21
22use crate::confidence::Stage;
23use crate::config::Config;
24use crate::index::Index;
25use crate::lexical::{self, Lex};
26use crate::rank::Hit;
27use crate::rerank;
28
29/// The outcome of the decision cascade for one prompt.
30#[derive(Debug)]
31pub struct Plan {
32 /// Which stage produced the decision.
33 pub stage: Stage,
34 /// The winning stage's ranking, for display: the reranked list when the
35 /// cross-encoder fired, otherwise the stage-1 hits. (For the lexical fast-path
36 /// these are still the stage-1 hits; the winner is in [`Plan::lexical`].)
37 pub rows: Vec<Hit>,
38 /// The lexical fast-path winner, if one fired.
39 pub lexical: Option<Lex>,
40 /// Hits that clear the winning stage's gate, in rank order, *before*
41 /// deny / dedup / slash-removal / cap. For the lexical stage this is the single
42 /// dominant winner (its stage-1 [`Hit`], pulled from `rows`).
43 pub passed: Vec<Hit>,
44 /// Display threshold for the winning stage (`min_similarity` / `rerank_min` /
45 /// `lexical_min`).
46 pub threshold: f32,
47}
48
49/// Stage-1 cosine gate: the hits clearing the absolute floor (`min_similarity`)
50/// and within the relative margin (`score_margin`) of the leader, plus any `force`d
51/// skill on a keyword hit. Pre deny/dedup/cap; pure (no IO), so it is the unit-test
52/// seam for the gate the hook used to inline in `select`.
53pub fn cosine_passed(hits: &[Hit], cfg: &Config) -> Vec<Hit> {
54 let top = hits.first().map(|h| h.score).unwrap_or(0.0);
55 hits.iter()
56 .filter(|h| {
57 let forced = cfg.force.contains(&h.id) && h.keyword > 0.0;
58 forced || (h.score >= cfg.min_similarity && h.score >= top - cfg.score_margin)
59 })
60 .cloned()
61 .collect()
62}
63
64/// Run the stage cascade over already-ranked `hits` (from
65/// [`crate::rank::rank_all_ctx`]). `prompt` is the bare user prompt (for the lexical
66/// channel); `rerank_query` is the context-enriched query the cross-encoder reads.
67pub fn decide(hits: &[Hit], idx: &Index, prompt: &str, rerank_query: &str, cfg: &Config) -> Plan {
68 // Stage 1.5: a dominant lexical (BM25-over-description) winner injects directly
69 // — high precision exactly where the bi-encoder cosine is muddy — unless stage-1
70 // already has a confident lone dense winner, which is trusted outright.
71 if !rerank::confident_winner(hits, cfg) {
72 if let Some(win) = lexical::dominant(prompt, idx, cfg) {
73 let passed = hits.iter().filter(|h| h.id == win.id).cloned().collect();
74 return Plan {
75 stage: Stage::Lexical,
76 rows: hits.to_vec(),
77 lexical: Some(win),
78 passed,
79 threshold: cfg.lexical_min,
80 };
81 }
82 }
83 // Stage 2: the cross-encoder arbitrates the ambiguous middle; a confident winner
84 // / nothing-relevant keeps the cheap stage-1 result.
85 match rerank::is_ambiguous(hits, cfg)
86 .then(|| rerank::rerank(hits, idx, rerank_query, cfg))
87 .flatten()
88 {
89 Some(reranked) => {
90 let passed = rerank::passes(&reranked, cfg);
91 Plan {
92 stage: Stage::Rerank,
93 rows: reranked,
94 lexical: None,
95 passed,
96 threshold: cfg.rerank_min,
97 }
98 }
99 None => Plan {
100 stage: Stage::Cosine,
101 passed: cosine_passed(hits, cfg),
102 rows: hits.to_vec(),
103 lexical: None,
104 threshold: cfg.min_similarity,
105 },
106 }
107}
108
109/// Human-readable stage label for `ski why` / `examples/eval` display.
110/// `model` names the stage-1 embedder (only shown for the cosine stage).
111pub fn stage_label(stage: Stage, model: &str) -> String {
112 match stage {
113 Stage::Cosine => format!("stage1:{model}"),
114 Stage::Rerank => "rerank:turbo".to_string(),
115 Stage::Lexical => "lexical(BM25)".to_string(),
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a stage-1 hit whose final `score` includes the given `keyword`
    /// component; every other channel is zeroed.
    fn hit(id: &str, score: f32, keyword: f32) -> Hit {
        let label = id.to_owned();
        Hit {
            name: label.clone(),
            id: label,
            cosine: score - keyword,
            context: 0.0,
            file: 0.0,
            project: 0.0,
            keyword,
            phrase: 0.0,
            score,
        }
    }

    /// Ids of the hits that clear the cosine gate, in order.
    fn passed_ids(hits: &[Hit], cfg: &Config) -> Vec<String> {
        cosine_passed(hits, cfg)
            .into_iter()
            .map(|h| h.id)
            .collect()
    }

    #[test]
    fn cosine_passed_applies_floor_and_margin() {
        // Assumes the default config uses min_similarity 0.30 and score_margin 0.15.
        let cfg = Config::default();
        let ranked = [
            hit("a", 0.90, 0.0), // the leader
            hit("b", 0.80, 0.0), // inside the margin of the leader
            hit("c", 0.50, 0.0), // above the floor but outside the margin
            hit("d", 0.10, 0.0), // under the absolute floor
        ];
        assert_eq!(passed_ids(&ranked, &cfg), ["a", "b"]);
    }

    #[test]
    fn cosine_passed_force_bypasses_floor_on_keyword() {
        let cfg = Config {
            force: vec![String::from("x")],
            ..Config::default()
        };
        // Both hits sit under the floor; only `x` is forced and has a keyword hit.
        let ranked = [hit("x", 0.10, 0.15), hit("y", 0.20, 0.0)];
        assert_eq!(passed_ids(&ranked, &cfg), ["x"]);
    }

    #[test]
    fn stage_label_renders_each_stage() {
        let expected = [
            (Stage::Cosine, "stage1:bge"),
            (Stage::Rerank, "rerank:turbo"),
            (Stage::Lexical, "lexical(BM25)"),
        ];
        for (stage, want) in expected {
            assert_eq!(stage_label(stage, "bge"), want);
        }
    }
}