Skip to main content

symbi_runtime/routing/
engine.rs

1//! Core routing engine implementation
2
3use async_trait::async_trait;
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::RwLock;
8
9use super::classifier::TaskClassifier;
10use super::confidence::{ConfidenceMonitorTrait, NoOpConfidenceMonitor};
11use super::config::RoutingConfig;
12use super::decision::{
13    FinishReason, LLMProvider, ModelRequest, ModelResponse, RouteDecision, RoutingContext,
14    RoutingStatistics, TokenUsage,
15};
16use super::error::RoutingError;
17use super::policy::PolicyEvaluator;
18use crate::logging::{
19    ModelInteractionType, ModelLogger, RequestData, ResponseData, TokenUsage as LogTokenUsage,
20};
21use crate::models::{ModelCatalog, SlmRunnerError};
22
/// Core routing engine trait for SLM-first architecture
///
/// Implementations decide whether a request is served by a local SLM,
/// forwarded to an LLM provider, or denied, and then execute that decision
/// (including fallback handling).
#[async_trait]
pub trait RoutingEngine: Send + Sync {
    /// Route a model request based on configured policies
    ///
    /// Produces a `RouteDecision` without executing it.
    async fn route_request(&self, context: &RoutingContext) -> Result<RouteDecision, RoutingError>;

    /// Execute the routing decision and handle fallbacks
    ///
    /// Routes `context`, runs the chosen model, and returns its response.
    async fn execute_with_routing(
        &self,
        context: RoutingContext,
        request: ModelRequest,
    ) -> Result<ModelResponse, RoutingError>;

    /// Validate routing policies
    ///
    /// Synchronous check that the currently loaded policy set is valid.
    fn validate_policies(&self) -> Result<(), RoutingError>;

    /// Get routing statistics
    ///
    /// Returns a snapshot of the engine's accumulated statistics.
    async fn get_routing_stats(&self) -> RoutingStatistics;

