Skip to main content

symbi_runtime/routing/
config.rs

//! Configuration types for the routing module: routing policy and rules, LLM provider and fallback settings, and task classification.
2
3use super::decision::{LLMProvider, MonitoringLevel, SecurityLevel};
4use super::error::TaskType;
5use crate::config::ResourceConstraints;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::time::Duration;
9
10/// Complete routing configuration
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct RoutingConfig {
13    /// Enable the policy-driven router
14    pub enabled: bool,
15    /// Routing policy configuration
16    pub policy: RoutingPolicyConfig,
17    /// Task classification settings
18    pub classification: TaskClassificationConfig,
19    /// LLM provider configurations
20    pub llm_providers: HashMap<String, LLMProviderConfig>,
21}
22
23/// Core routing policy configuration
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct RoutingPolicyConfig {
26    /// Global routing settings
27    pub global_settings: GlobalRoutingSettings,
28    /// Ordered list of routing rules
29    pub rules: Vec<RoutingRule>,
30    /// Default action when no rules match
31    pub default_action: RouteAction,
32    /// LLM fallback configuration
33    pub fallback_config: FallbackConfig,
34}
35
36/// Global routing settings
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct GlobalRoutingSettings {
39    /// Enable/disable SLM routing globally
40    pub slm_routing_enabled: bool,
41    /// Always audit routing decisions
42    pub always_audit: bool,
43    /// Global confidence threshold for SLM responses
44    pub global_confidence_threshold: f64,
45    /// Maximum retry attempts for failed SLM calls
46    pub max_slm_retries: u32,
47}
48
49/// Individual routing rule
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct RoutingRule {
52    /// Rule identifier
53    pub name: String,
54    /// Rule priority (higher = evaluated first)
55    pub priority: u32,
56    /// Conditions that must be met
57    pub conditions: RoutingConditions,
58    /// Action to take if conditions match
59    pub action: RouteAction,
60    /// Action extensions for additional configuration
61    #[serde(default)]
62    pub action_extension: Option<ActionExtension>,
63    /// Whether this rule can be overridden
64    pub override_allowed: bool,
65}
66
67/// Conditions for routing rules
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct RoutingConditions {
70    /// Task types this rule applies to
71    pub task_types: Option<Vec<TaskType>>,
72    /// Agent IDs this rule applies to
73    pub agent_ids: Option<Vec<String>>,
74    /// Resource requirements
75    pub resource_constraints: Option<ResourceConstraints>,
76    /// Security level requirements
77    pub security_level: Option<SecurityLevel>,
78    /// Custom condition expressions
79    pub custom_conditions: Option<Vec<String>>,
80}
81
82/// Action to take when routing conditions are met
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub enum RouteAction {
85    /// Use SLM with specified preferences
86    UseSLM {
87        model_preference: ModelPreference,
88        monitoring_level: MonitoringLevel,
89        fallback_on_low_confidence: bool,
90        confidence_threshold: Option<f64>,
91    },
92    /// Use LLM provider
93    UseLLM {
94        provider: LLMProvider,
95        model: Option<String>,
96    },
97    /// Deny request
98    Deny { reason: String },
99}
100
101/// Model preference for SLM selection
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub enum ModelPreference {
104    /// Prefer specialist models for the task type
105    Specialist,
106    /// Prefer general-purpose models
107    Generalist,
108    /// Use specific model by ID
109    Specific { model_id: String },
110    /// Use best available model for requirements
111    BestAvailable,
112}
113
114/// Action extensions for additional routing configuration
115#[derive(Debug, Clone, Serialize, Deserialize, Default)]
116pub struct ActionExtension {
117    /// Preferred sandbox tier for execution
118    pub sandbox: Option<String>,
119}
120
121/// Fallback configuration for LLM providers
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct FallbackConfig {
124    /// Enable fallback mechanism
125    pub enabled: bool,
126    /// Maximum fallback attempts
127    pub max_attempts: u32,
128    /// Timeout for fallback operations
129    #[serde(with = "humantime_serde")]
130    pub timeout: Duration,
131    /// Provider priority order
132    pub providers: HashMap<String, LLMProviderConfig>,
133}
134
135/// LLM provider configuration
136#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct LLMProviderConfig {
138    /// API key environment variable name
139    pub api_key_env: String,
140    /// Base URL for the provider
141    pub base_url: String,
142    /// Default model for this provider
143    pub default_model: String,
144    /// Request timeout
145    #[serde(with = "humantime_serde")]
146    pub timeout: Duration,
147    /// Maximum retries
148    pub max_retries: u32,
149    /// Rate limiting settings
150    pub rate_limit: Option<RateLimitConfig>,
151}
152
153/// Rate limiting configuration
154#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct RateLimitConfig {
156    /// Requests per minute
157    pub requests_per_minute: u32,
158    /// Tokens per minute
159    pub tokens_per_minute: Option<u32>,
160    /// Burst allowance
161    pub burst_allowance: Option<u32>,
162}
163
164/// Task classification configuration
165#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct TaskClassificationConfig {
167    /// Enable automatic task classification
168    pub enabled: bool,
169    /// Classification patterns for different task types
170    pub patterns: HashMap<TaskType, ClassificationPattern>,
171    /// Confidence threshold for classification
172    pub confidence_threshold: f64,
173    /// Fallback task type when classification fails
174    pub default_task_type: TaskType,
175}
176
177/// Pattern for task classification
178#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct ClassificationPattern {
180    /// Keywords that indicate this task type
181    pub keywords: Vec<String>,
182    /// Regex patterns for classification
183    pub patterns: Vec<String>,
184    /// Weight for this pattern
185    pub weight: f64,
186}
187
188impl Default for RoutingConfig {
189    fn default() -> Self {
190        let mut llm_providers = HashMap::new();
191
192        llm_providers.insert(
193            "openai".to_string(),
194            LLMProviderConfig {
195                api_key_env: "OPENAI_API_KEY".to_string(),
196                base_url: "https://api.openai.com/v1".to_string(),
197                default_model: "gpt-3.5-turbo".to_string(),
198                timeout: Duration::from_secs(60),
199                max_retries: 3,
200                rate_limit: Some(RateLimitConfig {
201                    requests_per_minute: 60,
202                    tokens_per_minute: Some(10000),
203                    burst_allowance: Some(10),
204                }),
205            },
206        );
207
208        llm_providers.insert(
209            "anthropic".to_string(),
210            LLMProviderConfig {
211                api_key_env: "ANTHROPIC_API_KEY".to_string(),
212                base_url: "https://api.anthropic.com".to_string(),
213                default_model: "claude-3-sonnet-20240229".to_string(),
214                timeout: Duration::from_secs(60),
215                max_retries: 3,
216                rate_limit: Some(RateLimitConfig {
217                    requests_per_minute: 60,
218                    tokens_per_minute: Some(10000),
219                    burst_allowance: Some(10),
220                }),
221            },
222        );
223
224        Self {
225            enabled: true,
226            policy: RoutingPolicyConfig::default(),
227            classification: TaskClassificationConfig::default(),
228            llm_providers,
229        }
230    }
231}
232
233impl Default for RoutingPolicyConfig {
234    fn default() -> Self {
235        Self {
236            global_settings: GlobalRoutingSettings::default(),
237            rules: Vec::new(),
238            default_action: RouteAction::UseLLM {
239                provider: LLMProvider::OpenAI { model: None },
240                model: Some("gpt-3.5-turbo".to_string()),
241            },
242            fallback_config: FallbackConfig::default(),
243        }
244    }
245}
246
247impl Default for GlobalRoutingSettings {
248    fn default() -> Self {
249        Self {
250            slm_routing_enabled: true,
251            always_audit: true,
252            global_confidence_threshold: 0.85,
253            max_slm_retries: 2,
254        }
255    }
256}
257
258impl Default for FallbackConfig {
259    fn default() -> Self {
260        let mut providers = HashMap::new();
261        providers.insert(
262            "primary".to_string(),
263            LLMProviderConfig {
264                api_key_env: "OPENAI_API_KEY".to_string(),
265                base_url: "https://api.openai.com/v1".to_string(),
266                default_model: "gpt-3.5-turbo".to_string(),
267                timeout: Duration::from_secs(60),
268                max_retries: 3,
269                rate_limit: None,
270            },
271        );
272
273        Self {
274            enabled: true,
275            max_attempts: 3,
276            timeout: Duration::from_secs(30),
277            providers,
278        }
279    }
280}
281
282impl Default for TaskClassificationConfig {
283    fn default() -> Self {
284        let mut patterns = HashMap::new();
285
286        patterns.insert(
287            TaskType::Intent,
288            ClassificationPattern {
289                keywords: vec![
290                    "intent".to_string(),
291                    "intention".to_string(),
292                    "purpose".to_string(),
293                ],
294                patterns: vec![r"what.*intent".to_string(), r"user.*wants".to_string()],
295                weight: 1.0,
296            },
297        );
298
299        patterns.insert(
300            TaskType::CodeGeneration,
301            ClassificationPattern {
302                keywords: vec![
303                    "code".to_string(),
304                    "function".to_string(),
305                    "implement".to_string(),
306                    "generate".to_string(),
307                ],
308                patterns: vec![
309                    r"write.*code".to_string(),
310                    r"implement.*function".to_string(),
311                ],
312                weight: 1.0,
313            },
314        );
315
316        patterns.insert(
317            TaskType::Analysis,
318            ClassificationPattern {
319                keywords: vec![
320                    "analyze".to_string(),
321                    "analysis".to_string(),
322                    "examine".to_string(),
323                    "review".to_string(),
324                ],
325                patterns: vec![
326                    r"analyze.*data".to_string(),
327                    r"perform.*analysis".to_string(),
328                ],
329                weight: 1.0,
330            },
331        );
332
333        Self {
334            enabled: true,
335            patterns,
336            confidence_threshold: 0.7,
337            default_task_type: TaskType::Custom("unknown".to_string()),
338        }
339    }
340}
341
342impl RoutingRule {
343    /// Check if this rule's conditions match the given context
344    pub fn matches(&self, context: &super::decision::RoutingContext) -> bool {
345        // Check task types
346        if let Some(ref task_types) = self.conditions.task_types {
347            if !task_types.contains(&context.task_type) {
348                return false;
349            }
350        }
351
352        // Check agent IDs
353        if let Some(ref agent_ids) = self.conditions.agent_ids {
354            if !agent_ids.contains(&context.agent_id.to_string()) {
355                return false;
356            }
357        }
358
359        // Check security level
360        if let Some(ref required_level) = self.conditions.security_level {
361            if context.agent_security_level < *required_level {
362                return false;
363            }
364        }
365
366        // Check resource constraints
367        if let Some(ref rule_constraints) = self.conditions.resource_constraints {
368            if let Some(ref context_limits) = context.resource_limits {
369                if context_limits.max_memory_mb > rule_constraints.max_memory_mb {
370                    return false;
371                }
372            }
373        }
374
375        // Implement custom condition evaluation
376        if let Some(ref custom_conditions) = self.conditions.custom_conditions {
377            for condition_expr in custom_conditions {
378                if !self.evaluate_custom_condition(condition_expr, context) {
379                    return false;
380                }
381            }
382        }
383
384        true
385    }
386
387    /// Evaluate a custom condition expression
388    fn evaluate_custom_condition(
389        &self,
390        condition_expr: &str,
391        context: &super::decision::RoutingContext,
392    ) -> bool {
393        // Simple expression evaluator for basic conditions
394        // In a real implementation, this could use a proper expression parser
395
396        // Handle common condition patterns
397        if condition_expr.contains("agent_id") {
398            if let Some(expected_id) = condition_expr.strip_prefix("agent_id == ") {
399                let expected_id = expected_id.trim_matches('"');
400                return context.agent_id.to_string() == expected_id;
401            }
402        }
403
404        if condition_expr.contains("task_type") {
405            if let Some(expected_type) = condition_expr.strip_prefix("task_type == ") {
406                let expected_type = expected_type.trim_matches('"');
407                return format!("{:?}", context.task_type)
408                    .to_lowercase()
409                    .contains(&expected_type.to_lowercase());
410            }
411        }
412
413        if condition_expr.contains("security_level") && condition_expr.contains(">=") {
414            if let Some(level_str) = condition_expr.strip_prefix("security_level >= ") {
415                if let Ok(required_level) = level_str.trim().parse::<u8>() {
416                    let current_level = match context.agent_security_level {
417                        SecurityLevel::Low => 1,
418                        SecurityLevel::Medium => 2,
419                        SecurityLevel::High => 3,
420                        SecurityLevel::Critical => 4,
421                    };
422                    return current_level >= required_level;
423                }
424            }
425        }
426
427        if condition_expr.contains("memory_limit") {
428            if let Some(ref resource_limits) = context.resource_limits {
429                if condition_expr.contains("<=") {
430                    if let Some(limit_str) = condition_expr.strip_prefix("memory_limit <= ") {
431                        if let Ok(max_memory) = limit_str.trim().parse::<u64>() {
432                            return resource_limits.max_memory_mb <= max_memory;
433                        }
434                    }
435                }
436            }
437        }
438
439        // Handle boolean expressions
440        if condition_expr == "true" {
441            return true;
442        }
443        if condition_expr == "false" {
444            return false;
445        }
446
447        // Default: log unrecognized condition and return true to be permissive
448        tracing::warn!("Unrecognized custom condition: {}", condition_expr);
449        true
450    }
451}
452
453impl RoutingPolicyConfig {
454    /// Validate the routing policy configuration
455    pub fn validate(&self) -> Result<(), super::error::RoutingError> {
456        // Validate rules are sorted by priority
457        let mut prev_priority = u32::MAX;
458        for rule in &self.rules {
459            if rule.priority > prev_priority {
460                return Err(super::error::RoutingError::ConfigurationError {
461                    key: "policy.rules".to_string(),
462                    reason: "Rules must be ordered by priority (highest first)".to_string(),
463                });
464            }
465            prev_priority = rule.priority;
466        }
467
468        // Validate confidence thresholds
469        if self.global_settings.global_confidence_threshold < 0.0
470            || self.global_settings.global_confidence_threshold > 1.0
471        {
472            return Err(super::error::RoutingError::ConfigurationError {
473                key: "policy.global_settings.global_confidence_threshold".to_string(),
474                reason: "Confidence threshold must be between 0.0 and 1.0".to_string(),
475            });
476        }
477
478        Ok(())
479    }
480}