// symbi_runtime/routing/engine.rs
1//! Core routing engine implementation
2
3use async_trait::async_trait;
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::RwLock;
8
9use super::classifier::TaskClassifier;
10use super::confidence::{ConfidenceMonitorTrait, NoOpConfidenceMonitor};
11use super::config::RoutingConfig;
12use super::decision::{
13    FinishReason, LLMProvider, ModelRequest, ModelResponse, RouteDecision, RoutingContext,
14    RoutingStatistics,
15};
16use super::error::RoutingError;
17use super::policy::PolicyEvaluator;
18use crate::logging::{
19    ModelInteractionType, ModelLogger, RequestData, ResponseData, TokenUsage as LogTokenUsage,
20};
21use crate::models::{ModelCatalog, SlmRunnerError};
22
/// Core routing engine trait for SLM-first architecture.
///
/// Implementations decide, per request, whether to run a local small
/// language model (SLM), route to a remote LLM, or deny the request —
/// and then execute that decision, including any fallback handling.
#[async_trait]
pub trait RoutingEngine: Send + Sync {
    /// Route a model request based on configured policies.
    ///
    /// Produces the `RouteDecision` for `context` without executing it.
    async fn route_request(&self, context: &RoutingContext) -> Result<RouteDecision, RoutingError>;

    /// Execute the routing decision and handle fallbacks.
    ///
    /// Combines `route_request` with model execution; takes ownership of
    /// `context` and `request` for the duration of the call.
    async fn execute_with_routing(
        &self,
        context: RoutingContext,
        request: ModelRequest,
    ) -> Result<RoutingError> being absent — validate routing policies.
    fn validate_policies(&self) -> Result<(), RoutingError>;

    /// Get a snapshot of routing statistics.
    async fn get_routing_stats(&self) -> RoutingStatistics;

    /// Update routing configuration at runtime.
    async fn update_config(&self, config: RoutingConfig) -> Result<(), RoutingError>;
}
45
/// Default implementation of the routing engine.
///
/// Owns everything needed to act on routing decisions: the policy
/// evaluator, model catalog, confidence monitor, optional audit logger,
/// atomic statistics, the live configuration, and both executors
/// (local SLM + remote LLM pool).
pub struct DefaultRoutingEngine {
    /// Policy evaluator for making routing decisions.
    policy_evaluator: Arc<RwLock<PolicyEvaluator>>,
    /// Model catalog for SLM information.
    model_catalog: Arc<ModelCatalog>,
    /// Confidence monitor for evaluating SLM responses.
    // NOTE(review): not referenced by any visible code path (hence
    // `dead_code`); presumably reserved for future confidence evaluation
    // — confirm before removing.
    #[allow(dead_code)]
    confidence_monitor: Arc<RwLock<Box<dyn ConfidenceMonitorTrait>>>,
    /// Optional model logger for audit trails; `None` disables logging.
    model_logger: Option<Arc<ModelLogger>>,
    /// Routing statistics (lock-free atomic counters).
    statistics: Arc<RoutingStatistics>,
    /// Current configuration, swapped atomically via `ArcSwap` on update.
    config: Arc<arc_swap::ArcSwap<RoutingConfig>>,
    /// LLM client pool used for direct LLM routes and fallback.
    llm_clients: Arc<LLMClientPool>,
    /// SLM executor for local model inference.
    slm_executor: Arc<dyn SlmExecutor>,
}
66
/// Trait for SLM execution — implement for real model inference.
///
/// Abstracts the local inference backend so the engine can be exercised
/// with mocks in tests; failures are reported as `SlmRunnerError`.
#[async_trait]
pub trait SlmExecutor: Send + Sync {
    /// Run `request` against the given catalog `model` and return the
    /// model's response.
    async fn execute(
        &self,
        request: &ModelRequest,
        model: &crate::config::Model,
    ) -> Result<ModelResponse, SlmRunnerError>;
}
76
/// Trait for LLM clients.
///
/// One implementation per remote provider, registered with
/// `LLMClientPool` under the provider's key.
#[async_trait]
pub trait LLMClient: Send + Sync {
    /// Send `request` to the remote `provider` and return its response.
    async fn execute_request(
        &self,
        request: &ModelRequest,
        provider: &LLMProvider,
    ) -> Result<ModelResponse, RoutingError>;
}
86
/// Pool of LLM clients for different providers.
///
/// Keyed by provider identifier ("openai", "anthropic", "custom");
/// values are boxed client implementations.
pub struct LLMClientPool {
    clients: HashMap<String, Box<dyn LLMClient>>,
}
91
92impl LLMClientPool {
93    pub fn new() -> Self {
94        Self {
95            clients: HashMap::new(),
96        }
97    }
98
99    pub fn register(&mut self, key: &str, client: Box<dyn LLMClient>) {
100        self.clients.insert(key.to_string(), client);
101    }
102
103    pub async fn execute_request(
104        &self,
105        request: &ModelRequest,
106        provider: &LLMProvider,
107    ) -> Result<ModelResponse, RoutingError> {
108        let client_key = match provider {
109            LLMProvider::OpenAI { .. } => "openai",
110            LLMProvider::Anthropic { .. } => "anthropic",
111            LLMProvider::Custom { .. } => "custom",
112        };
113
114        let client =
115            self.clients
116                .get(client_key)
117                .ok_or_else(|| RoutingError::LLMFallbackFailed {
118                    provider: provider.to_string(),
119                    reason: format!("No client registered for provider '{}'", client_key),
120                })?;
121
122        client.execute_request(request, provider).await
123    }
124}
125
impl Default for LLMClientPool {
    /// An empty pool; identical to `LLMClientPool::new()`.
    fn default() -> Self {
        Self::new()
    }
}
131
impl DefaultRoutingEngine {
    /// Create a new routing engine with the given configuration.
    ///
    /// Convenience constructor that installs a `NoOpConfidenceMonitor`;
    /// use [`Self::new_with_confidence_monitor`] to supply a real one.
    pub async fn new(
        config: RoutingConfig,
        model_catalog: ModelCatalog,
        model_logger: Option<Arc<ModelLogger>>,
        llm_clients: LLMClientPool,
        slm_executor: Arc<dyn SlmExecutor>,
    ) -> Result<Self, RoutingError> {
        Self::new_with_confidence_monitor(
            config,
            model_catalog,
            model_logger,
            Box::new(NoOpConfidenceMonitor),
            llm_clients,
            slm_executor,
        )
        .await
    }

