1use std::collections::{HashMap, HashSet};
7
/// The high-level goal behind a user query, as decided by the rule-based
/// intent classifier of `QueryProcessor`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum QueryIntent {
    /// Looking for which tools exist (e.g. "what tools", "which tool").
    ToolDiscovery,
    /// Asking to run a tool (e.g. queries starting with "run " or "execute ").
    ToolExecution,
    /// Asking how something works (e.g. "how does", "explain").
    ToolDocumentation,
    /// Comparing options (e.g. " vs ", "difference between").
    Comparison,
    /// Diagnosing a failure (e.g. "error", "not working").
    Troubleshooting,
    /// Fallback when no more specific rule matched.
    General,
}
29
30impl QueryIntent {
31 pub fn confidence_threshold(&self) -> f32 {
33 match self {
34 QueryIntent::ToolExecution => 0.8,
35 QueryIntent::Comparison => 0.7,
36 QueryIntent::Troubleshooting => 0.7,
37 QueryIntent::ToolDocumentation => 0.6,
38 QueryIntent::ToolDiscovery => 0.5,
39 QueryIntent::General => 0.0,
40 }
41 }
42}
43
/// Kind of thing an [`ExtractedEntity`] was recognised as.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum EntityType {
    /// A token matching a skill registered via `QueryProcessor::with_skills`.
    SkillName,
    /// A token matching a tool registered via `QueryProcessor::with_tools`.
    ToolName,
    /// A token from the built-in action-verb list ("create", "delete", ...).
    ActionVerb,
    /// A category inferred from a keyword; the entity text is the category name.
    Category,
    /// The object an action applies to.
    ///
    /// NOTE(review): `QueryProcessor` in this file never produces this variant.
    Target,
}
58
59#[derive(Debug, Clone)]
61pub struct ExtractedEntity {
62 pub text: String,
64 pub entity_type: EntityType,
66 pub confidence: f32,
68 pub position: usize,
70}
71
72#[derive(Debug, Clone)]
74pub struct QueryExpansion {
75 pub original: String,
77 pub expanded: Vec<String>,
79 pub expansion_type: ExpansionType,
81}
82
/// How the terms of a [`QueryExpansion`] were derived.
///
/// Derives `Hash` like the sibling enums `QueryIntent` and `EntityType`, so it
/// can key a `HashMap`/`HashSet` as well.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExpansionType {
    /// Terms from the built-in synonym table.
    Synonym,
    /// Expansion of an abbreviation.
    ///
    /// NOTE(review): `QueryProcessor` in this file never emits this variant;
    /// acronyms are rewritten in place during normalization instead.
    Acronym,
    /// Pattern-based expansion.
    ///
    /// NOTE(review): not produced anywhere in this file.
    Pattern,
}
93
94#[derive(Debug, Clone)]
96pub struct ProcessedQuery {
97 pub original: String,
99 pub normalized: String,
101 pub intent: QueryIntent,
103 pub intent_confidence: f32,
105 pub entities: Vec<ExtractedEntity>,
107 pub expansions: Vec<QueryExpansion>,
109 pub suggested_filters: Vec<SuggestedFilter>,
111}
112
/// A search filter suggested from an extracted entity.
///
/// Derives `PartialEq` so suggested filters can be compared as values.
#[derive(Debug, Clone, PartialEq)]
pub struct SuggestedFilter {
    /// Field to filter on (`QueryProcessor` emits "skill_name" or "category").
    pub field: String,
    /// Value the field should match.
    pub value: String,
    /// Confidence inherited from the entity the filter was derived from.
    pub confidence: f32,
}
123
/// Rule-based analyser that turns a free-text query into a [`ProcessedQuery`].
///
/// It holds the vocabularies used for normalization, entity extraction and
/// query expansion. Build one with `QueryProcessor::new` (or `Default`), then
/// register domain names with `with_skills` / `with_tools`.
pub struct QueryProcessor {
    /// Registered skill names, stored lowercased.
    known_skills: HashSet<String>,
    /// Registered tool names, stored lowercased.
    known_tools: HashSet<String>,
    /// Canonical term -> alternative terms used for synonym expansion.
    synonyms: HashMap<String, Vec<String>>,
    /// Acronym -> expansion, applied to the query during normalization.
    acronyms: HashMap<String, String>,
    /// Verbs recognised as `ActionVerb` entities.
    action_verbs: HashSet<String>,
    /// Category name -> keywords whose presence signals that category.
    categories: HashMap<String, Vec<String>>,
}
139
140impl Default for QueryProcessor {
141 fn default() -> Self {
142 Self::new()
143 }
144}
145
146impl QueryProcessor {
147 pub fn new() -> Self {
149 let mut processor = Self {
150 known_skills: HashSet::new(),
151 known_tools: HashSet::new(),
152 synonyms: HashMap::new(),
153 acronyms: HashMap::new(),
154 action_verbs: HashSet::new(),
155 categories: HashMap::new(),
156 };
157
158 processor.init_action_verbs();
160 processor.init_synonyms();
161 processor.init_acronyms();
162 processor.init_categories();
163
164 processor
165 }
166
167 pub fn with_skills(mut self, skills: impl IntoIterator<Item = impl Into<String>>) -> Self {
169 for skill in skills {
170 self.known_skills.insert(skill.into().to_lowercase());
171 }
172 self
173 }
174
175 pub fn with_tools(mut self, tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
177 for tool in tools {
178 self.known_tools.insert(tool.into().to_lowercase());
179 }
180 self
181 }
182
183 pub fn process(&self, query: &str) -> ProcessedQuery {
185 let normalized = self.normalize_query(query);
186 let tokens = self.tokenize(&normalized);
187
188 let (intent, intent_confidence) = self.classify_intent(&normalized, &tokens);
190
191 let entities = self.extract_entities(&normalized, &tokens);
193
194 let expansions = self.generate_expansions(&tokens);
196
197 let suggested_filters = self.suggest_filters(&entities);
199
200 ProcessedQuery {
201 original: query.to_string(),
202 normalized,
203 intent,
204 intent_confidence,
205 entities,
206 expansions,
207 suggested_filters,
208 }
209 }
210
211 pub fn get_expanded_terms(&self, query: &ProcessedQuery) -> Vec<String> {
213 let mut terms = vec![query.normalized.clone()];
214
215 for expansion in &query.expansions {
217 for term in &expansion.expanded {
218 if !terms.contains(term) {
219 terms.push(term.clone());
220 }
221 }
222 }
223
224 for entity in &query.entities {
226 if !terms.contains(&entity.text) {
227 terms.push(entity.text.clone());
228 }
229 }
230
231 terms
232 }
233
234 fn init_action_verbs(&mut self) {
237 let verbs = [
238 "create", "make", "add", "new", "generate",
239 "delete", "remove", "destroy", "drop",
240 "list", "get", "fetch", "retrieve", "show", "display",
241 "update", "modify", "change", "edit", "patch",
242 "read", "view", "inspect", "describe",
243 "run", "execute", "invoke", "call", "start",
244 "stop", "kill", "terminate", "cancel",
245 "deploy", "install", "setup", "configure",
246 "search", "find", "query", "filter",
247 "connect", "disconnect", "link",
248 "send", "receive", "push", "pull",
249 "upload", "download", "sync",
250 "validate", "verify", "check", "test",
251 ];
252 self.action_verbs = verbs.iter().map(|s| s.to_string()).collect();
253 }
254
255 fn init_synonyms(&mut self) {
256 let synonym_map = [
257 ("create", vec!["make", "generate", "add", "new", "build"]),
258 ("delete", vec!["remove", "destroy", "drop", "erase"]),
259 ("list", vec!["get", "show", "display", "fetch", "retrieve"]),
260 ("update", vec!["modify", "change", "edit", "patch", "alter"]),
261 ("run", vec!["execute", "invoke", "call", "start", "launch"]),
262 ("find", vec!["search", "query", "lookup", "locate"]),
263 ("stop", vec!["kill", "terminate", "cancel", "halt"]),
264 ("deploy", vec!["install", "setup", "release", "publish"]),
265 ("file", vec!["document", "artifact"]),
266 ("folder", vec!["directory", "dir"]),
267 ("container", vec!["pod", "instance"]),
268 ];
269
270 for (key, synonyms) in synonym_map {
271 self.synonyms.insert(key.to_string(), synonyms.iter().map(|s| s.to_string()).collect());
272 }
273 }
274
275 fn init_acronyms(&mut self) {
276 let acronym_map = [
277 ("k8s", "kubernetes"),
278 ("gh", "github"),
279 ("gl", "gitlab"),
280 ("db", "database"),
281 ("aws", "amazon web services"),
282 ("gcp", "google cloud platform"),
283 ("az", "azure"),
284 ("tf", "terraform"),
285 ("ci", "continuous integration"),
286 ("cd", "continuous deployment"),
287 ("api", "application programming interface"),
288 ("cli", "command line interface"),
289 ("env", "environment"),
290 ("vars", "variables"),
291 ("config", "configuration"),
292 ("auth", "authentication"),
293 ("repo", "repository"),
294 ];
295
296 for (acronym, expanded) in acronym_map {
297 self.acronyms.insert(acronym.to_string(), expanded.to_string());
298 }
299 }
300
301 fn init_categories(&mut self) {
302 let category_map = [
303 ("kubernetes", vec!["pod", "deployment", "service", "namespace", "ingress", "configmap", "secret", "node", "cluster"]),
304 ("git", vec!["commit", "branch", "merge", "pull", "push", "clone", "checkout", "repository", "repo"]),
305 ("database", vec!["query", "table", "schema", "index", "migration", "backup", "restore"]),
306 ("cloud", vec!["instance", "bucket", "function", "lambda", "storage", "network", "vpc"]),
307 ("docker", vec!["container", "image", "volume", "network", "compose"]),
308 ("file", vec!["read", "write", "copy", "move", "delete", "list", "directory"]),
309 ];
310
311 for (category, keywords) in category_map {
312 self.categories.insert(category.to_string(), keywords.iter().map(|s| s.to_string()).collect());
313 }
314 }
315
316 fn normalize_query(&self, query: &str) -> String {
317 let mut normalized = query.to_lowercase();
318
319 for (acronym, expanded) in &self.acronyms {
321 if normalized.contains(acronym) {
322 let pattern = format!(r"\b{}\b", acronym);
324 if let Ok(re) = regex_lite::Regex::new(&pattern) {
325 normalized = re.replace_all(&normalized, expanded.as_str()).to_string();
326 }
327 }
328 }
329
330 normalized.split_whitespace().collect::<Vec<_>>().join(" ")
332 }
333
334 fn tokenize(&self, text: &str) -> Vec<String> {
335 text.split_whitespace()
336 .map(|s| s.trim_matches(|c: char| c.is_ascii_punctuation()).to_string())
337 .filter(|s| !s.is_empty())
338 .collect()
339 }
340
341 fn classify_intent(&self, query: &str, _tokens: &[String]) -> (QueryIntent, f32) {
342 let query_lower = query.to_lowercase();
343
344 let execution_patterns = ["run ", "execute ", "invoke ", "call "];
346 for pattern in execution_patterns {
347 if query_lower.starts_with(pattern) {
348 return (QueryIntent::ToolExecution, 0.9);
349 }
350 }
351
352 if query_lower.contains(" vs ") ||
354 query_lower.contains(" versus ") ||
355 query_lower.contains("compare ") ||
356 query_lower.contains("difference between") {
357 return (QueryIntent::Comparison, 0.85);
358 }
359
360 let trouble_patterns = ["why ", "error", "fail", "not working", "issue", "problem", "debug"];
362 for pattern in trouble_patterns {
363 if query_lower.contains(pattern) {
364 return (QueryIntent::Troubleshooting, 0.8);
365 }
366 }
367
368 let doc_patterns = ["how does", "how to", "what is", "explain", "documentation", "help with"];
370 for pattern in doc_patterns {
371 if query_lower.contains(pattern) {
372 return (QueryIntent::ToolDocumentation, 0.75);
373 }
374 }
375
376 let discovery_patterns = ["what tools", "tools for", "which tool", "find tool", "available"];
378 for pattern in discovery_patterns {
379 if query_lower.contains(pattern) {
380 return (QueryIntent::ToolDiscovery, 0.7);
381 }
382 }
383
384 (QueryIntent::General, 0.5)
386 }
387
388 fn extract_entities(&self, _query: &str, tokens: &[String]) -> Vec<ExtractedEntity> {
389 let mut entities = Vec::new();
390
391 for (pos, token) in tokens.iter().enumerate() {
392 let token_lower = token.to_lowercase();
393
394 if self.known_skills.contains(&token_lower) {
396 entities.push(ExtractedEntity {
397 text: token.clone(),
398 entity_type: EntityType::SkillName,
399 confidence: 0.95,
400 position: pos,
401 });
402 continue;
403 }
404
405 if self.known_tools.contains(&token_lower) {
407 entities.push(ExtractedEntity {
408 text: token.clone(),
409 entity_type: EntityType::ToolName,
410 confidence: 0.95,
411 position: pos,
412 });
413 continue;
414 }
415
416 if self.action_verbs.contains(&token_lower) {
418 entities.push(ExtractedEntity {
419 text: token.clone(),
420 entity_type: EntityType::ActionVerb,
421 confidence: 0.85,
422 position: pos,
423 });
424 continue;
425 }
426
427 for (category, keywords) in &self.categories {
429 if keywords.iter().any(|k| token_lower.contains(k) || k.contains(&token_lower)) {
430 entities.push(ExtractedEntity {
431 text: category.clone(),
432 entity_type: EntityType::Category,
433 confidence: 0.75,
434 position: pos,
435 });
436 break;
437 }
438 }
439 }
440
441 let mut seen = HashSet::new();
443 entities.retain(|e| seen.insert((e.text.clone(), e.entity_type)));
444
445 entities
446 }
447
448 fn generate_expansions(&self, tokens: &[String]) -> Vec<QueryExpansion> {
449 let mut expansions = Vec::new();
450
451 for token in tokens {
452 let token_lower = token.to_lowercase();
453
454 if let Some(synonyms) = self.synonyms.get(&token_lower) {
456 expansions.push(QueryExpansion {
457 original: token.clone(),
458 expanded: synonyms.clone(),
459 expansion_type: ExpansionType::Synonym,
460 });
461 }
462
463 }
465
466 expansions
467 }
468
469 fn suggest_filters(&self, entities: &[ExtractedEntity]) -> Vec<SuggestedFilter> {
470 let mut filters = Vec::new();
471
472 for entity in entities {
473 match entity.entity_type {
474 EntityType::SkillName => {
475 filters.push(SuggestedFilter {
476 field: "skill_name".to_string(),
477 value: entity.text.clone(),
478 confidence: entity.confidence,
479 });
480 }
481 EntityType::Category => {
482 filters.push(SuggestedFilter {
483 field: "category".to_string(),
484 value: entity.text.clone(),
485 confidence: entity.confidence,
486 });
487 }
488 _ => {}
489 }
490 }
491
492 filters
493 }
494}
495
/// Minimal in-file stand-in for a regex engine. It supports only the pattern
/// shape built by `QueryProcessor::normalize_query`: `\b<literal>\b`, i.e.
/// whole-word literal replacement.
mod regex_lite {
    /// A whole-word literal pattern. Leading and trailing `\b` markers are
    /// stripped and the remainder is matched as plain text (no metacharacters).
    pub struct Regex(String);

