zeph_index/retriever.rs
1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Hybrid code retrieval: query classification, semantic search, budget packing.
5//!
6//! # Retrieval strategy
7//!
8//! [`classify_query`] inspects the free-text query for heuristic signals:
9//!
10//! | Signal | Examples | Strategy |
11//! |--------|----------|----------|
12//! | Symbol patterns only | `"fn my_fn"`, `"SkillMatcher::match"`, `"my_snake_func"` | [`RetrievalStrategy::Grep`] |
13//! | Conceptual patterns only | `"how does auth work?"`, `"explain the retry logic"` | [`RetrievalStrategy::Semantic`] |
14//! | Both | `"where is SkillMatcher used?"` | [`RetrievalStrategy::Hybrid`] |
15//!
16//! For `Grep` queries, [`CodeRetriever::retrieve`] returns an empty chunk list and
17//! the agent falls back to its shell grep tool. For `Semantic` and `Hybrid` queries
18//! an embedding round-trip is made and the top-scoring Qdrant results are packed
19//! within a token budget.
20//!
21//! # Token budget
22//!
23//! [`RetrievalConfig::budget_ratio`] controls what fraction of the caller's available
24//! context window is allocated to code chunks. The packing loop stops before adding a
25//! chunk that would exceed the budget, so the retrieved set always fits the window.
26
27use std::fmt::Write;
28use std::sync::Arc;
29
30use crate::error::Result;
31use crate::store::{CodeStore, SearchHit};
32use zeph_common::{EmbeddingVector, Unnormalized};
33use zeph_llm::any::AnyProvider;
34use zeph_llm::provider::LlmProvider;
35use zeph_memory::TokenCounter;
36
/// How a query should be answered, as decided by [`classify_query`].
///
/// # Examples
///
/// ```
/// use zeph_index::retriever::{RetrievalStrategy, classify_query};
///
/// assert_eq!(classify_query("how does authentication work?"), RetrievalStrategy::Semantic);
/// assert_eq!(classify_query("fn my_handler"), RetrievalStrategy::Grep);
/// assert_eq!(classify_query("where is MyHandler used?"), RetrievalStrategy::Hybrid);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum RetrievalStrategy {
    /// Conceptual or descriptive query, answered by vector similarity search.
    ///
    /// The query text is embedded and the best-scoring chunks from Qdrant are
    /// returned.
    Semantic,
    /// Precise symbol lookup, for which the retriever returns no chunks at all.
    ///
    /// The calling agent is expected to reach for a `grep` or
    /// `symbol_definition` tool rather than the vector store.
    Grep,
    /// Semantic search plus a hint that a textual search may help as well.
    ///
    /// Semantic results are still produced; the caller may additionally grep
    /// for the symbol names that appear in the query.
    Hybrid,
}
66
/// Tunable knobs for [`CodeRetriever`].
///
/// # Examples
///
/// ```
/// use zeph_index::retriever::RetrievalConfig;
///
/// let cfg = RetrievalConfig::default();
/// assert_eq!(cfg.max_chunks, 12);
/// assert!(cfg.score_threshold > 0.0);
/// assert!(cfg.budget_ratio > 0.0 && cfg.budget_ratio < 1.0);
/// ```
#[derive(Debug, Clone)]
pub struct RetrievalConfig {
    /// Upper bound on the number of chunks requested from Qdrant; the score and
    /// budget filters are applied to this candidate set afterwards.
    pub max_chunks: usize,
    /// Lowest cosine similarity score a chunk may have and still be kept.
    pub score_threshold: f32,
    /// Share of `available_tokens` that code chunks may occupy (0.0–1.0).
    pub budget_ratio: f32,
    /// Seconds to wait for `provider.embed()` before giving up with
    /// [`crate::error::IndexError::EmbedTimeout`]. Defaults to `10`.
    pub embed_timeout_secs: u64,
}
91
92impl Default for RetrievalConfig {
93 fn default() -> Self {
94 Self {
95 max_chunks: 12,
96 score_threshold: 0.25,
97 budget_ratio: 0.40,
98 embed_timeout_secs: 10,
99 }
100 }
101}
102
103/// The result of a single retrieval operation.
104///
105/// Returned by [`CodeRetriever::retrieve`] and [`CodeRetriever::retrieve_filtered`].
106/// Pass to [`format_as_context`] to produce an XML snippet for injection into the
107/// agent message.
108#[derive(Debug)]
109pub struct RetrievedCode {
110 /// Ordered list of matching chunks (highest score first, budget-capped).
111 pub chunks: Vec<SearchHit>,
112 /// Estimated total tokens consumed by `chunks` (including a small per-chunk overhead).
113 pub total_tokens: usize,
114 /// Strategy that was used to produce this result.
115 pub strategy: RetrievalStrategy,
116}
117
118/// Budget-aware code retriever with automatic query classification.
119///
120/// Wraps a [`CodeStore`] and an LLM provider (for embedding) and exposes a single
121/// high-level [`CodeRetriever::retrieve`] method.
122///
123/// # Examples
124///
125/// ```no_run
126/// use std::sync::Arc;
127/// use zeph_index::retriever::{CodeRetriever, RetrievalConfig, format_as_context};
128/// use zeph_index::store::CodeStore;
129/// # async fn example() -> zeph_index::Result<()> {
130/// # let store: CodeStore = panic!("placeholder");
131/// # let provider: Arc<zeph_llm::any::AnyProvider> = panic!("placeholder");
132///
133/// let retriever = CodeRetriever::new(store, provider, RetrievalConfig::default());
134/// let result = retriever.retrieve("explain how authentication works", 8_000).await?;
135/// let xml = format_as_context(&result);
136/// println!("{xml}");
137/// # Ok(())
138/// # }
139/// ```
140pub struct CodeRetriever {
141 store: CodeStore,
142 provider: Arc<AnyProvider>,
143 config: RetrievalConfig,
144 token_counter: Arc<TokenCounter>,
145}
146
147impl CodeRetriever {
148 /// Create a new `CodeRetriever`.
149 ///
150 /// `store` must have its Qdrant collection already created (see
151 /// [`CodeStore::ensure_collection`]).
152 #[must_use]
153 pub fn new(store: CodeStore, provider: Arc<AnyProvider>, config: RetrievalConfig) -> Self {
154 Self {
155 store,
156 provider,
157 config,
158 token_counter: Arc::new(TokenCounter::new()),
159 }
160 }
161
162 /// Retrieve relevant code chunks for a free-text query.
163 ///
164 /// Classifies `query` via [`classify_query`], then:
165 ///
166 /// * For [`RetrievalStrategy::Grep`] queries — returns an empty [`RetrievedCode`]
167 /// so the agent falls back to its shell `grep` or `symbol_definition` tools.
168 /// * For [`RetrievalStrategy::Semantic`] / [`RetrievalStrategy::Hybrid`] — embeds
169 /// the query, searches Qdrant, applies the score threshold, and packs results
170 /// within `available_tokens * budget_ratio`.
171 ///
172 /// # Errors
173 ///
174 /// Returns an error if the embedding call or Qdrant search fails.
175 #[tracing::instrument(name = "index.retriever.retrieve", skip(self), fields(%query, available_tokens))]
176 pub async fn retrieve(&self, query: &str, available_tokens: usize) -> Result<RetrievedCode> {
177 let strategy = classify_query(query);
178
179 let token_budget = budget_tokens(available_tokens, self.config.budget_ratio);
180
181 match strategy {
182 RetrievalStrategy::Grep => Ok(RetrievedCode {
183 chunks: vec![],
184 total_tokens: 0,
185 strategy,
186 }),
187 RetrievalStrategy::Semantic | RetrievalStrategy::Hybrid => {
188 let chunks = self
189 .semantic_search(query, token_budget, None::<String>)
190 .await?;
191 let total_tokens: usize = chunks
192 .iter()
193 .map(|c| self.token_counter.count_tokens(&c.code) + 20)
194 .sum();
195 Ok(RetrievedCode {
196 chunks,
197 total_tokens,
198 strategy,
199 })
200 }
201 }
202 }
203
204 /// Retrieve relevant code, restricting results to a single language.
205 ///
206 /// Behaves like [`CodeRetriever::retrieve`] but adds a Qdrant payload filter so
207 /// only chunks whose `language` field matches `language` are returned.
208 ///
209 /// Useful when the user or agent has already established the relevant language
210 /// (e.g. "show me the Python error handling" should not return Rust results).
211 ///
212 /// # Arguments
213 ///
214 /// * `language` — the language identifier as returned by [`crate::languages::Lang::id`]
215 /// (e.g. `"rust"`, `"python"`).
216 ///
217 /// # Errors
218 ///
219 /// Returns an error if embedding or Qdrant search fails.
220 #[tracing::instrument(name = "index.retriever.retrieve_filtered", skip(self), fields(%query, available_tokens, %language))]
221 pub async fn retrieve_filtered(
222 &self,
223 query: &str,
224 available_tokens: usize,
225 language: &str,
226 ) -> Result<RetrievedCode> {
227 let strategy = classify_query(query);
228
229 let token_budget = budget_tokens(available_tokens, self.config.budget_ratio);
230
231 let chunks = self
232 .semantic_search(query, token_budget, Some(language.to_string()))
233 .await?;
234 let total_tokens: usize = chunks
235 .iter()
236 .map(|c| self.token_counter.count_tokens(&c.code) + 20)
237 .sum();
238
239 Ok(RetrievedCode {
240 chunks,
241 total_tokens,
242 strategy,
243 })
244 }
245
246 #[tracing::instrument(name = "index.retriever.semantic_search", skip(self), fields(%query, token_budget))]
247 async fn semantic_search(
248 &self,
249 query: &str,
250 token_budget: usize,
251 language_filter: Option<String>,
252 ) -> Result<Vec<SearchHit>> {
253 let timeout = std::time::Duration::from_secs(self.config.embed_timeout_secs);
254 let raw_vector = tokio::time::timeout(timeout, self.provider.embed(query))
255 .await
256 .map_err(|_| {
257 tracing::warn!(
258 embed_timeout_secs = self.config.embed_timeout_secs,
259 "embedding timed out"
260 );
261 crate::error::IndexError::EmbedTimeout(self.config.embed_timeout_secs)
262 })??;
263
264 // Normalize to unit length so Qdrant gRPC cosine search returns correct scores.
265 // Qdrant gRPC silently returns near-zero scores for unnormalized vectors (#3421).
266 let query_vector = EmbeddingVector::<Unnormalized>::new(raw_vector).normalize();
267
268 let mut hits = self
269 .store
270 .search(query_vector, self.config.max_chunks, language_filter)
271 .await?;
272
273 hits.retain(|h| h.score >= self.config.score_threshold);
274
275 let mut packed = Vec::new();
276 let mut used_tokens = 0;
277
278 for hit in hits {
279 let cost = self.token_counter.count_tokens(&hit.code) + 20;
280 if used_tokens + cost > token_budget {
281 break;
282 }
283 used_tokens += cost;
284 packed.push(hit);
285 }
286
287 Ok(packed)
288 }
289}
290
291/// Format retrieved code chunks as an XML `<code_context>` block.
292///
293/// The output is suitable for direct injection into the agent's user or assistant
294/// message. Each chunk is wrapped in a `<chunk>` element with `file`, `lines`,
295/// `name`, and `score` attributes.
296///
297/// Returns an empty string when `result.chunks` is empty so callers can append
298/// without adding unnecessary whitespace.
299///
300/// # Examples
301///
302/// ```
303/// use zeph_index::retriever::{RetrievedCode, RetrievalStrategy, format_as_context};
304/// use zeph_index::store::SearchHit;
305///
306/// let result = RetrievedCode {
307/// chunks: vec![SearchHit {
308/// code: "fn hello() {}".to_string(),
309/// file_path: "src/lib.rs".to_string(),
310/// line_range: (1, 1),
311/// score: 0.9,
312/// node_type: zeph_index::store::NodeKind::from("function_item"),
313/// language: zeph_index::languages::Lang::Rust,
314/// entity_name: Some("hello".to_string()),
315/// scope_chain: String::new(),
316/// }],
317/// total_tokens: 10,
318/// strategy: RetrievalStrategy::Semantic,
319/// };
320///
321/// let xml = format_as_context(&result);
322/// assert!(xml.starts_with("<code_context>"));
323/// assert!(xml.contains("file=\"src/lib.rs\""));
324/// assert!(xml.ends_with("</code_context>"));
325/// ```
326#[must_use]
327pub fn format_as_context(result: &RetrievedCode) -> String {
328 if result.chunks.is_empty() {
329 return String::new();
330 }
331
332 let mut out = String::from("<code_context>\n");
333
334 for chunk in &result.chunks {
335 let name = chunk
336 .entity_name
337 .as_deref()
338 .unwrap_or(chunk.node_type.as_ref());
339 let _ = writeln!(
340 out,
341 " <chunk file=\"{}\" lines=\"{}-{}\" name=\"{}\" score=\"{:.2}\">",
342 chunk.file_path, chunk.line_range.0, chunk.line_range.1, name, chunk.score,
343 );
344 out.push_str(&chunk.code);
345 out.push_str("\n </chunk>\n");
346 }
347
348 out.push_str("</code_context>");
349 out
350}
351
352/// Classify a free-text query to select the best retrieval strategy.
353///
354/// The heuristic looks for symbol-like patterns (Rust path syntax, `fn`/`struct`/`impl`
355/// keywords, `CamelCase` type names, `snake_case` identifiers) and conceptual signal
356/// words (`"how"`, `"explain"`, `"where"`, …).
357///
358/// | Signals present | Returned strategy |
359/// |-----------------|-------------------|
360/// | Symbol only | [`RetrievalStrategy::Grep`] |
361/// | Conceptual only | [`RetrievalStrategy::Semantic`] |
362/// | Both | [`RetrievalStrategy::Hybrid`] |
363/// | Neither | [`RetrievalStrategy::Semantic`] |
364///
365/// # Examples
366///
367/// ```
368/// use zeph_index::retriever::{RetrievalStrategy, classify_query};
369///
370/// assert_eq!(classify_query("how does retry logic work?"), RetrievalStrategy::Semantic);
371/// assert_eq!(classify_query("fn handle_request"), RetrievalStrategy::Grep);
372/// assert_eq!(classify_query("where is MyRouter defined?"), RetrievalStrategy::Hybrid);
373/// ```
374#[must_use]
375pub fn classify_query(query: &str) -> RetrievalStrategy {
376 let has_symbol_pattern = query.contains("::")
377 || query.contains("fn ")
378 || query.contains("struct ")
379 || query.contains("impl ")
380 || query.contains("trait ")
381 || query.contains("mod ")
382 || query.contains("class ")
383 || query.contains("def ")
384 || has_camel_case(query)
385 || has_snake_case_identifier(query);
386
387 let has_conceptual = query.contains("how")
388 || query.contains("where")
389 || query.contains("why")
390 || query.contains("find all")
391 || query.contains("explain")
392 || query.contains("what does")
393 || query.contains("show me");
394
395 match (has_symbol_pattern, has_conceptual) {
396 (true, true) => RetrievalStrategy::Hybrid,
397 (true, false) => RetrievalStrategy::Grep,
398 (false, _) => RetrievalStrategy::Semantic,
399 }
400}
401
/// Report whether any whitespace-separated word in `text` looks like a
/// `CamelCase` name.
///
/// A word qualifies when it is at least three characters long, starts with an
/// uppercase letter, contains a lowercase letter, and has a further uppercase
/// letter after the first character. All-caps words such as `HTTP` and plainly
/// capitalised words such as `Hello` therefore do not count.
fn has_camel_case(text: &str) -> bool {
    fn looks_camel(word: &str) -> bool {
        let mut rest = word.chars();
        if !matches!(rest.next(), Some(first) if first.is_uppercase()) {
            return false;
        }

        // The leading character is uppercase, hence never lowercase, so
        // scanning only the remainder for a lowercase letter is sufficient.
        let mut length = 1usize;
        let mut saw_lower = false;
        let mut saw_inner_upper = false;
        for ch in rest {
            length += 1;
            saw_lower |= ch.is_lowercase();
            saw_inner_upper |= ch.is_uppercase();
        }

        length >= 3 && saw_lower && saw_inner_upper
    }

    for word in text.split_whitespace() {
        if looks_camel(word) {
            return true;
        }
    }
    false
}
411
/// Report whether any whitespace-separated word in `text` looks like a
/// `snake_case` identifier.
///
/// Punctuation adjoining a word (for example a trailing `?`, a comma, or
/// surrounding backticks or parentheses) is stripped first, so identifiers
/// embedded in ordinary sentences are still recognised. The remaining core must
/// be at least three bytes long, start with a lowercase letter, contain an
/// underscore, and consist solely of alphanumerics and underscores.
fn has_snake_case_identifier(text: &str) -> bool {
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    text.split_whitespace().any(|word| {
        // Only the ends are trimmed: punctuation inside the word (as in
        // `user_id=5`) still disqualifies it.
        let core = word.trim_matches(|c: char| !is_ident_char(c));
        core.len() >= 3
            && core.contains('_')
            && core.starts_with(char::is_lowercase)
            && core.chars().all(is_ident_char)
    })
}
420
/// Compute the token budget for code chunks: `available` scaled by `ratio`.
///
/// `ratio` is clamped to `0.0..=1.0`, so the budget never exceeds `available`;
/// a NaN ratio yields `0`. The ratio is resolved to the nearest per-mille and
/// the result is rounded down to a whole token.
fn budget_tokens(available: usize, ratio: f32) -> usize {
    let fraction = ratio.clamp(0.0, 1.0);

    // Scale to per-mille to stay in integer arithmetic. Rounding (rather than
    // truncating) keeps float error from turning e.g. 290‰ into 289‰.
    // The cast is safe: the value lies in 0..=1000, and NaN casts to 0.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    let per_mille = (fraction * 1000.0).round() as usize;

    // Split `available` into thousands and remainder so the multiplication
    // cannot overflow (per_mille <= 1000); the sum equals
    // floor(available * per_mille / 1000) exactly.
    (available / 1000) * per_mille + (available % 1000) * per_mille / 1000
}
427
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::{NodeKind, SearchHit};

    /// Wrap `chunks` in a `Semantic` result that reports `total_tokens`.
    fn semantic_result(chunks: Vec<SearchHit>, total_tokens: usize) -> RetrievedCode {
        RetrievedCode {
            chunks,
            total_tokens,
            strategy: RetrievalStrategy::Semantic,
        }
    }

    // --- classify_query -------------------------------------------------

    #[test]
    fn classify_symbol_query_rust() {
        let strategy = classify_query("find SkillMatcher::match_skills");
        assert_eq!(strategy, RetrievalStrategy::Grep);
    }

    #[test]
    fn classify_conceptual_query() {
        let strategy = classify_query("how does skill matching work?");
        assert_eq!(strategy, RetrievalStrategy::Semantic);
    }

    #[test]
    fn classify_mixed_query() {
        let strategy = classify_query("where is SkillMatcher used?");
        assert_eq!(strategy, RetrievalStrategy::Hybrid);
    }

    #[test]
    fn classify_default_is_semantic() {
        let strategy = classify_query("help");
        assert_eq!(strategy, RetrievalStrategy::Semantic);
    }

    #[test]
    fn classify_snake_case_identifier() {
        let strategy = classify_query("my_function");
        assert_eq!(strategy, RetrievalStrategy::Grep);
    }

    // --- word-shape helpers ---------------------------------------------

    #[test]
    fn camel_case_detection() {
        for symbol in ["HttpClient", "find MyStruct"] {
            assert!(has_camel_case(symbol));
        }
        for plain in ["simple word", "HTTP", "ab"] {
            assert!(!has_camel_case(plain));
        }
    }

    #[test]
    fn snake_case_detection() {
        for symbol in ["my_function", "call some_method here"] {
            assert!(has_snake_case_identifier(symbol));
        }
        assert!(!has_snake_case_identifier("NoSnake"));
        assert!(has_snake_case_identifier("a_b"));
    }

    // --- format_as_context ----------------------------------------------

    #[test]
    fn format_as_context_empty() {
        let rendered = format_as_context(&semantic_result(Vec::new(), 0));
        assert_eq!(rendered, "");
    }

    #[test]
    fn format_as_context_xml() {
        let hit = SearchHit {
            code: "fn hello() {}".to_string(),
            file_path: "src/lib.rs".to_string(),
            line_range: (1, 3),
            score: 0.85,
            node_type: NodeKind::from("function_item"),
            language: crate::languages::Lang::Rust,
            entity_name: Some("hello".to_string()),
            scope_chain: String::new(),
        };
        let xml = format_as_context(&semantic_result(vec![hit], 10));

        let expected_fragments = [
            "<code_context>",
            "</code_context>",
            "file=\"src/lib.rs\"",
            "name=\"hello\"",
            "score=\"0.85\"",
            "fn hello() {}",
        ];
        for fragment in expected_fragments {
            assert!(xml.contains(fragment));
        }
    }

    #[test]
    fn snake_case_a_b_three_chars_passes() {
        // Shortest accepted identifier: exactly three bytes.
        assert!(has_snake_case_identifier("a_b"));
    }

    // --- budget_tokens ---------------------------------------------------

    #[test]
    fn budget_tokens_ratio_zero() {
        let budget = budget_tokens(10_000, 0.0);
        assert_eq!(budget, 0);
    }

    #[test]
    fn budget_tokens_ratio_one() {
        let budget = budget_tokens(10_000, 1.0);
        assert_eq!(budget, 10_000);
    }

    #[test]
    fn budget_tokens_ratio_half() {
        let budget = budget_tokens(8_000, 0.5);
        assert_eq!(budget, 4_000);
    }

    #[test]
    fn budget_tokens_zero_available() {
        let budget = budget_tokens(0, 0.4);
        assert_eq!(budget, 0);
    }

    #[test]
    fn format_as_context_uses_node_type_when_no_entity_name() {
        let anonymous = SearchHit {
            code: "struct Foo {}".to_string(),
            file_path: "src/foo.rs".to_string(),
            line_range: (1, 2),
            score: 0.75,
            node_type: NodeKind::from("struct_item"),
            language: crate::languages::Lang::Rust,
            entity_name: None,
            scope_chain: String::new(),
        };
        let xml = format_as_context(&semantic_result(vec![anonymous], 5));
        assert!(xml.contains("name=\"struct_item\""));
    }

    #[test]
    fn classify_fn_keyword_is_grep() {
        let strategy = classify_query("fn my_func");
        assert_eq!(strategy, RetrievalStrategy::Grep);
    }

    #[test]
    fn classify_struct_keyword_is_grep() {
        let strategy = classify_query("struct MyType");
        assert_eq!(strategy, RetrievalStrategy::Grep);
    }

    #[test]
    fn classify_explain_conceptual_is_semantic() {
        let strategy = classify_query("explain the architecture");
        assert_eq!(strategy, RetrievalStrategy::Semantic);
    }

    // --- plain data types ------------------------------------------------

    #[test]
    fn retrieval_strategy_debug() {
        let expectations = [
            (RetrievalStrategy::Semantic, "Semantic"),
            (RetrievalStrategy::Grep, "Grep"),
            (RetrievalStrategy::Hybrid, "Hybrid"),
        ];
        for (variant, rendered) in expectations {
            assert_eq!(format!("{:?}", variant), rendered);
        }
    }

    #[test]
    fn retrieval_config_defaults() {
        let RetrievalConfig {
            max_chunks,
            score_threshold,
            budget_ratio,
            ..
        } = RetrievalConfig::default();

        assert_eq!(max_chunks, 12);
        assert!(score_threshold > 0.0);
        assert!(budget_ratio > 0.0 && budget_ratio < 1.0);
    }
}
596}