    /// Create a new routing engine with a custom confidence monitor implementation.
    ///
    /// Builds the task classifier and policy evaluator from `config`;
    /// either construction may fail, which is the only error source here.
    pub async fn new_with_confidence_monitor(
        config: RoutingConfig,
        model_catalog: ModelCatalog,
        model_logger: Option<Arc<ModelLogger>>,
        confidence_monitor: Box<dyn ConfidenceMonitorTrait>,
        llm_clients: LLMClientPool,
        slm_executor: Arc<dyn SlmExecutor>,
    ) -> Result<Self, RoutingError> {
        // Create task classifier
        let classifier = TaskClassifier::new(config.classification.clone())?;

        // Create policy evaluator (takes its own copy of the catalog).
        let policy_evaluator =
            PolicyEvaluator::new(config.policy.clone(), classifier, model_catalog.clone())?;

        let llm_clients = Arc::new(llm_clients);

        let engine = Self {
            policy_evaluator: Arc::new(RwLock::new(policy_evaluator)),
            model_catalog: Arc::new(model_catalog),
            confidence_monitor: Arc::new(RwLock::new(confidence_monitor)),
            model_logger,
            statistics: Arc::new(RoutingStatistics::default()),
            config: Arc::new(arc_swap::ArcSwap::from_pointee(config)),
            llm_clients,
            slm_executor,
        };

        Ok(engine)
    }

    /// Execute an SLM route with monitoring and fallback.
    ///
    /// Runs `request` on catalog model `model_id`, then applies
    /// `monitoring_level` to judge the response. A weak response or an
    /// execution failure triggers LLM fallback when `fallback_on_failure`
    /// is set; otherwise the weak response (or the error) is returned.
    async fn execute_slm_route(
        &self,
        context: &RoutingContext,
        request: &ModelRequest,
        model_id: &str,
        monitoring_level: &super::decision::MonitoringLevel,
        fallback_on_failure: bool,
    ) -> Result<ModelResponse, RoutingError> {
        // Get the model from catalog; an unknown id is reported as
        // "no suitable model" for the context's task type.
        let model = self.model_catalog.get_model(model_id).ok_or_else(|| {
            RoutingError::NoSuitableModel {
                task_type: context.task_type.clone(),
            }
        })?;

        // Execute the SLM via the injected executor
        let slm_result = self.slm_executor.execute(request, model).await;

        match slm_result {
            Ok(response) => {
                // Evaluate confidence if monitoring is enabled
                let should_fallback = match monitoring_level {
                    super::decision::MonitoringLevel::None => false,
                    super::decision::MonitoringLevel::Basic => {
                        // Basic monitoring - check for obvious failures
                        response.finish_reason != FinishReason::Stop
                    }
                    super::decision::MonitoringLevel::Enhanced {
                        confidence_threshold,
                    } => {
                        // Use confidence score if available; a missing score
                        // is treated as neutral 0.5 rather than as a failure.
                        let confidence_score = response.confidence_score.unwrap_or(0.5);
                        confidence_score < *confidence_threshold
                    }
                };

                if should_fallback && fallback_on_failure {
                    tracing::warn!(
                        "SLM response did not meet confidence threshold, falling back to LLM"
                    );

                    self.fallback_to_llm(request, "Low confidence SLM response")
                        .await
                } else {
                    // NOTE(review): a low-confidence response with fallback
                    // disabled is returned as-is — confirm that is intended.
                    Ok(response)
                }
            }
            Err(e) => {
                if fallback_on_failure {
                    tracing::error!("SLM execution failed, falling back to LLM: {}", e);

                    self.fallback_to_llm(request, &format!("SLM execution failed: {}", e))
                        .await
                } else {
                    Err(RoutingError::ModelExecutionFailed {
                        model_id: model_id.to_string(),
                        reason: e.to_string(),
                    })
                }
            }
        }
    }

    /// Execute LLM fallback and record the fallback in statistics.
    async fn fallback_to_llm(
        &self,
        request: &ModelRequest,
        reason: &str,
    ) -> Result<ModelResponse, RoutingError> {
        // Count the fallback even if the LLM call below fails.
        self.statistics.record_fallback();
        self.execute_llm_fallback(request, reason).await
    }

    /// Execute LLM fallback.
    ///
    /// Refuses immediately when fallback is disabled in the live config.
    /// NOTE(review): only the OpenAI provider is attempted despite the
    /// "try primary provider first" wording — no secondary provider is
    /// tried, and `_reason` is unused. Confirm whether fallback_config's
    /// provider settings should be consulted here.
    async fn execute_llm_fallback(
        &self,
        request: &ModelRequest,
        _reason: &str,
    ) -> Result<ModelResponse, RoutingError> {
        let config = self.config.load();
        let fallback_config = &config.policy.fallback_config;

        if !fallback_config.enabled {
            return Err(RoutingError::LLMFallbackFailed {
                provider: "disabled".to_string(),
                reason: "LLM fallback is disabled".to_string(),
            });
        }

        // Try primary provider first
        let provider = LLMProvider::OpenAI { model: None };

        match self.llm_clients.execute_request(request, &provider).await {
            Ok(response) => Ok(response),
            Err(e) => Err(RoutingError::LLMFallbackFailed {
                provider: provider.to_string(),
                reason: e.to_string(),
            }),
        }
    }

