1use super::config::{ClassificationPattern, TaskClassificationConfig};
4use super::decision::RoutingContext;
5use super::error::{RoutingError, TaskType};
6use regex::Regex;
7use std::collections::HashMap;
8
9#[derive(Debug, Clone)]
11pub struct TaskClassifier {
12 patterns: HashMap<TaskType, ClassificationPattern>,
14 compiled_patterns: HashMap<TaskType, Vec<Regex>>,
16 config: TaskClassificationConfig,
18}
19
20#[derive(Debug, Clone)]
22pub struct ClassificationResult {
23 pub task_type: TaskType,
24 pub confidence: f64,
25 pub matched_patterns: Vec<String>,
26 pub keyword_matches: Vec<String>,
27}
28
29impl TaskClassifier {
30 pub fn new(config: TaskClassificationConfig) -> Result<Self, RoutingError> {
32 let mut compiled_patterns = HashMap::new();
33
34 for (task_type, pattern) in &config.patterns {
36 let mut regexes = Vec::new();
37 for pattern_str in &pattern.patterns {
38 let regex =
39 Regex::new(pattern_str).map_err(|e| RoutingError::ConfigurationError {
40 key: format!("classification.patterns.{}.patterns", task_type),
41 reason: format!("Invalid regex pattern '{}': {}", pattern_str, e),
42 })?;
43 regexes.push(regex);
44 }
45 compiled_patterns.insert(task_type.clone(), regexes);
46 }
47
48 Ok(Self {
49 patterns: config.patterns.clone(),
50 compiled_patterns,
51 config,
52 })
53 }
54
55 pub fn classify_task(
57 &self,
58 prompt: &str,
59 context: &RoutingContext,
60 ) -> Result<ClassificationResult, RoutingError> {
61 if !self.config.enabled {
62 return Ok(ClassificationResult {
63 task_type: self.config.default_task_type.clone(),
64 confidence: 1.0,
65 matched_patterns: vec!["classification_disabled".to_string()],
66 keyword_matches: Vec::new(),
67 });
68 }
69
70 let prompt_lower = prompt.to_lowercase();
71 let mut scores = HashMap::new();
72 let mut all_matches = HashMap::new();
73
74 for (task_type, pattern) in &self.patterns {
76 let mut score = 0.0;
77 let mut matches = Vec::new();
78 let mut keyword_matches = Vec::new();
79
80 for keyword in &pattern.keywords {
82 if prompt_lower.contains(&keyword.to_lowercase()) {
83 score += pattern.weight * 0.5; keyword_matches.push(keyword.clone());
85 }
86 }
87
88 if let Some(regexes) = self.compiled_patterns.get(task_type) {
90 for (i, regex) in regexes.iter().enumerate() {
91 if regex.is_match(&prompt_lower) {
92 score += pattern.weight; matches.push(pattern.patterns[i].clone());
94 }
95 }
96 }
97
98 if score > 0.0 {
99 scores.insert(task_type.clone(), score);
100 all_matches.insert(task_type.clone(), (matches, keyword_matches));
101 }
102 }
103
104 self.apply_context_adjustments(&mut scores, context);
106
107 let best_match = scores
109 .iter()
110 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal));
111
112 let (task_type, raw_score) = match best_match {
113 Some((task_type, score)) => (task_type.clone(), *score),
114 None => {
115 return Ok(ClassificationResult {
117 task_type: self.config.default_task_type.clone(),
118 confidence: 0.0,
119 matched_patterns: vec!["no_patterns_matched".to_string()],
120 keyword_matches: Vec::new(),
121 });
122 }
123 };
124
125 let max_possible_score = self.calculate_max_possible_score(&task_type);
127 let confidence = if max_possible_score > 0.0 {
128 (raw_score / max_possible_score).min(1.0)
129 } else {
130 0.0
131 };
132
133 let (matched_patterns, keyword_matches) = all_matches
134 .get(&task_type)
135 .cloned()
136 .unwrap_or((Vec::new(), Vec::new()));
137
138 if confidence < self.config.confidence_threshold {
140 return Ok(ClassificationResult {
141 task_type: self.config.default_task_type.clone(),
142 confidence,
143 matched_patterns: vec!["confidence_below_threshold".to_string()],
144 keyword_matches: Vec::new(),
145 });
146 }
147
148 Ok(ClassificationResult {
149 task_type,
150 confidence,
151 matched_patterns,
152 keyword_matches,
153 })
154 }
155
156 fn apply_context_adjustments(
158 &self,
159 scores: &mut HashMap<TaskType, f64>,
160 context: &RoutingContext,
161 ) {
162 match context.expected_output_type {
164 super::decision::OutputType::Code => {
165 if let Some(score) = scores.get_mut(&TaskType::CodeGeneration) {
167 *score *= 1.5;
168 }
169 if let Some(score) = scores.get_mut(&TaskType::BoilerplateCode) {
170 *score *= 1.3;
171 }
172 }
173 super::decision::OutputType::Json | super::decision::OutputType::Structured => {
174 if let Some(score) = scores.get_mut(&TaskType::Extract) {
176 *score *= 1.4;
177 }
178 if let Some(score) = scores.get_mut(&TaskType::Analysis) {
179 *score *= 1.2;
180 }
181 }
182 _ => {}
183 }
184
185 for capability in &context.agent_capabilities {
187 match capability.as_str() {
188 "code_generation" => {
189 if let Some(score) = scores.get_mut(&TaskType::CodeGeneration) {
190 *score *= 1.2;
191 }
192 }
193 "analysis" => {
194 if let Some(score) = scores.get_mut(&TaskType::Analysis) {
195 *score *= 1.2;
196 }
197 if let Some(score) = scores.get_mut(&TaskType::Reasoning) {
198 *score *= 1.1;
199 }
200 }
201 "translation" => {
202 if let Some(score) = scores.get_mut(&TaskType::Translation) {
203 *score *= 1.3;
204 }
205 }
206 _ => {}
207 }
208 }
209
210 match context.agent_security_level {
212 super::decision::SecurityLevel::Critical | super::decision::SecurityLevel::High => {
213 if let Some(score) = scores.get_mut(&TaskType::Intent) {
215 *score *= 1.1;
216 }
217 if let Some(score) = scores.get_mut(&TaskType::Extract) {
218 *score *= 1.1;
219 }
220 if let Some(score) = scores.get_mut(&TaskType::Reasoning) {
222 *score *= 0.9;
223 }
224 }
225 _ => {}
226 }
227 }
228
229 fn calculate_max_possible_score(&self, task_type: &TaskType) -> f64 {
231 if let Some(pattern) = self.patterns.get(task_type) {
232 let keyword_score = pattern.keywords.len() as f64 * pattern.weight * 0.5;
233 let pattern_score = pattern.patterns.len() as f64 * pattern.weight;
234 keyword_score + pattern_score
235 } else {
236 1.0
237 }
238 }
239
240 pub fn add_pattern(
242 &mut self,
243 task_type: TaskType,
244 pattern: ClassificationPattern,
245 ) -> Result<(), RoutingError> {
246 let mut regexes = Vec::new();
248 for pattern_str in &pattern.patterns {
249 let regex = Regex::new(pattern_str).map_err(|e| RoutingError::ConfigurationError {
250 key: format!("pattern.{}", task_type),
251 reason: format!("Invalid regex pattern '{}': {}", pattern_str, e),
252 })?;
253 regexes.push(regex);
254 }
255
256 self.compiled_patterns.insert(task_type.clone(), regexes);
257 self.patterns.insert(task_type, pattern);
258 Ok(())
259 }
260
261 pub fn remove_pattern(&mut self, task_type: &TaskType) {
263 self.patterns.remove(task_type);
264 self.compiled_patterns.remove(task_type);
265 }
266
267 pub fn get_statistics(&self) -> ClassificationStatistics {
269 ClassificationStatistics {
270 total_patterns: self.patterns.len(),
271 task_type_coverage: self.patterns.keys().cloned().collect(),
272 total_keywords: self.patterns.values().map(|p| p.keywords.len()).sum(),
273 total_regex_patterns: self.patterns.values().map(|p| p.patterns.len()).sum(),
274 confidence_threshold: self.config.confidence_threshold,
275 default_task_type: self.config.default_task_type.clone(),
276 }
277 }
278}
279
280#[derive(Debug, Clone)]
282pub struct ClassificationStatistics {
283 pub total_patterns: usize,
284 pub task_type_coverage: Vec<TaskType>,
285 pub total_keywords: usize,
286 pub total_regex_patterns: usize,
287 pub confidence_threshold: f64,
288 pub default_task_type: TaskType,
289}
290
#[cfg(test)]
mod tests {
    use super::super::decision::{OutputType, RoutingContext};
    use super::*;
    use crate::types::AgentId;

