// search_semantically/query_classifier.rs
/// The broad category a search query is sorted into by `classify_query`.
///
/// The enum is fieldless, so it derives `Copy` (cheap to pass by value) and
/// `Hash` (usable as a map or set key) in addition to the comparison traits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryType {
    /// A single word, or a short phrase containing a code-styled word
    /// (camelCase, snake_case or SCREAMING_SNAKE).
    Identifier,
    /// Free-form text; also the result for empty or whitespace-only queries.
    NaturalLanguage,
    /// Text shaped like a path: contains `/` or `\`, is a dotted path of three
    /// or more segments, or is a single word ending in a short extension.
    PathLike,
}
7
/// Reports whether `s` contains a lowercase ASCII letter immediately followed
/// by an uppercase ASCII letter (the "hump" in `camelCase`).
///
/// The scan works on raw bytes, so only ASCII letters can form a hump; strings
/// shorter than two bytes never match.
fn has_camel_case(s: &str) -> bool {
    s.as_bytes()
        .windows(2)
        .any(|pair| pair[0].is_ascii_lowercase() && pair[1].is_ascii_uppercase())
}
17
/// Reports whether `s` contains an underscore together with at least one
/// alphanumeric character.
///
/// The underscore may appear anywhere, including the first or last position.
fn has_snake_case(s: &str) -> bool {
    // Without an underscore there is nothing snake-like to find.
    if !s.contains('_') {
        return false;
    }
    s.chars().any(char::is_alphanumeric)
}
21
/// Reports whether `s` is written in `SCREAMING_SNAKE_CASE`.
///
/// A match requires all of:
/// - every character is an ASCII uppercase letter, an ASCII digit, or `_`;
/// - at least one `_` is present;
/// - at least one ASCII uppercase letter is present.
///
/// The last rule keeps strings with no letters at all — such as `"_"`,
/// `"___"` or `"1_000"` — from being reported as identifiers.
fn is_screaming_snake(s: &str) -> bool {
    let mut saw_underscore = false;
    let mut saw_uppercase = false;

    for c in s.chars() {
        match c {
            '_' => saw_underscore = true,
            'A'..='Z' => saw_uppercase = true,
            '0'..='9' => {}
            // Any other character (lowercase, punctuation, non-ASCII) disqualifies.
            _ => return false,
        }
    }

    saw_underscore && saw_uppercase
}
29
30pub fn classify_query(query: &str) -> QueryType {
31 let trimmed = query.trim();
32 if trimmed.is_empty() {
33 return QueryType::NaturalLanguage;
34 }
35
36 if trimmed.contains('/') || trimmed.contains('\\') {
37 return QueryType::PathLike;
38 }
39
40 if has_dotted_path(trimmed) {
41 return QueryType::PathLike;
42 }
43
44 let words: Vec<&str> = trimmed.split_whitespace().collect();
45 if words.len() == 1 && has_file_extension(trimmed) {
46 return QueryType::PathLike;
47 }
48
49 if words.len() == 1 {
50 return QueryType::Identifier;
51 }
52
53 if words.len() <= 3 {
54 let looks_like_code = words
55 .iter()
56 .any(|w| has_camel_case(w) || has_snake_case(w) || is_screaming_snake(w));
57 if looks_like_code {
58 return QueryType::Identifier;
59 }
60 }
61
62 QueryType::NaturalLanguage
63}
64
/// Reports whether `s` is a dotted path such as `foo.bar.baz`.
///
/// Splitting on `.` must yield at least three segments, and every segment must
/// be non-empty and consist only of alphanumeric characters or `_`.
fn has_dotted_path(s: &str) -> bool {
    let mut segment_count = 0usize;

    for segment in s.split('.') {
        let well_formed = !segment.is_empty()
            && segment.chars().all(|c| c == '_' || c.is_alphanumeric());
        // One bad segment rules out the whole string.
        if !well_formed {
            return false;
        }
        segment_count += 1;
    }

    segment_count >= 3
}
72
/// Reports whether `s` ends in something that looks like a file extension.
///
/// The extension is the text after the last `.`. It must be non-empty, at most
/// five characters long, and entirely alphanumeric. Length is measured in
/// characters rather than bytes, so a non-ASCII extension is held to the same
/// limit as an ASCII one. Strings without a `.` never match.
fn has_file_extension(s: &str) -> bool {
    const MAX_EXTENSION_CHARS: usize = 5;

    // `rsplit_once` yields `None` when there is no dot, so no panic path is needed.
    match s.rsplit_once('.') {
        Some((_, ext)) => {
            !ext.is_empty()
                && ext.chars().count() <= MAX_EXTENSION_CHARS
                && ext.chars().all(char::is_alphanumeric)
        }
        None => false,
    }
}
80
#[cfg(test)]
mod tests {
    use super::{classify_query, QueryType};

    /// Asserts that `query` is classified as `expected`, naming the query on failure.
    fn assert_classified(query: &str, expected: QueryType) {
        assert_eq!(classify_query(query), expected, "query: {:?}", query);
    }

    // --- Empty input -------------------------------------------------------

    #[test]
    fn empty_query_classified_as_natural_language() {
        assert_classified("", QueryType::NaturalLanguage);
    }

    #[test]
    fn whitespace_only_query_classified_as_natural_language() {
        assert_classified(" ", QueryType::NaturalLanguage);
    }

    // --- Single words ------------------------------------------------------

    #[test]
    fn single_camel_case_word_is_identifier() {
        assert_classified("myFunction", QueryType::Identifier);
    }

    #[test]
    fn single_snake_case_word_is_identifier() {
        assert_classified("my_function", QueryType::Identifier);
    }

    #[test]
    fn single_lowercase_word_is_identifier() {
        assert_classified("search", QueryType::Identifier);
    }

    #[test]
    fn single_uppercase_word_is_identifier() {
        assert_classified("Search", QueryType::Identifier);
    }

    // --- Paths -------------------------------------------------------------

    #[test]
    fn path_with_slash_is_path_like() {
        assert_classified("src/tools/mod.rs", QueryType::PathLike);
    }

    #[test]
    fn path_with_backslash_is_path_like() {
        assert_classified("src\\tools\\mod.rs", QueryType::PathLike);
    }

    #[test]
    fn dotted_path_three_segments_is_path_like() {
        assert_classified("foo.bar.baz", QueryType::PathLike);
    }

    #[test]
    fn filename_with_extension_is_path_like() {
        assert_classified("config.yaml", QueryType::PathLike);
    }

    // --- Short phrases containing code-styled words ------------------------

    #[test]
    fn screaming_snake_case_is_identifier() {
        assert_classified("MAX_SIZE DEFAULT", QueryType::Identifier);
    }

    #[test]
    fn short_camel_case_phrase_is_identifier() {
        assert_classified("myFunction handles input", QueryType::Identifier);
    }

    // --- Natural language --------------------------------------------------

    #[test]
    fn natural_language_sentence_is_natural_language() {
        assert_classified(
            "find all places where database connections are established",
            QueryType::NaturalLanguage,
        );
    }

    #[test]
    fn short_mixed_words_without_code_patterns_is_natural_language() {
        assert_classified("how does this work", QueryType::NaturalLanguage);
    }

    #[test]
    fn four_words_without_code_patterns_is_natural_language() {
        assert_classified("find the error handler function", QueryType::NaturalLanguage);
    }
}