Skip to main content

zeph_index/
retriever.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Hybrid code retrieval: query classification, semantic search, budget packing.
5//!
6//! # Retrieval strategy
7//!
8//! [`classify_query`] inspects the free-text query for heuristic signals:
9//!
10//! | Signal | Examples | Strategy |
11//! |--------|----------|----------|
12//! | Symbol patterns only | `"fn my_fn"`, `"SkillMatcher::match"`, `"my_snake_func"` | [`RetrievalStrategy::Grep`] |
13//! | Conceptual patterns only | `"how does auth work?"`, `"explain the retry logic"` | [`RetrievalStrategy::Semantic`] |
14//! | Both | `"where is SkillMatcher used?"` | [`RetrievalStrategy::Hybrid`] |
15//!
16//! For `Grep` queries, [`CodeRetriever::retrieve`] returns an empty chunk list and
17//! the agent falls back to its shell grep tool. For `Semantic` and `Hybrid` queries
18//! an embedding round-trip is made and the top-scoring Qdrant results are packed
19//! within a token budget.
20//!
21//! # Token budget
22//!
23//! [`RetrievalConfig::budget_ratio`] controls what fraction of the caller's available
24//! context window is allocated to code chunks. The packing loop stops before adding a
25//! chunk that would exceed the budget, so the retrieved set always fits the window.
26
27use std::fmt::Write;
28use std::sync::Arc;
29
30use crate::error::Result;
31use crate::store::{CodeStore, SearchHit};
32use zeph_common::{EmbeddingVector, Unnormalized};
33use zeph_llm::any::AnyProvider;
34use zeph_llm::provider::LlmProvider;
35use zeph_memory::TokenCounter;
36
/// The retrieval strategy selected by [`classify_query`] for a given query.
///
/// Every [`RetrievedCode`] records the strategy that produced it in its `strategy`
/// field, so callers can distinguish "no relevant chunks" from "this query was routed
/// to a textual tool instead".
///
/// # Examples
///
/// ```
/// use zeph_index::retriever::{RetrievalStrategy, classify_query};
///
/// assert_eq!(classify_query("how does authentication work?"), RetrievalStrategy::Semantic);
/// assert_eq!(classify_query("fn my_handler"), RetrievalStrategy::Grep);
/// assert_eq!(classify_query("where is MyHandler used?"), RetrievalStrategy::Hybrid);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum RetrievalStrategy {
    /// Vector similarity search for conceptual or descriptive queries.
    ///
    /// The query is embedded and the top-K chunks from Qdrant that pass the score
    /// threshold and token budget are returned.
    Semantic,
    /// Exact symbol lookup — [`CodeRetriever::retrieve`] returns an empty chunk list.
    ///
    /// The caller (agent) is expected to use a `grep` or `symbol_definition` tool
    /// instead of the vector store for precise symbol lookups.
    Grep,
    /// Both semantic search **and** a hint that grep may also help.
    ///
    /// Semantic results are still returned, but the caller can additionally
    /// perform a textual search for the identified symbol names.
    Hybrid,
}
66
/// Configuration for [`CodeRetriever`].
///
/// All fields are public so callers can override individual values on top of
/// [`RetrievalConfig::default`].
///
/// # Examples
///
/// ```
/// use zeph_index::retriever::RetrievalConfig;
///
/// let cfg = RetrievalConfig::default();
/// assert_eq!(cfg.max_chunks, 12);
/// assert!(cfg.score_threshold > 0.0);
/// assert!(cfg.budget_ratio > 0.0 && cfg.budget_ratio < 1.0);
/// ```
#[derive(Debug, Clone)]
pub struct RetrievalConfig {
    /// Maximum number of chunks to fetch from Qdrant before applying score and budget filters.
    ///
    /// This is the search limit; the final result may contain fewer chunks.
    pub max_chunks: usize,
    /// Minimum cosine similarity score to accept.
    ///
    /// Hits with `score >= score_threshold` are kept; the rest are dropped before packing.
    pub score_threshold: f32,
    /// Maximum fraction of `available_tokens` allocated to code chunks.
    ///
    /// Expected to be in `0.0..=1.0`; this struct itself does not validate the value.
    pub budget_ratio: f32,
    /// Maximum seconds to wait for `provider.embed()` before returning
    /// [`crate::error::IndexError::EmbedTimeout`]. Defaults to `10`.
    pub embed_timeout_secs: u64,
}
91
92impl Default for RetrievalConfig {
93    fn default() -> Self {
94        Self {
95            max_chunks: 12,
96            score_threshold: 0.25,
97            budget_ratio: 0.40,
98            embed_timeout_secs: 10,
99        }
100    }
101}
102
103/// The result of a single retrieval operation.
104///
105/// Returned by [`CodeRetriever::retrieve`] and [`CodeRetriever::retrieve_filtered`].
106/// Pass to [`format_as_context`] to produce an XML snippet for injection into the
107/// agent message.
108#[derive(Debug)]
109pub struct RetrievedCode {
110    /// Ordered list of matching chunks (highest score first, budget-capped).
111    pub chunks: Vec<SearchHit>,
112    /// Estimated total tokens consumed by `chunks` (including a small per-chunk overhead).
113    pub total_tokens: usize,
114    /// Strategy that was used to produce this result.
115    pub strategy: RetrievalStrategy,
116}
117
118/// Budget-aware code retriever with automatic query classification.
119///
120/// Wraps a [`CodeStore`] and an LLM provider (for embedding) and exposes a single
121/// high-level [`CodeRetriever::retrieve`] method.
122///
123/// # Examples
124///
125/// ```no_run
126/// use std::sync::Arc;
127/// use zeph_index::retriever::{CodeRetriever, RetrievalConfig, format_as_context};
128/// use zeph_index::store::CodeStore;
129/// # async fn example() -> zeph_index::Result<()> {
130/// # let store: CodeStore = panic!("placeholder");
131/// # let provider: Arc<zeph_llm::any::AnyProvider> = panic!("placeholder");
132///
133/// let retriever = CodeRetriever::new(store, provider, RetrievalConfig::default());
134/// let result = retriever.retrieve("explain how authentication works", 8_000).await?;
135/// let xml = format_as_context(&result);
136/// println!("{xml}");
137/// # Ok(())
138/// # }
139/// ```
140pub struct CodeRetriever {
141    store: CodeStore,
142    provider: Arc<AnyProvider>,
143    config: RetrievalConfig,
144    token_counter: Arc<TokenCounter>,
145}
146
147impl CodeRetriever {
148    /// Create a new `CodeRetriever`.
149    ///
150    /// `store` must have its Qdrant collection already created (see
151    /// [`CodeStore::ensure_collection`]).
152    #[must_use]
153    pub fn new(store: CodeStore, provider: Arc<AnyProvider>, config: RetrievalConfig) -> Self {
154        Self {
155            store,
156            provider,
157            config,
158            token_counter: Arc::new(TokenCounter::new()),
159        }
160    }
161
162    /// Retrieve relevant code chunks for a free-text query.
163    ///
164    /// Classifies `query` via [`classify_query`], then:
165    ///
166    /// * For [`RetrievalStrategy::Grep`] queries — returns an empty [`RetrievedCode`]
167    ///   so the agent falls back to its shell `grep` or `symbol_definition` tools.
168    /// * For [`RetrievalStrategy::Semantic`] / [`RetrievalStrategy::Hybrid`] — embeds
169    ///   the query, searches Qdrant, applies the score threshold, and packs results
170    ///   within `available_tokens * budget_ratio`.
171    ///
172    /// # Errors
173    ///
174    /// Returns an error if the embedding call or Qdrant search fails.
175    #[tracing::instrument(name = "index.retriever.retrieve", skip(self), fields(%query, available_tokens))]
176    pub async fn retrieve(&self, query: &str, available_tokens: usize) -> Result<RetrievedCode> {
177        let strategy = classify_query(query);
178
179        let token_budget = budget_tokens(available_tokens, self.config.budget_ratio);
180
181        match strategy {
182            RetrievalStrategy::Grep => Ok(RetrievedCode {
183                chunks: vec![],
184                total_tokens: 0,
185                strategy,
186            }),
187            RetrievalStrategy::Semantic | RetrievalStrategy::Hybrid => {
188                let chunks = self
189                    .semantic_search(query, token_budget, None::<String>)
190                    .await?;
191                let total_tokens: usize = chunks
192                    .iter()
193                    .map(|c| self.token_counter.count_tokens(&c.code) + 20)
194                    .sum();
195                Ok(RetrievedCode {
196                    chunks,
197                    total_tokens,
198                    strategy,
199                })
200            }
201        }
202    }
203
204    /// Retrieve relevant code, restricting results to a single language.
205    ///
206    /// Behaves like [`CodeRetriever::retrieve`] but adds a Qdrant payload filter so
207    /// only chunks whose `language` field matches `language` are returned.
208    ///
209    /// Useful when the user or agent has already established the relevant language
210    /// (e.g. "show me the Python error handling" should not return Rust results).
211    ///
212    /// # Arguments
213    ///
214    /// * `language` — the language identifier as returned by [`crate::languages::Lang::id`]
215    ///   (e.g. `"rust"`, `"python"`).
216    ///
217    /// # Errors
218    ///
219    /// Returns an error if embedding or Qdrant search fails.
220    #[tracing::instrument(name = "index.retriever.retrieve_filtered", skip(self), fields(%query, available_tokens, %language))]
221    pub async fn retrieve_filtered(
222        &self,
223        query: &str,
224        available_tokens: usize,
225        language: &str,
226    ) -> Result<RetrievedCode> {
227        let strategy = classify_query(query);
228
229        let token_budget = budget_tokens(available_tokens, self.config.budget_ratio);
230
231        let chunks = self
232            .semantic_search(query, token_budget, Some(language.to_string()))
233            .await?;
234        let total_tokens: usize = chunks
235            .iter()
236            .map(|c| self.token_counter.count_tokens(&c.code) + 20)
237            .sum();
238
239        Ok(RetrievedCode {
240            chunks,
241            total_tokens,
242            strategy,
243        })
244    }
245
246    #[tracing::instrument(name = "index.retriever.semantic_search", skip(self), fields(%query, token_budget))]
247    async fn semantic_search(
248        &self,
249        query: &str,
250        token_budget: usize,
251        language_filter: Option<String>,
252    ) -> Result<Vec<SearchHit>> {
253        let timeout = std::time::Duration::from_secs(self.config.embed_timeout_secs);
254        let raw_vector = tokio::time::timeout(timeout, self.provider.embed(query))
255            .await
256            .map_err(|_| {
257                tracing::warn!(
258                    embed_timeout_secs = self.config.embed_timeout_secs,
259                    "embedding timed out"
260                );
261                crate::error::IndexError::EmbedTimeout(self.config.embed_timeout_secs)
262            })??;
263
264        // Normalize to unit length so Qdrant gRPC cosine search returns correct scores.
265        // Qdrant gRPC silently returns near-zero scores for unnormalized vectors (#3421).
266        let query_vector = EmbeddingVector::<Unnormalized>::new(raw_vector).normalize();
267
268        let mut hits = self
269            .store
270            .search(query_vector, self.config.max_chunks, language_filter)
271            .await?;
272
273        hits.retain(|h| h.score >= self.config.score_threshold);
274
275        let mut packed = Vec::new();
276        let mut used_tokens = 0;
277
278        for hit in hits {
279            let cost = self.token_counter.count_tokens(&hit.code) + 20;
280            if used_tokens + cost > token_budget {
281                break;
282            }
283            used_tokens += cost;
284            packed.push(hit);
285        }
286
287        Ok(packed)
288    }
289}
290
291/// Format retrieved code chunks as an XML `<code_context>` block.
292///
293/// The output is suitable for direct injection into the agent's user or assistant
294/// message. Each chunk is wrapped in a `<chunk>` element with `file`, `lines`,
295/// `name`, and `score` attributes.
296///
297/// Returns an empty string when `result.chunks` is empty so callers can append
298/// without adding unnecessary whitespace.
299///
300/// # Examples
301///
302/// ```
303/// use zeph_index::retriever::{RetrievedCode, RetrievalStrategy, format_as_context};
304/// use zeph_index::store::SearchHit;
305///
306/// let result = RetrievedCode {
307///     chunks: vec![SearchHit {
308///         code: "fn hello() {}".to_string(),
309///         file_path: "src/lib.rs".to_string(),
310///         line_range: (1, 1),
311///         score: 0.9,
312///         node_type: zeph_index::store::NodeKind::from("function_item"),
313///         language: zeph_index::languages::Lang::Rust,
314///         entity_name: Some("hello".to_string()),
315///         scope_chain: String::new(),
316///     }],
317///     total_tokens: 10,
318///     strategy: RetrievalStrategy::Semantic,
319/// };
320///
321/// let xml = format_as_context(&result);
322/// assert!(xml.starts_with("<code_context>"));
323/// assert!(xml.contains("file=\"src/lib.rs\""));
324/// assert!(xml.ends_with("</code_context>"));
325/// ```
326#[must_use]
327pub fn format_as_context(result: &RetrievedCode) -> String {
328    if result.chunks.is_empty() {
329        return String::new();
330    }
331
332    let mut out = String::from("<code_context>\n");
333
334    for chunk in &result.chunks {
335        let name = chunk
336            .entity_name
337            .as_deref()
338            .unwrap_or(chunk.node_type.as_ref());
339        let _ = writeln!(
340            out,
341            "  <chunk file=\"{}\" lines=\"{}-{}\" name=\"{}\" score=\"{:.2}\">",
342            chunk.file_path, chunk.line_range.0, chunk.line_range.1, name, chunk.score,
343        );
344        out.push_str(&chunk.code);
345        out.push_str("\n  </chunk>\n");
346    }
347
348    out.push_str("</code_context>");
349    out
350}
351
352/// Classify a free-text query to select the best retrieval strategy.
353///
354/// The heuristic looks for symbol-like patterns (Rust path syntax, `fn`/`struct`/`impl`
355/// keywords, `CamelCase` type names, `snake_case` identifiers) and conceptual signal
356/// words (`"how"`, `"explain"`, `"where"`, …).
357///
358/// | Signals present | Returned strategy |
359/// |-----------------|-------------------|
360/// | Symbol only | [`RetrievalStrategy::Grep`] |
361/// | Conceptual only | [`RetrievalStrategy::Semantic`] |
362/// | Both | [`RetrievalStrategy::Hybrid`] |
363/// | Neither | [`RetrievalStrategy::Semantic`] |
364///
365/// # Examples
366///
367/// ```
368/// use zeph_index::retriever::{RetrievalStrategy, classify_query};
369///
370/// assert_eq!(classify_query("how does retry logic work?"), RetrievalStrategy::Semantic);
371/// assert_eq!(classify_query("fn handle_request"), RetrievalStrategy::Grep);
372/// assert_eq!(classify_query("where is MyRouter defined?"), RetrievalStrategy::Hybrid);
373/// ```
374#[must_use]
375pub fn classify_query(query: &str) -> RetrievalStrategy {
376    let has_symbol_pattern = query.contains("::")
377        || query.contains("fn ")
378        || query.contains("struct ")
379        || query.contains("impl ")
380        || query.contains("trait ")
381        || query.contains("mod ")
382        || query.contains("class ")
383        || query.contains("def ")
384        || has_camel_case(query)
385        || has_snake_case_identifier(query);
386
387    let has_conceptual = query.contains("how")
388        || query.contains("where")
389        || query.contains("why")
390        || query.contains("find all")
391        || query.contains("explain")
392        || query.contains("what does")
393        || query.contains("show me");
394
395    match (has_symbol_pattern, has_conceptual) {
396        (true, true) => RetrievalStrategy::Hybrid,
397        (true, false) => RetrievalStrategy::Grep,
398        (false, _) => RetrievalStrategy::Semantic,
399    }
400}
401
402fn has_camel_case(text: &str) -> bool {
403    text.split_whitespace().any(|word| {
404        let chars: Vec<char> = word.chars().collect();
405        chars.len() >= 3
406            && chars[0].is_uppercase()
407            && chars.iter().any(|c| c.is_lowercase())
408            && chars.iter().skip(1).any(|c| c.is_uppercase())
409    })
410}
411
412fn has_snake_case_identifier(text: &str) -> bool {
413    text.split_whitespace().any(|word| {
414        word.len() >= 3
415            && word.contains('_')
416            && word.chars().all(|c| c.is_alphanumeric() || c == '_')
417            && word.starts_with(|c: char| c.is_lowercase())
418    })
419}
420
421fn budget_tokens(available: usize, ratio: f32) -> usize {
422    // Scale to per-mille to stay in integer arithmetic.
423    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
424    let per_mille = (ratio * 1000.0) as usize;
425    available.saturating_mul(per_mille) / 1000
426}
427
428#[cfg(test)]
429mod tests {
430    use super::*;
431    use crate::store::{NodeKind, SearchHit};
432
433    #[test]
434    fn classify_symbol_query_rust() {
435        assert_eq!(
436            classify_query("find SkillMatcher::match_skills"),
437            RetrievalStrategy::Grep
438        );
439    }
440
441    #[test]
442    fn classify_conceptual_query() {
443        assert_eq!(
444            classify_query("how does skill matching work?"),
445            RetrievalStrategy::Semantic
446        );
447    }
448
449    #[test]
450    fn classify_mixed_query() {
451        assert_eq!(
452            classify_query("where is SkillMatcher used?"),
453            RetrievalStrategy::Hybrid
454        );
455    }
456
457    #[test]
458    fn classify_default_is_semantic() {
459        assert_eq!(classify_query("help"), RetrievalStrategy::Semantic);
460    }
461
462    #[test]
463    fn classify_snake_case_identifier() {
464        assert_eq!(classify_query("my_function"), RetrievalStrategy::Grep);
465    }
466
467    #[test]
468    fn camel_case_detection() {
469        assert!(has_camel_case("HttpClient"));
470        assert!(has_camel_case("find MyStruct"));
471        assert!(!has_camel_case("simple word"));
472        assert!(!has_camel_case("HTTP"));
473        assert!(!has_camel_case("ab"));
474    }
475
476    #[test]
477    fn snake_case_detection() {
478        assert!(has_snake_case_identifier("my_function"));
479        assert!(has_snake_case_identifier("call some_method here"));
480        assert!(!has_snake_case_identifier("NoSnake"));
481        assert!(has_snake_case_identifier("a_b"));
482    }
483
484    #[test]
485    fn format_as_context_empty() {
486        let result = RetrievedCode {
487            chunks: vec![],
488            total_tokens: 0,
489            strategy: RetrievalStrategy::Semantic,
490        };
491        assert_eq!(format_as_context(&result), "");
492    }
493
494    #[test]
495    fn format_as_context_xml() {
496        let result = RetrievedCode {
497            chunks: vec![SearchHit {
498                code: "fn hello() {}".to_string(),
499                file_path: "src/lib.rs".to_string(),
500                line_range: (1, 3),
501                score: 0.85,
502                node_type: NodeKind::from("function_item"),
503                language: crate::languages::Lang::Rust,
504                entity_name: Some("hello".to_string()),
505                scope_chain: String::new(),
506            }],
507            total_tokens: 10,
508            strategy: RetrievalStrategy::Semantic,
509        };
510        let xml = format_as_context(&result);
511        assert!(xml.contains("<code_context>"));
512        assert!(xml.contains("</code_context>"));
513        assert!(xml.contains("file=\"src/lib.rs\""));
514        assert!(xml.contains("name=\"hello\""));
515        assert!(xml.contains("score=\"0.85\""));
516        assert!(xml.contains("fn hello() {}"));
517    }
518
519    #[test]
520    fn snake_case_a_b_three_chars_passes() {
521        assert!(has_snake_case_identifier("a_b"));
522    }
523
524    #[test]
525    fn budget_tokens_ratio_zero() {
526        assert_eq!(budget_tokens(10_000, 0.0), 0);
527    }
528
529    #[test]
530    fn budget_tokens_ratio_one() {
531        assert_eq!(budget_tokens(10_000, 1.0), 10_000);
532    }
533
534    #[test]
535    fn budget_tokens_ratio_half() {
536        assert_eq!(budget_tokens(8_000, 0.5), 4_000);
537    }
538
539    #[test]
540    fn budget_tokens_zero_available() {
541        assert_eq!(budget_tokens(0, 0.4), 0);
542    }
543
544    #[test]
545    fn format_as_context_uses_node_type_when_no_entity_name() {
546        let result = RetrievedCode {
547            chunks: vec![SearchHit {
548                code: "struct Foo {}".to_string(),
549                file_path: "src/foo.rs".to_string(),
550                line_range: (1, 2),
551                score: 0.75,
552                node_type: NodeKind::from("struct_item"),
553                language: crate::languages::Lang::Rust,
554                entity_name: None,
555                scope_chain: String::new(),
556            }],
557            total_tokens: 5,
558            strategy: RetrievalStrategy::Semantic,
559        };
560        let xml = format_as_context(&result);
561        assert!(xml.contains("name=\"struct_item\""));
562    }
563
564    #[test]
565    fn classify_fn_keyword_is_grep() {
566        assert_eq!(classify_query("fn my_func"), RetrievalStrategy::Grep);
567    }
568
569    #[test]
570    fn classify_struct_keyword_is_grep() {
571        assert_eq!(classify_query("struct MyType"), RetrievalStrategy::Grep);
572    }
573
574    #[test]
575    fn classify_explain_conceptual_is_semantic() {
576        assert_eq!(
577            classify_query("explain the architecture"),
578            RetrievalStrategy::Semantic
579        );
580    }
581
582    #[test]
583    fn retrieval_strategy_debug() {
584        assert_eq!(format!("{:?}", RetrievalStrategy::Semantic), "Semantic");
585        assert_eq!(format!("{:?}", RetrievalStrategy::Grep), "Grep");
586        assert_eq!(format!("{:?}", RetrievalStrategy::Hybrid), "Hybrid");
587    }
588
589    #[test]
590    fn retrieval_config_defaults() {
591        let cfg = RetrievalConfig::default();
592        assert_eq!(cfg.max_chunks, 12);
593        assert!(cfg.score_threshold > 0.0);
594        assert!(cfg.budget_ratio > 0.0 && cfg.budget_ratio < 1.0);
595    }
596}