1use super::decision::{LLMProvider, MonitoringLevel, SecurityLevel};
4use super::error::TaskType;
5use crate::config::ResourceConstraints;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::time::Duration;
9
/// Top-level routing configuration.
///
/// Bundles the routing policy, the task-classification settings and the set of
/// configured LLM providers. The [`Default`] implementation enables routing and
/// registers two providers under the keys `"openai"` and `"anthropic"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConfig {
    /// Master switch for routing (`true` by default). How the flag is consumed
    /// is not visible in this file.
    pub enabled: bool,
    /// Routing policy: global settings, ordered rules, default action and
    /// fallback behaviour.
    pub policy: RoutingPolicyConfig,
    /// Settings used to classify incoming tasks into a `TaskType`.
    pub classification: TaskClassificationConfig,
    /// Provider configurations keyed by a provider name such as `"openai"`.
    pub llm_providers: HashMap<String, LLMProviderConfig>,
}
22
/// Routing policy: global knobs, the rule list, and what to do when nothing
/// else applies.
///
/// `RoutingPolicyConfig::validate` checks the rule ordering and the global
/// confidence threshold.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingPolicyConfig {
    /// Settings that apply across all rules.
    pub global_settings: GlobalRoutingSettings,
    /// Routing rules. `validate` requires them to be listed in non-increasing
    /// `priority` order (highest priority first).
    pub rules: Vec<RoutingRule>,
    /// Action used by default; the default configuration routes to the OpenAI
    /// provider with model `"gpt-3.5-turbo"`.
    // NOTE(review): presumably applied when no rule matches — confirm against the router.
    pub default_action: RouteAction,
    /// Fallback behaviour (attempt limit, timeout and provider set).
    pub fallback_config: FallbackConfig,
}
35
/// Policy-wide routing settings.
///
/// Defaults: SLM routing enabled, auditing always on, confidence threshold
/// `0.85`, and at most `2` SLM retries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalRoutingSettings {
    /// Whether routing to small language models (SLMs) is enabled.
    pub slm_routing_enabled: bool,
    /// Audit flag (`true` by default).
    // NOTE(review): presumably forces an audit record for every routing decision — confirm.
    pub always_audit: bool,
    /// Global confidence threshold. `RoutingPolicyConfig::validate` rejects
    /// values below `0.0` or above `1.0`.
    pub global_confidence_threshold: f64,
    /// Maximum number of SLM retries.
    // NOTE(review): presumably counted before falling back to an LLM — confirm.
    pub max_slm_retries: u32,
}
48
/// A single routing rule: a set of conditions plus the action to take when
/// they hold.
///
/// `RoutingRule::matches` decides whether the rule applies to a routing context.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingRule {
    /// Human-readable rule name.
    pub name: String,
    /// Rule priority. `RoutingPolicyConfig::validate` requires rules to be
    /// listed from highest to lowest priority.
    pub priority: u32,
    /// Conditions that must all hold for the rule to match.
    pub conditions: RoutingConditions,
    /// Action associated with this rule.
    pub action: RouteAction,
    /// Optional extra action settings; defaults to `None` when absent from the
    /// serialized form (`#[serde(default)]`).
    #[serde(default)]
    pub action_extension: Option<ActionExtension>,
    /// Override flag for this rule.
    // NOTE(review): presumably lets a caller override the rule's action — not used in this file; confirm.
    pub override_allowed: bool,
}
66
/// Conditions under which a [`RoutingRule`] matches.
///
/// Every field is optional. `RoutingRule::matches` skips fields that are
/// `None` and requires all remaining ones to hold.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConditions {
    /// When set, the context's task type must be one of these.
    pub task_types: Option<Vec<TaskType>>,
    /// When set, the string form of the context's agent id must be in this list.
    pub agent_ids: Option<Vec<String>>,
    /// When set and the context carries resource limits, the context's
    /// `max_memory_mb` must not exceed this constraint's `max_memory_mb`.
    pub resource_constraints: Option<ResourceConstraints>,
    /// When set, the context's agent security level must be at least this level.
    pub security_level: Option<SecurityLevel>,
    /// When set, each expression is evaluated against the context and all of
    /// them must hold.
    pub custom_conditions: Option<Vec<String>>,
}
81
/// What to do with a request once a routing decision has been made.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RouteAction {
    /// Route the request to a small language model (SLM).
    UseSLM {
        /// Which kind of model to prefer.
        model_preference: ModelPreference,
        /// Monitoring level to apply to the SLM execution.
        monitoring_level: MonitoringLevel,
        /// Fallback flag for low-confidence results.
        // NOTE(review): presumably triggers the LLM fallback path — confirm against the router.
        fallback_on_low_confidence: bool,
        /// Optional action-specific confidence threshold.
        // NOTE(review): presumably overrides the global threshold when set — confirm.
        confidence_threshold: Option<f64>,
    },
    /// Route the request to a large language model (LLM) provider.
    UseLLM {
        /// Provider to use.
        provider: LLMProvider,
        /// Optional model name; the default policy uses `"gpt-3.5-turbo"`.
        model: Option<String>,
    },
    /// Refuse the request, carrying the reason for the denial.
    Deny { reason: String },
}
100
/// Model selection preference carried by [`RouteAction::UseSLM`].
///
/// The selection logic itself is not part of this file; the variant
/// descriptions below follow the variant names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelPreference {
    /// Prefer a specialist model.
    Specialist,
    /// Prefer a generalist model.
    Generalist,
    /// Use exactly the model identified by `model_id`.
    Specific { model_id: String },
    /// Use the best model available.
    BestAvailable,
}
113
/// Optional extra settings attached to a rule via `RoutingRule::action_extension`.
///
/// Derives [`Default`], which leaves every field `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ActionExtension {
    /// Optional sandbox setting.
    // NOTE(review): presumably a sandbox tier/name for executing the action — confirm with the consumer.
    pub sandbox: Option<String>,
}
120
/// Fallback settings for the routing policy.
///
/// Defaults: enabled, `3` attempts, a 30-second timeout, and a single provider
/// registered under the key `"primary"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FallbackConfig {
    /// Whether fallback is enabled.
    pub enabled: bool,
    /// Maximum number of fallback attempts.
    pub max_attempts: u32,
    /// Fallback timeout, (de)serialized through `humantime_serde`.
    #[serde(with = "humantime_serde")]
    pub timeout: Duration,
    /// Fallback provider configurations keyed by name.
    pub providers: HashMap<String, LLMProviderConfig>,
}
134
/// Connection settings for one LLM provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMProviderConfig {
    /// Name of the API-key environment variable (e.g. `OPENAI_API_KEY` in the
    /// defaults).
    // NOTE(review): the environment lookup itself is not in this file — confirm where it is read.
    pub api_key_env: String,
    /// Base URL of the provider's API (e.g. `https://api.openai.com/v1`).
    pub base_url: String,
    /// Model used when none is specified explicitly.
    pub default_model: String,
    /// Request timeout, (de)serialized through `humantime_serde`.
    #[serde(with = "humantime_serde")]
    pub timeout: Duration,
    /// Maximum number of retries for this provider.
    pub max_retries: u32,
    /// Optional rate limiting; `None` means no limit is configured here.
    pub rate_limit: Option<RateLimitConfig>,
}
152
/// Rate-limit settings for an LLM provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Allowed requests per minute.
    pub requests_per_minute: u32,
    /// Optional token budget per minute.
    pub tokens_per_minute: Option<u32>,
    /// Optional burst allowance.
    // NOTE(review): presumably extra requests tolerated above the steady rate — confirm with the limiter.
    pub burst_allowance: Option<u32>,
}
163
/// Settings for classifying tasks into a `TaskType`.
///
/// Defaults: enabled, patterns for `Intent`, `CodeGeneration` and `Analysis`,
/// a confidence threshold of `0.7`, and `TaskType::Custom("unknown")` as the
/// default task type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskClassificationConfig {
    /// Whether task classification is enabled.
    pub enabled: bool,
    /// Classification pattern for each task type.
    pub patterns: HashMap<TaskType, ClassificationPattern>,
    /// Confidence threshold for classification.
    // NOTE(review): presumably the minimum confidence to accept a classification — confirm with the classifier.
    pub confidence_threshold: f64,
    /// Default task type.
    // NOTE(review): presumably used when classification is inconclusive — confirm with the classifier.
    pub default_task_type: TaskType,
}
176
/// Matching hints for one task type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationPattern {
    /// Keywords associated with the task type (e.g. `"code"`, `"implement"`).
    pub keywords: Vec<String>,
    /// Pattern strings associated with the task type (e.g. `write.*code`).
    // NOTE(review): these look like regular expressions — confirm how the classifier compiles them.
    pub patterns: Vec<String>,
    /// Weight of this pattern; every default pattern uses `1.0`.
    pub weight: f64,
}
187
188impl Default for RoutingConfig {
189 fn default() -> Self {
190 let mut llm_providers = HashMap::new();
191
192 llm_providers.insert(
193 "openai".to_string(),
194 LLMProviderConfig {
195 api_key_env: "OPENAI_API_KEY".to_string(),
196 base_url: "https://api.openai.com/v1".to_string(),
197 default_model: "gpt-3.5-turbo".to_string(),
198 timeout: Duration::from_secs(60),
199 max_retries: 3,
200 rate_limit: Some(RateLimitConfig {
201 requests_per_minute: 60,
202 tokens_per_minute: Some(10000),
203 burst_allowance: Some(10),
204 }),
205 },
206 );
207
208 llm_providers.insert(
209 "anthropic".to_string(),
210 LLMProviderConfig {
211 api_key_env: "ANTHROPIC_API_KEY".to_string(),
212 base_url: "https://api.anthropic.com".to_string(),
213 default_model: "claude-3-sonnet-20240229".to_string(),
214 timeout: Duration::from_secs(60),
215 max_retries: 3,
216 rate_limit: Some(RateLimitConfig {
217 requests_per_minute: 60,
218 tokens_per_minute: Some(10000),
219 burst_allowance: Some(10),
220 }),
221 },
222 );
223
224 Self {
225 enabled: true,
226 policy: RoutingPolicyConfig::default(),
227 classification: TaskClassificationConfig::default(),
228 llm_providers,
229 }
230 }
231}
232
233impl Default for RoutingPolicyConfig {
234 fn default() -> Self {
235 Self {
236 global_settings: GlobalRoutingSettings::default(),
237 rules: Vec::new(),
238 default_action: RouteAction::UseLLM {
239 provider: LLMProvider::OpenAI { model: None },
240 model: Some("gpt-3.5-turbo".to_string()),
241 },
242 fallback_config: FallbackConfig::default(),
243 }
244 }
245}
246
247impl Default for GlobalRoutingSettings {
248 fn default() -> Self {
249 Self {
250 slm_routing_enabled: true,
251 always_audit: true,
252 global_confidence_threshold: 0.85,
253 max_slm_retries: 2,
254 }
255 }
256}
257
258impl Default for FallbackConfig {
259 fn default() -> Self {
260 let mut providers = HashMap::new();
261 providers.insert(
262 "primary".to_string(),
263 LLMProviderConfig {
264 api_key_env: "OPENAI_API_KEY".to_string(),
265 base_url: "https://api.openai.com/v1".to_string(),
266 default_model: "gpt-3.5-turbo".to_string(),
267 timeout: Duration::from_secs(60),
268 max_retries: 3,
269 rate_limit: None,
270 },
271 );
272
273 Self {
274 enabled: true,
275 max_attempts: 3,
276 timeout: Duration::from_secs(30),
277 providers,
278 }
279 }
280}
281
282impl Default for TaskClassificationConfig {
283 fn default() -> Self {
284 let mut patterns = HashMap::new();
285
286 patterns.insert(
287 TaskType::Intent,
288 ClassificationPattern {
289 keywords: vec![
290 "intent".to_string(),
291 "intention".to_string(),
292 "purpose".to_string(),
293 ],
294 patterns: vec![r"what.*intent".to_string(), r"user.*wants".to_string()],
295 weight: 1.0,
296 },
297 );
298
299 patterns.insert(
300 TaskType::CodeGeneration,
301 ClassificationPattern {
302 keywords: vec![
303 "code".to_string(),
304 "function".to_string(),
305 "implement".to_string(),
306 "generate".to_string(),
307 ],
308 patterns: vec![
309 r"write.*code".to_string(),
310 r"implement.*function".to_string(),
311 ],
312 weight: 1.0,
313 },
314 );
315
316 patterns.insert(
317 TaskType::Analysis,
318 ClassificationPattern {
319 keywords: vec![
320 "analyze".to_string(),
321 "analysis".to_string(),
322 "examine".to_string(),
323 "review".to_string(),
324 ],
325 patterns: vec![
326 r"analyze.*data".to_string(),
327 r"perform.*analysis".to_string(),
328 ],
329 weight: 1.0,
330 },
331 );
332
333 Self {
334 enabled: true,
335 patterns,
336 confidence_threshold: 0.7,
337 default_task_type: TaskType::Custom("unknown".to_string()),
338 }
339 }
340}
341
342impl RoutingRule {
343 pub fn matches(&self, context: &super::decision::RoutingContext) -> bool {
345 if let Some(ref task_types) = self.conditions.task_types {
347 if !task_types.contains(&context.task_type) {
348 return false;
349 }
350 }
351
352 if let Some(ref agent_ids) = self.conditions.agent_ids {
354 if !agent_ids.contains(&context.agent_id.to_string()) {
355 return false;
356 }
357 }
358
359 if let Some(ref required_level) = self.conditions.security_level {
361 if context.agent_security_level < *required_level {
362 return false;
363 }
364 }
365
366 if let Some(ref rule_constraints) = self.conditions.resource_constraints {
368 if let Some(ref context_limits) = context.resource_limits {
369 if context_limits.max_memory_mb > rule_constraints.max_memory_mb {
370 return false;
371 }
372 }
373 }
374
375 if let Some(ref custom_conditions) = self.conditions.custom_conditions {
377 for condition_expr in custom_conditions {
378 if !self.evaluate_custom_condition(condition_expr, context) {
379 return false;
380 }
381 }
382 }
383
384 true
385 }
386
387 fn evaluate_custom_condition(
389 &self,
390 condition_expr: &str,
391 context: &super::decision::RoutingContext,
392 ) -> bool {
393 if condition_expr.contains("agent_id") {
398 if let Some(expected_id) = condition_expr.strip_prefix("agent_id == ") {
399 let expected_id = expected_id.trim_matches('"');
400 return context.agent_id.to_string() == expected_id;
401 }
402 }
403
404 if condition_expr.contains("task_type") {
405 if let Some(expected_type) = condition_expr.strip_prefix("task_type == ") {
406 let expected_type = expected_type.trim_matches('"');
407 return format!("{:?}", context.task_type)
408 .to_lowercase()
409 .contains(&expected_type.to_lowercase());
410 }
411 }
412
413 if condition_expr.contains("security_level") && condition_expr.contains(">=") {
414 if let Some(level_str) = condition_expr.strip_prefix("security_level >= ") {
415 if let Ok(required_level) = level_str.trim().parse::<u8>() {
416 let current_level = match context.agent_security_level {
417 SecurityLevel::Low => 1,
418 SecurityLevel::Medium => 2,
419 SecurityLevel::High => 3,
420 SecurityLevel::Critical => 4,
421 };
422 return current_level >= required_level;
423 }
424 }
425 }
426
427 if condition_expr.contains("memory_limit") {
428 if let Some(ref resource_limits) = context.resource_limits {
429 if condition_expr.contains("<=") {
430 if let Some(limit_str) = condition_expr.strip_prefix("memory_limit <= ") {
431 if let Ok(max_memory) = limit_str.trim().parse::<u64>() {
432 return resource_limits.max_memory_mb <= max_memory;
433 }
434 }
435 }
436 }
437 }
438
439 if condition_expr == "true" {
441 return true;
442 }
443 if condition_expr == "false" {
444 return false;
445 }
446
447 tracing::warn!("Unrecognized custom condition: {}", condition_expr);
449 true
450 }
451}
452
453impl RoutingPolicyConfig {
454 pub fn validate(&self) -> Result<(), super::error::RoutingError> {
456 let mut prev_priority = u32::MAX;
458 for rule in &self.rules {
459 if rule.priority > prev_priority {
460 return Err(super::error::RoutingError::ConfigurationError {
461 key: "policy.rules".to_string(),
462 reason: "Rules must be ordered by priority (highest first)".to_string(),
463 });
464 }
465 prev_priority = rule.priority;
466 }
467
468 if self.global_settings.global_confidence_threshold < 0.0
470 || self.global_settings.global_confidence_threshold > 1.0
471 {
472 return Err(super::error::RoutingError::ConfigurationError {
473 key: "policy.global_settings.global_confidence_threshold".to_string(),
474 reason: "Confidence threshold must be between 0.0 and 1.0".to_string(),
475 });
476 }
477
478 Ok(())
479 }
480}