Skip to main content

the_code_graph_domain/analysis/
search.rs

1use std::collections::HashMap;
2use std::path::Path;
3
4use crate::model::{Edge, EdgeKind, SymbolKind, SymbolNode};
5
// ---------------------------------------------------------------------------
// RRF Fusion
// ---------------------------------------------------------------------------

/// Reciprocal Rank Fusion: merge multiple ranked lists into a single ranking.
///
/// Every entry contributes `1 / (k + rank + 1)` to its identifier's fused
/// score, where `rank` is the entry's 0-based position in its list. The `f64`
/// half of each input pair is ignored — only list order matters, so each list
/// must already be sorted best-first. An identifier that appears several
/// times in the same list contributes once per occurrence.
///
/// `k` is the smoothing constant (typically 60); larger values shrink the
/// score gap between top and lower ranks.
///
/// The result is sorted by fused score, highest first. Identifiers with equal
/// fused scores are ordered by identifier (ascending) so the output is
/// deterministic; without this, ties would follow `HashMap` iteration order,
/// which is randomized per map.
pub fn rrf_merge(lists: &[Vec<(String, f64)>], k: usize) -> Vec<(String, f64)> {
    let mut scores: HashMap<String, f64> = HashMap::new();
    for list in lists {
        for (rank, (qn, _)) in list.iter().enumerate() {
            *scores.entry(qn.clone()).or_default() += 1.0 / (k + rank + 1) as f64;
        }
    }
    let mut merged: Vec<_> = scores.into_iter().collect();
    // Scores are finite sums of positive reciprocals, so `partial_cmp` never
    // returns `None` in practice; the fallback only keeps the comparator total.
    merged.sort_by(|a, b| {
        b.1.partial_cmp(&a.1)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.0.cmp(&b.0))
    });
    merged
}
23
24// ---------------------------------------------------------------------------
25// Text representation
26// ---------------------------------------------------------------------------
27
28/// Build a natural-language text representation of a symbol for embedding.
29pub fn symbol_to_text(sym: &SymbolNode, edges: &[Edge]) -> String {
30    let mut parts = vec![kind_to_str(sym.kind).to_string(), sym.name.clone()];
31    parts.push(format!("in {}", file_stem(&sym.location.file)));
32    if let Some(sig) = &sym.signature {
33        parts.push(format!("signature: {sig}"));
34    }
35    let calls: Vec<_> = edges
36        .iter()
37        .filter(|e| e.kind == EdgeKind::Calls && e.source == sym.qualified_name)
38        .take(3)
39        .map(|e| short_name(&e.target))
40        .collect();
41    if !calls.is_empty() {
42        parts.push(format!("calls {}", calls.join(", ")));
43    }
44    let callers: Vec<_> = edges
45        .iter()
46        .filter(|e| e.kind == EdgeKind::Calls && e.target == sym.qualified_name)
47        .take(3)
48        .map(|e| short_name(&e.source))
49        .collect();
50    if !callers.is_empty() {
51        parts.push(format!("called by {}", callers.join(", ")));
52    }
53    parts.join(", ")
54}
55
56fn kind_to_str(kind: SymbolKind) -> &'static str {
57    match kind {
58        SymbolKind::Function => "Function",
59        SymbolKind::Class => "Class",
60        SymbolKind::Interface => "Interface",
61        SymbolKind::Struct => "Struct",
62        SymbolKind::Trait => "Trait",
63        SymbolKind::Enum => "Enum",
64        SymbolKind::TypeAlias => "TypeAlias",
65        SymbolKind::Method => "Method",
66        SymbolKind::Property => "Property",
67        SymbolKind::Const => "Const",
68        SymbolKind::Macro => "Macro",
69        SymbolKind::Variable => "Variable",
70        SymbolKind::Component => "Component",
71        SymbolKind::Test => "Test",
72    }
73}
74
/// Return the file name of `path` without its final extension, for use in
/// human-readable symbol descriptions.
///
/// A stem that is not valid UTF-8 is converted lossily (invalid sequences
/// become U+FFFD) rather than discarded, so the description still identifies
/// the file. Paths with no file-name component (e.g. `""`, `"/"`, or a path
/// ending in `..`) yield `"unknown"`.
fn file_stem(path: &Path) -> String {
    path.file_stem().map_or_else(
        || "unknown".to_owned(),
        |stem| stem.to_string_lossy().into_owned(),
    )
}
81
/// Return the last `::`-separated segment of a qualified name.
///
/// A name without `::` is returned whole; a trailing `::` yields an empty
/// string.
fn short_name(qualified: &str) -> String {
    let last_segment = match qualified.rsplit_once("::") {
        Some((_, tail)) => tail,
        None => qualified,
    };
    last_segment.to_owned()
}
89
90// ---------------------------------------------------------------------------
91// Kind boosting
92// ---------------------------------------------------------------------------
93
94/// A boost hint derived from query shape heuristics.
95pub struct KindBoost {
96    pub kind: SymbolKind,
97    pub multiplier: f64,
98}
99
100/// Detect likely symbol kinds from query shape to boost relevance scores.
101/// Returns an empty vec when the query is a qualified name (contains `::`)
102/// because qualified names use exact-match boosting instead (see `qualified_name_boost`).
103pub fn detect_kind_boost(query: &str) -> Vec<KindBoost> {
104    let mut boosts = Vec::new();
105    if query.contains("::") {
106        return boosts; // qualified name pattern — use qualified_name_boost instead
107    }
108    let first = query.chars().next().unwrap_or('a');
109    if first.is_uppercase() && !query.contains('_') {
110        // PascalCase → likely a type-level symbol
111        for kind in [SymbolKind::Struct, SymbolKind::Trait, SymbolKind::Interface] {
112            boosts.push(KindBoost {
113                kind,
114                multiplier: 1.5,
115            });
116        }
117    } else if query.contains('_') && query.chars().all(|c| c.is_lowercase() || c == '_') {
118        // snake_case → likely a function or method
119        for kind in [SymbolKind::Function, SymbolKind::Method] {
120            boosts.push(KindBoost {
121                kind,
122                multiplier: 1.5,
123            });
124        }
125    }
126    boosts
127}
128
/// Score multiplier for qualified-name queries.
///
/// A query containing `::` is treated as a qualified-name pattern and yields
/// 2.0; every other query yields the neutral 1.0. The multiplier is applied to
/// results whose `qualified_name` contains the query as a substring.
pub fn qualified_name_boost(query: &str) -> f64 {
    const QUALIFIED: f64 = 2.0;
    const NEUTRAL: f64 = 1.0;

    let looks_qualified = query.contains("::");
    if looks_qualified {
        QUALIFIED
    } else {
        NEUTRAL
    }
}
139
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{Edge, EdgeKind, Location, SymbolKind, SymbolNode, Visibility};
    use std::path::PathBuf;

    // ---- rrf_merge -----------------------------------------------------

    /// The first entry of a single list ranks highest, with a higher score.
    #[test]
    fn rrf_merge_single_list() {
        let ranked = vec![vec![("a".to_string(), 1.0), ("b".to_string(), 0.5)]];
        let fused = rrf_merge(&ranked, 60);
        assert_eq!(fused[0].0.as_str(), "a");
        assert!(fused[0].1 > fused[1].1);
    }

    /// An id present in both lists outranks ids that appear in only one.
    #[test]
    fn rrf_merge_two_lists_boosts_overlap() {
        let first = vec![("a".to_string(), 1.0), ("b".to_string(), 0.5)];
        let second = vec![("b".to_string(), 1.0), ("c".to_string(), 0.5)];
        let fused = rrf_merge(&[first, second], 60);
        let leader = fused.first().map(|(id, _)| id.as_str());
        assert_eq!(leader, Some("b"));
    }

    /// No input lists produce no output.
    #[test]
    fn rrf_merge_empty_lists() {
        assert!(rrf_merge(&[], 60).is_empty());
    }

    // ---- symbol_to_text ------------------------------------------------

    /// Kind, name, file stem and signature all appear in the text.
    #[test]
    fn symbol_to_text_basic() {
        let signature = Some("fn foo(x: i32) -> bool".to_string());
        let sym = symbol_fixture("foo", SymbolKind::Function, "src/lib.rs", signature);
        let text = symbol_to_text(&sym, &[]);
        for needle in ["Function", "foo", "lib", "signature:"] {
            assert!(text.contains(needle), "expected {needle:?} in {text:?}");
        }
    }

    /// Outgoing and incoming call edges are rendered by short name.
    #[test]
    fn symbol_to_text_with_edges() {
        let sym = symbol_fixture("foo", SymbolKind::Function, "src/lib.rs", None);
        let edges = [
            call_edge_fixture("mod::foo", "mod::bar"), // foo calls bar
            call_edge_fixture("mod::baz", "mod::foo"), // baz calls foo
        ];
        let text = symbol_to_text(&sym, &edges);
        assert!(text.contains("calls bar"));
        assert!(text.contains("called by baz"));
    }

    // ---- detect_kind_boost / qualified_name_boost ----------------------

    #[test]
    fn detect_kind_boost_pascal_case() {
        let boosts = detect_kind_boost("AuthService");
        assert!(!boosts.is_empty());
        assert!(boosts.iter().any(|b| b.kind == SymbolKind::Struct));
        let uniform = boosts
            .iter()
            .all(|b| (b.multiplier - 1.5).abs() < f64::EPSILON);
        assert!(uniform);
    }

    #[test]
    fn detect_kind_boost_snake_case() {
        let boosts = detect_kind_boost("validate_token");
        assert!(!boosts.is_empty());
        assert!(boosts.iter().any(|b| b.kind == SymbolKind::Function));
    }

    #[test]
    fn detect_kind_boost_qualified() {
        // A `::` pattern is handled by qualified_name_boost instead.
        assert!(detect_kind_boost("auth::validate").is_empty());
    }

    #[test]
    fn qualified_name_boost_with_colons() {
        let boost = qualified_name_boost("auth::validate");
        assert!((boost - 2.0).abs() < f64::EPSILON);
    }

    #[test]
    fn qualified_name_boost_without_colons() {
        let boost = qualified_name_boost("validate_token");
        assert!((boost - 1.0).abs() < f64::EPSILON);
    }

    // ---- fixtures ------------------------------------------------------

    /// Build a public, exported, non-async, non-test symbol `mod::<name>`.
    fn symbol_fixture(name: &str, kind: SymbolKind, file: &str, sig: Option<String>) -> SymbolNode {
        let location = Location {
            file: PathBuf::from(file),
            line_start: 1,
            line_end: 5,
            col_start: 0,
            col_end: 0,
        };
        SymbolNode {
            qualified_name: format!("mod::{name}"),
            name: name.to_owned(),
            kind,
            location,
            visibility: Visibility::Public,
            is_exported: true,
            is_async: false,
            is_test: false,
            decorators: vec![],
            signature: sig,
        }
    }

    /// Build a metadata-free `Calls` edge from `source` to `target`.
    fn call_edge_fixture(source: &str, target: &str) -> Edge {
        Edge {
            kind: EdgeKind::Calls,
            source: source.to_owned(),
            target: target.to_owned(),
            metadata: None,
        }
    }
}