    /// Update routing configuration
    ///
    /// Replaces the active configuration at runtime.
    async fn update_config(&self, config: RoutingConfig) -> Result<(), RoutingError>;
}
45
/// Default implementation of the routing engine
///
/// All shared state is wrapped in `Arc`/`RwLock` so one engine instance can
/// be used concurrently from multiple tasks.
pub struct DefaultRoutingEngine {
    /// Policy evaluator for making routing decisions
    policy_evaluator: Arc<RwLock<PolicyEvaluator>>,
    /// Model catalog for SLM information
    model_catalog: Arc<ModelCatalog>,
    /// Confidence monitor for evaluating SLM responses.
    /// Not consulted in this build (see the note in `execute_slm_route`);
    /// kept so enterprise builds can inject their own implementation.
    #[allow(dead_code)]
    confidence_monitor: Arc<RwLock<Box<dyn ConfidenceMonitorTrait>>>,
    /// Optional model logger for audit trails; `None` disables audit logging
    model_logger: Option<Arc<ModelLogger>>,
    /// Routing statistics, updated after every executed request
    statistics: Arc<RwLock<RoutingStatistics>>,
    /// Configuration (replaceable at runtime via `update_config`)
    config: Arc<RwLock<RoutingConfig>>,
    /// LLM client pool for fallback
    llm_clients: Arc<LLMClientPool>,
}
64
/// Pool of LLM clients for different providers
///
/// Keyed by provider family name ("openai", "anthropic", "custom"); see
/// `LLMClientPool::execute_request` for the `LLMProvider` -> key mapping.
struct LLMClientPool {
    /// Provider-family key -> client implementation.
    clients: HashMap<String, Box<dyn LLMClient>>,
}
69
/// Trait for LLM clients
///
/// Abstracts the transport to a remote LLM so the pool can hold
/// heterogeneous (or mock) implementations behind one interface.
#[async_trait]
trait LLMClient: Send + Sync {
    /// Execute `request` against the given `provider` and return its response.
    async fn execute_request(
        &self,
        request: &ModelRequest,
        provider: &LLMProvider,
    ) -> Result<ModelResponse, RoutingError>;
}
79
/// Mock LLM client implementation
///
/// Stateless stand-in used for every provider family; fabricates responses
/// instead of performing network calls.
#[derive(Debug)]
struct MockLLMClient;
83
84#[async_trait]
85impl LLMClient for MockLLMClient {
86    async fn execute_request(
87        &self,
88        request: &ModelRequest,
89        provider: &LLMProvider,
90    ) -> Result<ModelResponse, RoutingError> {
91        // Mock implementation - in real system would call actual LLM APIs
92        tokio::time::sleep(Duration::from_millis(100)).await;
93
94        Ok(ModelResponse {
95            content: format!("LLM response to: {}", request.prompt),
96            finish_reason: FinishReason::Stop,
97            token_usage: Some(TokenUsage {
98                prompt_tokens: request.prompt.len() as u32 / 4,
99                completion_tokens: 50,
100                total_tokens: (request.prompt.len() as u32 / 4) + 50,
101            }),
102            metadata: {
103                let mut meta = HashMap::new();
104                meta.insert(
105                    "provider".to_string(),
106                    serde_json::Value::String(provider.to_string()),
107                );
108                meta.insert("mock".to_string(), serde_json::Value::Bool(true));
109                meta
110            },
111            confidence_score: Some(0.95),
112        })
113    }
114}
115
116impl LLMClientPool {
117    fn new() -> Self {
118        let mut clients: HashMap<String, Box<dyn LLMClient>> = HashMap::new();
119        clients.insert("openai".to_string(), Box::new(MockLLMClient));
120        clients.insert("anthropic".to_string(), Box::new(MockLLMClient));
121        clients.insert("custom".to_string(), Box::new(MockLLMClient));
122
123        Self { clients }
124    }
125
126    async fn execute_request(
127        &self,
128        request: &ModelRequest,
129        provider: &LLMProvider,
130    ) -> Result<ModelResponse, RoutingError> {
131        let client_key = match provider {
132            LLMProvider::OpenAI { .. } => "openai",
133            LLMProvider::Anthropic { .. } => "anthropic",
134            LLMProvider::Custom { .. } => "custom",
135        };
136
137        let client =
138            self.clients
139                .get(client_key)
140                .ok_or_else(|| RoutingError::LLMFallbackFailed {
141                    provider: provider.to_string(),
142                    reason: "No client available for provider".to_string(),
143                })?;
144
145        client.execute_request(request, provider).await
146    }
147}
148
149impl DefaultRoutingEngine {
150    /// Create a new routing engine with the given configuration
151    pub async fn new(
152        config: RoutingConfig,
153        model_catalog: ModelCatalog,
154        model_logger: Option<Arc<ModelLogger>>,
155    ) -> Result<Self, RoutingError> {
156        Self::new_with_confidence_monitor(
157            config,
158            model_catalog,
159            model_logger,
160            Box::new(NoOpConfidenceMonitor),
161        )
162        .await
163    }
164
165    /// Create a new routing engine with a custom confidence monitor implementation
166    /// This allows enterprise builds to inject their own confidence monitor
167    pub async fn new_with_confidence_monitor(
168        config: RoutingConfig,
169        model_catalog: ModelCatalog,
170        model_logger: Option<Arc<ModelLogger>>,
171        confidence_monitor: Box<dyn ConfidenceMonitorTrait>,
172    ) -> Result<Self, RoutingError> {
173        // Create task classifier
174        let classifier = TaskClassifier::new(config.classification.clone())?;
175
176        // Create policy evaluator
177        let policy_evaluator =
178            PolicyEvaluator::new(config.policy.clone(), classifier, model_catalog.clone())?;
179
180        // Create LLM client pool
181        let llm_clients = Arc::new(LLMClientPool::new());
182
183        let engine = Self {
184            policy_evaluator: Arc::new(RwLock::new(policy_evaluator)),
185            model_catalog: Arc::new(model_catalog),
186            confidence_monitor: Arc::new(RwLock::new(confidence_monitor)),
187            model_logger,
188            statistics: Arc::new(RwLock::new(RoutingStatistics::default())),
189            config: Arc::new(RwLock::new(config)),
190            llm_clients,
191        };
192
193        Ok(engine)
194    }
195
196    /// Execute an SLM route with monitoring and fallback
197    async fn execute_slm_route(
198        &self,
199        context: &RoutingContext,
200        request: &ModelRequest,
201        model_id: &str,
202        monitoring_level: &super::decision::MonitoringLevel,
203        fallback_on_failure: bool,
204    ) -> Result<ModelResponse, RoutingError> {
205        let _start_time = Instant::now();
206
207        // Get the model from catalog
208        let model = self.model_catalog.get_model(model_id).ok_or_else(|| {
209            RoutingError::NoSuitableModel {
210                task_type: context.task_type.clone(),
211            }
212        })?;
213
214        // Execute the SLM (mock implementation)
215        let slm_result = self.execute_slm_mock(request, model).await;
216
217        match slm_result {
218            Ok(response) => {
219                // Evaluate confidence if monitoring is enabled
220                let should_fallback = match monitoring_level {
221                    super::decision::MonitoringLevel::None => false,
222                    super::decision::MonitoringLevel::Basic => {
223                        // Basic monitoring - check for obvious failures
224                        response.finish_reason != FinishReason::Stop
225                    }
226                    super::decision::MonitoringLevel::Enhanced {
227                        confidence_threshold,
228                    } => {
229                        // For enhanced monitoring, use confidence score if available
230                        // Enterprise builds can inject more sophisticated confidence monitors
231                        let confidence_score = response.confidence_score.unwrap_or(0.5);
232                        confidence_score < *confidence_threshold
233                    }
234                };
235
236                if should_fallback && fallback_on_failure {
237                    tracing::warn!(
238                        "SLM response did not meet confidence threshold, falling back to LLM"
239                    );
240
241                    // Update statistics for fallback
242                    {
243                        let mut stats = self.statistics.write().await;
244                        stats.fallback_routes += 1;
245                    }
246
247                    self.execute_llm_fallback(request, "Low confidence SLM response")
248                        .await
249                } else {
250                    // Note: In enterprise mode, the ConfidenceMonitor would record evaluation results
251                    // but the trait interface doesn't expose this method to keep OSS code clean
252                    Ok(response)
253                }
254            }
255            Err(e) => {
256                if fallback_on_failure {
257                    tracing::error!("SLM execution failed, falling back to LLM: {}", e);
258
259                    // Update statistics for fallback
260                    {
261                        let mut stats = self.statistics.write().await;
262                        stats.fallback_routes += 1;
263                    }
264
265                    self.execute_llm_fallback(request, &format!("SLM execution failed: {}", e))
266                        .await
267                } else {
268                    Err(RoutingError::ModelExecutionFailed {
269                        model_id: model_id.to_string(),
270                        reason: e.to_string(),
271                    })
272                }
273            }
274        }
275    }
276
277    /// Mock SLM execution (in real implementation, would use SlmRunner)
278    async fn execute_slm_mock(
279        &self,
280        request: &ModelRequest,
281        model: &crate::config::Model,
282    ) -> Result<ModelResponse, SlmRunnerError> {
283        // Simulate SLM execution time
284        tokio::time::sleep(Duration::from_millis(200)).await;
285
286        // Simulate potential failure for certain inputs
287        if request.prompt.contains("error") {
288            return Err(SlmRunnerError::ExecutionFailed {
289                reason: "Simulated execution error".to_string(),
290            });
291        }
292
293        Ok(ModelResponse {
294            content: format!("SLM ({}) response: {}", model.name, request.prompt),
295            finish_reason: FinishReason::Stop,
296            token_usage: Some(TokenUsage {
297                prompt_tokens: request.prompt.len() as u32 / 4,
298                completion_tokens: 30,
299                total_tokens: (request.prompt.len() as u32 / 4) + 30,
300            }),
301            metadata: {
302                let mut meta = HashMap::new();
303                meta.insert(
304                    "model_id".to_string(),
305                    serde_json::Value::String(model.id.clone()),
306                );
307                meta.insert(
308                    "provider".to_string(),
309                    serde_json::Value::String(format!("{:?}", model.provider)),
310                );
311                meta
312            },
313            confidence_score: Some(0.8 + (request.prompt.len() % 20) as f64 / 100.0), // Mock confidence
314        })
315    }
316
317    /// Execute LLM fallback
318    async fn execute_llm_fallback(
319        &self,
320        request: &ModelRequest,
321        _reason: &str,
322    ) -> Result<ModelResponse, RoutingError> {
323        let config = self.config.read().await;
324        let fallback_config = &config.policy.fallback_config;
325
326        if !fallback_config.enabled {
327            return Err(RoutingError::LLMFallbackFailed {
328                provider: "disabled".to_string(),
329                reason: "LLM fallback is disabled".to_string(),
330            });
331        }
332
333        // Try primary provider first
334        let provider = LLMProvider::OpenAI { model: None };
335
336        match self.llm_clients.execute_request(request, &provider).await {
337            Ok(response) => Ok(response),
338            Err(e) => Err(RoutingError::LLMFallbackFailed {
339                provider: provider.to_string(),
340                reason: e.to_string(),
341            }),
342        }
343    }
344
345    /// Log routing decision and execution
346    async fn log_routing_execution(
347        &self,
348        context: &RoutingContext,
349        decision: &RouteDecision,
350        request: &ModelRequest,
351        response: &ModelResponse,
352        execution_time: Duration,
353        error: Option<&RoutingError>,
354    ) {
355        if let Some(ref logger) = self.model_logger {
356            let model_name = match decision {
357                RouteDecision::UseSLM { model_id, .. } => model_id.clone(),
358                RouteDecision::UseLLM { provider, .. } => provider.to_string(),
359                RouteDecision::Deny { .. } => "denied".to_string(),
360            };
361
362            let request_data = RequestData {
363                prompt: request.prompt.clone(),
364                tool_name: None,
365                tool_arguments: None,
366                parameters: {
367                    let mut params = HashMap::new();
368                    params.insert(
369                        "routing_decision".to_string(),
370                        serde_json::Value::String(format!("{:?}", decision)),
371                    );
372                    params.insert(
373                        "task_type".to_string(),
374                        serde_json::Value::String(context.task_type.to_string()),
375                    );
376                    params
377                },
378            };
379
380            let response_data = ResponseData {
381                content: response.content.clone(),
382                tool_result: None,
383                confidence: response.confidence_score,
384                metadata: response.metadata.clone(),
385            };
386
387            let metadata = {
388                let mut meta = HashMap::new();
389                meta.insert("routing_engine".to_string(), "default".to_string());
390                meta.insert("agent_id".to_string(), context.agent_id.to_string());
391                meta.insert("request_id".to_string(), context.request_id.clone());
392                meta
393            };
394
395            if let Err(e) = logger
396                .log_interaction(
397                    context.agent_id,
398                    ModelInteractionType::Completion,
399                    &model_name,
400                    request_data,
401                    response_data,
402                    execution_time,
403                    metadata,
404                    response.token_usage.as_ref().map(|t| LogTokenUsage {
405                        input_tokens: t.prompt_tokens,
406                        output_tokens: t.completion_tokens,
407                        total_tokens: t.total_tokens,
408                    }),
409                    error.map(|e| e.to_string()),
410                )
411                .await
412            {
413                tracing::warn!("Failed to log routing execution: {}", e);
414            }
415        }
416    }
417}
418
419#[async_trait]
420impl RoutingEngine for DefaultRoutingEngine {
421    async fn route_request(&self, context: &RoutingContext) -> Result<RouteDecision, RoutingError> {
422        let policy_result = self
423            .policy_evaluator
424            .read()
425            .await
426            .evaluate_policies(context)
427            .await?;
428
429        tracing::debug!(
430            "Routing decision for agent {}: {:?} (rule: {:?})",
431            context.agent_id,
432            policy_result.decision,
433            policy_result.matched_rule
434        );
435
436        Ok(policy_result.decision)
437    }
438
    /// Route `context`, execute the chosen path, then record statistics and
    /// (optionally) an audit log entry.
    ///
    /// Note: a `Deny` decision returns early, so denied requests are neither
    /// counted in statistics nor logged via `log_routing_execution`.
    async fn execute_with_routing(
        &self,
        context: RoutingContext,
        request: ModelRequest,
    ) -> Result<ModelResponse, RoutingError> {
        let start_time = Instant::now();
        let route_decision = self.route_request(&context).await?;

        // Execute the decision; SLM and LLM paths produce a Result that is
        // post-processed below, Deny short-circuits.
        let result = match &route_decision {
            RouteDecision::UseSLM {
                model_id,
                monitoring,
                fallback_on_failure,
                sandbox_tier: _,
            } => {
                self.execute_slm_route(
                    &context,
                    &request,
                    model_id,
                    monitoring,
                    *fallback_on_failure,
                )
                .await
            }
            RouteDecision::UseLLM {
                provider,
                reason,
                sandbox_tier: _,
            } => {
                tracing::info!("Routing to LLM: {}", reason);
                self.llm_clients.execute_request(&request, provider).await
            }
            RouteDecision::Deny {
                reason,
                policy_violated,
            } => {
                return Err(RoutingError::RoutingDenied {
                    policy: policy_violated.clone(),
                    reason: reason.clone(),
                });
            }
        };

        // Wall-clock time for the whole routed execution (including fallback).
        let execution_time = start_time.elapsed();

        // Update statistics
        {
            let mut stats = self.statistics.write().await;
            stats.update(&route_decision, execution_time, result.is_ok());

            if let Ok(ref response) = result {
                if let Some(confidence) = response.confidence_score {
                    stats.add_confidence_score(confidence);
                }
            }
        }

        // Log the execution
        match &result {
            Ok(response) => {
                self.log_routing_execution(
                    &context,
                    &route_decision,
                    &request,
                    response,
                    execution_time,
                    None,
                )
                .await;
            }
            Err(error) => {
                // Create a dummy response for logging, since the logger
                // requires a response payload even on the error path.
                let dummy_response = ModelResponse {
                    content: "Error occurred".to_string(),
                    finish_reason: FinishReason::Error,
                    token_usage: None,
                    metadata: HashMap::new(),
                    confidence_score: Some(0.0),
                };

                self.log_routing_execution(
                    &context,
                    &route_decision,
                    &request,
                    &dummy_response,
                    execution_time,
                    Some(error),
                )
                .await;
            }
        }

        result
    }
533
    fn validate_policies(&self) -> Result<(), RoutingError> {
        // Validation is done during PolicyEvaluator creation
        // (`PolicyEvaluator::new` fails construction on an invalid policy set),
        // so a successfully constructed engine always holds validated policies.
        Ok(())
    }
538
539    async fn get_routing_stats(&self) -> RoutingStatistics {
540        self.statistics.read().await.clone()
541    }
542
543    async fn update_config(&self, config: RoutingConfig) -> Result<(), RoutingError> {
544        // Update configuration
545        *self.config.write().await = config.clone();
546
547        // Update policy evaluator
548        self.policy_evaluator
549            .write()
550            .await
551            .update_config(config.policy)?;
552
553        Ok(())
554    }
555}
556
557#[cfg(test)]
558mod tests {
559    use super::*;
560    use crate::config::{
561        Model, ModelAllowListConfig, ModelProvider, ModelResourceRequirements, SandboxProfile, Slm,
562    };
563    use crate::types::AgentId;
564    use std::collections::HashMap;
565    use std::path::PathBuf;
566
    /// Build a test engine with two local models ("test-slm", "error-slm"),
    /// one secure sandbox profile, and a routing rule that sends
    /// CodeGeneration tasks to the best available SLM with Basic monitoring
    /// and LLM fallback enabled.
    async fn create_test_engine() -> DefaultRoutingEngine {
        let global_models = vec![
            Model {
                id: "test-slm".to_string(),
                name: "Test SLM".to_string(),
                provider: ModelProvider::LocalFile {
                    file_path: PathBuf::from("/tmp/test.gguf"),
                },
                capabilities: vec![
                    crate::config::ModelCapability::TextGeneration,
                    crate::config::ModelCapability::CodeGeneration,
                ],
                resource_requirements: ModelResourceRequirements {
                    min_memory_mb: 1024,
                    preferred_cpu_cores: 2.0,
                    gpu_requirements: None,
                },
            },
            // Note: SLM failures are actually triggered by the substring
            // "error" in the *prompt* (see execute_slm_mock), not by this
            // model's id/file name.
            Model {
                id: "error-slm".to_string(),
                name: "Error SLM".to_string(),
                provider: ModelProvider::LocalFile {
                    file_path: PathBuf::from("/tmp/error.gguf"),
                },
                capabilities: vec![crate::config::ModelCapability::TextGeneration],
                resource_requirements: ModelResourceRequirements {
                    min_memory_mb: 512,
                    preferred_cpu_cores: 1.0,
                    gpu_requirements: None,
                },
            },
        ];

        let mut sandbox_profiles = HashMap::new();
        sandbox_profiles.insert("default".to_string(), SandboxProfile::secure_default());

        let slm_config = Slm {
            enabled: true,
            model_allow_lists: ModelAllowListConfig {
                global_models,
                agent_model_maps: HashMap::new(),
                allow_runtime_overrides: false,
            },
            sandbox_profiles,
            default_sandbox_profile: "default".to_string(),
        };

        let model_catalog = ModelCatalog::new(slm_config).unwrap();
        let mut config = RoutingConfig::default();

        // Add SLM routing rule for code generation tasks
        config.policy.rules.push(super::super::config::RoutingRule {
            name: "slm_code_rule".to_string(),
            priority: 100,
            conditions: super::super::config::RoutingConditions {
                task_types: Some(vec![super::super::error::TaskType::CodeGeneration]),
                agent_ids: None,
                resource_constraints: None,
                security_level: None,
                custom_conditions: None,
            },
            action: super::super::config::RouteAction::UseSLM {
                model_preference: super::super::config::ModelPreference::BestAvailable,
                monitoring_level: crate::routing::MonitoringLevel::Basic,
                fallback_on_low_confidence: true,
                confidence_threshold: Some(0.8),
            },
            override_allowed: true,
            action_extension: None,
        });

        DefaultRoutingEngine::new(config, model_catalog, None)
            .await
            .unwrap()
    }
642
    /// Like `create_test_engine` but with a single model, the default routing
    /// config (no extra rules), and a real `ModelLogger` attached so audit
    /// logging can be exercised.
    async fn create_test_engine_with_logger() -> DefaultRoutingEngine {
        let logger =
            ModelLogger::new(super::super::super::logging::LoggingConfig::default(), None).unwrap();

        let global_models = vec![Model {
            id: "test-slm".to_string(),
            name: "Test SLM".to_string(),
            provider: ModelProvider::LocalFile {
                file_path: PathBuf::from("/tmp/test.gguf"),
            },
            capabilities: vec![crate::config::ModelCapability::TextGeneration],
            resource_requirements: ModelResourceRequirements {
                min_memory_mb: 1024,
                preferred_cpu_cores: 2.0,
                gpu_requirements: None,
            },
        }];

        let mut sandbox_profiles = HashMap::new();
        sandbox_profiles.insert("default".to_string(), SandboxProfile::secure_default());

        let slm_config = Slm {
            enabled: true,
            model_allow_lists: ModelAllowListConfig {
                global_models,
                agent_model_maps: HashMap::new(),
                allow_runtime_overrides: false,
            },
            sandbox_profiles,
            default_sandbox_profile: "default".to_string(),
        };

        let model_catalog = ModelCatalog::new(slm_config).unwrap();
        let config = RoutingConfig::default();

        DefaultRoutingEngine::new(config, model_catalog, Some(Arc::new(logger)))
            .await
            .unwrap()
    }
682
683    fn create_test_request(prompt: &str) -> ModelRequest {
684        ModelRequest::from_task(prompt.to_string())
685    }
686
687    fn create_test_context(
688        prompt: &str,
689        task_type: super::super::error::TaskType,
690    ) -> RoutingContext {
691        RoutingContext::new(AgentId::new(), task_type, prompt.to_string())
692    }
693
    /// A freshly built engine has zero recorded requests and passes policy
    /// validation.
    #[tokio::test]
    async fn test_routing_engine_creation() {
        let engine = create_test_engine().await;

        // Verify engine was created successfully
        let stats = engine.get_routing_stats().await;
        assert_eq!(stats.total_requests, 0);

        // Verify policies can be validated
        assert!(engine.validate_policies().is_ok());
    }
705
    /// An engine built via `create_test_engine_with_logger` carries a logger
    /// and still starts with empty statistics.
    #[tokio::test]
    async fn test_routing_engine_with_logger() {
        let engine = create_test_engine_with_logger().await;

        // Should have logger configured
        assert!(engine.model_logger.is_some());

        let stats = engine.get_routing_stats().await;
        assert_eq!(stats.total_requests, 0);
    }
716
    /// A plain CodeGeneration request must route somewhere (SLM or LLM) and
    /// never be denied.
    #[tokio::test]
    async fn test_routing_engine_basic_flow() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "Write a hello world function",
            super::super::error::TaskType::CodeGeneration,
        );

        let decision = engine.route_request(&context).await.unwrap();

        // Should get some kind of routing decision
        match decision {
            RouteDecision::UseSLM { .. } | RouteDecision::UseLLM { .. } => {
                // Expected outcomes
            }
            RouteDecision::Deny { .. } => {
                panic!("Should not deny basic request");
            }
        }
    }
