Skip to main content

the_code_graph_domain/use_cases/
query.rs

1use crate::analysis::search::{detect_kind_boost, qualified_name_boost, rrf_merge};
2use crate::error::Result;
3use crate::model::*;
4use crate::ports::{EmbeddingProvider, GraphStore, SearchIndex, VectorStore};
5use std::sync::Arc;
6
7pub struct QueryUseCase<S, I> {
8    store: S,
9    index: I,
10    vector_store: Option<Arc<dyn VectorStore>>,
11    embedding_provider: Option<Arc<dyn EmbeddingProvider>>,
12}
13
14impl<S: GraphStore, I: SearchIndex> QueryUseCase<S, I> {
15    pub fn new(store: S, index: I) -> Self {
16        Self {
17            store,
18            index,
19            vector_store: None,
20            embedding_provider: None,
21        }
22    }
23
24    pub fn with_hybrid(
25        store: S,
26        index: I,
27        vector_store: Option<Arc<dyn VectorStore>>,
28        embedding_provider: Option<Arc<dyn EmbeddingProvider>>,
29    ) -> Self {
30        Self {
31            store,
32            index,
33            vector_store,
34            embedding_provider,
35        }
36    }
37
38    pub fn find(&self, pattern: &str) -> Result<Vec<SymbolNode>> {
39        self.store.find_by_name(pattern)
40    }
41
42    pub fn refs(&self, qualified_name: &str) -> Result<Vec<Reference>> {
43        let edges = self.store.get_edges_to(qualified_name)?;
44        Ok(edges
45            .into_iter()
46            .map(|e| Reference {
47                symbol: e.source,
48                edge_kind: e.kind,
49                location: None,
50            })
51            .collect())
52    }
53
54    pub fn callers(&self, qualified_name: &str) -> Result<Vec<Reference>> {
55        let edges = self.store.get_edges_to(qualified_name)?;
56        Ok(edges
57            .into_iter()
58            .filter(|e| e.kind == EdgeKind::Calls)
59            .map(|e| Reference {
60                symbol: e.source,
61                edge_kind: e.kind,
62                location: None,
63            })
64            .collect())
65    }
66
67    pub fn callees(&self, qualified_name: &str) -> Result<Vec<Reference>> {
68        let edges = self.store.get_edges_from(qualified_name)?;
69        Ok(edges
70            .into_iter()
71            .filter(|e| e.kind == EdgeKind::Calls)
72            .map(|e| Reference {
73                symbol: e.target,
74                edge_kind: e.kind,
75                location: None,
76            })
77            .collect())
78    }
79
80    pub fn search(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>> {
81        self.index.search(query, limit)
82    }
83
84    /// Hybrid search combining FTS and/or semantic vector search with RRF fusion.
85    /// Falls back to FTS when `mode == Hybrid` but no vector store is available.
86    pub fn hybrid_search(
87        &self,
88        query: &str,
89        limit: usize,
90        mode: SearchMode,
91        config: &HybridSearchConfig,
92    ) -> Result<Vec<SearchResult>> {
93        if query.is_empty() {
94            return Ok(vec![]);
95        }
96
97        // Collect FTS results unless semantic-only
98        let fts_results: Vec<(String, f64)> = if mode != SearchMode::SemanticOnly {
99            self.index
100                .search(query, limit)?
101                .into_iter()
102                .map(|r| (r.qualified_name, r.score))
103                .collect()
104        } else {
105            vec![]
106        };
107
108        // Collect vector results unless FTS-only, and only when a vector store is present
109        let vec_results: Vec<(String, f64)> = if mode != SearchMode::FtsOnly {
110            if let (Some(vs), Some(ep)) =
111                (self.vector_store.as_ref(), self.embedding_provider.as_ref())
112            {
113                if vs.has_embeddings() {
114                    let query_vec = ep.embed_query(query)?;
115                    vs.search_nearest(&query_vec, limit)?
116                } else {
117                    vec![]
118                }
119            } else {
120                vec![]
121            }
122        } else {
123            vec![]
124        };
125
126        // Build the merged ranked list
127        let merged: Vec<(String, f64)> = match mode {
128            SearchMode::FtsOnly => fts_results,
129            SearchMode::SemanticOnly => vec_results,
130            SearchMode::Hybrid => {
131                if vec_results.is_empty() {
132                    // Graceful fallback: no vector store / no embeddings → return FTS
133                    fts_results
134                } else {
135                    rrf_merge(&[fts_results, vec_results], config.rrf_k)
136                }
137            }
138        };
139
140        // Compute kind boosts once
141        let kind_boosts = if config.kind_boost {
142            detect_kind_boost(query)
143        } else {
144            vec![]
145        };
146        let qn_boost = qualified_name_boost(query);
147
148        // Resolve qualified names to SearchResult, applying kind boost
149        let mut results: Vec<SearchResult> = merged
150            .into_iter()
151            .take(limit)
152            .filter_map(|(qn, mut score)| {
153                let sym = self.store.get_symbol(&qn).ok().flatten()?;
154                // Apply qualified-name exact-match boost (2.0x for :: queries)
155                if qn_boost > 1.0 && qn.contains(query) {
156                    score *= qn_boost;
157                }
158                if !kind_boosts.is_empty() {
159                    for kb in &kind_boosts {
160                        if sym.kind == kb.kind {
161                            score *= kb.multiplier;
162                        }
163                    }
164                }
165                Some(SearchResult {
166                    qualified_name: sym.qualified_name.clone(),
167                    name: sym.name.clone(),
168                    kind: sym.kind,
169                    file_path: sym.location.file.clone(),
170                    score,
171                    score_source: Some(match mode {
172                        SearchMode::FtsOnly => ScoreSource::Fts5,
173                        SearchMode::SemanticOnly => ScoreSource::Semantic,
174                        SearchMode::Hybrid => ScoreSource::Hybrid,
175                    }),
176                })
177            })
178            .collect();
179
180        // Sort by score descending
181        results.sort_by(|a, b| {
182            b.score
183                .partial_cmp(&a.score)
184                .unwrap_or(std::cmp::Ordering::Equal)
185        });
186        Ok(results)
187    }
188
189    pub fn stats(&self) -> Result<GraphStats> {
190        self.store.stats()
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::{InMemoryEmbeddingProvider, InMemoryGraphStore, InMemoryVectorStore};
    use std::sync::Arc;

    /// Use case where one in-memory store serves as both graph store and search index.
    type TestUseCase = QueryUseCase<InMemoryGraphStore, InMemoryGraphStore>;

    /// Public function symbol `name` in `test.rs` (lines 1-5), qualified as `test.rs::{name}`.
    fn make_symbol(name: &str) -> SymbolNode {
        let location = Location {
            file: "test.rs".into(),
            line_start: 1,
            line_end: 5,
            col_start: 0,
            col_end: 0,
        };
        SymbolNode {
            qualified_name: format!("test.rs::{name}"),
            name: name.into(),
            kind: SymbolKind::Function,
            location,
            visibility: Visibility::Public,
            is_exported: false,
            is_async: false,
            is_test: false,
            decorators: vec![],
            signature: None,
        }
    }

    /// In-memory store holding one function symbol per entry of `names`.
    fn seeded_store(names: &[&str]) -> InMemoryGraphStore {
        let mut store = InMemoryGraphStore::new();
        for name in names {
            store.insert_symbol(make_symbol(name));
        }
        store
    }

    /// Use case with no semantic backend configured.
    fn fts_use_case(names: &[&str]) -> TestUseCase {
        let store = seeded_store(names);
        QueryUseCase::new(store.clone(), store)
    }

    /// In-memory embedding provider constructed with `4`, the length of the test vectors.
    fn embedder() -> Arc<dyn crate::ports::EmbeddingProvider> {
        Arc::new(InMemoryEmbeddingProvider::new(4))
    }

    /// Hybrid use case over `names` whose vector store is pre-seeded with `entries`.
    fn hybrid_use_case(names: &[&str], entries: &[EmbeddingEntry]) -> TestUseCase {
        let store = seeded_store(names);
        let vectors = Arc::new(InMemoryVectorStore::new());
        vectors.store_embeddings(entries).unwrap();
        let vector_store: Arc<dyn crate::ports::VectorStore> = vectors;
        QueryUseCase::with_hybrid(store.clone(), store, Some(vector_store), Some(embedder()))
    }

    #[test]
    fn find_exact_match() {
        let found = fts_use_case(&["foo"]).find("foo").unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name, "foo");
    }

    #[test]
    fn find_prefix_fallback() {
        let found = fts_use_case(&["foobar"]).find("foo").unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name, "foobar");
    }

    #[test]
    fn find_no_match_returns_empty() {
        let found = fts_use_case(&[]).find("bar").unwrap();
        assert!(found.is_empty());
    }

    #[test]
    fn find_exact_takes_priority_over_prefix() {
        let found = fts_use_case(&["foo", "foobar"]).find("foo").unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name, "foo");
    }

    // -----------------------------------------------------------------------
    // Hybrid search tests
    // -----------------------------------------------------------------------

    #[test]
    fn search_falls_back_to_fts_when_no_vector_store() {
        let store = seeded_store(&["foo"]);
        // An embedding provider alone is not enough: without a vector store the
        // hybrid query must degrade to FTS and still find the symbol.
        let uc = QueryUseCase::with_hybrid(store.clone(), store, None, Some(embedder()));
        let hits = uc
            .hybrid_search("foo", 10, SearchMode::Hybrid, &HybridSearchConfig::default())
            .unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].name, "foo");
    }

    #[test]
    fn search_uses_hybrid_when_vector_store_has_embeddings() {
        let uc = hybrid_use_case(
            &["foo", "bar"],
            &[
                EmbeddingEntry {
                    qualified_name: "test.rs::foo".into(),
                    vector: vec![1.0, 0.0, 0.0, 0.0],
                    text_hash: "h1".into(),
                },
                EmbeddingEntry {
                    qualified_name: "test.rs::bar".into(),
                    vector: vec![0.0, 1.0, 0.0, 0.0],
                    text_hash: "h2".into(),
                },
            ],
        );
        let hits = uc
            .hybrid_search("foo", 10, SearchMode::Hybrid, &HybridSearchConfig::default())
            .unwrap();
        // The FTS hit "foo" must survive the RRF merge with the vector ranking.
        assert!(!hits.is_empty());
        assert!(hits.iter().any(|hit| hit.name == "foo"));
    }

    #[test]
    fn search_semantic_only_skips_fts() {
        // "alpha" does not match the text query "foo"; it is reachable only through
        // the vector store, where it is the sole entry.
        let uc = hybrid_use_case(
            &["alpha"],
            &[EmbeddingEntry {
                qualified_name: "test.rs::alpha".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                text_hash: "h1".into(),
            }],
        );
        let hits = uc
            .hybrid_search("foo", 10, SearchMode::SemanticOnly, &HybridSearchConfig::default())
            .unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].name, "alpha");
    }

    #[test]
    fn search_fts_only_skips_vectors() {
        // The vector store only knows "bar"; FtsOnly must never surface it.
        let uc = hybrid_use_case(
            &["foo"],
            &[EmbeddingEntry {
                qualified_name: "test.rs::bar".into(),
                vector: vec![1.0, 0.0, 0.0, 0.0],
                text_hash: "h1".into(),
            }],
        );
        let hits = uc
            .hybrid_search("foo", 10, SearchMode::FtsOnly, &HybridSearchConfig::default())
            .unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].name, "foo");
        assert!(hits.iter().all(|hit| hit.name != "bar"));
    }

    #[test]
    fn search_empty_query_returns_empty() {
        let uc = fts_use_case(&["foo"]);
        let hits = uc
            .hybrid_search("", 10, SearchMode::Hybrid, &HybridSearchConfig::default())
            .unwrap();
        assert!(hits.is_empty());
    }
}