// skill_runtime/search/query_processor.rs

//! Query understanding with intent classification and entity extraction.
//!
//! Provides intelligent query preprocessing to improve search relevance
//! through rule-based intent detection, entity recognition, and query expansion.
5
6use std::collections::{HashMap, HashSet};
7
/// Intent categories that the rule-based classifier can assign to a query.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryIntent {
    /// The user wants to discover which tools can do something.
    ///
    /// E.g., "what tools can list files", "tools for kubernetes".
    ToolDiscovery,
    /// The user wants to execute a specific tool.
    ///
    /// E.g., "run list_pods", "execute get deployment".
    ToolExecution,
    /// The user wants documentation on how a tool works.
    ///
    /// E.g., "how does list_pods work", "explain kubernetes tool".
    ToolDocumentation,
    /// The user wants to compare tools.
    ///
    /// E.g., "difference between X and Y", "X vs Y".
    Comparison,
    /// The user is troubleshooting.
    ///
    /// E.g., "why is X failing", "error with X".
    Troubleshooting,
    /// Fallback when no specific intent is detected.
    General,
}
29
30impl QueryIntent {
31    /// Get confidence threshold for this intent
32    pub fn confidence_threshold(&self) -> f32 {
33        match self {
34            QueryIntent::ToolExecution => 0.8,
35            QueryIntent::Comparison => 0.7,
36            QueryIntent::Troubleshooting => 0.7,
37            QueryIntent::ToolDocumentation => 0.6,
38            QueryIntent::ToolDiscovery => 0.5,
39            QueryIntent::General => 0.0,
40        }
41    }
42}
43
/// Kind of entity recognised in a query.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EntityType {
    /// Skill name (e.g., "kubernetes", "github").
    SkillName,
    /// Tool name (e.g., "list_pods", "create_issue").
    ToolName,
    /// Action verb (e.g., "create", "delete", "list").
    ActionVerb,
    /// Category (e.g., "database", "cloud", "git").
    Category,
    /// Target/object (e.g., "pods", "files", "users").
    ///
    /// The entity extractor in this file never emits this variant.
    Target,
}
58
59/// Extracted entity from query
60#[derive(Debug, Clone)]
61pub struct ExtractedEntity {
62    /// The entity text
63    pub text: String,
64    /// Entity type
65    pub entity_type: EntityType,
66    /// Confidence score (0.0-1.0)
67    pub confidence: f32,
68    /// Position in the original query
69    pub position: usize,
70}
71
72/// Query expansion with synonyms and related terms
73#[derive(Debug, Clone)]
74pub struct QueryExpansion {
75    /// Original term
76    pub original: String,
77    /// Expanded terms (synonyms, related)
78    pub expanded: Vec<String>,
79    /// Expansion type
80    pub expansion_type: ExpansionType,
81}
82
/// How the alternative terms of a [`QueryExpansion`] were derived.
///
/// The expansion generator in this file only produces
/// [`ExpansionType::Synonym`]; acronyms are substituted directly into the
/// normalized query instead of being reported as expansions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExpansionType {
    /// Synonym expansion (create -> make, generate).
    Synonym,
    /// Acronym expansion (k8s -> kubernetes).
    Acronym,
    /// Pattern expansion (get pods -> list pods).
    Pattern,
}
93
94/// Processed query with intent, entities, and expansions
95#[derive(Debug, Clone)]
96pub struct ProcessedQuery {
97    /// Original query string
98    pub original: String,
99    /// Normalized/cleaned query
100    pub normalized: String,
101    /// Detected intent
102    pub intent: QueryIntent,
103    /// Intent confidence score
104    pub intent_confidence: f32,
105    /// Extracted entities
106    pub entities: Vec<ExtractedEntity>,
107    /// Query expansions
108    pub expansions: Vec<QueryExpansion>,
109    /// Suggested search filters
110    pub suggested_filters: Vec<SuggestedFilter>,
111}
112
/// A search filter suggested by an extracted entity.
#[derive(Debug, Clone)]
pub struct SuggestedFilter {
    /// Name of the field to filter on; the processor in this file emits
    /// `"skill_name"` and `"category"`.
    pub field: String,
    /// Value the field should match.
    pub value: String,
    /// Confidence of the suggestion, copied from the entity that produced it.
    pub confidence: f32,
}
123
/// Rule-based query preprocessor: normalizes a query, classifies its intent,
/// extracts entities and proposes term expansions and search filters.
#[derive(Debug, Clone)]
pub struct QueryProcessor {
    /// Known skill names (stored lowercased) used for entity extraction.
    known_skills: HashSet<String>,
    /// Known tool names (stored lowercased) used for entity extraction.
    known_tools: HashSet<String>,
    /// Synonym lists keyed by the term they expand.
    synonyms: HashMap<String, Vec<String>>,
    /// Acronyms mapped to their expanded form.
    acronyms: HashMap<String, String>,
    /// Words recognised as action verbs.
    action_verbs: HashSet<String>,
    /// Category names mapped to the keywords that indicate them.
    categories: HashMap<String, Vec<String>>,
}
139
140impl Default for QueryProcessor {
141    fn default() -> Self {
142        Self::new()
143    }
144}
145
146impl QueryProcessor {
147    /// Create a new query processor with default knowledge
148    pub fn new() -> Self {
149        let mut processor = Self {
150            known_skills: HashSet::new(),
151            known_tools: HashSet::new(),
152            synonyms: HashMap::new(),
153            acronyms: HashMap::new(),
154            action_verbs: HashSet::new(),
155            categories: HashMap::new(),
156        };
157
158        // Initialize with common knowledge
159        processor.init_action_verbs();
160        processor.init_synonyms();
161        processor.init_acronyms();
162        processor.init_categories();
163
164        processor
165    }
166
167    /// Add known skills for entity extraction
168    pub fn with_skills(mut self, skills: impl IntoIterator<Item = impl Into<String>>) -> Self {
169        for skill in skills {
170            self.known_skills.insert(skill.into().to_lowercase());
171        }
172        self
173    }
174
175    /// Add known tools for entity extraction
176    pub fn with_tools(mut self, tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
177        for tool in tools {
178            self.known_tools.insert(tool.into().to_lowercase());
179        }
180        self
181    }
182
183    /// Process a query for improved search
184    pub fn process(&self, query: &str) -> ProcessedQuery {
185        let normalized = self.normalize_query(query);
186        let tokens = self.tokenize(&normalized);
187
188        // Classify intent
189        let (intent, intent_confidence) = self.classify_intent(&normalized, &tokens);
190
191        // Extract entities
192        let entities = self.extract_entities(&normalized, &tokens);
193
194        // Generate expansions
195        let expansions = self.generate_expansions(&tokens);
196
197        // Suggest filters based on entities
198        let suggested_filters = self.suggest_filters(&entities);
199
200        ProcessedQuery {
201            original: query.to_string(),
202            normalized,
203            intent,
204            intent_confidence,
205            entities,
206            expansions,
207            suggested_filters,
208        }
209    }
210
211    /// Get expanded query terms for search
212    pub fn get_expanded_terms(&self, query: &ProcessedQuery) -> Vec<String> {
213        let mut terms = vec![query.normalized.clone()];
214
215        // Add expansion terms
216        for expansion in &query.expansions {
217            for term in &expansion.expanded {
218                if !terms.contains(term) {
219                    terms.push(term.clone());
220                }
221            }
222        }
223
224        // Add entity text with higher weight potential
225        for entity in &query.entities {
226            if !terms.contains(&entity.text) {
227                terms.push(entity.text.clone());
228            }
229        }
230
231        terms
232    }
233
234    // --- Internal methods ---
235
236    fn init_action_verbs(&mut self) {
237        let verbs = [
238            "create", "make", "add", "new", "generate",
239            "delete", "remove", "destroy", "drop",
240            "list", "get", "fetch", "retrieve", "show", "display",
241            "update", "modify", "change", "edit", "patch",
242            "read", "view", "inspect", "describe",
243            "run", "execute", "invoke", "call", "start",
244            "stop", "kill", "terminate", "cancel",
245            "deploy", "install", "setup", "configure",
246            "search", "find", "query", "filter",
247            "connect", "disconnect", "link",
248            "send", "receive", "push", "pull",
249            "upload", "download", "sync",
250            "validate", "verify", "check", "test",
251        ];
252        self.action_verbs = verbs.iter().map(|s| s.to_string()).collect();
253    }
254
255    fn init_synonyms(&mut self) {
256        let synonym_map = [
257            ("create", vec!["make", "generate", "add", "new", "build"]),
258            ("delete", vec!["remove", "destroy", "drop", "erase"]),
259            ("list", vec!["get", "show", "display", "fetch", "retrieve"]),
260            ("update", vec!["modify", "change", "edit", "patch", "alter"]),
261            ("run", vec!["execute", "invoke", "call", "start", "launch"]),
262            ("find", vec!["search", "query", "lookup", "locate"]),
263            ("stop", vec!["kill", "terminate", "cancel", "halt"]),
264            ("deploy", vec!["install", "setup", "release", "publish"]),
265            ("file", vec!["document", "artifact"]),
266            ("folder", vec!["directory", "dir"]),
267            ("container", vec!["pod", "instance"]),
268        ];
269
270        for (key, synonyms) in synonym_map {
271            self.synonyms.insert(key.to_string(), synonyms.iter().map(|s| s.to_string()).collect());
272        }
273    }
274
275    fn init_acronyms(&mut self) {
276        let acronym_map = [
277            ("k8s", "kubernetes"),
278            ("gh", "github"),
279            ("gl", "gitlab"),
280            ("db", "database"),
281            ("aws", "amazon web services"),
282            ("gcp", "google cloud platform"),
283            ("az", "azure"),
284            ("tf", "terraform"),
285            ("ci", "continuous integration"),
286            ("cd", "continuous deployment"),
287            ("api", "application programming interface"),
288            ("cli", "command line interface"),
289            ("env", "environment"),
290            ("vars", "variables"),
291            ("config", "configuration"),
292            ("auth", "authentication"),
293            ("repo", "repository"),
294        ];
295
296        for (acronym, expanded) in acronym_map {
297            self.acronyms.insert(acronym.to_string(), expanded.to_string());
298        }
299    }
300
301    fn init_categories(&mut self) {
302        let category_map = [
303            ("kubernetes", vec!["pod", "deployment", "service", "namespace", "ingress", "configmap", "secret", "node", "cluster"]),
304            ("git", vec!["commit", "branch", "merge", "pull", "push", "clone", "checkout", "repository", "repo"]),
305            ("database", vec!["query", "table", "schema", "index", "migration", "backup", "restore"]),
306            ("cloud", vec!["instance", "bucket", "function", "lambda", "storage", "network", "vpc"]),
307            ("docker", vec!["container", "image", "volume", "network", "compose"]),
308            ("file", vec!["read", "write", "copy", "move", "delete", "list", "directory"]),
309        ];
310
311        for (category, keywords) in category_map {
312            self.categories.insert(category.to_string(), keywords.iter().map(|s| s.to_string()).collect());
313        }
314    }
315
316    fn normalize_query(&self, query: &str) -> String {
317        let mut normalized = query.to_lowercase();
318
319        // Expand acronyms
320        for (acronym, expanded) in &self.acronyms {
321            if normalized.contains(acronym) {
322                // Only expand if it's a whole word
323                let pattern = format!(r"\b{}\b", acronym);
324                if let Ok(re) = regex_lite::Regex::new(&pattern) {
325                    normalized = re.replace_all(&normalized, expanded.as_str()).to_string();
326                }
327            }
328        }
329
330        // Remove extra whitespace
331        normalized.split_whitespace().collect::<Vec<_>>().join(" ")
332    }
333
334    fn tokenize(&self, text: &str) -> Vec<String> {
335        text.split_whitespace()
336            .map(|s| s.trim_matches(|c: char| c.is_ascii_punctuation()).to_string())
337            .filter(|s| !s.is_empty())
338            .collect()
339    }
340
341    fn classify_intent(&self, query: &str, _tokens: &[String]) -> (QueryIntent, f32) {
342        let query_lower = query.to_lowercase();
343
344        // Check for execution intent (highest priority)
345        let execution_patterns = ["run ", "execute ", "invoke ", "call "];
346        for pattern in execution_patterns {
347            if query_lower.starts_with(pattern) {
348                return (QueryIntent::ToolExecution, 0.9);
349            }
350        }
351
352        // Check for comparison intent
353        if query_lower.contains(" vs ") ||
354           query_lower.contains(" versus ") ||
355           query_lower.contains("compare ") ||
356           query_lower.contains("difference between") {
357            return (QueryIntent::Comparison, 0.85);
358        }
359
360        // Check for troubleshooting intent
361        let trouble_patterns = ["why ", "error", "fail", "not working", "issue", "problem", "debug"];
362        for pattern in trouble_patterns {
363            if query_lower.contains(pattern) {
364                return (QueryIntent::Troubleshooting, 0.8);
365            }
366        }
367
368        // Check for documentation intent
369        let doc_patterns = ["how does", "how to", "what is", "explain", "documentation", "help with"];
370        for pattern in doc_patterns {
371            if query_lower.contains(pattern) {
372                return (QueryIntent::ToolDocumentation, 0.75);
373            }
374        }
375
376        // Check for discovery intent
377        let discovery_patterns = ["what tools", "tools for", "which tool", "find tool", "available"];
378        for pattern in discovery_patterns {
379            if query_lower.contains(pattern) {
380                return (QueryIntent::ToolDiscovery, 0.7);
381            }
382        }
383
384        // Default to general with lower confidence
385        (QueryIntent::General, 0.5)
386    }
387
388    fn extract_entities(&self, _query: &str, tokens: &[String]) -> Vec<ExtractedEntity> {
389        let mut entities = Vec::new();
390
391        for (pos, token) in tokens.iter().enumerate() {
392            let token_lower = token.to_lowercase();
393
394            // Check for known skills
395            if self.known_skills.contains(&token_lower) {
396                entities.push(ExtractedEntity {
397                    text: token.clone(),
398                    entity_type: EntityType::SkillName,
399                    confidence: 0.95,
400                    position: pos,
401                });
402                continue;
403            }
404
405            // Check for known tools
406            if self.known_tools.contains(&token_lower) {
407                entities.push(ExtractedEntity {
408                    text: token.clone(),
409                    entity_type: EntityType::ToolName,
410                    confidence: 0.95,
411                    position: pos,
412                });
413                continue;
414            }
415
416            // Check for action verbs
417            if self.action_verbs.contains(&token_lower) {
418                entities.push(ExtractedEntity {
419                    text: token.clone(),
420                    entity_type: EntityType::ActionVerb,
421                    confidence: 0.85,
422                    position: pos,
423                });
424                continue;
425            }
426
427            // Check for category matches
428            for (category, keywords) in &self.categories {
429                if keywords.iter().any(|k| token_lower.contains(k) || k.contains(&token_lower)) {
430                    entities.push(ExtractedEntity {
431                        text: category.clone(),
432                        entity_type: EntityType::Category,
433                        confidence: 0.75,
434                        position: pos,
435                    });
436                    break;
437                }
438            }
439        }
440
441        // Deduplicate entities by (text, type) pair
442        let mut seen = HashSet::new();
443        entities.retain(|e| seen.insert((e.text.clone(), e.entity_type)));
444
445        entities
446    }
447
448    fn generate_expansions(&self, tokens: &[String]) -> Vec<QueryExpansion> {
449        let mut expansions = Vec::new();
450
451        for token in tokens {
452            let token_lower = token.to_lowercase();
453
454            // Check for synonym expansion
455            if let Some(synonyms) = self.synonyms.get(&token_lower) {
456                expansions.push(QueryExpansion {
457                    original: token.clone(),
458                    expanded: synonyms.clone(),
459                    expansion_type: ExpansionType::Synonym,
460                });
461            }
462
463            // Note: Acronym expansion is handled in normalize_query
464        }
465
466        expansions
467    }
468
469    fn suggest_filters(&self, entities: &[ExtractedEntity]) -> Vec<SuggestedFilter> {
470        let mut filters = Vec::new();
471
472        for entity in entities {
473            match entity.entity_type {
474                EntityType::SkillName => {
475                    filters.push(SuggestedFilter {
476                        field: "skill_name".to_string(),
477                        value: entity.text.clone(),
478                        confidence: entity.confidence,
479                    });
480                }
481                EntityType::Category => {
482                    filters.push(SuggestedFilter {
483                        field: "category".to_string(),
484                        value: entity.text.clone(),
485                        confidence: entity.confidence,
486                    });
487                }
488                _ => {}
489            }
490        }
491
492        filters
493    }
494}
495
// Minimal local stand-in named after the `regex-lite` crate. It understands
// only the `\bword\b` whole-word patterns built by `normalize_query`, which
// avoids pulling in a regex dependency.
mod regex_lite {
    use std::borrow::Cow;