    /// Word characters in the regex `\b` sense: alphanumerics and `_`.
    /// `None` (start or end of the text) is not a word character.
    fn is_word_char(c: Option<char>) -> bool {
        matches!(c, Some(ch) if ch.is_alphanumeric() || ch == '_')
    }

    impl Regex {
        /// Builds a pattern from `pattern`. It never fails; the `Result` only
        /// mirrors the usual regex constructor signature.
        pub fn new(pattern: &str) -> Result<Self, ()> {
            Ok(Regex(pattern.to_string()))
        }

        /// Replaces every whole-word occurrence of the literal in `text` with
        /// `replacement`.
        ///
        /// A match is whole-word when the characters directly before and after
        /// it (if any) are not word characters, as with regex `\b`. Matches are
        /// therefore found next to punctuation too ("k8s?", "(gh)",
        /// "k8s-cluster"), which the previous whitespace-split version missed.
        /// ASCII letters compare case-insensitively. Text outside matches,
        /// including whitespace, is copied verbatim. When nothing matches, the
        /// input is returned borrowed.
        pub fn replace_all<'a>(&self, text: &'a str, replacement: &str) -> std::borrow::Cow<'a, str> {
            let word = self.0.trim_start_matches(r"\b").trim_end_matches(r"\b");
            if word.is_empty() {
                return std::borrow::Cow::Borrowed(text);
            }