    /// Fallback task type used by the test configuration.
    fn unknown_task() -> TaskType {
        TaskType::Custom("unknown".to_string())
    }

    /// Builds a `ClassificationPattern` from string slices.
    fn pattern(keywords: &[&str], regexes: &[&str], weight: f64) -> ClassificationPattern {
        ClassificationPattern {
            keywords: keywords.iter().map(|kw| kw.to_string()).collect(),
            patterns: regexes.iter().map(|re| re.to_string()).collect(),
            weight,
        }
    }

    /// Configuration with code-generation and analysis patterns, threshold 0.3.
    fn create_test_config() -> TaskClassificationConfig {
        let patterns: HashMap<TaskType, ClassificationPattern> = vec![
            (
                TaskType::CodeGeneration,
                pattern(
                    &["code", "function", "implement"],
                    &[r"write.*code", r"implement.*function"],
                    1.0,
                ),
            ),
            (
                TaskType::Analysis,
                pattern(
                    &["analyze", "analysis", "examine"],
                    &[r"analyze.*data", r"perform.*analysis"],
                    1.0,
                ),
            ),
        ]
        .into_iter()
        .collect();

        TaskClassificationConfig {
            enabled: true,
            patterns,
            confidence_threshold: 0.3,
            default_task_type: unknown_task(),
        }
    }