    /// A "compiled" whole-word pattern; only the raw pattern text is stored.
    pub struct Regex(String);

    impl Regex {
        /// Stores `pattern` without validating it; this never returns `Err`.
        pub fn new(pattern: &str) -> Result<Self, ()> {
            Ok(Regex(pattern.to_string()))
        }

        /// Replaces every whole-word occurrence of the pattern's word in `text`
        /// with `replacement`.
        ///
        /// A word is a maximal run of alphanumeric characters or `_`, so the
        /// word is also found next to punctuation (`k8s?`, `(k8s)`,
        /// `k8s-cluster`) but never inside a longer word (`k8sx`, `my_k8s`).
        /// Words are lowercased before being compared with the pattern. Every
        /// run of whitespace in the result is collapsed to a single space and
        /// leading/trailing whitespace is dropped.
        pub fn replace_all<'a>(&self, text: &'a str, replacement: &str) -> Cow<'a, str> {
            let word = self.0.trim_start_matches(r"\b").trim_end_matches(r"\b");
            let mut output = String::with_capacity(text.len());

            for (index, chunk) in text.split_whitespace().enumerate() {
                if index > 0 {
                    output.push(' ');
                }

                // Walk the chunk as alternating runs of word / non-word characters.
                let mut rest = chunk;
                while let Some(first) = rest.chars().next() {
                    let in_word = is_word_char(first);
                    let end = rest
                        .find(|c: char| is_word_char(c) != in_word)
                        .unwrap_or(rest.len());
                    let (run, tail) = rest.split_at(end);

                    if in_word && run.to_lowercase() == word {
                        output.push_str(replacement);
                    } else {
                        output.push_str(run);
                    }
                    rest = tail;
                }
            }

