Skip to main content

sqry_core/search/
classifier.rs

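//! Query classifier - determines whether a query is semantic, text, or hybrid.
//!
//! Classification rules (checked in this order; the first match wins):
//! - Semantic: contains relation queries (`callers(foo)`), field filters
//!   (`kind:function`), AST capture names (`@function.def`), or symbol paths (`foo::bar`)
//! - Text: contains code markers (`TODO`, `FIXME`, ...), regex anchors or escapes, or comment syntax
//! - Hybrid: ambiguous queries that should try semantic search first, then fall back to text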
7use regex::Regex;
8use std::sync::OnceLock;
9
/// Outcome of classifying a search query.
///
/// Produced by [`QueryClassifier::classify`]; each variant names the search
/// strategy the caller should run.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueryType {
    /// Run only the semantic (AST-based) search.
    Semantic,
    /// Run only the plain text / regex search.
    Text,
    /// Ambiguous: try semantic first, fall back to text if it yields no results.
    Hybrid,
}
20
/// Stateless entry point for query classification.
///
/// Holds no data; all behavior lives in associated functions such as
/// [`QueryClassifier::classify`].
pub struct QueryClassifier;
23
24impl QueryClassifier {
25    /// Classify a query string into semantic, text, or hybrid
26    ///
27    /// # Examples
28    ///
29    /// ```
30    /// use sqry_core::search::classifier::{QueryClassifier, QueryType};
31    ///
32    /// // Semantic patterns
33    /// assert_eq!(QueryClassifier::classify("callers(foo)"), QueryType::Semantic);
34    /// assert_eq!(QueryClassifier::classify("kind:function AND name:bar"), QueryType::Semantic);
35    ///
36    /// // Text patterns
37    /// assert_eq!(QueryClassifier::classify("TODO: fix this"), QueryType::Text);
38    /// assert_eq!(QueryClassifier::classify("^fn \\w+"), QueryType::Text);
39    ///
40    /// // Hybrid (ambiguous)
41    /// assert_eq!(QueryClassifier::classify("find_user"), QueryType::Hybrid);
42    /// ```
43    #[must_use]
44    pub fn classify(query: &str) -> QueryType {
45        // Rule 1: Explicit semantic patterns
46        if Self::is_semantic_pattern(query) {
47            return QueryType::Semantic;
48        }
49
50        // Rule 2: Explicit text patterns
51        if Self::is_text_pattern(query) {
52            return QueryType::Text;
53        }
54
55        // Rule 3: Ambiguous - use hybrid mode
56        QueryType::Hybrid
57    }
58
59    /// Check if query contains semantic indicators
60    ///
61    /// Optimized with compiled regex for single-pass matching
62    fn is_semantic_pattern(query: &str) -> bool {
63        // Compile regex once at first use
64        static SEMANTIC_RE: OnceLock<Regex> = OnceLock::new();
65
66        let re = SEMANTIC_RE.get_or_init(|| {
67            Regex::new(
68                r"(?x)
69                callers[(:]  |   # callers(foo) or callers:foo
70                callees[(:]  |   # callees(bar) or callees:bar
71                imports[(:]  |   # imports(baz) or imports:baz
72                exports[(:]  |   # exports(qux) or exports:qux
73                returns[(:]  |   # returns(type) or returns:type
74                impl:        |   # impl:Debug (CD Static Analysis trait implementation)
75                duplicates:  |   # duplicates:body (CD Static Analysis duplicate detection)
76                unused:      |   # unused:public (CD Static Analysis dead code)
77                circular:    |   # circular:calls (CD Static Analysis cycle detection)
78                kind:        |   # kind:function
79                visibility:  |   # visibility:public
80                scope\.      |   # scope.type, scope.name, scope.parent, scope.ancestor (P2-34)
81                async:       |   # async:true
82                static:      |   # static:false
83                @            |   # @function.def
84                ::           # foo::bar::baz
85            ",
86            )
87            .expect("Failed to compile semantic pattern regex")
88        });
89
90        re.is_match(query)
91    }
92
93    /// Check if query contains text search indicators
94    fn is_text_pattern(query: &str) -> bool {
95        // Text indicators:
96        // - Common code markers: TODO, FIXME, HACK, XXX, NOTE, BUG
97        // - Regex anchors: ^, $
98        // - Regex character classes: \b, \w, \d, \s
99        // - Regex quantifiers: +, *, ?, {n,m}
100        // - Regex groups: (...)
101        // - Comment patterns: //, #, /*
102
103        let markers = ["TODO", "FIXME", "HACK", "XXX", "NOTE", "BUG"];
104        if markers.iter().any(|&m| query.contains(m)) {
105            return true;
106        }
107
108        // Regex metacharacters (strong signal for text search)
109        if query.starts_with('^') || query.ends_with('$') {
110            return true;
111        }
112
113        // Character classes and escape sequences
114        if query.contains("\\b")
115            || query.contains("\\w")
116            || query.contains("\\d")
117            || query.contains("\\s")
118            || query.contains("\\W")
119            || query.contains("\\D")
120            || query.contains("\\S")
121        {
122            return true;
123        }
124
125        // Comment patterns
126        if query.contains("//") || query.starts_with('#') || query.contains("/*") {
127            return true;
128        }
129
130        false
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use super::QueryType::{Hybrid, Semantic, Text};

    /// Assert that every query in `queries` classifies as `expected`,
    /// checked in order and naming the offending query on failure.
    #[track_caller]
    fn expect_all(expected: QueryType, queries: &[&str]) {
        for &query in queries {
            assert_eq!(
                QueryClassifier::classify(query),
                expected,
                "query {query:?} misclassified"
            );
        }
    }

    #[test]
    fn test_classify_semantic_relation_queries() {
        expect_all(Semantic, &["callers(foo)", "callees(bar)", "imports(baz)"]);
    }

    #[test]
    fn test_classify_semantic_symbol_filters() {
        expect_all(
            Semantic,
            &[
                "kind:function AND name:bar",
                "visibility:public",
                "async:true AND static:false",
            ],
        );
    }

    #[test]
    fn test_classify_semantic_scope_queries() {
        expect_all(Semantic, &["scope:module::class", "foo::bar::baz"]);
    }

    #[test]
    fn test_classify_semantic_ast_nodes() {
        expect_all(Semantic, &["@function.def", "@class.name"]);
    }

    #[test]
    fn test_classify_text_markers() {
        expect_all(
            Text,
            &[
                "TODO: fix this",
                "FIXME: handle error",
                "HACK: temporary",
                "XXX: review",
            ],
        );
    }

    #[test]
    fn test_classify_text_regex_anchors() {
        expect_all(Text, &["^fn main", "error$"]);
    }

    #[test]
    fn test_classify_text_regex_character_classes() {
        expect_all(Text, &[r"fn \w+", r"\d{3}", r"\bfoo\b"]);
    }

    #[test]
    fn test_classify_text_comment_patterns() {
        expect_all(Text, &["// comment", "# TODO", "/* block */"]);
    }

    #[test]
    fn test_classify_hybrid_ambiguous() {
        // Bare identifiers could name a symbol or appear as plain text.
        expect_all(Hybrid, &["find_user", "calculate", "UserModel"]);
    }

    #[test]
    fn test_classify_hybrid_simple_patterns() {
        // No semantic or text indicator present.
        expect_all(Hybrid, &["error", "config"]);
    }

    // Edge case tests (Phase 1.2)
    #[test]
    fn test_classify_edge_case_empty_query() {
        // Nothing to classify, so the query is ambiguous.
        expect_all(Hybrid, &[""]);
    }

    #[test]
    fn test_classify_edge_case_whitespace_only() {
        // Whitespace carries no indicator either.
        expect_all(Hybrid, &["   ", "\t", "\n", "  \t\n  "]);
    }

    #[test]
    fn test_classify_edge_case_semantic_with_special_chars() {
        // Punctuation in the filter value must not affect detection.
        expect_all(
            Semantic,
            &[
                "kind:function-name",
                "kind:foo_bar",
                "kind:foo.bar",
                "callers(my::path::Foo)",
            ],
        );
    }

    #[test]
    fn test_classify_edge_case_colon_without_field() {
        // A lone colon is not a field, but any `::` run is a symbol path.
        expect_all(Hybrid, &[":foo"]);
        expect_all(Semantic, &["::", ":::"]);
    }

    #[test]
    fn test_classify_edge_case_mixed_semantic_and_text() {
        // Semantic indicators are checked first and take priority.
        expect_all(
            Semantic,
            &[
                "kind:function AND TODO",
                "@function TODO: fix",
                "callers(foo) // comment",
            ],
        );
    }

    #[test]
    fn test_classify_edge_case_parentheses_without_semantic() {
        // Parentheses without a relation name are not semantic.
        expect_all(Hybrid, &["(foo)", "foo(bar)"]);
    }

    #[test]
    fn test_classify_edge_case_at_symbol_positions() {
        // `@` is semantic regardless of where it appears.
        expect_all(Semantic, &["@", "@function", "function@", "foo@bar"]);
    }

    #[test]
    fn test_classify_edge_case_double_colon_positions() {
        // `::` is semantic regardless of where it appears.
        expect_all(Semantic, &["::", "foo::bar", "::bar", "foo::"]);
    }

    #[test]
    fn test_classify_edge_case_semantic_fields_case_sensitive() {
        // Only the lowercase field name is recognised.
        expect_all(Semantic, &["kind:function"]);
        expect_all(Hybrid, &["KIND:function", "Kind:function"]);
    }

    #[test]
    fn test_classify_edge_case_relation_with_colon() {
        // Relations accept the `name:value` form.
        expect_all(
            Semantic,
            &[
                "callers:foo",
                "callees:bar",
                "imports:baz",
                "exports:qux",
                "returns:Type",
            ],
        );
    }

    #[test]
    fn test_classify_edge_case_relation_with_parentheses() {
        // Relations also accept the `name(value)` form.
        expect_all(
            Semantic,
            &[
                "exports(MyType)",
                "returns(Result<T>)",
                "imports(std::collections)",
            ],
        );
    }

    #[test]
    fn test_classify_edge_case_boolean_fields() {
        expect_all(
            Semantic,
            &["async:true", "async:false", "static:true", "static:false"],
        );
    }

    #[test]
    fn test_classify_edge_case_complex_queries() {
        expect_all(
            Semantic,
            &[
                "kind:function AND async:true",
                "(callers(foo) OR callees(bar)) AND kind:struct",
                "visibility:public AND scope:module::MyClass",
            ],
        );
    }

    #[test]
    fn test_classify_edge_case_text_false_positives() {
        // Near-misses: `caller` lacks the trailing `s(`, while `foo`, `name`
        // and `type` are not fields recognised by the semantic regex.
        expect_all(Hybrid, &["caller", "foo:bar", "name:foo", "type:Bar"]);
    }

    #[test]
    fn test_classify_edge_case_unicode() {
        // Non-ASCII values must not interfere with detection.
        expect_all(Semantic, &["kind:函数", "foo::バー::baz", "callers(αβγ)"]);
    }

    // CD Static Analysis predicate tests
    #[test]
    fn test_classify_cd_duplicates_predicate() {
        expect_all(
            Semantic,
            &[
                "duplicates:body",
                "duplicates:signature",
                "duplicates:struct",
                "kind:function AND duplicates:body",
            ],
        );
    }

    #[test]
    fn test_classify_cd_unused_predicate() {
        expect_all(
            Semantic,
            &[
                "unused:public",
                "unused:private",
                "unused:function",
                "unused:all",
            ],
        );
    }

    #[test]
    fn test_classify_cd_circular_predicate() {
        expect_all(
            Semantic,
            &["circular:calls", "circular:imports", "circular:all"],
        );
    }

    #[test]
    fn test_classify_cd_combined_predicates() {
        // Several CD predicates in one query.
        expect_all(
            Semantic,
            &[
                "kind:function AND duplicates:body AND unused:public",
                "impl:Debug AND circular:calls",
            ],
        );
    }
}