    /// Log routing decision and execution.
    ///
    /// Best-effort audit logging: a no-op when no logger is configured,
    /// and a logging failure only emits a warning — it never affects the
    /// routing result.
    async fn log_routing_execution(
        &self,
        context: &RoutingContext,
        decision: &RouteDecision,
        request: &ModelRequest,
        response: &ModelResponse,
        execution_time: Duration,
        error: Option<&RoutingError>,
    ) {
        if let Some(ref logger) = self.model_logger {
            // Human-readable model name derived from the decision taken.
            let model_name = match decision {
                RouteDecision::UseSLM { model_id, .. } => model_id.clone(),
                RouteDecision::UseLLM { provider, .. } => provider.to_string(),
                RouteDecision::Deny { .. } => "denied".to_string(),
            };

            let request_data = RequestData {
                prompt: request.prompt.clone(),
                tool_name: None,
                tool_arguments: None,
                parameters: {
                    // Attach the decision and task type for audit context.
                    let mut params = HashMap::new();
                    params.insert(
                        "routing_decision".to_string(),
                        serde_json::Value::String(format!("{:?}", decision)),
                    );
                    params.insert(
                        "task_type".to_string(),
                        serde_json::Value::String(context.task_type.to_string()),
                    );
                    params
                },
            };

            let response_data = ResponseData {
                content: response.content.clone(),
                tool_result: None,
                confidence: response.confidence_score,
                metadata: response.metadata.clone(),
            };

            let metadata = {
                let mut meta = HashMap::new();
                meta.insert("routing_engine".to_string(), "default".to_string());
                meta.insert("agent_id".to_string(), context.agent_id.to_string());
                meta.insert("request_id".to_string(), context.request_id.clone());
                meta
            };

            if let Err(e) = logger
                .log_interaction(
                    context.agent_id,
                    ModelInteractionType::Completion,
                    &model_name,
                    request_data,
                    response_data,
                    execution_time,
                    metadata,
                    // Translate routing token usage into the logging schema.
                    response.token_usage.as_ref().map(|t| LogTokenUsage {
                        input_tokens: t.prompt_tokens,
                        output_tokens: t.completion_tokens,
                        total_tokens: t.total_tokens,
                    }),
                    error.map(|e| e.to_string()),
                )
                .await
            {
                tracing::warn!("Failed to log routing execution: {}", e);
            }
        }
    }
}
359
#[async_trait]
impl RoutingEngine for DefaultRoutingEngine {
    /// Delegates the decision to the policy evaluator (read lock only).
    async fn route_request(&self, context: &RoutingContext) -> Result<RouteDecision, RoutingError> {
        let policy_result = self
            .policy_evaluator
            .read()
            .await
            .evaluate_policies(context)
            .await?;

        tracing::debug!(
            "Routing decision for agent {}: {:?} (rule: {:?})",
            context.agent_id,
            policy_result.decision,
            policy_result.matched_rule
        );

        Ok(policy_result.decision)
    }

    /// Full pipeline: decide a route, execute it, update statistics, and
    /// audit-log the outcome (success or failure).
    async fn execute_with_routing(
        &self,
        context: RoutingContext,
        request: ModelRequest,
    ) -> Result<ModelResponse, RoutingError> {
        let start_time = Instant::now();
        let route_decision = self.route_request(&context).await?;

        let result = match &route_decision {
            RouteDecision::UseSLM {
                model_id,
                monitoring,
                fallback_on_failure,
                sandbox_tier: _,
            } => {
                self.execute_slm_route(
                    &context,
                    &request,
                    model_id,
                    monitoring,
                    *fallback_on_failure,
                )
                .await
            }
            RouteDecision::UseLLM {
                provider,
                reason,
                sandbox_tier: _,
            } => {
                tracing::info!("Routing to LLM: {}", reason);
                self.llm_clients.execute_request(&request, provider).await
            }
            RouteDecision::Deny {
                reason,
                policy_violated,
            } => {
                // Denials return early: nothing is executed, counted,
                // or logged below.
                return Err(RoutingError::RoutingDenied {
                    policy: policy_violated.clone(),
                    reason: reason.clone(),
                });
            }
        };

        let execution_time = start_time.elapsed();

        // Update statistics (lock-free)
        self.statistics
            .record_request(&route_decision, execution_time, result.is_ok());

        // Confidence is only aggregated for successful responses.
        if let Ok(ref response) = result {
            if let Some(confidence) = response.confidence_score {
                self.statistics.add_confidence_score(confidence);
            }
        }

        // Log the execution
        match &result {
            Ok(response) => {
                self.log_routing_execution(
                    &context,
                    &route_decision,
                    &request,
                    response,
                    execution_time,
                    None,
                )
                .await;
            }
            Err(error) => {
                // Create a placeholder response so the failure still
                // produces a complete audit record.
                let dummy_response = ModelResponse {
                    content: "Error occurred".to_string(),
                    finish_reason: FinishReason::Error,
                    token_usage: None,
                    metadata: HashMap::new(),
                    confidence_score: Some(0.0),
                };

                self.log_routing_execution(
                    &context,
                    &route_decision,
                    &request,
                    &dummy_response,
                    execution_time,
                    Some(error),
                )
                .await;
            }
        }

        result
    }

    /// Always succeeds: policies are validated when the `PolicyEvaluator`
    /// is constructed, so a live engine only holds valid policies.
    fn validate_policies(&self) -> Result<(), RoutingError> {
        // Validation is done during PolicyEvaluator creation
        Ok(())
    }

    /// Returns a point-in-time clone of the statistics counters.
    async fn get_routing_stats(&self) -> RoutingStatistics {
        (*self.statistics).clone()
    }

