// search_semantically/query_classifier.rs

/// Broad category of a search query, as decided by `classify_query`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum QueryType {
    /// Looks like a code symbol: a single token, or a short (two- or three-word)
    /// phrase containing a camelCase, snake_case or SCREAMING_SNAKE_CASE token.
    Identifier,
    /// Free-form prose; also the result for an empty or whitespace-only query.
    NaturalLanguage,
    /// Looks like a location: contains a path separator, is a dotted path such as
    /// `foo.bar.baz`, or is a single filename with an extension.
    PathLike,
}

/// Returns `true` if `s` contains an ASCII lowercase letter immediately followed
/// by an ASCII uppercase letter, as in `myFunction` or `parseHTTPRequest`.
///
/// Only adjacent ASCII letters count: a digit or non-ASCII character before an
/// uppercase letter (e.g. `utf8Decode`) is not a camel-case hump.
fn has_camel_case(s: &str) -> bool {
    s.as_bytes()
        .windows(2)
        .any(|pair| pair[0].is_ascii_lowercase() && pair[1].is_ascii_uppercase())
}

/// Returns `true` if `s` contains at least one `_` and at least one alphanumeric
/// character (e.g. `my_function`, `_private`, `MAX_SIZE`).
///
/// Underscore-only strings such as `___` are rejected.
fn has_snake_case(s: &str) -> bool {
    if !s.contains('_') {
        return false;
    }
    s.chars().any(char::is_alphanumeric)
}

/// Returns `true` if `s` is written in `SCREAMING_SNAKE_CASE`.
///
/// The token must contain at least one `_`, at least one ASCII uppercase
/// letter, and nothing other than ASCII uppercase letters, ASCII digits and
/// underscores. Requiring a letter keeps bare underscore runs such as `_` or
/// `___` from being mistaken for constants.
fn is_screaming_snake(s: &str) -> bool {
    s.contains('_')
        && s.chars().any(|c| c.is_ascii_uppercase())
        && s.chars()
            .all(|c| c.is_ascii_uppercase() || c.is_ascii_digit() || c == '_')
}

30pub fn classify_query(query: &str) -> QueryType {
31    let trimmed = query.trim();
32    if trimmed.is_empty() {
33        return QueryType::NaturalLanguage;
34    }
35
36    if trimmed.contains('/') || trimmed.contains('\\') {
37        return QueryType::PathLike;
38    }
39
40    if has_dotted_path(trimmed) {
41        return QueryType::PathLike;
42    }
43
44    let words: Vec<&str> = trimmed.split_whitespace().collect();
45    if words.len() == 1 && has_file_extension(trimmed) {
46        return QueryType::PathLike;
47    }
48
49    if words.len() == 1 {
50        return QueryType::Identifier;
51    }
52
53    if words.len() <= 3 {
54        let looks_like_code = words
55            .iter()
56            .any(|w| has_camel_case(w) || has_snake_case(w) || is_screaming_snake(w));
57        if looks_like_code {
58            return QueryType::Identifier;
59        }
60    }
61
62    QueryType::NaturalLanguage
63}
64
/// Returns `true` if `s` is a dotted path with at least three segments
/// (e.g. `foo.bar.baz`), where every segment is non-empty and consists only of
/// alphanumeric characters or `_`.
///
/// Whitespace is not a valid segment character, so multi-word text never matches.
fn has_dotted_path(s: &str) -> bool {
    let mut segment_count: usize = 0;
    for segment in s.split('.') {
        let is_word = !segment.is_empty()
            && segment.chars().all(|c| c.is_alphanumeric() || c == '_');
        if !is_word {
            return false;
        }
        segment_count += 1;
    }
    segment_count >= 3
}

/// Returns `true` if `s` ends in what looks like a file extension: a final `.`
/// followed by one to five alphanumeric characters (e.g. `config.yaml`, `mod.rs`).
///
/// The stem before the dot may be empty, so `.rs` qualifies. `.gitignore` does
/// not, because its "extension" is longer than five characters.
fn has_file_extension(s: &str) -> bool {
    // `rsplit_once` yields the text after the *last* dot, or `None` when there is
    // no dot at all, so no separate `contains` scan or panic path is needed.
    let ext = match s.rsplit_once('.') {
        Some((_, ext)) => ext,
        None => return false,
    };
    // Count chars, not bytes. The alphanumeric test below is Unicode-aware, so a
    // byte-length limit would wrongly reject short non-ASCII extensions (a
    // three-letter Cyrillic extension is six bytes).
    let ext_len = ext.chars().count();
    (1..=5).contains(&ext_len) && ext.chars().all(char::is_alphanumeric)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `query` is classified as `expected`.
    ///
    /// `#[track_caller]` makes a failure report the calling test's line.
    #[track_caller]
    fn check(query: &str, expected: QueryType) {
        assert_eq!(classify_query(query), expected);
    }

    // Degenerate input.

    #[test]
    fn empty_query_classified_as_natural_language() {
        check("", QueryType::NaturalLanguage);
    }

    #[test]
    fn whitespace_only_query_classified_as_natural_language() {
        check("   ", QueryType::NaturalLanguage);
    }

    // Single tokens.

    #[test]
    fn single_camel_case_word_is_identifier() {
        check("myFunction", QueryType::Identifier);
    }

    #[test]
    fn single_snake_case_word_is_identifier() {
        check("my_function", QueryType::Identifier);
    }

    #[test]
    fn single_lowercase_word_is_identifier() {
        check("search", QueryType::Identifier);
    }

    #[test]
    fn single_uppercase_word_is_identifier() {
        check("Search", QueryType::Identifier);
    }

    // Paths.

    #[test]
    fn path_with_slash_is_path_like() {
        check("src/tools/mod.rs", QueryType::PathLike);
    }

    #[test]
    fn path_with_backslash_is_path_like() {
        check("src\\tools\\mod.rs", QueryType::PathLike);
    }

    #[test]
    fn dotted_path_three_segments_is_path_like() {
        check("foo.bar.baz", QueryType::PathLike);
    }

    #[test]
    fn filename_with_extension_is_path_like() {
        check("config.yaml", QueryType::PathLike);
    }

    // Short code-like phrases.

    #[test]
    fn screaming_snake_case_is_identifier() {
        check("MAX_SIZE DEFAULT", QueryType::Identifier);
    }

    #[test]
    fn short_camel_case_phrase_is_identifier() {
        check("myFunction handles input", QueryType::Identifier);
    }

    // Natural language.

    #[test]
    fn natural_language_sentence_is_natural_language() {
        check(
            "find all places where database connections are established",
            QueryType::NaturalLanguage,
        );
    }

    #[test]
    fn short_mixed_words_without_code_patterns_is_natural_language() {
        check("how does this work", QueryType::NaturalLanguage);
    }

    #[test]
    fn four_words_without_code_patterns_is_natural_language() {
        check("find the error handler function", QueryType::NaturalLanguage);
    }
}