the_code_graph_domain/use_cases/
query.rs1use crate::analysis::search::{detect_kind_boost, qualified_name_boost, rrf_merge};
2use crate::error::Result;
3use crate::model::*;
4use crate::ports::{EmbeddingProvider, GraphStore, SearchIndex, VectorStore};
5use std::sync::Arc;
6
7pub struct QueryUseCase<S, I> {
8 store: S,
9 index: I,
10 vector_store: Option<Arc<dyn VectorStore>>,
11 embedding_provider: Option<Arc<dyn EmbeddingProvider>>,
12}
13
14impl<S: GraphStore, I: SearchIndex> QueryUseCase<S, I> {
15 pub fn new(store: S, index: I) -> Self {
16 Self {
17 store,
18 index,
19 vector_store: None,
20 embedding_provider: None,
21 }
22 }
23
24 pub fn with_hybrid(
25 store: S,
26 index: I,
27 vector_store: Option<Arc<dyn VectorStore>>,
28 embedding_provider: Option<Arc<dyn EmbeddingProvider>>,
29 ) -> Self {
30 Self {
31 store,
32 index,
33 vector_store,
34 embedding_provider,
35 }
36 }
37
38 pub fn find(&self, pattern: &str) -> Result<Vec<SymbolNode>> {
39 self.store.find_by_name(pattern)
40 }
41
42 pub fn refs(&self, qualified_name: &str) -> Result<Vec<Reference>> {
43 let edges = self.store.get_edges_to(qualified_name)?;
44 Ok(edges
45 .into_iter()
46 .map(|e| Reference {
47 symbol: e.source,
48 edge_kind: e.kind,
49 location: None,
50 })
51 .collect())
52 }
53
54 pub fn callers(&self, qualified_name: &str) -> Result<Vec<Reference>> {
55 let edges = self.store.get_edges_to(qualified_name)?;
56 Ok(edges
57 .into_iter()
58 .filter(|e| e.kind == EdgeKind::Calls)
59 .map(|e| Reference {
60 symbol: e.source,
61 edge_kind: e.kind,
62 location: None,
63 })
64 .collect())
65 }
66
67 pub fn callees(&self, qualified_name: &str) -> Result<Vec<Reference>> {
68 let edges = self.store.get_edges_from(qualified_name)?;
69 Ok(edges
70 .into_iter()
71 .filter(|e| e.kind == EdgeKind::Calls)
72 .map(|e| Reference {
73 symbol: e.target,
74 edge_kind: e.kind,
75 location: None,
76 })
77 .collect())
78 }
79
80 pub fn search(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>> {
81 self.index.search(query, limit)
82 }
83
84 pub fn hybrid_search(
87 &self,
88 query: &str,
89 limit: usize,
90 mode: SearchMode,
91 config: &HybridSearchConfig,
92 ) -> Result<Vec<SearchResult>> {
93 if query.is_empty() {
94 return Ok(vec![]);
95 }
96
97 let fts_results: Vec<(String, f64)> = if mode != SearchMode::SemanticOnly {
99 self.index
100 .search(query, limit)?
101 .into_iter()
102 .map(|r| (r.qualified_name, r.score))
103 .collect()
104 } else {
105 vec![]
106 };
107
108 let vec_results: Vec<(String, f64)> = if mode != SearchMode::FtsOnly {
110 if let (Some(vs), Some(ep)) =
111 (self.vector_store.as_ref(), self.embedding_provider.as_ref())
112 {
113 if vs.has_embeddings() {
114 let query_vec = ep.embed_query(query)?;
115 vs.search_nearest(&query_vec, limit)?
116 } else {
117 vec![]
118 }
119 } else {
120 vec![]
121 }
122 } else {
123 vec![]
124 };
125
126 let merged: Vec<(String, f64)> = match mode {
128 SearchMode::FtsOnly => fts_results,
129 SearchMode::SemanticOnly => vec_results,
130 SearchMode::Hybrid => {
131 if vec_results.is_empty() {
132 fts_results
134 } else {
135 rrf_merge(&[fts_results, vec_results], config.rrf_k)
136 }
137 }
138 };
139
140 let kind_boosts = if config.kind_boost {
142 detect_kind_boost(query)
143 } else {
144 vec![]
145 };
146 let qn_boost = qualified_name_boost(query);
147
148 let mut results: Vec<SearchResult> = merged
150 .into_iter()
151 .take(limit)
152 .filter_map(|(qn, mut score)| {
153 let sym = self.store.get_symbol(&qn).ok().flatten()?;
154 if qn_boost > 1.0 && qn.contains(query) {
156 score *= qn_boost;
157 }
158 if !kind_boosts.is_empty() {
159 for kb in &kind_boosts {
160 if sym.kind == kb.kind {
161 score *= kb.multiplier;
162 }
163 }
164 }
165 Some(SearchResult {
166 qualified_name: sym.qualified_name.clone(),
167 name: sym.name.clone(),
168 kind: sym.kind,
169 file_path: sym.location.file.clone(),
170 score,
171 score_source: Some(match mode {
172 SearchMode::FtsOnly => ScoreSource::Fts5,
173 SearchMode::SemanticOnly => ScoreSource::Semantic,
174 SearchMode::Hybrid => ScoreSource::Hybrid,
175 }),
176 })
177 })
178 .collect();
179
180 results.sort_by(|a, b| {
182 b.score
183 .partial_cmp(&a.score)
184 .unwrap_or(std::cmp::Ordering::Equal)
185 });
186 Ok(results)
187 }
188
189 pub fn stats(&self) -> Result<GraphStats> {
190 self.store.stats()
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::{InMemoryEmbeddingProvider, InMemoryGraphStore, InMemoryVectorStore};
    use std::sync::Arc;

    /// Builds a public, synchronous, non-test function symbol `test.rs::{name}`.
    fn make_symbol(name: &str) -> SymbolNode {
        SymbolNode {
            name: name.into(),
            qualified_name: format!("test.rs::{name}"),
            kind: SymbolKind::Function,
            location: Location {
                file: "test.rs".into(),
                line_start: 1,
                line_end: 5,
                col_start: 0,
                col_end: 0,
            },
            visibility: Visibility::Public,
            is_exported: false,
            is_async: false,
            is_test: false,
            decorators: vec![],
            signature: None,
        }
    }

    /// In-memory graph store holding one function symbol per entry of `names`,
    /// inserted in order.
    fn store_with(names: &[&str]) -> InMemoryGraphStore {
        let mut store = InMemoryGraphStore::new();
        for name in names {
            store.insert_symbol(make_symbol(name));
        }
        store
    }

    /// Use case without semantic backends; the store doubles as search index.
    fn plain_use_case(names: &[&str]) -> QueryUseCase<InMemoryGraphStore, InMemoryGraphStore> {
        let store = store_with(names);
        QueryUseCase::new(store.clone(), store)
    }

    /// Use case wired to an in-memory vector store pre-loaded with `entries`
    /// and a 4-dimensional in-memory embedding provider.
    fn hybrid_use_case(
        store: InMemoryGraphStore,
        entries: &[EmbeddingEntry],
    ) -> QueryUseCase<InMemoryGraphStore, InMemoryGraphStore> {
        let vectors = Arc::new(InMemoryVectorStore::new());
        vectors.store_embeddings(entries).unwrap();
        let vectors: Arc<dyn crate::ports::VectorStore> = vectors;
        let embedder: Arc<dyn crate::ports::EmbeddingProvider> =
            Arc::new(InMemoryEmbeddingProvider::new(4));
        QueryUseCase::with_hybrid(store.clone(), store, Some(vectors), Some(embedder))
    }

    #[test]
    fn find_exact_match() {
        let found = plain_use_case(&["foo"]).find("foo").unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name, "foo");
    }

    #[test]
    fn find_prefix_fallback() {
        let found = plain_use_case(&["foobar"]).find("foo").unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name, "foobar");
    }

    #[test]
    fn find_no_match_returns_empty() {
        let found = plain_use_case(&[]).find("bar").unwrap();
        assert!(found.is_empty());
    }

    #[test]
    fn find_exact_takes_priority_over_prefix() {
        let found = plain_use_case(&["foo", "foobar"]).find("foo").unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name, "foo");
    }

    #[test]
    fn search_falls_back_to_fts_when_no_vector_store() {
        let store = store_with(&["foo"]);
        let embedder: Arc<dyn crate::ports::EmbeddingProvider> =
            Arc::new(InMemoryEmbeddingProvider::new(4));
        let uc = QueryUseCase::with_hybrid(store.clone(), store, None, Some(embedder));
        let config = HybridSearchConfig::default();

        let hits = uc
            .hybrid_search("foo", 10, SearchMode::Hybrid, &config)
            .unwrap();

        assert!(!hits.is_empty());
        assert_eq!(hits[0].name, "foo");
    }

    #[test]
    fn search_uses_hybrid_when_vector_store_has_embeddings() {
        let uc = hybrid_use_case(
            store_with(&["foo", "bar"]),
            &[
                EmbeddingEntry {
                    qualified_name: "test.rs::foo".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    text_hash: "h1".into(),
                },
                EmbeddingEntry {
                    qualified_name: "test.rs::bar".into(),
                    vector: vec![0.0, 1.0, 0.0, 0.0],
                    text_hash: "h2".into(),
                },
            ],
        );
        let config = HybridSearchConfig::default();

        let hits = uc
            .hybrid_search("foo", 10, SearchMode::Hybrid, &config)
            .unwrap();

        assert!(!hits.is_empty());
        assert!(hits.iter().any(|hit| hit.name == "foo"));
    }

    #[test]
    fn search_semantic_only_skips_fts() {
        let uc = hybrid_use_case(
            store_with(&["alpha"]),
            &[EmbeddingEntry {
                qualified_name: "test.rs::alpha".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                text_hash: "h1".into(),
            }],
        );
        let config = HybridSearchConfig::default();

        let hits = uc
            .hybrid_search("foo", 10, SearchMode::SemanticOnly, &config)
            .unwrap();

        assert!(!hits.is_empty());
        assert_eq!(hits[0].name, "alpha");
    }

    #[test]
    fn search_fts_only_skips_vectors() {
        let uc = hybrid_use_case(
            store_with(&["foo"]),
            &[EmbeddingEntry {
                qualified_name: "test.rs::bar".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                text_hash: "h1".into(),
            }],
        );
        let config = HybridSearchConfig::default();

        let hits = uc
            .hybrid_search("foo", 10, SearchMode::FtsOnly, &config)
            .unwrap();

        assert!(!hits.is_empty());
        assert_eq!(hits[0].name, "foo");
        assert!(hits.iter().all(|hit| hit.name != "bar"));
    }

    #[test]
    fn search_empty_query_returns_empty() {
        let uc = plain_use_case(&["foo"]);
        let config = HybridSearchConfig::default();
        let hits = uc.hybrid_search("", 10, SearchMode::Hybrid, &config).unwrap();
        assert!(hits.is_empty());
    }
}