    /// Applies a new configuration.
    ///
    /// NOTE(review): the evaluator update and the config swap are two
    /// separate steps, so a concurrent request may briefly observe the new
    /// policy with the old config (or vice versa) — confirm acceptable.
    async fn update_config(&self, config: RoutingConfig) -> Result<(), RoutingError> {
        // Update policy evaluator
        self.policy_evaluator
            .write()
            .await
            .update_config(config.policy.clone())?;

        // Update configuration (swap atomically)
        self.config.store(Arc::new(config));

        Ok(())
    }
}
495
496#[cfg(test)]
497use super::decision::TokenUsage;
498
#[cfg(test)]
/// Test double for `SlmExecutor`; behavior lives in the impl below.
struct MockSlmExecutor;
501
#[cfg(test)]
#[async_trait]
impl SlmExecutor for MockSlmExecutor {
    /// Simulated SLM inference: a short delay, then a deterministic echo
    /// response. Prompts containing "error" fail with `ExecutionFailed`.
    async fn execute(
        &self,
        request: &ModelRequest,
        model: &crate::config::Model,
    ) -> Result<ModelResponse, SlmRunnerError> {
        // Keep the mock realistic without slowing the suite (10ms).
        tokio::time::sleep(Duration::from_millis(10)).await;

        if request.prompt.contains("error") {
            return Err(SlmRunnerError::ExecutionFailed {
                reason: "Simulated execution error".to_string(),
            });
        }

        // Rough ~4 chars/token estimate for the prompt side.
        let prompt_tokens = request.prompt.len() as u32 / 4;

        let mut meta = HashMap::new();
        meta.insert(
            "model_id".to_string(),
            serde_json::Value::String(model.id.clone()),
        );
        meta.insert(
            "provider".to_string(),
            serde_json::Value::String(format!("{:?}", model.provider)),
        );

        Ok(ModelResponse {
            content: format!("SLM ({}) response: {}", model.name, request.prompt),
            finish_reason: FinishReason::Stop,
            token_usage: Some(TokenUsage {
                prompt_tokens,
                completion_tokens: 30,
                total_tokens: prompt_tokens + 30,
            }),
            metadata: meta,
            // Deterministic pseudo-confidence in [0.8, 1.0).
            confidence_score: Some(0.8 + (request.prompt.len() % 20) as f64 / 100.0),
        })
    }
}
543
#[cfg(test)]
/// Test double for `LLMClient`; behavior lives in the impl below.
#[derive(Debug)]
struct MockLLMClient;
547
#[cfg(test)]
#[async_trait]
impl LLMClient for MockLLMClient {
    /// Simulated LLM call: fixed 100ms delay, deterministic echo response
    /// tagged with the provider, high fixed confidence.
    async fn execute_request(
        &self,
        request: &ModelRequest,
        provider: &LLMProvider,
    ) -> Result<ModelResponse, RoutingError> {
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Rough ~4 chars/token estimate for the prompt side.
        let prompt_tokens = request.prompt.len() as u32 / 4;

        let mut meta = HashMap::new();
        meta.insert(
            "provider".to_string(),
            serde_json::Value::String(provider.to_string()),
        );
        meta.insert("mock".to_string(), serde_json::Value::Bool(true));

        Ok(ModelResponse {
            content: format!("LLM response to: {}", request.prompt),
            finish_reason: FinishReason::Stop,
            token_usage: Some(TokenUsage {
                prompt_tokens,
                completion_tokens: 50,
                total_tokens: prompt_tokens + 50,
            }),
            metadata: meta,
            confidence_score: Some(0.95),
        })
    }
}
579
580#[cfg(test)]
581mod tests {
582    use super::*;
583    use crate::config::{
584        Model, ModelAllowListConfig, ModelProvider, ModelResourceRequirements, SandboxProfile, Slm,
585    };
586    use crate::types::AgentId;
587    use std::collections::HashMap;
588    use std::path::PathBuf;
589
590    fn create_mock_pool() -> LLMClientPool {
591        let mut pool = LLMClientPool::new();
592        pool.register("openai", Box::new(MockLLMClient));
593        pool.register("anthropic", Box::new(MockLLMClient));
594        pool.register("custom", Box::new(MockLLMClient));
595        pool
596    }
597
    /// Shared mock SLM executor for the tests.
    fn create_mock_slm() -> Arc<dyn SlmExecutor> {
        Arc::new(MockSlmExecutor)
    }
601
    /// Build an engine fixture with two mock SLMs ("test-slm", "error-slm"),
    /// mock LLM clients for all providers, and a high-priority rule that
    /// routes CodeGeneration tasks to the best available SLM with basic
    /// monitoring and LLM fallback enabled. No audit logger.
    async fn create_test_engine() -> DefaultRoutingEngine {
        let global_models = vec![
            Model {
                id: "test-slm".to_string(),
                name: "Test SLM".to_string(),
                provider: ModelProvider::LocalFile {
                    file_path: PathBuf::from("/tmp/test.gguf"),
                },
                capabilities: vec![
                    crate::config::ModelCapability::TextGeneration,
                    crate::config::ModelCapability::CodeGeneration,
                ],
                resource_requirements: ModelResourceRequirements {
                    min_memory_mb: 1024,
                    preferred_cpu_cores: 2.0,
                    gpu_requirements: None,
                },
            },
            Model {
                id: "error-slm".to_string(),
                name: "Error SLM".to_string(),
                provider: ModelProvider::LocalFile {
                    file_path: PathBuf::from("/tmp/error.gguf"),
                },
                capabilities: vec![crate::config::ModelCapability::TextGeneration],
                resource_requirements: ModelResourceRequirements {
                    min_memory_mb: 512,
                    preferred_cpu_cores: 1.0,
                    gpu_requirements: None,
                },
            },
        ];

        let mut sandbox_profiles = HashMap::new();
        sandbox_profiles.insert("default".to_string(), SandboxProfile::secure_default());

        let slm_config = Slm {
            enabled: true,
            model_allow_lists: ModelAllowListConfig {
                global_models,
                agent_model_maps: HashMap::new(),
                allow_runtime_overrides: false,
            },
            sandbox_profiles,
            default_sandbox_profile: "default".to_string(),
        };

        let model_catalog = ModelCatalog::new(slm_config).unwrap();
        let mut config = RoutingConfig::default();

        // Add SLM routing rule for code generation tasks
        config.policy.rules.push(super::super::config::RoutingRule {
            name: "slm_code_rule".to_string(),
            priority: 100,
            conditions: super::super::config::RoutingConditions {
                task_types: Some(vec![super::super::error::TaskType::CodeGeneration]),
                agent_ids: None,
                resource_constraints: None,
                security_level: None,
                custom_conditions: None,
            },
            action: super::super::config::RouteAction::UseSLM {
                model_preference: super::super::config::ModelPreference::BestAvailable,
                monitoring_level: crate::routing::MonitoringLevel::Basic,
                fallback_on_low_confidence: true,
                confidence_threshold: Some(0.8),
            },
            override_allowed: true,
            action_extension: None,
        });

        DefaultRoutingEngine::new(
            config,
            model_catalog,
            None,
            create_mock_pool(),
            create_mock_slm(),
        )
        .await
        .unwrap()
    }
683
    /// Build an engine fixture like `create_test_engine`, but with a real
    /// `ModelLogger` attached, a single SLM, and the default routing
    /// config (no extra rules).
    async fn create_test_engine_with_logger() -> DefaultRoutingEngine {
        let logger =
            ModelLogger::new(super::super::super::logging::LoggingConfig::default(), None).unwrap();

        let global_models = vec![Model {
            id: "test-slm".to_string(),
            name: "Test SLM".to_string(),
            provider: ModelProvider::LocalFile {
                file_path: PathBuf::from("/tmp/test.gguf"),
            },
            capabilities: vec![crate::config::ModelCapability::TextGeneration],
            resource_requirements: ModelResourceRequirements {
                min_memory_mb: 1024,
                preferred_cpu_cores: 2.0,
                gpu_requirements: None,
            },
        }];

        let mut sandbox_profiles = HashMap::new();
        sandbox_profiles.insert("default".to_string(), SandboxProfile::secure_default());

        let slm_config = Slm {
            enabled: true,
            model_allow_lists: ModelAllowListConfig {
                global_models,
                agent_model_maps: HashMap::new(),
                allow_runtime_overrides: false,
            },
            sandbox_profiles,
            default_sandbox_profile: "default".to_string(),
        };

        let model_catalog = ModelCatalog::new(slm_config).unwrap();
        let config = RoutingConfig::default();

        DefaultRoutingEngine::new(
            config,
            model_catalog,
            Some(Arc::new(logger)),
            create_mock_pool(),
            create_mock_slm(),
        )
        .await
        .unwrap()
    }
729
    /// Wrap a prompt string in a `ModelRequest`.
    fn create_test_request(prompt: &str) -> ModelRequest {
        ModelRequest::from_task(prompt.to_string())
    }
733
    /// Build a `RoutingContext` for a fresh agent with the given task type.
    fn create_test_context(
        prompt: &str,
        task_type: super::super::error::TaskType,
    ) -> RoutingContext {
        RoutingContext::new(AgentId::new(), task_type, prompt.to_string())
    }
740
741    #[tokio::test]
742    async fn test_routing_engine_creation() {
743        let engine = create_test_engine().await;
744
745        // Verify engine was created successfully
746        let stats = engine.get_routing_stats().await;
747        assert_eq!(stats.total_requests(), 0);
748
749        // Verify policies can be validated
750        assert!(engine.validate_policies().is_ok());
751    }
752
753    #[tokio::test]
754    async fn test_routing_engine_with_logger() {
755        let engine = create_test_engine_with_logger().await;
756
757        // Should have logger configured
758        assert!(engine.model_logger.is_some());
759
760        let stats = engine.get_routing_stats().await;
761        assert_eq!(stats.total_requests(), 0);
762    }
763
764    #[tokio::test]
765    async fn test_routing_engine_basic_flow() {
766        let engine = create_test_engine().await;
767
768        let context = create_test_context(
769            "Write a hello world function",
770            super::super::error::TaskType::CodeGeneration,
771        );
772
773        let decision = engine.route_request(&context).await.unwrap();
774
775        // Should get some kind of routing decision
776        match decision {
777            RouteDecision::UseSLM { .. } | RouteDecision::UseLLM { .. } => {
778                // Expected outcomes
779            }
780            RouteDecision::Deny { .. } => {
781                panic!("Should not deny basic request");
782            }
783        }
784    }
785
786    #[tokio::test]
787    async fn test_execute_with_routing_slm_success() {
788        let engine = create_test_engine().await;
789
790        let context = create_test_context(
791            "Write a hello world function",
792            super::super::error::TaskType::CodeGeneration,
793        );
794
795        let request = create_test_request("Write a hello world function");
796
797        let response = engine.execute_with_routing(context, request).await.unwrap();
798
799        assert!(!response.content.is_empty());
800        assert!(response.confidence_score.is_some());
801        assert_eq!(response.finish_reason, FinishReason::Stop);
802
803        // Check that statistics were updated
804        let stats = engine.get_routing_stats().await;
805        assert!(stats.total_requests() > 0);
806    }
807
808    #[tokio::test]
809    async fn test_execute_with_routing_slm_error_fallback() {
810        let engine = create_test_engine().await;
811
812        let context = create_test_context(
813            "This should trigger an error in SLM",
814            super::super::error::TaskType::CodeGeneration,
815        );
816
817        let request = create_test_request("error trigger");
818
819        let response = engine.execute_with_routing(context, request).await.unwrap();
820
821        // Should get LLM fallback response
822        assert!(!response.content.is_empty());
823        assert!(response.content.contains("LLM response"));
824
825        let stats = engine.get_routing_stats().await;
826        assert!(stats.fallback_routes() > 0);
827    }
828
829    #[tokio::test]
830    async fn test_slm_execution_success() {
831        let engine = create_test_engine().await;
832
833        let context =
834            create_test_context("Test prompt", super::super::error::TaskType::CodeGeneration);
835
836        let request = create_test_request("Test prompt");
837
838        let response = engine
839            .execute_slm_route(
840                &context,
841                &request,
842                "test-slm",
843                &crate::routing::MonitoringLevel::Basic,
844                true,
845            )
846            .await
847            .unwrap();
848
849        assert!(!response.content.is_empty());
850        assert!(response.content.contains("Test SLM"));
851        assert!(response.confidence_score.is_some());
852    }
853
854    #[tokio::test]
855    async fn test_slm_execution_with_enhanced_monitoring() {
856        let engine = create_test_engine().await;
857
858        let context = create_test_context(
859            "Test prompt with monitoring",
860            super::super::error::TaskType::CodeGeneration,
861        );
862
863        let request = create_test_request("Test prompt with monitoring");
864
865        let response = engine
866            .execute_slm_route(
867                &context,
868                &request,
869                "test-slm",
870                &crate::routing::MonitoringLevel::Enhanced {
871                    confidence_threshold: 0.9,
872                },
873                true,
874            )
875            .await
876            .unwrap();
877
878        // Should either get SLM response or LLM fallback
879        assert!(!response.content.is_empty());
880    }
881
882    #[tokio::test]
883    async fn test_slm_execution_no_monitoring() {
884        let engine = create_test_engine().await;
885
886        let context = create_test_context(
887            "Test prompt no monitoring",
888            super::super::error::TaskType::CodeGeneration,
889        );
890
891        let request = create_test_request("Test prompt no monitoring");
892
893        let response = engine
894            .execute_slm_route(
895                &context,
896                &request,
897                "test-slm",
898                &crate::routing::MonitoringLevel::None,
899                true,
900            )
901            .await
902            .unwrap();
903
904        // Should get SLM response without monitoring
905        assert!(!response.content.is_empty());
906        assert!(response.content.contains("Test SLM"));
907    }
908
909    #[tokio::test]
910    async fn test_slm_execution_error_no_fallback() {
911        let engine = create_test_engine().await;
912
913        let context = create_test_context(
914            "error trigger",
915            super::super::error::TaskType::CodeGeneration,
916        );
917
918        let request = create_test_request("error trigger");
919
920        let result = engine
921            .execute_slm_route(
922                &context,
923                &request,
924                "test-slm",
925                &crate::routing::MonitoringLevel::Basic,
926                false, // No fallback
927            )
928            .await;
929
930        assert!(result.is_err());
931        assert!(matches!(
932            result.unwrap_err(),
933            RoutingError::ModelExecutionFailed { .. }
934        ));
935    }
936
937    #[tokio::test]
938    async fn test_slm_execution_nonexistent_model() {
939        let engine = create_test_engine().await;
940
941        let context =
942            create_test_context("Test prompt", super::super::error::TaskType::CodeGeneration);
943
944        let request = create_test_request("Test prompt");
945
946        let result = engine
947            .execute_slm_route(
948                &context,
949                &request,
950                "nonexistent-model",
951                &crate::routing::MonitoringLevel::Basic,
952                true,
953            )
954            .await;
955
956        assert!(result.is_err());
957        assert!(matches!(
958            result.unwrap_err(),
959            RoutingError::NoSuitableModel { .. }
960        ));
961    }
962
963    #[tokio::test]
964    async fn test_llm_fallback_execution() {
965        let engine = create_test_engine().await;
966
967        let request = create_test_request("Test LLM fallback");
968
969        let response = engine
970            .execute_llm_fallback(&request, "Test reason")
971            .await
972            .unwrap();
973
974        assert!(!response.content.is_empty());
975        assert!(response.content.contains("LLM response"));
976        assert_eq!(response.finish_reason, FinishReason::Stop);
977        assert!(response.confidence_score.is_some());
978    }
979
980    #[tokio::test]
981    async fn test_llm_fallback_disabled() {
982        let engine = create_test_engine().await;
983
984        // Disable fallback in config
985        {
986            let mut config = (*engine.config.load().as_ref()).clone();
987            config.policy.fallback_config.enabled = false;
988            engine.config.store(Arc::new(config));
989        }
990
991        let request = create_test_request("Test disabled fallback");
992
993        let result = engine.execute_llm_fallback(&request, "Test reason").await;
994
995        assert!(result.is_err());
996        assert!(matches!(
997            result.unwrap_err(),
998            RoutingError::LLMFallbackFailed { .. }
999        ));
1000    }
1001
1002    #[tokio::test]
1003    async fn test_llm_client_pool() {
1004        let pool = create_mock_pool();
1005
1006        let request = create_test_request("Test LLM client");
1007
1008        // Test OpenAI provider
1009        let openai_response = pool
1010            .execute_request(
1011                &request,
1012                &LLMProvider::OpenAI {
1013                    model: Some("gpt-3.5-turbo".to_string()),
1014                },
1015            )
1016            .await
1017            .unwrap();
1018
1019        assert!(!openai_response.content.is_empty());
1020        assert!(openai_response.metadata.contains_key("provider"));
1021
1022        // Test Anthropic provider
1023        let anthropic_response = pool
1024            .execute_request(
1025                &request,
1026                &LLMProvider::Anthropic {
1027                    model: Some("claude-3".to_string()),
1028                },
1029            )
1030            .await
1031            .unwrap();
1032
1033        assert!(!anthropic_response.content.is_empty());
1034        assert!(anthropic_response.metadata.contains_key("provider"));
1035
1036        // Test Custom provider
1037        let custom_response = pool
1038            .execute_request(
1039                &request,
1040                &LLMProvider::Custom {
1041                    endpoint: "http://localhost:8080".to_string(),
1042                    model: None,
1043                },
1044            )
1045            .await
1046            .unwrap();
1047
1048        assert!(!custom_response.content.is_empty());
1049    }
1050
1051    #[tokio::test]
1052    async fn test_routing_statistics_tracking() {
1053        let engine = create_test_engine().await;
1054
1055        // Execute a few requests to track statistics
1056        let context1 = create_test_context("Test 1", super::super::error::TaskType::CodeGeneration);
1057        let request1 = create_test_request("Test 1");
1058
1059        let _response1 = engine
1060            .execute_with_routing(context1, request1)
1061            .await
1062            .unwrap();
1063
1064        let context2 = create_test_context(
1065            "error trigger",
1066            super::super::error::TaskType::CodeGeneration,
1067        );
1068        let request2 = create_test_request("error trigger");
1069
1070        let _response2 = engine
1071            .execute_with_routing(context2, request2)
1072            .await
1073            .unwrap();
1074
1075        let stats = engine.get_routing_stats().await;
1076
1077        assert!(stats.total_requests() > 0);
1078        assert!(stats.fallback_routes() > 0); // Second request should trigger fallback
1079        assert!(stats.average_response_time() > Duration::ZERO);
1080    }
1081
1082    #[tokio::test]
1083    async fn test_config_update() {
1084        let engine = create_test_engine().await;
1085
1086        let mut new_config = RoutingConfig::default();
1087        new_config.policy.fallback_config.enabled = false;
1088
1089        let result = engine.update_config(new_config.clone()).await;
1090        assert!(result.is_ok());
1091
1092        // Verify config was updated
1093        let updated_config = engine.config.load();
1094        assert!(!updated_config.policy.fallback_config.enabled);
1095    }
1096
1097    #[tokio::test]
1098    async fn test_routing_with_deny_decision() {
1099        let engine = create_test_engine().await;
1100
1101        // Create a routing context that would trigger a deny decision
1102        // (This would need specific policy configuration to work in practice)
1103        let context = create_test_context(
1104            "forbidden operation",
1105            super::super::error::TaskType::Custom("forbidden".to_string()),
1106        );
1107
1108        let request = create_test_request("forbidden operation");
1109
1110        // This might not trigger a deny in the default config, but test the error handling
1111        let result = engine.execute_with_routing(context, request).await;
1112
1113        // Should either succeed with a response or fail with specific error
1114        match result {
1115            Ok(response) => {
1116                assert!(!response.content.is_empty());
1117            }
1118            Err(RoutingError::RoutingDenied { .. }) => {
1119                // Expected for deny decision
1120            }
1121            Err(e) => {
1122                panic!("Unexpected error: {:?}", e);
1123            }
1124        }
1125    }
1126
1127    #[tokio::test]
1128    async fn test_logging_integration() {
1129        let engine = create_test_engine_with_logger().await;
1130
1131        let context = create_test_context(
1132            "Test logging integration",
1133            super::super::error::TaskType::CodeGeneration,
1134        );
1135
1136        let request = create_test_request("Test logging integration");
1137
1138        let response = engine.execute_with_routing(context, request).await.unwrap();
1139
1140        assert!(!response.content.is_empty());
1141        // Logging should happen in the background without affecting the response
1142    }
1143
1144    #[tokio::test]
1145    async fn test_confidence_monitoring_integration() {
1146        let engine = create_test_engine().await;
1147
1148        let context = create_test_context(
1149            "Test confidence monitoring",
1150            super::super::error::TaskType::CodeGeneration,
1151        );
1152
1153        let request = create_test_request("Test confidence monitoring");
1154
1155        let response = engine.execute_with_routing(context, request).await.unwrap();
1156
1157        assert!(!response.content.is_empty());
1158        assert!(response.confidence_score.is_some());
1159    }
1160
1161    #[tokio::test]
1162    async fn test_policy_evaluation_integration() {
1163        let engine = create_test_engine().await;
1164
1165        // Test different task types to ensure policy evaluation works
1166        let task_types = vec![
1167            super::super::error::TaskType::CodeGeneration,
1168            super::super::error::TaskType::CodeGeneration,
1169            super::super::error::TaskType::Analysis,
1170            super::super::error::TaskType::Reasoning,
1171        ];
1172
1173        for task_type in task_types {
1174            let context = create_test_context("Test policy evaluation", task_type.clone());
1175
1176            let decision = engine.route_request(&context).await.unwrap();
1177
1178            // Should get a valid routing decision for each task type
1179            match decision {
1180                RouteDecision::UseSLM { .. }
1181                | RouteDecision::UseLLM { .. }
1182                | RouteDecision::Deny { .. } => {
1183                    // All are valid outcomes
1184                }
1185            }
1186        }
1187    }
1188
1189    #[tokio::test]
1190    async fn test_concurrent_routing_requests() {
1191        let engine = Arc::new(create_test_engine().await);
1192
1193        let mut handles = Vec::new();
1194
1195        // Spawn multiple concurrent routing requests
1196        for i in 0..10 {
1197            let engine_clone = Arc::clone(&engine);
1198            let handle = tokio::spawn(async move {
1199                let context = create_test_context(
1200                    &format!("Concurrent request {}", i),
1201                    super::super::error::TaskType::CodeGeneration,
1202                );
1203
1204                let request = create_test_request(&format!("Concurrent request {}", i));
1205
1206                engine_clone.execute_with_routing(context, request).await
1207            });
1208            handles.push(handle);
1209        }
1210
1211        // Wait for all requests to complete
1212        let results = futures::future::join_all(handles).await;
1213
1214        // All requests should succeed
1215        for result in results {
1216            let response = result.unwrap().unwrap();
1217            assert!(!response.content.is_empty());
1218        }
1219
1220        // Check that statistics reflect all requests
1221        let stats = engine.get_routing_stats().await;
1222        assert_eq!(stats.total_requests(), 10);
1223    }
1224
1225    #[tokio::test]
1226    async fn test_error_handling_and_recovery() {
1227        let engine = create_test_engine().await;
1228
1229        // Test various error scenarios
1230        let error_scenarios = vec![
1231            ("error trigger", "Should trigger SLM execution error"),
1232            ("", "Empty prompt"),
1233            ("   ", "Whitespace-only prompt"),
1234        ];
1235
1236        for (prompt, description) in error_scenarios {
1237            let context =
1238                create_test_context(prompt, super::super::error::TaskType::CodeGeneration);
1239            let request = create_test_request(prompt);
1240
1241            let result = engine.execute_with_routing(context, request).await;
1242
1243            match result {
1244                Ok(response) => {
1245                    // Should get a response (likely from LLM fallback)
1246                    assert!(!response.content.is_empty(), "Failed for: {}", description);
1247                }
1248                Err(e) => {
1249                    // Some errors are expected, but should be handled gracefully
1250                    tracing::info!("Expected error for '{}': {:?}", description, e);
1251                }
1252            }
1253        }
1254    }
1255
1256    #[tokio::test]
1257    async fn test_model_metadata_and_token_usage() {
1258        let engine = create_test_engine().await;
1259
1260        let context = create_test_context(
1261            "Test metadata and token usage",
1262            super::super::error::TaskType::CodeGeneration,
1263        );
1264
1265        let request = create_test_request("Test metadata and token usage");
1266
1267        let response = engine.execute_with_routing(context, request).await.unwrap();
1268
1269        // Verify response structure
1270        assert!(!response.content.is_empty());
1271        assert!(response.token_usage.is_some());
1272        assert!(!response.metadata.is_empty());
1273
1274        let token_usage = response.token_usage.unwrap();
1275        assert!(token_usage.prompt_tokens > 0);
1276        assert!(token_usage.completion_tokens > 0);
1277        assert_eq!(
1278            token_usage.total_tokens,
1279            token_usage.prompt_tokens + token_usage.completion_tokens
1280        );
1281    }
1282
1283    #[tokio::test]
1284    async fn test_validate_policies() {
1285        let engine = create_test_engine().await;
1286
1287        let result = engine.validate_policies();
1288        assert!(result.is_ok());
1289    }
1290
1291    #[tokio::test]
1292    async fn test_engine_state_consistency() {
1293        let engine = create_test_engine().await;
1294
1295        // Verify initial state
1296        let initial_stats = engine.get_routing_stats().await;
1297        assert_eq!(initial_stats.total_requests(), 0);
1298
1299        // Execute some requests
1300        for i in 0..5 {
1301            let context = create_test_context(
1302                &format!("Test request {}", i),
1303                super::super::error::TaskType::CodeGeneration,
1304            );
1305            let request = create_test_request(&format!("Test request {}", i));
1306
1307            let _response = engine.execute_with_routing(context, request).await.unwrap();
1308        }
1309
1310        // Verify state was updated consistently
1311        let final_stats = engine.get_routing_stats().await;
1312        assert_eq!(final_stats.total_requests(), 5);
1313        assert!(final_stats.average_response_time() > Duration::ZERO);
1314    }
1315
1316    #[tokio::test]
1317    async fn test_routing_statistics_concurrent_updates() {
1318        use super::super::decision::MonitoringLevel;
1319        use std::sync::Arc;
1320
1321        let stats = Arc::new(RoutingStatistics::default());
1322        let mut handles = vec![];
1323
1324        for _ in 0..100 {
1325            let stats_clone = Arc::clone(&stats);
1326            handles.push(tokio::spawn(async move {
1327                stats_clone.record_request(
1328                    &RouteDecision::UseSLM {
1329                        model_id: "test".to_string(),
1330                        monitoring: MonitoringLevel::None,
1331                        fallback_on_failure: false,
1332                        sandbox_tier: None,
1333                    },
1334                    Duration::from_millis(10),
1335                    true,
1336                );
1337            }));
1338        }
1339
1340        for handle in handles {
1341            handle.await.unwrap();
1342        }
1343
1344        assert_eq!(stats.total_requests(), 100);
1345        assert_eq!(stats.slm_routes(), 100);
1346    }
1347
1348    #[tokio::test]
1349    async fn test_fallback_counted_once_per_request() {
1350        let engine = create_test_engine().await;
1351
1352        // Low-confidence path
1353        let context =
1354            create_test_context("Test prompt", super::super::error::TaskType::CodeGeneration);
1355        let request = create_test_request("Test prompt");
1356        let _response = engine
1357            .execute_slm_route(
1358                &context,
1359                &request,
1360                "test-slm",
1361                &crate::routing::MonitoringLevel::Enhanced {
1362                    confidence_threshold: 1.0,
1363                },
1364                true,
1365            )
1366            .await
1367            .unwrap();
1368
1369        // Error path
1370        let context2 = create_test_context(
1371            "error trigger",
1372            super::super::error::TaskType::CodeGeneration,
1373        );
1374        let request2 = create_test_request("error trigger");
1375        let _response2 = engine
1376            .execute_slm_route(
1377                &context2,
1378                &request2,
1379                "test-slm",
1380                &crate::routing::MonitoringLevel::Basic,
1381                true,
1382            )
1383            .await
1384            .unwrap();
1385
1386        let stats = engine.get_routing_stats().await;
1387        assert_eq!(stats.fallback_routes(), 2);
1388    }
1389
1390    #[tokio::test]
1391    async fn test_llm_client_pool_register_and_execute() {
1392        let mut pool = LLMClientPool::new();
1393        pool.register("openai", Box::new(MockLLMClient));
1394
1395        let request = create_test_request("Test");
1396        let provider = LLMProvider::OpenAI {
1397            model: Some("gpt-4".to_string()),
1398        };
1399        let response = pool.execute_request(&request, &provider).await.unwrap();
1400        assert!(!response.content.is_empty());
1401    }
1402
1403    #[tokio::test]
1404    async fn test_llm_client_pool_no_client_registered() {
1405        let pool = LLMClientPool::new();
1406        let request = create_test_request("Test");
1407        let provider = LLMProvider::OpenAI { model: None };
1408        let result = pool.execute_request(&request, &provider).await;
1409        assert!(result.is_err());
1410    }
1411}