    /// Neutral routing context with no output-type or capability hints.
    fn create_test_context() -> RoutingContext {
        RoutingContext::new(AgentId::new(), unknown_task(), "test prompt".to_string())
    }

    /// Classifier built from the standard test configuration.
    fn classifier() -> TaskClassifier {
        TaskClassifier::new(create_test_config()).unwrap()
    }

    #[test]
    fn test_code_generation_classification() {
        let ctx = create_test_context();

        let outcome = classifier()
            .classify_task("Please write code to implement a sorting function", &ctx)
            .unwrap();

        assert_eq!(outcome.task_type, TaskType::CodeGeneration);
        assert!(outcome.confidence > 0.5);
        assert!(!outcome.keyword_matches.is_empty());
    }

    #[test]
    fn test_analysis_classification() {
        let ctx = create_test_context();

        let outcome = classifier()
            .classify_task("Please analyze the data trends", &ctx)
            .unwrap();

        assert_eq!(outcome.task_type, TaskType::Analysis);
        assert!(outcome.confidence > 0.3);
    }

    #[test]
    fn test_no_match_fallback() {
        let ctx = create_test_context();

        let outcome = classifier().classify_task("Hello world", &ctx).unwrap();

        assert_eq!(outcome.task_type, unknown_task());
        assert_eq!(outcome.confidence, 0.0);
    }

    #[test]
    fn test_context_adjustments() {
        let mut ctx = create_test_context();
        ctx.expected_output_type = OutputType::Code;
        ctx.agent_capabilities = vec!["code_generation".to_string()];

        let outcome = classifier()
            .classify_task("Please write some code", &ctx)
            .unwrap();

        assert_eq!(outcome.task_type, TaskType::CodeGeneration);
        assert!(outcome.confidence > 0.5);
    }
}