            Cow::Owned(output)
        }
    }

    /// True for characters that belong to a word for boundary purposes.
    fn is_word_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }
}
516
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the entities of `query` that have the type `wanted`.
    fn entities_of(query: &ProcessedQuery, wanted: EntityType) -> Vec<&ExtractedEntity> {
        query
            .entities
            .iter()
            .filter(|entity| entity.entity_type == wanted)
            .collect()
    }

    #[test]
    fn test_intent_classification_execution() {
        let processor = QueryProcessor::new();

        let query = processor.process("run list_pods");
        assert_eq!(query.intent, QueryIntent::ToolExecution);
        assert!(query.intent_confidence > 0.8);

        let query = processor.process("execute get deployment");
        assert_eq!(query.intent, QueryIntent::ToolExecution);
    }

    #[test]
    fn test_intent_classification_comparison() {
        let processor = QueryProcessor::new();

        for text in ["kubernetes vs docker", "difference between list and get"] {
            assert_eq!(processor.process(text).intent, QueryIntent::Comparison);
        }
    }

    #[test]
    fn test_intent_classification_troubleshooting() {
        let processor = QueryProcessor::new();

        for text in ["why is the pod failing", "error connecting to database"] {
            assert_eq!(processor.process(text).intent, QueryIntent::Troubleshooting);
        }
    }

    #[test]
    fn test_intent_classification_documentation() {
        let processor = QueryProcessor::new();

        for text in ["how does list_pods work", "explain kubernetes deployment"] {
            assert_eq!(processor.process(text).intent, QueryIntent::ToolDocumentation);
        }
    }

    #[test]
    fn test_entity_extraction_with_known_skills() {
        let processor = QueryProcessor::new().with_skills(["kubernetes", "github", "docker"]);

        let query = processor.process("list pods in kubernetes");
        let skill_entities = entities_of(&query, EntityType::SkillName);

        assert_eq!(skill_entities.len(), 1);
        assert_eq!(skill_entities[0].text, "kubernetes");
    }

    #[test]
    fn test_entity_extraction_action_verbs() {
        let processor = QueryProcessor::new();

        let query = processor.process("create a new deployment");
        let verb_entities = entities_of(&query, EntityType::ActionVerb);

        assert!(verb_entities.iter().any(|entity| entity.text == "create"));
    }

    #[test]
    fn test_query_expansion_synonyms() {
        let processor = QueryProcessor::new();

        let query = processor.process("create pod");
        let expansion = query
            .expansions
            .iter()
            .find(|candidate| candidate.original.to_lowercase() == "create")
            .expect("`create` has an entry in the built-in synonym table");

        assert!(expansion.expanded.contains(&"make".to_string()));
        assert!(expansion.expanded.contains(&"generate".to_string()));
    }

    #[test]
    fn test_acronym_expansion() {
        let processor = QueryProcessor::new();

        let query = processor.process("list pods in k8s");
        assert!(query.normalized.contains("kubernetes"));
    }

    #[test]
    fn test_category_detection() {
        let processor = QueryProcessor::new();

        let query = processor.process("get deployment information");
        let category_entities = entities_of(&query, EntityType::Category);

        // "deployment" should trigger the kubernetes category.
        assert!(category_entities.iter().any(|entity| entity.text == "kubernetes"));
    }

    #[test]
    fn test_suggested_filters() {
        let processor = QueryProcessor::new().with_skills(["kubernetes"]);

        let query = processor.process("kubernetes pod list");
        let skill_filters: Vec<_> = query
            .suggested_filters
            .iter()
            .filter(|filter| filter.field == "skill_name")
            .collect();

        assert_eq!(skill_filters.len(), 1);
        assert_eq!(skill_filters[0].value, "kubernetes");
    }

    #[test]
    fn test_get_expanded_terms() {
        let processor = QueryProcessor::new();

        let query = processor.process("create deployment");
        let terms = processor.get_expanded_terms(&query);

        // The original query must be present alongside at least one expansion.
        assert!(terms
            .iter()
            .any(|term| term.contains("create") || term.contains("deployment")));
        assert!(terms.len() > 1);
    }

    #[test]
    fn test_normalize_query() {
        let processor = QueryProcessor::new();

        // Leading, trailing and repeated whitespace is collapsed.
        let query = processor.process("  list    pods  ");
        assert_eq!(query.normalized, "list pods");
    }
}