738
    /// End-to-end happy path: the request executes, yields a non-empty
    /// response with a confidence score, and bumps the request counter.
    #[tokio::test]
    async fn test_execute_with_routing_slm_success() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "Write a hello world function",
            super::super::error::TaskType::CodeGeneration,
        );

        let request = create_test_request("Write a hello world function");

        let response = engine.execute_with_routing(context, request).await.unwrap();

        assert!(!response.content.is_empty());
        assert!(response.confidence_score.is_some());
        assert_eq!(response.finish_reason, FinishReason::Stop);

        // Check that statistics were updated
        let stats = engine.get_routing_stats().await;
        assert!(stats.total_requests > 0);
    }
760
    /// The prompt contains "error", which makes `execute_slm_mock` fail, so
    /// the engine must fall back to the (mock) LLM and count the fallback.
    #[tokio::test]
    async fn test_execute_with_routing_slm_error_fallback() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "This should trigger an error in SLM",
            super::super::error::TaskType::CodeGeneration,
        );

        let request = create_test_request("error trigger");

        let response = engine.execute_with_routing(context, request).await.unwrap();

        // Should get LLM fallback response
        assert!(!response.content.is_empty());
        assert!(response.content.contains("LLM response"));

        let stats = engine.get_routing_stats().await;
        assert!(stats.fallback_routes > 0);
    }
