Skip to main content

symbi_runtime/routing/
classifier.rs

1//! Task classification system for routing decisions
2
3use super::config::{ClassificationPattern, TaskClassificationConfig};
4use super::decision::RoutingContext;
5use super::error::{RoutingError, TaskType};
6use regex::Regex;
7use std::collections::HashMap;
8
/// Task classifier for determining task types from prompts.
///
/// Keeps the classification patterns for every known task type next to the
/// regexes compiled from them, plus the configuration the classifier was
/// built from.
#[derive(Debug, Clone)]
pub struct TaskClassifier {
    /// Classification patterns (keywords, regex sources and weight) for each
    /// task type. Starts as a copy of `config.patterns`; `add_pattern` and
    /// `remove_pattern` modify this map only, not the copy inside `config`.
    patterns: HashMap<TaskType, ClassificationPattern>,
    /// Regexes compiled once from each task type's pattern strings, stored in
    /// the same order as the source strings so they are not recompiled on
    /// every classification.
    compiled_patterns: HashMap<TaskType, Vec<Regex>>,
    /// Configuration settings: the `enabled` switch, the confidence
    /// threshold and the default task type used for fallbacks.
    config: TaskClassificationConfig,
}
19
/// Classification result with confidence score.
///
/// Produced by `TaskClassifier::classify_task`, both for real matches and for
/// the fallback cases (classification disabled, nothing matched, confidence
/// below the configured threshold).
#[derive(Debug, Clone)]
pub struct ClassificationResult {
    /// The selected task type, or the configured default task type on any
    /// fallback path.
    pub task_type: TaskType,
    /// Score of the winning task type divided by its maximum possible score,
    /// capped at 1.0. It is 1.0 when classification is disabled and 0.0 when
    /// nothing matched.
    pub confidence: f64,
    /// Source strings of the regex patterns that matched. On fallback paths
    /// this holds a single marker instead: `"classification_disabled"`,
    /// `"no_patterns_matched"` or `"confidence_below_threshold"`.
    pub matched_patterns: Vec<String>,
    /// Configured keywords that were found in the prompt; empty on fallback
    /// paths.
    pub keyword_matches: Vec<String>,
}
28
29impl TaskClassifier {
30    /// Create a new task classifier with the given configuration
31    pub fn new(config: TaskClassificationConfig) -> Result<Self, RoutingError> {
32        let mut compiled_patterns = HashMap::new();
33
34        // Compile regex patterns for efficiency
35        for (task_type, pattern) in &config.patterns {
36            let mut regexes = Vec::new();
37            for pattern_str in &pattern.patterns {
38                let regex =
39                    Regex::new(pattern_str).map_err(|e| RoutingError::ConfigurationError {
40                        key: format!("classification.patterns.{}.patterns", task_type),
41                        reason: format!("Invalid regex pattern '{}': {}", pattern_str, e),
42                    })?;
43                regexes.push(regex);
44            }
45            compiled_patterns.insert(task_type.clone(), regexes);
46        }
47
48        Ok(Self {
49            patterns: config.patterns.clone(),
50            compiled_patterns,
51            config,
52        })
53    }
54
55    /// Classify a task based on the prompt and context
56    pub fn classify_task(
57        &self,
58        prompt: &str,
59        context: &RoutingContext,
60    ) -> Result<ClassificationResult, RoutingError> {
61        if !self.config.enabled {
62            return Ok(ClassificationResult {
63                task_type: self.config.default_task_type.clone(),
64                confidence: 1.0,
65                matched_patterns: vec!["classification_disabled".to_string()],
66                keyword_matches: Vec::new(),
67            });
68        }
69
70        let prompt_lower = prompt.to_lowercase();
71        let mut scores = HashMap::new();
72        let mut all_matches = HashMap::new();
73
74        // Score each task type based on pattern matching
75        for (task_type, pattern) in &self.patterns {
76            let mut score = 0.0;
77            let mut matches = Vec::new();
78            let mut keyword_matches = Vec::new();
79
80            // Check keyword matches
81            for keyword in &pattern.keywords {
82                if prompt_lower.contains(&keyword.to_lowercase()) {
83                    score += pattern.weight * 0.5; // Keywords get half weight
84                    keyword_matches.push(keyword.clone());
85                }
86            }
87
88            // Check regex pattern matches
89            if let Some(regexes) = self.compiled_patterns.get(task_type) {
90                for (i, regex) in regexes.iter().enumerate() {
91                    if regex.is_match(&prompt_lower) {
92                        score += pattern.weight; // Full weight for regex matches
93                        matches.push(pattern.patterns[i].clone());
94                    }
95                }
96            }
97
98            if score > 0.0 {
99                scores.insert(task_type.clone(), score);
100                all_matches.insert(task_type.clone(), (matches, keyword_matches));
101            }
102        }
103
104        // Apply context-based adjustments
105        self.apply_context_adjustments(&mut scores, context);
106
107        // Find the highest scoring task type
108        let best_match = scores
109            .iter()
110            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal));
111
112        let (task_type, raw_score) = match best_match {
113            Some((task_type, score)) => (task_type.clone(), *score),
114            None => {
115                // No matches found, use default
116                return Ok(ClassificationResult {
117                    task_type: self.config.default_task_type.clone(),
118                    confidence: 0.0,
119                    matched_patterns: vec!["no_patterns_matched".to_string()],
120                    keyword_matches: Vec::new(),
121                });
122            }
123        };
124
125        // Normalize confidence score
126        let max_possible_score = self.calculate_max_possible_score(&task_type);
127        let confidence = if max_possible_score > 0.0 {
128            (raw_score / max_possible_score).min(1.0)
129        } else {
130            0.0
131        };
132
133        let (matched_patterns, keyword_matches) = all_matches
134            .get(&task_type)
135            .cloned()
136            .unwrap_or((Vec::new(), Vec::new()));
137
138        // Check if confidence meets threshold
139        if confidence < self.config.confidence_threshold {
140            return Ok(ClassificationResult {
141                task_type: self.config.default_task_type.clone(),
142                confidence,
143                matched_patterns: vec!["confidence_below_threshold".to_string()],
144                keyword_matches: Vec::new(),
145            });
146        }
147
148        Ok(ClassificationResult {
149            task_type,
150            confidence,
151            matched_patterns,
152            keyword_matches,
153        })
154    }
155
156    /// Apply context-based adjustments to scores
157    fn apply_context_adjustments(
158        &self,
159        scores: &mut HashMap<TaskType, f64>,
160        context: &RoutingContext,
161    ) {
162        // Adjust based on expected output type
163        match context.expected_output_type {
164            super::decision::OutputType::Code => {
165                // Boost code-related task types
166                if let Some(score) = scores.get_mut(&TaskType::CodeGeneration) {
167                    *score *= 1.5;
168                }
169                if let Some(score) = scores.get_mut(&TaskType::BoilerplateCode) {
170                    *score *= 1.3;
171                }
172            }
173            super::decision::OutputType::Json | super::decision::OutputType::Structured => {
174                // Boost extraction and analysis tasks
175                if let Some(score) = scores.get_mut(&TaskType::Extract) {
176                    *score *= 1.4;
177                }
178                if let Some(score) = scores.get_mut(&TaskType::Analysis) {
179                    *score *= 1.2;
180                }
181            }
182            _ => {}
183        }
184
185        // Adjust based on agent capabilities
186        for capability in &context.agent_capabilities {
187            match capability.as_str() {
188                "code_generation" => {
189                    if let Some(score) = scores.get_mut(&TaskType::CodeGeneration) {
190                        *score *= 1.2;
191                    }
192                }
193                "analysis" => {
194                    if let Some(score) = scores.get_mut(&TaskType::Analysis) {
195                        *score *= 1.2;
196                    }
197                    if let Some(score) = scores.get_mut(&TaskType::Reasoning) {
198                        *score *= 1.1;
199                    }
200                }
201                "translation" => {
202                    if let Some(score) = scores.get_mut(&TaskType::Translation) {
203                        *score *= 1.3;
204                    }
205                }
206                _ => {}
207            }
208        }
209
210        // Adjust based on security level (higher security might prefer certain task types)
211        match context.agent_security_level {
212            super::decision::SecurityLevel::Critical | super::decision::SecurityLevel::High => {
213                // Prefer simpler, more predictable tasks for high security
214                if let Some(score) = scores.get_mut(&TaskType::Intent) {
215                    *score *= 1.1;
216                }
217                if let Some(score) = scores.get_mut(&TaskType::Extract) {
218                    *score *= 1.1;
219                }
220                // Slightly penalize complex reasoning tasks
221                if let Some(score) = scores.get_mut(&TaskType::Reasoning) {
222                    *score *= 0.9;
223                }
224            }
225            _ => {}
226        }
227    }
228
229    /// Calculate the maximum possible score for a task type
230    fn calculate_max_possible_score(&self, task_type: &TaskType) -> f64 {
231        if let Some(pattern) = self.patterns.get(task_type) {
232            let keyword_score = pattern.keywords.len() as f64 * pattern.weight * 0.5;
233            let pattern_score = pattern.patterns.len() as f64 * pattern.weight;
234            keyword_score + pattern_score
235        } else {
236            1.0
237        }
238    }
239
240    /// Add or update a classification pattern
241    pub fn add_pattern(
242        &mut self,
243        task_type: TaskType,
244        pattern: ClassificationPattern,
245    ) -> Result<(), RoutingError> {
246        // Compile regex patterns
247        let mut regexes = Vec::new();
248        for pattern_str in &pattern.patterns {
249            let regex = Regex::new(pattern_str).map_err(|e| RoutingError::ConfigurationError {
250                key: format!("pattern.{}", task_type),
251                reason: format!("Invalid regex pattern '{}': {}", pattern_str, e),
252            })?;
253            regexes.push(regex);
254        }
255
256        self.compiled_patterns.insert(task_type.clone(), regexes);
257        self.patterns.insert(task_type, pattern);
258        Ok(())
259    }
260
261    /// Remove a classification pattern
262    pub fn remove_pattern(&mut self, task_type: &TaskType) {
263        self.patterns.remove(task_type);
264        self.compiled_patterns.remove(task_type);
265    }
266
267    /// Get classification statistics
268    pub fn get_statistics(&self) -> ClassificationStatistics {
269        ClassificationStatistics {
270            total_patterns: self.patterns.len(),
271            task_type_coverage: self.patterns.keys().cloned().collect(),
272            total_keywords: self.patterns.values().map(|p| p.keywords.len()).sum(),
273            total_regex_patterns: self.patterns.values().map(|p| p.patterns.len()).sum(),
274            confidence_threshold: self.config.confidence_threshold,
275            default_task_type: self.config.default_task_type.clone(),
276        }
277    }
278}
279
/// Statistics about the task classifier.
///
/// A snapshot returned by `TaskClassifier::get_statistics`; it does not
/// update when patterns are added or removed afterwards.
#[derive(Debug, Clone)]
pub struct ClassificationStatistics {
    /// Number of task types that have a classification pattern registered.
    pub total_patterns: usize,
    /// The task types that have a pattern registered (in no particular order,
    /// since they come from a hash map).
    pub task_type_coverage: Vec<TaskType>,
    /// Total number of keywords across all task types.
    pub total_keywords: usize,
    /// Total number of regex pattern strings across all task types.
    pub total_regex_patterns: usize,
    /// Minimum confidence a classification needs before it is accepted.
    pub confidence_threshold: f64,
    /// Task type reported when classification falls back.
    pub default_task_type: TaskType,
}
290
291#[cfg(test)]
292mod tests {
293    use super::super::decision::{OutputType, RoutingContext};
294    use super::*;
295    use crate::types::AgentId;
296
297    fn create_test_config() -> TaskClassificationConfig {
298        let mut patterns = HashMap::new();
299
300        patterns.insert(
301            TaskType::CodeGeneration,
302            ClassificationPattern {
303                keywords: vec![
304                    "code".to_string(),
305                    "function".to_string(),
306                    "implement".to_string(),
307                ],
308                patterns: vec![
309                    r"write.*code".to_string(),
310                    r"implement.*function".to_string(),
311                ],
312                weight: 1.0,
313            },
314        );
315
316        patterns.insert(
317            TaskType::Analysis,
318            ClassificationPattern {
319                keywords: vec![
320                    "analyze".to_string(),
321                    "analysis".to_string(),
322                    "examine".to_string(),
323                ],
324                patterns: vec![
325                    r"analyze.*data".to_string(),
326                    r"perform.*analysis".to_string(),
327                ],
328                weight: 1.0,
329            },
330        );
331
332        TaskClassificationConfig {
333            enabled: true,
334            patterns,
335            confidence_threshold: 0.3,
336            default_task_type: TaskType::Custom("unknown".to_string()),
337        }
338    }
339
340    fn create_test_context() -> RoutingContext {
341        RoutingContext::new(
342            AgentId::new(),
343            TaskType::Custom("unknown".to_string()),
344            "test prompt".to_string(),
345        )
346    }
347
348    #[test]
349    fn test_code_generation_classification() {
350        let config = create_test_config();
351        let classifier = TaskClassifier::new(config).unwrap();
352        let context = create_test_context();
353
354        let result = classifier
355            .classify_task(
356                "Please write code to implement a sorting function",
357                &context,
358            )
359            .unwrap();
360
361        assert_eq!(result.task_type, TaskType::CodeGeneration);
362        assert!(result.confidence > 0.5);
363        assert!(!result.keyword_matches.is_empty());
364    }
365
366    #[test]
367    fn test_analysis_classification() {
368        let config = create_test_config();
369        let classifier = TaskClassifier::new(config).unwrap();
370        let context = create_test_context();
371
372        let result = classifier
373            .classify_task("Please analyze the data trends", &context)
374            .unwrap();
375
376        assert_eq!(result.task_type, TaskType::Analysis);
377        assert!(result.confidence > 0.3);
378    }
379
380    #[test]
381    fn test_no_match_fallback() {
382        let config = create_test_config();
383        let classifier = TaskClassifier::new(config).unwrap();
384        let context = create_test_context();
385
386        let result = classifier.classify_task("Hello world", &context).unwrap();
387
388        assert_eq!(result.task_type, TaskType::Custom("unknown".to_string()));
389        assert_eq!(result.confidence, 0.0);
390    }
391
392    #[test]
393    fn test_context_adjustments() {
394        let config = create_test_config();
395        let classifier = TaskClassifier::new(config).unwrap();
396        let mut context = create_test_context();
397        context.expected_output_type = OutputType::Code;
398        context.agent_capabilities = vec!["code_generation".to_string()];
399
400        let result = classifier
401            .classify_task("Please write some code", &context)
402            .unwrap();
403
404        assert_eq!(result.task_type, TaskType::CodeGeneration);
405        // Should have higher confidence due to context adjustments
406        assert!(result.confidence > 0.5);
407    }
408}