            let mut out = String::with_capacity(text.len());
            // Byte offset up to which `text` has already been emitted into `out`.
            let mut copied = 0;
            let mut replaced = false;

            for (start, _) in text.char_indices() {
                if start < copied {
                    continue; // still inside the previous replacement
                }
                let end = start + word.len();
                let candidate = match text.get(start..end) {
                    Some(candidate) => candidate,
                    None => continue, // past the end or not on a char boundary
                };
                if !candidate.eq_ignore_ascii_case(word) {
                    continue;
                }
                let at_boundary = !is_word_char(text[..start].chars().next_back())
                    && !is_word_char(text[end..].chars().next());
                if at_boundary {
                    out.push_str(&text[copied..start]);
                    out.push_str(replacement);
                    copied = end;
                    replaced = true;
                }
            }

            if !replaced {
                return std::borrow::Cow::Borrowed(text);
            }
            out.push_str(&text[copied..]);
            std::borrow::Cow::Owned(out)
        }
    }
}
516
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the entities of `query` that have type `kind`, keeping their order.
    fn entities_of(query: &ProcessedQuery, kind: EntityType) -> Vec<&ExtractedEntity> {
        query
            .entities
            .iter()
            .filter(|entity| entity.entity_type == kind)
            .collect()
    }

    #[test]
    fn test_intent_classification_execution() {
        let processor = QueryProcessor::new();

        let run = processor.process("run list_pods");
        assert_eq!(run.intent, QueryIntent::ToolExecution);
        assert!(run.intent_confidence > 0.8);

        let execute = processor.process("execute get deployment");
        assert_eq!(execute.intent, QueryIntent::ToolExecution);
    }

    #[test]
    fn test_intent_classification_comparison() {
        let processor = QueryProcessor::new();

        for text in ["kubernetes vs docker", "difference between list and get"].iter() {
            assert_eq!(processor.process(text).intent, QueryIntent::Comparison);
        }
    }

    #[test]
    fn test_intent_classification_troubleshooting() {
        let processor = QueryProcessor::new();

        for text in ["why is the pod failing", "error connecting to database"].iter() {
            assert_eq!(processor.process(text).intent, QueryIntent::Troubleshooting);
        }
    }

    #[test]
    fn test_intent_classification_documentation() {
        let processor = QueryProcessor::new();

        for text in ["how does list_pods work", "explain kubernetes deployment"].iter() {
            assert_eq!(processor.process(text).intent, QueryIntent::ToolDocumentation);
        }
    }

    #[test]
    fn test_entity_extraction_with_known_skills() {
        let processor = QueryProcessor::new().with_skills(["kubernetes", "github", "docker"]);

        let query = processor.process("list pods in kubernetes");
        let skills = entities_of(&query, EntityType::SkillName);

        assert_eq!(skills.len(), 1);
        assert_eq!(skills[0].text, "kubernetes");
    }

    #[test]
    fn test_entity_extraction_action_verbs() {
        let query = QueryProcessor::new().process("create a new deployment");
        let verbs = entities_of(&query, EntityType::ActionVerb);

        assert!(verbs.iter().any(|verb| verb.text == "create"));
    }

    #[test]
    fn test_query_expansion_synonyms() {
        let query = QueryProcessor::new().process("create pod");
        let create = query
            .expansions
            .iter()
            .find(|expansion| expansion.original.to_lowercase() == "create");

        assert!(create.is_some());
        let synonyms = &create.unwrap().expanded;
        assert!(synonyms.contains(&"make".to_string()));
        assert!(synonyms.contains(&"generate".to_string()));
    }

    #[test]
    fn test_acronym_expansion() {
        let query = QueryProcessor::new().process("list pods in k8s");
        assert!(query.normalized.contains("kubernetes"));
    }

    #[test]
    fn test_category_detection() {
        let query = QueryProcessor::new().process("get deployment information");
        let categories = entities_of(&query, EntityType::Category);

        assert!(categories.iter().any(|category| category.text == "kubernetes"));
    }

    #[test]
    fn test_suggested_filters() {
        let processor = QueryProcessor::new().with_skills(["kubernetes"]);

        let query = processor.process("kubernetes pod list");
        let skill_filters: Vec<&SuggestedFilter> = query
            .suggested_filters
            .iter()
            .filter(|filter| filter.field == "skill_name")
            .collect();

        assert_eq!(skill_filters.len(), 1);
        assert_eq!(skill_filters[0].value, "kubernetes");
    }

    #[test]
    fn test_get_expanded_terms() {
        let processor = QueryProcessor::new();

        let query = processor.process("create deployment");
        let terms = processor.get_expanded_terms(&query);

        assert!(terms
            .iter()
            .any(|term| term.contains("create") || term.contains("deployment")));
        assert!(terms.len() > 1);
    }

    #[test]
    fn test_normalize_query() {
        let query = QueryProcessor::new().process(" list pods ");
        assert_eq!(query.normalized, "list pods");
    }
}