781
    /// Direct SLM execution with Basic monitoring succeeds and returns the
    /// mock response that embeds the model name.
    #[tokio::test]
    async fn test_slm_execution_success() {
        let engine = create_test_engine().await;

        let context =
            create_test_context("Test prompt", super::super::error::TaskType::CodeGeneration);

        let request = create_test_request("Test prompt");

        let response = engine
            .execute_slm_route(
                &context,
                &request,
                "test-slm",
                &crate::routing::MonitoringLevel::Basic,
                true,
            )
            .await
            .unwrap();

        assert!(!response.content.is_empty());
        assert!(response.content.contains("Test SLM"));
        assert!(response.confidence_score.is_some());
    }
806
    /// Enhanced monitoring with a 0.9 threshold: whether the mock confidence
    /// clears the bar depends on the prompt length, so either the SLM answer
    /// or the LLM fallback is acceptable here.
    #[tokio::test]
    async fn test_slm_execution_with_enhanced_monitoring() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "Test prompt with monitoring",
            super::super::error::TaskType::CodeGeneration,
        );

        let request = create_test_request("Test prompt with monitoring");

        let response = engine
            .execute_slm_route(
                &context,
                &request,
                "test-slm",
                &crate::routing::MonitoringLevel::Enhanced {
                    confidence_threshold: 0.9,
                },
                true,
            )
            .await
            .unwrap();

        // Should either get SLM response or LLM fallback
        assert!(!response.content.is_empty());
    }
834
    /// With monitoring disabled the SLM response is returned as-is; no
    /// confidence check can trigger a fallback.
    #[tokio::test]
    async fn test_slm_execution_no_monitoring() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "Test prompt no monitoring",
            super::super::error::TaskType::CodeGeneration,
        );

        let request = create_test_request("Test prompt no monitoring");

        let response = engine
            .execute_slm_route(
                &context,
                &request,
                "test-slm",
                &crate::routing::MonitoringLevel::None,
                true,
            )
            .await
            .unwrap();

        // Should get SLM response without monitoring
        assert!(!response.content.is_empty());
        assert!(response.content.contains("Test SLM"));
    }
861
862    #[tokio::test]
863    async fn test_slm_execution_error_no_fallback() {
864        let engine = create_test_engine().await;
865
866        let context = create_test_context(
867            "error trigger",
868            super::super::error::TaskType::CodeGeneration,
869        );
870
871        let request = create_test_request("error trigger");
872
873        let result = engine
874            .execute_slm_route(
875                &context,
876                &request,
877                "test-slm",
878                &crate::routing::MonitoringLevel::Basic,
879                false, // No fallback
880            )
881            .await;
882
883        assert!(result.is_err());
884        assert!(matches!(
885            result.unwrap_err(),
886            RoutingError::ModelExecutionFailed { .. }
887        ));
888    }
889
890    #[tokio::test]
891    async fn test_slm_execution_nonexistent_model() {
892        let engine = create_test_engine().await;
893
894        let context =
895            create_test_context("Test prompt", super::super::error::TaskType::CodeGeneration);
896
897        let request = create_test_request("Test prompt");
898
899        let result = engine
900            .execute_slm_route(
901                &context,
902                &request,
903                "nonexistent-model",
904                &crate::routing::MonitoringLevel::Basic,
905                true,
906            )
907            .await;
908
909        assert!(result.is_err());
910        assert!(matches!(
911            result.unwrap_err(),
912            RoutingError::NoSuitableModel { .. }
913        ));
914    }
915
916    #[tokio::test]
917    async fn test_llm_fallback_execution() {
918        let engine = create_test_engine().await;
919
920        let request = create_test_request("Test LLM fallback");
921
922        let response = engine
923            .execute_llm_fallback(&request, "Test reason")
924            .await
925            .unwrap();
926
927        assert!(!response.content.is_empty());
928        assert!(response.content.contains("LLM response"));
929        assert_eq!(response.finish_reason, FinishReason::Stop);
930        assert!(response.confidence_score.is_some());
931    }
932
933    #[tokio::test]
934    async fn test_llm_fallback_disabled() {
935        let engine = create_test_engine().await;
936
937        // Disable fallback in config
938        {
939            let mut config = engine.config.write().await;
940            config.policy.fallback_config.enabled = false;
941        }
942
943        let request = create_test_request("Test disabled fallback");
944
945        let result = engine.execute_llm_fallback(&request, "Test reason").await;
946
947        assert!(result.is_err());
948        assert!(matches!(
949            result.unwrap_err(),
950            RoutingError::LLMFallbackFailed { .. }
951        ));
952    }
953
954    #[tokio::test]
955    async fn test_llm_client_pool() {
956        let pool = LLMClientPool::new();
957
958        let request = create_test_request("Test LLM client");
959
960        // Test OpenAI provider
961        let openai_response = pool
962            .execute_request(
963                &request,
964                &LLMProvider::OpenAI {
965                    model: Some("gpt-3.5-turbo".to_string()),
966                },
967            )
968            .await
969            .unwrap();
970
971        assert!(!openai_response.content.is_empty());
972        assert!(openai_response.metadata.contains_key("provider"));
973
974        // Test Anthropic provider
975        let anthropic_response = pool
976            .execute_request(
977                &request,
978                &LLMProvider::Anthropic {
979                    model: Some("claude-3".to_string()),
980                },
981            )
982            .await
983            .unwrap();
984
985        assert!(!anthropic_response.content.is_empty());
986        assert!(anthropic_response.metadata.contains_key("provider"));
987
988        // Test Custom provider
989        let custom_response = pool
990            .execute_request(
991                &request,
992                &LLMProvider::Custom {
993                    endpoint: "http://localhost:8080".to_string(),
994                    model: None,
995                },
996            )
997            .await
998            .unwrap();
999
1000        assert!(!custom_response.content.is_empty());
1001    }
1002
1003    #[tokio::test]
1004    async fn test_mock_slm_execution() {
1005        let engine = create_test_engine().await;
1006
1007        let request = create_test_request("Test SLM execution");
1008        let model = engine.model_catalog.get_model("test-slm").unwrap();
1009
1010        let response = engine.execute_slm_mock(&request, model).await.unwrap();
1011
1012        assert!(!response.content.is_empty());
1013        assert!(response.content.contains("Test SLM"));
1014        assert!(response.content.contains("Test SLM execution"));
1015        assert_eq!(response.finish_reason, FinishReason::Stop);
1016        assert!(response.confidence_score.is_some());
1017        assert!(response.token_usage.is_some());
1018    }
1019
1020    #[tokio::test]
1021    async fn test_mock_slm_execution_error() {
1022        let engine = create_test_engine().await;
1023
1024        let request = create_test_request("This should error out");
1025        let model = engine.model_catalog.get_model("test-slm").unwrap();
1026
1027        let result = engine.execute_slm_mock(&request, model).await;
1028
1029        assert!(result.is_err());
1030        assert!(matches!(
1031            result.unwrap_err(),
1032            SlmRunnerError::ExecutionFailed { .. }
1033        ));
1034    }
1035
1036    #[tokio::test]
1037    async fn test_routing_statistics_tracking() {
1038        let engine = create_test_engine().await;
1039
1040        // Execute a few requests to track statistics
1041        let context1 = create_test_context("Test 1", super::super::error::TaskType::CodeGeneration);
1042        let request1 = create_test_request("Test 1");
1043
1044        let _response1 = engine
1045            .execute_with_routing(context1, request1)
1046            .await
1047            .unwrap();
1048
1049        let context2 = create_test_context(
1050            "error trigger",
1051            super::super::error::TaskType::CodeGeneration,
1052        );
1053        let request2 = create_test_request("error trigger");
1054
1055        let _response2 = engine
1056            .execute_with_routing(context2, request2)
1057            .await
1058            .unwrap();
1059
1060        let stats = engine.get_routing_stats().await;
1061
1062        assert!(stats.total_requests > 0);
1063        assert!(stats.fallback_routes > 0); // Second request should trigger fallback
1064        assert!(stats.average_response_time > Duration::from_millis(0));
1065    }
1066
1067    #[tokio::test]
1068    async fn test_config_update() {
1069        let engine = create_test_engine().await;
1070
1071        let mut new_config = RoutingConfig::default();
1072        new_config.policy.fallback_config.enabled = false;
1073
1074        let result = engine.update_config(new_config.clone()).await;
1075        assert!(result.is_ok());
1076
1077        // Verify config was updated
1078        let updated_config = engine.config.read().await;
1079        assert!(!updated_config.policy.fallback_config.enabled);
1080    }
1081
1082    #[tokio::test]
1083    async fn test_routing_with_deny_decision() {
1084        let engine = create_test_engine().await;
1085
1086        // Create a routing context that would trigger a deny decision
1087        // (This would need specific policy configuration to work in practice)
1088        let context = create_test_context(
1089            "forbidden operation",
1090            super::super::error::TaskType::Custom("forbidden".to_string()),
1091        );
1092
1093        let request = create_test_request("forbidden operation");
1094
1095        // This might not trigger a deny in the default config, but test the error handling
1096        let result = engine.execute_with_routing(context, request).await;
1097
1098        // Should either succeed with a response or fail with specific error
1099        match result {
1100            Ok(response) => {
1101                assert!(!response.content.is_empty());
1102            }
1103            Err(RoutingError::RoutingDenied { .. }) => {
1104                // Expected for deny decision
1105            }
1106            Err(e) => {
1107                panic!("Unexpected error: {:?}", e);
1108            }
1109        }
1110    }
1111
1112    #[tokio::test]
1113    async fn test_logging_integration() {
1114        let engine = create_test_engine_with_logger().await;
1115
1116        let context = create_test_context(
1117            "Test logging integration",
1118            super::super::error::TaskType::CodeGeneration,
1119        );
1120
1121        let request = create_test_request("Test logging integration");
1122
1123        let response = engine.execute_with_routing(context, request).await.unwrap();
1124
1125        assert!(!response.content.is_empty());
1126        // Logging should happen in the background without affecting the response
1127    }
1128
1129    #[tokio::test]
1130    async fn test_confidence_monitoring_integration() {
1131        let engine = create_test_engine().await;
1132
1133        let context = create_test_context(
1134            "Test confidence monitoring",
1135            super::super::error::TaskType::CodeGeneration,
1136        );
1137
1138        let request = create_test_request("Test confidence monitoring");
1139
1140        let response = engine.execute_with_routing(context, request).await.unwrap();
1141
1142        assert!(!response.content.is_empty());
1143        assert!(response.confidence_score.is_some());
1144
1145        // Note: Confidence monitoring statistics are only available in enterprise mode
1146        // The trait interface doesn't expose statistics to keep OSS code clean
1147    }
1148
1149    #[tokio::test]
1150    async fn test_policy_evaluation_integration() {
1151        let engine = create_test_engine().await;
1152
1153        // Test different task types to ensure policy evaluation works
1154        let task_types = vec![
1155            super::super::error::TaskType::CodeGeneration,
1156            super::super::error::TaskType::CodeGeneration,
1157            super::super::error::TaskType::Analysis,
1158            super::super::error::TaskType::Reasoning,
1159        ];
1160
1161        for task_type in task_types {
1162            let context = create_test_context("Test policy evaluation", task_type.clone());
1163
1164            let decision = engine.route_request(&context).await.unwrap();
1165
1166            // Should get a valid routing decision for each task type
1167            match decision {
1168                RouteDecision::UseSLM { .. }
1169                | RouteDecision::UseLLM { .. }
1170                | RouteDecision::Deny { .. } => {
1171                    // All are valid outcomes
1172                }
1173            }
1174        }
1175    }
1176
1177    #[tokio::test]
1178    async fn test_concurrent_routing_requests() {
1179        let engine = Arc::new(create_test_engine().await);
1180
1181        let mut handles = Vec::new();
1182
1183        // Spawn multiple concurrent routing requests
1184        for i in 0..10 {
1185            let engine_clone = Arc::clone(&engine);
1186            let handle = tokio::spawn(async move {
1187                let context = create_test_context(
1188                    &format!("Concurrent request {}", i),
1189                    super::super::error::TaskType::CodeGeneration,
1190                );
1191
1192                let request = create_test_request(&format!("Concurrent request {}", i));
1193
1194                engine_clone.execute_with_routing(context, request).await
1195            });
1196            handles.push(handle);
1197        }
1198
1199        // Wait for all requests to complete
1200        let results = futures::future::join_all(handles).await;
1201
1202        // All requests should succeed
1203        for result in results {
1204            let response = result.unwrap().unwrap();
1205            assert!(!response.content.is_empty());
1206        }
1207
1208        // Check that statistics reflect all requests
1209        let stats = engine.get_routing_stats().await;
1210        assert_eq!(stats.total_requests, 10);
1211    }
1212
1213    #[tokio::test]
1214    async fn test_error_handling_and_recovery() {
1215        let engine = create_test_engine().await;
1216
1217        // Test various error scenarios
1218        let error_scenarios = vec![
1219            ("error trigger", "Should trigger SLM execution error"),
1220            ("", "Empty prompt"),
1221            ("   ", "Whitespace-only prompt"),
1222        ];
1223
1224        for (prompt, description) in error_scenarios {
1225            let context =
1226                create_test_context(prompt, super::super::error::TaskType::CodeGeneration);
1227            let request = create_test_request(prompt);
1228
1229            let result = engine.execute_with_routing(context, request).await;
1230
1231            match result {
1232                Ok(response) => {
1233                    // Should get a response (likely from LLM fallback)
1234                    assert!(!response.content.is_empty(), "Failed for: {}", description);
1235                }
1236                Err(e) => {
1237                    // Some errors are expected, but should be handled gracefully
1238                    tracing::info!("Expected error for '{}': {:?}", description, e);
1239                }
1240            }
1241        }
1242    }
1243
1244    #[tokio::test]
1245    async fn test_model_metadata_and_token_usage() {
1246        let engine = create_test_engine().await;
1247
1248        let context = create_test_context(
1249            "Test metadata and token usage",
1250            super::super::error::TaskType::CodeGeneration,
1251        );
1252
1253        let request = create_test_request("Test metadata and token usage");
1254
1255        let response = engine.execute_with_routing(context, request).await.unwrap();
1256
1257        // Verify response structure
1258        assert!(!response.content.is_empty());
1259        assert!(response.token_usage.is_some());
1260        assert!(!response.metadata.is_empty());
1261
1262        let token_usage = response.token_usage.unwrap();
1263        assert!(token_usage.prompt_tokens > 0);
1264        assert!(token_usage.completion_tokens > 0);
1265        assert_eq!(
1266            token_usage.total_tokens,
1267            token_usage.prompt_tokens + token_usage.completion_tokens
1268        );
1269    }
1270
1271    #[tokio::test]
1272    async fn test_validate_policies() {
1273        let engine = create_test_engine().await;
1274
1275        let result = engine.validate_policies();
1276        assert!(result.is_ok());
1277    }
1278
1279    #[tokio::test]
1280    async fn test_engine_state_consistency() {
1281        let engine = create_test_engine().await;
1282
1283        // Verify initial state
1284        let initial_stats = engine.get_routing_stats().await;
1285        assert_eq!(initial_stats.total_requests, 0);
1286
1287        // Execute some requests
1288        for i in 0..5 {
1289            let context = create_test_context(
1290                &format!("Test request {}", i),
1291                super::super::error::TaskType::CodeGeneration,
1292            );
1293            let request = create_test_request(&format!("Test request {}", i));
1294
1295            let _response = engine.execute_with_routing(context, request).await.unwrap();
1296        }
1297
1298        // Verify state was updated consistently
1299        let final_stats = engine.get_routing_stats().await;
1300        assert_eq!(final_stats.total_requests, 5);
1301        assert!(final_stats.average_response_time > Duration::from_millis(0));
1302
1303        // Note: Confidence monitoring statistics are only available in enterprise mode
1304        // The trait interface doesn't expose statistics to keep OSS code clean
1305    }
1306}