1use async_trait::async_trait;
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::RwLock;
8
9use super::classifier::TaskClassifier;
10use super::confidence::{ConfidenceMonitorTrait, NoOpConfidenceMonitor};
11use super::config::RoutingConfig;
12use super::decision::{
13 FinishReason, LLMProvider, ModelRequest, ModelResponse, RouteDecision, RoutingContext,
14 RoutingStatistics,
15};
16use super::error::RoutingError;
17use super::policy::PolicyEvaluator;
18use crate::logging::{
19 ModelInteractionType, ModelLogger, RequestData, ResponseData, TokenUsage as LogTokenUsage,
20};
21use crate::models::{ModelCatalog, SlmRunnerError};
22
/// Core routing abstraction: decides whether a request is served by a
/// local SLM, a remote LLM, or denied, and executes it accordingly.
#[async_trait]
pub trait RoutingEngine: Send + Sync {
    /// Evaluate routing policies for `context` and return the decision
    /// without executing the request.
    async fn route_request(&self, context: &RoutingContext) -> Result<RouteDecision, RoutingError>;

    /// Route and then execute `request`, including any fallback handling.
    async fn execute_with_routing(
        &self,
        context: RoutingContext,
        request: ModelRequest,
    ) -> Result<ModelResponse, RoutingError>;

    /// Validate the currently loaded policies (synchronous check).
    fn validate_policies(&self) -> Result<(), RoutingError>;

    /// Snapshot of the accumulated routing statistics.
    async fn get_routing_stats(&self) -> RoutingStatistics;

    /// Replace the engine's routing configuration at runtime.
    async fn update_config(&self, config: RoutingConfig) -> Result<(), RoutingError>;
}
45
/// Default implementation of the routing engine.
///
/// Holds the policy evaluator behind an async `RwLock`, an immutable model
/// catalog, an optional interaction logger, shared statistics, and a
/// hot-swappable configuration (`arc_swap`).
pub struct DefaultRoutingEngine {
    // Policy evaluation state; write-locked only on config updates.
    policy_evaluator: Arc<RwLock<PolicyEvaluator>>,
    // Catalog of SLM models available for routing.
    model_catalog: Arc<ModelCatalog>,
    // Reserved for confidence-based routing; not read anywhere in this file.
    #[allow(dead_code)]
    confidence_monitor: Arc<RwLock<Box<dyn ConfidenceMonitorTrait>>>,
    // Optional logger for model interactions; `None` disables logging.
    model_logger: Option<Arc<ModelLogger>>,
    // Counters updated on every routed request.
    statistics: Arc<RoutingStatistics>,
    // Lock-free, atomically swappable configuration snapshot.
    config: Arc<arc_swap::ArcSwap<RoutingConfig>>,
    // Pool of LLM clients keyed by provider name.
    llm_clients: Arc<LLMClientPool>,
    // Executor that runs requests against local SLM models.
    slm_executor: Arc<dyn SlmExecutor>,
}
66
/// Abstraction over local SLM execution, allowing the engine to be tested
/// without a real model runtime.
#[async_trait]
pub trait SlmExecutor: Send + Sync {
    /// Run `request` against the given catalog `model`.
    async fn execute(
        &self,
        request: &ModelRequest,
        model: &crate::config::Model,
    ) -> Result<ModelResponse, SlmRunnerError>;
}
76
/// Abstraction over a remote LLM provider client.
#[async_trait]
pub trait LLMClient: Send + Sync {
    /// Execute `request` against the given `provider`.
    async fn execute_request(
        &self,
        request: &ModelRequest,
        provider: &LLMProvider,
    ) -> Result<ModelResponse, RoutingError>;
}
86
/// Registry of LLM clients keyed by provider name ("openai", "anthropic",
/// "custom").
pub struct LLMClientPool {
    // Provider-key -> client. Keys match the `LLMProvider` variants.
    clients: HashMap<String, Box<dyn LLMClient>>,
}
91
92impl LLMClientPool {
93 pub fn new() -> Self {
94 Self {
95 clients: HashMap::new(),
96 }
97 }
98
99 pub fn register(&mut self, key: &str, client: Box<dyn LLMClient>) {
100 self.clients.insert(key.to_string(), client);
101 }
102
103 pub async fn execute_request(
104 &self,
105 request: &ModelRequest,
106 provider: &LLMProvider,
107 ) -> Result<ModelResponse, RoutingError> {
108 let client_key = match provider {
109 LLMProvider::OpenAI { .. } => "openai",
110 LLMProvider::Anthropic { .. } => "anthropic",
111 LLMProvider::Custom { .. } => "custom",
112 };
113
114 let client =
115 self.clients
116 .get(client_key)
117 .ok_or_else(|| RoutingError::LLMFallbackFailed {
118 provider: provider.to_string(),
119 reason: format!("No client registered for provider '{}'", client_key),
120 })?;
121
122 client.execute_request(request, provider).await
123 }
124}
125
126impl Default for LLMClientPool {
127 fn default() -> Self {
128 Self::new()
129 }
130}
131
impl DefaultRoutingEngine {
    /// Build an engine with a no-op confidence monitor.
    ///
    /// Delegates to [`Self::new_with_confidence_monitor`].
    pub async fn new(
        config: RoutingConfig,
        model_catalog: ModelCatalog,
        model_logger: Option<Arc<ModelLogger>>,
        llm_clients: LLMClientPool,
        slm_executor: Arc<dyn SlmExecutor>,
    ) -> Result<Self, RoutingError> {
        Self::new_with_confidence_monitor(
            config,
            model_catalog,
            model_logger,
            Box::new(NoOpConfidenceMonitor),
            llm_clients,
            slm_executor,
        )
        .await
    }

    /// Build an engine with an explicit confidence monitor.
    ///
    /// Constructs the task classifier and policy evaluator from `config`;
    /// fails if either configuration is invalid.
    pub async fn new_with_confidence_monitor(
        config: RoutingConfig,
        model_catalog: ModelCatalog,
        model_logger: Option<Arc<ModelLogger>>,
        confidence_monitor: Box<dyn ConfidenceMonitorTrait>,
        llm_clients: LLMClientPool,
        slm_executor: Arc<dyn SlmExecutor>,
    ) -> Result<Self, RoutingError> {
        let classifier = TaskClassifier::new(config.classification.clone())?;

        let policy_evaluator =
            PolicyEvaluator::new(config.policy.clone(), classifier, model_catalog.clone())?;

        let llm_clients = Arc::new(llm_clients);

        let engine = Self {
            policy_evaluator: Arc::new(RwLock::new(policy_evaluator)),
            model_catalog: Arc::new(model_catalog),
            confidence_monitor: Arc::new(RwLock::new(confidence_monitor)),
            model_logger,
            statistics: Arc::new(RoutingStatistics::default()),
            config: Arc::new(arc_swap::ArcSwap::from_pointee(config)),
            llm_clients,
            slm_executor,
        };

        Ok(engine)
    }

    /// Execute a request on the SLM identified by `model_id`.
    ///
    /// Applies `monitoring_level` to judge the SLM response; when
    /// `fallback_on_failure` is set, falls back to the LLM either on
    /// execution failure or on an unsatisfactory response.
    async fn execute_slm_route(
        &self,
        context: &RoutingContext,
        request: &ModelRequest,
        model_id: &str,
        monitoring_level: &super::decision::MonitoringLevel,
        fallback_on_failure: bool,
    ) -> Result<ModelResponse, RoutingError> {
        // Unknown model ids surface as NoSuitableModel for this task type.
        let model = self.model_catalog.get_model(model_id).ok_or_else(|| {
            RoutingError::NoSuitableModel {
                task_type: context.task_type.clone(),
            }
        })?;

        let slm_result = self.slm_executor.execute(request, model).await;

        match slm_result {
            Ok(response) => {
                // Decide whether response quality warrants an LLM retry.
                let should_fallback = match monitoring_level {
                    super::decision::MonitoringLevel::None => false,
                    super::decision::MonitoringLevel::Basic => {
                        // Basic: any non-Stop finish reason is suspect.
                        response.finish_reason != FinishReason::Stop
                    }
                    super::decision::MonitoringLevel::Enhanced {
                        confidence_threshold,
                    } => {
                        // A missing confidence score is treated as neutral 0.5.
                        let confidence_score = response.confidence_score.unwrap_or(0.5);
                        confidence_score < *confidence_threshold
                    }
                };

                if should_fallback && fallback_on_failure {
                    tracing::warn!(
                        "SLM response did not meet confidence threshold, falling back to LLM"
                    );

                    self.fallback_to_llm(request, "Low confidence SLM response")
                        .await
                } else {
                    Ok(response)
                }
            }
            Err(e) => {
                if fallback_on_failure {
                    tracing::error!("SLM execution failed, falling back to LLM: {}", e);

                    self.fallback_to_llm(request, &format!("SLM execution failed: {}", e))
                        .await
                } else {
                    Err(RoutingError::ModelExecutionFailed {
                        model_id: model_id.to_string(),
                        reason: e.to_string(),
                    })
                }
            }
        }
    }

    /// Record a fallback in the statistics, then run the LLM fallback.
    async fn fallback_to_llm(
        &self,
        request: &ModelRequest,
        reason: &str,
    ) -> Result<ModelResponse, RoutingError> {
        self.statistics.record_fallback();
        self.execute_llm_fallback(request, reason).await
    }

    /// Execute `request` against the fallback LLM provider.
    ///
    /// Fails fast when fallback is disabled in the current config.
    /// NOTE(review): the provider is hard-coded to OpenAI with no model
    /// override; presumably `fallback_config` should select it — confirm.
    async fn execute_llm_fallback(
        &self,
        request: &ModelRequest,
        _reason: &str,
    ) -> Result<ModelResponse, RoutingError> {
        let config = self.config.load();
        let fallback_config = &config.policy.fallback_config;

        if !fallback_config.enabled {
            return Err(RoutingError::LLMFallbackFailed {
                provider: "disabled".to_string(),
                reason: "LLM fallback is disabled".to_string(),
            });
        }

        let provider = LLMProvider::OpenAI { model: None };

        match self.llm_clients.execute_request(request, &provider).await {
            Ok(response) => Ok(response),
            Err(e) => Err(RoutingError::LLMFallbackFailed {
                provider: provider.to_string(),
                reason: e.to_string(),
            }),
        }
    }

    /// Log one routed execution through the optional `ModelLogger`.
    ///
    /// A logging failure is reported via `tracing::warn!` and never
    /// propagated to the caller. No-op when no logger is configured.
    async fn log_routing_execution(
        &self,
        context: &RoutingContext,
        decision: &RouteDecision,
        request: &ModelRequest,
        response: &ModelResponse,
        execution_time: Duration,
        error: Option<&RoutingError>,
    ) {
        if let Some(ref logger) = self.model_logger {
            // The model name recorded depends on the route taken.
            let model_name = match decision {
                RouteDecision::UseSLM { model_id, .. } => model_id.clone(),
                RouteDecision::UseLLM { provider, .. } => provider.to_string(),
                RouteDecision::Deny { .. } => "denied".to_string(),
            };

            let request_data = RequestData {
                prompt: request.prompt.clone(),
                tool_name: None,
                tool_arguments: None,
                parameters: {
                    let mut params = HashMap::new();
                    params.insert(
                        "routing_decision".to_string(),
                        serde_json::Value::String(format!("{:?}", decision)),
                    );
                    params.insert(
                        "task_type".to_string(),
                        serde_json::Value::String(context.task_type.to_string()),
                    );
                    params
                },
            };

            let response_data = ResponseData {
                content: response.content.clone(),
                tool_result: None,
                confidence: response.confidence_score,
                metadata: response.metadata.clone(),
            };

            let metadata = {
                let mut meta = HashMap::new();
                meta.insert("routing_engine".to_string(), "default".to_string());
                meta.insert("agent_id".to_string(), context.agent_id.to_string());
                meta.insert("request_id".to_string(), context.request_id.clone());
                meta
            };

            if let Err(e) = logger
                .log_interaction(
                    context.agent_id,
                    ModelInteractionType::Completion,
                    &model_name,
                    request_data,
                    response_data,
                    execution_time,
                    metadata,
                    // Translate routing token usage into the logging schema.
                    response.token_usage.as_ref().map(|t| LogTokenUsage {
                        input_tokens: t.prompt_tokens,
                        output_tokens: t.completion_tokens,
                        total_tokens: t.total_tokens,
                    }),
                    error.map(|e| e.to_string()),
                )
                .await
            {
                tracing::warn!("Failed to log routing execution: {}", e);
            }
        }
    }
}
359
#[async_trait]
impl RoutingEngine for DefaultRoutingEngine {
    /// Evaluate policies under a read lock and return the decision.
    async fn route_request(&self, context: &RoutingContext) -> Result<RouteDecision, RoutingError> {
        let policy_result = self
            .policy_evaluator
            .read()
            .await
            .evaluate_policies(context)
            .await?;

        tracing::debug!(
            "Routing decision for agent {}: {:?} (rule: {:?})",
            context.agent_id,
            policy_result.decision,
            policy_result.matched_rule
        );

        Ok(policy_result.decision)
    }

    /// Route `request`, execute it, record statistics, and log the outcome.
    async fn execute_with_routing(
        &self,
        context: RoutingContext,
        request: ModelRequest,
    ) -> Result<ModelResponse, RoutingError> {
        let start_time = Instant::now();
        let route_decision = self.route_request(&context).await?;

        let result = match &route_decision {
            RouteDecision::UseSLM {
                model_id,
                monitoring,
                fallback_on_failure,
                sandbox_tier: _,
            } => {
                self.execute_slm_route(
                    &context,
                    &request,
                    model_id,
                    monitoring,
                    *fallback_on_failure,
                )
                .await
            }
            RouteDecision::UseLLM {
                provider,
                reason,
                sandbox_tier: _,
            } => {
                tracing::info!("Routing to LLM: {}", reason);
                self.llm_clients.execute_request(&request, provider).await
            }
            RouteDecision::Deny {
                reason,
                policy_violated,
            } => {
                // NOTE(review): this early return skips the statistics
                // recording and logging below — denied requests are not
                // counted. Confirm whether that is intended.
                return Err(RoutingError::RoutingDenied {
                    policy: policy_violated.clone(),
                    reason: reason.clone(),
                });
            }
        };

        let execution_time = start_time.elapsed();

        self.statistics
            .record_request(&route_decision, execution_time, result.is_ok());

        // Track confidence only for successful responses that report one.
        if let Ok(ref response) = result {
            if let Some(confidence) = response.confidence_score {
                self.statistics.add_confidence_score(confidence);
            }
        }

        match &result {
            Ok(response) => {
                self.log_routing_execution(
                    &context,
                    &route_decision,
                    &request,
                    response,
                    execution_time,
                    None,
                )
                .await;
            }
            Err(error) => {
                // The logger requires a response; synthesize a placeholder
                // so the failed attempt is still recorded.
                let dummy_response = ModelResponse {
                    content: "Error occurred".to_string(),
                    finish_reason: FinishReason::Error,
                    token_usage: None,
                    metadata: HashMap::new(),
                    confidence_score: Some(0.0),
                };

                self.log_routing_execution(
                    &context,
                    &route_decision,
                    &request,
                    &dummy_response,
                    execution_time,
                    Some(error),
                )
                .await;
            }
        }

        result
    }

    /// Currently a no-op: always reports the policies as valid.
    fn validate_policies(&self) -> Result<(), RoutingError> {
        Ok(())
    }

    /// Return a clone of the shared statistics snapshot.
    async fn get_routing_stats(&self) -> RoutingStatistics {
        (*self.statistics).clone()
    }

    /// Push new policy config into the evaluator, then swap the stored
    /// configuration atomically.
    async fn update_config(&self, config: RoutingConfig) -> Result<(), RoutingError> {
        self.policy_evaluator
            .write()
            .await
            .update_config(config.policy.clone())?;

        self.config.store(Arc::new(config));

        Ok(())
    }
}
495
496#[cfg(test)]
497use super::decision::TokenUsage;
498
#[cfg(test)]
/// Test double for `SlmExecutor`; fails when the prompt contains "error".
struct MockSlmExecutor;
501
#[cfg(test)]
#[async_trait]
impl SlmExecutor for MockSlmExecutor {
    /// Simulated SLM execution: short delay, an error for prompts that
    /// contain "error", otherwise a canned response with synthetic token
    /// usage, model metadata, and a prompt-length-derived confidence.
    async fn execute(
        &self,
        request: &ModelRequest,
        model: &crate::config::Model,
    ) -> Result<ModelResponse, SlmRunnerError> {
        // Pretend local inference takes a little time.
        tokio::time::sleep(Duration::from_millis(10)).await;

        if request.prompt.contains("error") {
            return Err(SlmRunnerError::ExecutionFailed {
                reason: "Simulated execution error".to_string(),
            });
        }

        // Roughly 4 chars per token, plus a fixed completion size.
        let prompt_tokens = request.prompt.len() as u32 / 4;

        let mut meta = HashMap::new();
        meta.insert(
            "model_id".to_string(),
            serde_json::Value::String(model.id.clone()),
        );
        meta.insert(
            "provider".to_string(),
            serde_json::Value::String(format!("{:?}", model.provider)),
        );

        Ok(ModelResponse {
            content: format!("SLM ({}) response: {}", model.name, request.prompt),
            finish_reason: FinishReason::Stop,
            token_usage: Some(TokenUsage {
                prompt_tokens,
                completion_tokens: 30,
                total_tokens: prompt_tokens + 30,
            }),
            metadata: meta,
            // Deterministic pseudo-confidence in [0.8, 0.99].
            confidence_score: Some(0.8 + (request.prompt.len() % 20) as f64 / 100.0),
        })
    }
}
543
#[cfg(test)]
/// Test double for `LLMClient`; always succeeds with a canned response.
#[derive(Debug)]
struct MockLLMClient;
547
#[cfg(test)]
#[async_trait]
impl LLMClient for MockLLMClient {
    /// Simulated LLM call: longer delay than the mock SLM, fixed 0.95
    /// confidence, and provider name echoed into the metadata.
    async fn execute_request(
        &self,
        request: &ModelRequest,
        provider: &LLMProvider,
    ) -> Result<ModelResponse, RoutingError> {
        // Remote calls are slower than local inference.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let prompt_tokens = request.prompt.len() as u32 / 4;

        let mut meta = HashMap::new();
        meta.insert(
            "provider".to_string(),
            serde_json::Value::String(provider.to_string()),
        );
        meta.insert("mock".to_string(), serde_json::Value::Bool(true));

        Ok(ModelResponse {
            content: format!("LLM response to: {}", request.prompt),
            finish_reason: FinishReason::Stop,
            token_usage: Some(TokenUsage {
                prompt_tokens,
                completion_tokens: 50,
                total_tokens: prompt_tokens + 50,
            }),
            metadata: meta,
            confidence_score: Some(0.95),
        })
    }
}
579
580#[cfg(test)]
581mod tests {
582 use super::*;
583 use crate::config::{
584 Model, ModelAllowListConfig, ModelProvider, ModelResourceRequirements, SandboxProfile, Slm,
585 };
586 use crate::types::AgentId;
587 use std::collections::HashMap;
588 use std::path::PathBuf;
589
590 fn create_mock_pool() -> LLMClientPool {
591 let mut pool = LLMClientPool::new();
592 pool.register("openai", Box::new(MockLLMClient));
593 pool.register("anthropic", Box::new(MockLLMClient));
594 pool.register("custom", Box::new(MockLLMClient));
595 pool
596 }
597
598 fn create_mock_slm() -> Arc<dyn SlmExecutor> {
599 Arc::new(MockSlmExecutor)
600 }
601
    // Build an engine with two local models ("test-slm", "error-slm") and a
    // high-priority rule routing CodeGeneration tasks to the best available
    // SLM with Basic monitoring and LLM fallback enabled.
    async fn create_test_engine() -> DefaultRoutingEngine {
        let global_models = vec![
            Model {
                id: "test-slm".to_string(),
                name: "Test SLM".to_string(),
                provider: ModelProvider::LocalFile {
                    file_path: PathBuf::from("/tmp/test.gguf"),
                },
                capabilities: vec![
                    crate::config::ModelCapability::TextGeneration,
                    crate::config::ModelCapability::CodeGeneration,
                ],
                resource_requirements: ModelResourceRequirements {
                    min_memory_mb: 1024,
                    preferred_cpu_cores: 2.0,
                    gpu_requirements: None,
                },
            },
            Model {
                id: "error-slm".to_string(),
                name: "Error SLM".to_string(),
                provider: ModelProvider::LocalFile {
                    file_path: PathBuf::from("/tmp/error.gguf"),
                },
                capabilities: vec![crate::config::ModelCapability::TextGeneration],
                resource_requirements: ModelResourceRequirements {
                    min_memory_mb: 512,
                    preferred_cpu_cores: 1.0,
                    gpu_requirements: None,
                },
            },
        ];

        let mut sandbox_profiles = HashMap::new();
        sandbox_profiles.insert("default".to_string(), SandboxProfile::secure_default());

        let slm_config = Slm {
            enabled: true,
            model_allow_lists: ModelAllowListConfig {
                global_models,
                agent_model_maps: HashMap::new(),
                allow_runtime_overrides: false,
            },
            sandbox_profiles,
            default_sandbox_profile: "default".to_string(),
        };

        let model_catalog = ModelCatalog::new(slm_config).unwrap();
        let mut config = RoutingConfig::default();

        // Rule that routes CodeGeneration to an SLM, falling back to the
        // LLM when confidence drops below 0.8.
        config.policy.rules.push(super::super::config::RoutingRule {
            name: "slm_code_rule".to_string(),
            priority: 100,
            conditions: super::super::config::RoutingConditions {
                task_types: Some(vec![super::super::error::TaskType::CodeGeneration]),
                agent_ids: None,
                resource_constraints: None,
                security_level: None,
                custom_conditions: None,
            },
            action: super::super::config::RouteAction::UseSLM {
                model_preference: super::super::config::ModelPreference::BestAvailable,
                monitoring_level: crate::routing::MonitoringLevel::Basic,
                fallback_on_low_confidence: true,
                confidence_threshold: Some(0.8),
            },
            override_allowed: true,
            action_extension: None,
        });

        DefaultRoutingEngine::new(
            config,
            model_catalog,
            None,
            create_mock_pool(),
            create_mock_slm(),
        )
        .await
        .unwrap()
    }
683
    // Same as `create_test_engine` but with a single model, the default
    // routing config (no extra rules), and a real `ModelLogger` attached.
    async fn create_test_engine_with_logger() -> DefaultRoutingEngine {
        let logger =
            ModelLogger::new(super::super::super::logging::LoggingConfig::default(), None).unwrap();

        let global_models = vec![Model {
            id: "test-slm".to_string(),
            name: "Test SLM".to_string(),
            provider: ModelProvider::LocalFile {
                file_path: PathBuf::from("/tmp/test.gguf"),
            },
            capabilities: vec![crate::config::ModelCapability::TextGeneration],
            resource_requirements: ModelResourceRequirements {
                min_memory_mb: 1024,
                preferred_cpu_cores: 2.0,
                gpu_requirements: None,
            },
        }];

        let mut sandbox_profiles = HashMap::new();
        sandbox_profiles.insert("default".to_string(), SandboxProfile::secure_default());

        let slm_config = Slm {
            enabled: true,
            model_allow_lists: ModelAllowListConfig {
                global_models,
                agent_model_maps: HashMap::new(),
                allow_runtime_overrides: false,
            },
            sandbox_profiles,
            default_sandbox_profile: "default".to_string(),
        };

        let model_catalog = ModelCatalog::new(slm_config).unwrap();
        let config = RoutingConfig::default();

        DefaultRoutingEngine::new(
            config,
            model_catalog,
            Some(Arc::new(logger)),
            create_mock_pool(),
            create_mock_slm(),
        )
        .await
        .unwrap()
    }
729
730 fn create_test_request(prompt: &str) -> ModelRequest {
731 ModelRequest::from_task(prompt.to_string())
732 }
733
734 fn create_test_context(
735 prompt: &str,
736 task_type: super::super::error::TaskType,
737 ) -> RoutingContext {
738 RoutingContext::new(AgentId::new(), task_type, prompt.to_string())
739 }
740
741 #[tokio::test]
742 async fn test_routing_engine_creation() {
743 let engine = create_test_engine().await;
744
745 let stats = engine.get_routing_stats().await;
747 assert_eq!(stats.total_requests(), 0);
748
749 assert!(engine.validate_policies().is_ok());
751 }
752
    // Engine built with a logger exposes it and starts with zero stats.
    #[tokio::test]
    async fn test_routing_engine_with_logger() {
        let engine = create_test_engine_with_logger().await;

        assert!(engine.model_logger.is_some());

        let stats = engine.get_routing_stats().await;
        assert_eq!(stats.total_requests(), 0);
    }
763
    // CodeGeneration requests must be routed somewhere, never denied.
    #[tokio::test]
    async fn test_routing_engine_basic_flow() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "Write a hello world function",
            super::super::error::TaskType::CodeGeneration,
        );

        let decision = engine.route_request(&context).await.unwrap();

        match decision {
            RouteDecision::UseSLM { .. } | RouteDecision::UseLLM { .. } => {
                // Either route is acceptable for this task type.
            }
            RouteDecision::Deny { .. } => {
                panic!("Should not deny basic request");
            }
        }
    }
785
    // Happy path: the SLM executes successfully and statistics update.
    #[tokio::test]
    async fn test_execute_with_routing_slm_success() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "Write a hello world function",
            super::super::error::TaskType::CodeGeneration,
        );

        let request = create_test_request("Write a hello world function");

        let response = engine.execute_with_routing(context, request).await.unwrap();

        assert!(!response.content.is_empty());
        assert!(response.confidence_score.is_some());
        assert_eq!(response.finish_reason, FinishReason::Stop);

        let stats = engine.get_routing_stats().await;
        assert!(stats.total_requests() > 0);
    }
807
    // "error" in the prompt makes the mock SLM fail; the engine must fall
    // back to the mock LLM and count the fallback.
    #[tokio::test]
    async fn test_execute_with_routing_slm_error_fallback() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "This should trigger an error in SLM",
            super::super::error::TaskType::CodeGeneration,
        );

        let request = create_test_request("error trigger");

        let response = engine.execute_with_routing(context, request).await.unwrap();

        assert!(!response.content.is_empty());
        assert!(response.content.contains("LLM response"));

        let stats = engine.get_routing_stats().await;
        assert!(stats.fallback_routes() > 0);
    }
828
    // Direct SLM route with Basic monitoring returns the mock SLM response.
    #[tokio::test]
    async fn test_slm_execution_success() {
        let engine = create_test_engine().await;

        let context =
            create_test_context("Test prompt", super::super::error::TaskType::CodeGeneration);

        let request = create_test_request("Test prompt");

        let response = engine
            .execute_slm_route(
                &context,
                &request,
                "test-slm",
                &crate::routing::MonitoringLevel::Basic,
                true,
            )
            .await
            .unwrap();

        assert!(!response.content.is_empty());
        assert!(response.content.contains("Test SLM"));
        assert!(response.confidence_score.is_some());
    }
853
    // Enhanced monitoring with a 0.9 threshold may trigger an LLM fallback
    // (the mock confidence varies with prompt length), so only a non-empty
    // response is asserted.
    #[tokio::test]
    async fn test_slm_execution_with_enhanced_monitoring() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "Test prompt with monitoring",
            super::super::error::TaskType::CodeGeneration,
        );

        let request = create_test_request("Test prompt with monitoring");

        let response = engine
            .execute_slm_route(
                &context,
                &request,
                "test-slm",
                &crate::routing::MonitoringLevel::Enhanced {
                    confidence_threshold: 0.9,
                },
                true,
            )
            .await
            .unwrap();

        assert!(!response.content.is_empty());
    }
881
    // With MonitoringLevel::None the SLM response is accepted as-is.
    #[tokio::test]
    async fn test_slm_execution_no_monitoring() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "Test prompt no monitoring",
            super::super::error::TaskType::CodeGeneration,
        );

        let request = create_test_request("Test prompt no monitoring");

        let response = engine
            .execute_slm_route(
                &context,
                &request,
                "test-slm",
                &crate::routing::MonitoringLevel::None,
                true,
            )
            .await
            .unwrap();

        assert!(!response.content.is_empty());
        assert!(response.content.contains("Test SLM"));
    }
908
    // With fallback disabled, an SLM failure surfaces as
    // ModelExecutionFailed instead of routing to the LLM.
    #[tokio::test]
    async fn test_slm_execution_error_no_fallback() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "error trigger",
            super::super::error::TaskType::CodeGeneration,
        );

        let request = create_test_request("error trigger");

        let result = engine
            .execute_slm_route(
                &context,
                &request,
                "test-slm",
                &crate::routing::MonitoringLevel::Basic,
                false,
            )
            .await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RoutingError::ModelExecutionFailed { .. }
        ));
    }
936
    // Unknown model ids surface as NoSuitableModel before any execution.
    #[tokio::test]
    async fn test_slm_execution_nonexistent_model() {
        let engine = create_test_engine().await;

        let context =
            create_test_context("Test prompt", super::super::error::TaskType::CodeGeneration);

        let request = create_test_request("Test prompt");

        let result = engine
            .execute_slm_route(
                &context,
                &request,
                "nonexistent-model",
                &crate::routing::MonitoringLevel::Basic,
                true,
            )
            .await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RoutingError::NoSuitableModel { .. }
        ));
    }
962
    // Direct LLM fallback succeeds against the mock client pool.
    #[tokio::test]
    async fn test_llm_fallback_execution() {
        let engine = create_test_engine().await;

        let request = create_test_request("Test LLM fallback");

        let response = engine
            .execute_llm_fallback(&request, "Test reason")
            .await
            .unwrap();

        assert!(!response.content.is_empty());
        assert!(response.content.contains("LLM response"));
        assert_eq!(response.finish_reason, FinishReason::Stop);
        assert!(response.confidence_score.is_some());
    }
979
    // When fallback is disabled via config, the fallback path must fail
    // with LLMFallbackFailed.
    #[tokio::test]
    async fn test_llm_fallback_disabled() {
        let engine = create_test_engine().await;

        {
            // Swap in a config copy with fallback turned off.
            let mut config = (*engine.config.load().as_ref()).clone();
            config.policy.fallback_config.enabled = false;
            engine.config.store(Arc::new(config));
        }

        let request = create_test_request("Test disabled fallback");

        let result = engine.execute_llm_fallback(&request, "Test reason").await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RoutingError::LLMFallbackFailed { .. }
        ));
    }
1001
    // The pool dispatches to the correct registered client for each of the
    // three provider variants.
    #[tokio::test]
    async fn test_llm_client_pool() {
        let pool = create_mock_pool();

        let request = create_test_request("Test LLM client");

        let openai_response = pool
            .execute_request(
                &request,
                &LLMProvider::OpenAI {
                    model: Some("gpt-3.5-turbo".to_string()),
                },
            )
            .await
            .unwrap();

        assert!(!openai_response.content.is_empty());
        assert!(openai_response.metadata.contains_key("provider"));

        let anthropic_response = pool
            .execute_request(
                &request,
                &LLMProvider::Anthropic {
                    model: Some("claude-3".to_string()),
                },
            )
            .await
            .unwrap();

        assert!(!anthropic_response.content.is_empty());
        assert!(anthropic_response.metadata.contains_key("provider"));

        let custom_response = pool
            .execute_request(
                &request,
                &LLMProvider::Custom {
                    endpoint: "http://localhost:8080".to_string(),
                    model: None,
                },
            )
            .await
            .unwrap();

        assert!(!custom_response.content.is_empty());
    }
1050
    // One successful SLM request plus one fallback request should leave
    // non-zero totals, fallback count, and average response time.
    #[tokio::test]
    async fn test_routing_statistics_tracking() {
        let engine = create_test_engine().await;

        let context1 = create_test_context("Test 1", super::super::error::TaskType::CodeGeneration);
        let request1 = create_test_request("Test 1");

        let _response1 = engine
            .execute_with_routing(context1, request1)
            .await
            .unwrap();

        // Second request forces the SLM error path and thus a fallback.
        let context2 = create_test_context(
            "error trigger",
            super::super::error::TaskType::CodeGeneration,
        );
        let request2 = create_test_request("error trigger");

        let _response2 = engine
            .execute_with_routing(context2, request2)
            .await
            .unwrap();

        let stats = engine.get_routing_stats().await;

        assert!(stats.total_requests() > 0);
        assert!(stats.fallback_routes() > 0);
        assert!(stats.average_response_time() > Duration::ZERO);
    }
1081
    // update_config must swap the stored config so later loads see it.
    #[tokio::test]
    async fn test_config_update() {
        let engine = create_test_engine().await;

        let mut new_config = RoutingConfig::default();
        new_config.policy.fallback_config.enabled = false;

        let result = engine.update_config(new_config.clone()).await;
        assert!(result.is_ok());

        let updated_config = engine.config.load();
        assert!(!updated_config.policy.fallback_config.enabled);
    }
1096
    // A custom "forbidden" task type may be denied or routed depending on
    // the default policy; only unexpected error kinds fail the test.
    #[tokio::test]
    async fn test_routing_with_deny_decision() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "forbidden operation",
            super::super::error::TaskType::Custom("forbidden".to_string()),
        );

        let request = create_test_request("forbidden operation");

        let result = engine.execute_with_routing(context, request).await;

        match result {
            Ok(response) => {
                // Routed and executed: response must be non-empty.
                assert!(!response.content.is_empty());
            }
            Err(RoutingError::RoutingDenied { .. }) => {
                // Denial is also an acceptable outcome here.
            }
            Err(e) => {
                panic!("Unexpected error: {:?}", e);
            }
        }
    }
1126
1127 #[tokio::test]
1128 async fn test_logging_integration() {
1129 let engine = create_test_engine_with_logger().await;
1130
1131 let context = create_test_context(
1132 "Test logging integration",
1133 super::super::error::TaskType::CodeGeneration,
1134 );
1135
1136 let request = create_test_request("Test logging integration");
1137
1138 let response = engine.execute_with_routing(context, request).await.unwrap();
1139
1140 assert!(!response.content.is_empty());
1141 }
1143
    // Responses routed through the engine carry a confidence score.
    #[tokio::test]
    async fn test_confidence_monitoring_integration() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "Test confidence monitoring",
            super::super::error::TaskType::CodeGeneration,
        );

        let request = create_test_request("Test confidence monitoring");

        let response = engine.execute_with_routing(context, request).await.unwrap();

        assert!(!response.content.is_empty());
        assert!(response.confidence_score.is_some());
    }
1160
    // Every task type produces some decision without erroring.
    // NOTE(review): CodeGeneration appears twice in the list — presumably
    // one entry was meant to be a different task type.
    #[tokio::test]
    async fn test_policy_evaluation_integration() {
        let engine = create_test_engine().await;

        let task_types = vec![
            super::super::error::TaskType::CodeGeneration,
            super::super::error::TaskType::CodeGeneration,
            super::super::error::TaskType::Analysis,
            super::super::error::TaskType::Reasoning,
        ];

        for task_type in task_types {
            let context = create_test_context("Test policy evaluation", task_type.clone());

            let decision = engine.route_request(&context).await.unwrap();

            match decision {
                RouteDecision::UseSLM { .. }
                | RouteDecision::UseLLM { .. }
                | RouteDecision::Deny { .. } => {
                    // Any decision variant is acceptable; reaching here
                    // just proves evaluation did not error.
                }
            }
        }
    }
1188
    // Ten concurrent requests all succeed and are all counted exactly once.
    #[tokio::test]
    async fn test_concurrent_routing_requests() {
        let engine = Arc::new(create_test_engine().await);

        let mut handles = Vec::new();

        for i in 0..10 {
            let engine_clone = Arc::clone(&engine);
            let handle = tokio::spawn(async move {
                let context = create_test_context(
                    &format!("Concurrent request {}", i),
                    super::super::error::TaskType::CodeGeneration,
                );

                let request = create_test_request(&format!("Concurrent request {}", i));

                engine_clone.execute_with_routing(context, request).await
            });
            handles.push(handle);
        }

        let results = futures::future::join_all(handles).await;

        for result in results {
            // Outer unwrap: join error; inner unwrap: routing error.
            let response = result.unwrap().unwrap();
            assert!(!response.content.is_empty());
        }

        let stats = engine.get_routing_stats().await;
        assert_eq!(stats.total_requests(), 10);
    }
1224
    // Degenerate prompts either recover to a non-empty response (via
    // fallback) or produce an expected, tolerated error.
    #[tokio::test]
    async fn test_error_handling_and_recovery() {
        let engine = create_test_engine().await;

        let error_scenarios = vec![
            ("error trigger", "Should trigger SLM execution error"),
            ("", "Empty prompt"),
            (" ", "Whitespace-only prompt"),
        ];

        for (prompt, description) in error_scenarios {
            let context =
                create_test_context(prompt, super::super::error::TaskType::CodeGeneration);
            let request = create_test_request(prompt);

            let result = engine.execute_with_routing(context, request).await;

            match result {
                Ok(response) => {
                    assert!(!response.content.is_empty(), "Failed for: {}", description);
                }
                Err(e) => {
                    // Errors are acceptable for these inputs; record them.
                    tracing::info!("Expected error for '{}': {:?}", description, e);
                }
            }
        }
    }
1255
    // Responses carry token usage whose total is the sum of its parts,
    // plus non-empty metadata.
    #[tokio::test]
    async fn test_model_metadata_and_token_usage() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "Test metadata and token usage",
            super::super::error::TaskType::CodeGeneration,
        );

        let request = create_test_request("Test metadata and token usage");

        let response = engine.execute_with_routing(context, request).await.unwrap();

        assert!(!response.content.is_empty());
        assert!(response.token_usage.is_some());
        assert!(!response.metadata.is_empty());

        let token_usage = response.token_usage.unwrap();
        assert!(token_usage.prompt_tokens > 0);
        assert!(token_usage.completion_tokens > 0);
        assert_eq!(
            token_usage.total_tokens,
            token_usage.prompt_tokens + token_usage.completion_tokens
        );
    }
1282
1283 #[tokio::test]
1284 async fn test_validate_policies() {
1285 let engine = create_test_engine().await;
1286
1287 let result = engine.validate_policies();
1288 assert!(result.is_ok());
1289 }
1290
1291 #[tokio::test]
1292 async fn test_engine_state_consistency() {
1293 let engine = create_test_engine().await;
1294
1295 let initial_stats = engine.get_routing_stats().await;
1297 assert_eq!(initial_stats.total_requests(), 0);
1298
1299 for i in 0..5 {
1301 let context = create_test_context(
1302 &format!("Test request {}", i),
1303 super::super::error::TaskType::CodeGeneration,
1304 );
1305 let request = create_test_request(&format!("Test request {}", i));
1306
1307 let _response = engine.execute_with_routing(context, request).await.unwrap();
1308 }
1309
1310 let final_stats = engine.get_routing_stats().await;
1312 assert_eq!(final_stats.total_requests(), 5);
1313 assert!(final_stats.average_response_time() > Duration::ZERO);
1314 }
1315
1316 #[tokio::test]
1317 async fn test_routing_statistics_concurrent_updates() {
1318 use super::super::decision::MonitoringLevel;
1319 use std::sync::Arc;
1320
1321 let stats = Arc::new(RoutingStatistics::default());
1322 let mut handles = vec![];
1323
1324 for _ in 0..100 {
1325 let stats_clone = Arc::clone(&stats);
1326 handles.push(tokio::spawn(async move {
1327 stats_clone.record_request(
1328 &RouteDecision::UseSLM {
1329 model_id: "test".to_string(),
1330 monitoring: MonitoringLevel::None,
1331 fallback_on_failure: false,
1332 sandbox_tier: None,
1333 },
1334 Duration::from_millis(10),
1335 true,
1336 );
1337 }));
1338 }
1339
1340 for handle in handles {
1341 handle.await.unwrap();
1342 }
1343
1344 assert_eq!(stats.total_requests(), 100);
1345 assert_eq!(stats.slm_routes(), 100);
1346 }
1347
1348 #[tokio::test]
1349 async fn test_fallback_counted_once_per_request() {
1350 let engine = create_test_engine().await;
1351
1352 let context =
1354 create_test_context("Test prompt", super::super::error::TaskType::CodeGeneration);
1355 let request = create_test_request("Test prompt");
1356 let _response = engine
1357 .execute_slm_route(
1358 &context,
1359 &request,
1360 "test-slm",
1361 &crate::routing::MonitoringLevel::Enhanced {
1362 confidence_threshold: 1.0,
1363 },
1364 true,
1365 )
1366 .await
1367 .unwrap();
1368
1369 let context2 = create_test_context(
1371 "error trigger",
1372 super::super::error::TaskType::CodeGeneration,
1373 );
1374 let request2 = create_test_request("error trigger");
1375 let _response2 = engine
1376 .execute_slm_route(
1377 &context2,
1378 &request2,
1379 "test-slm",
1380 &crate::routing::MonitoringLevel::Basic,
1381 true,
1382 )
1383 .await
1384 .unwrap();
1385
1386 let stats = engine.get_routing_stats().await;
1387 assert_eq!(stats.fallback_routes(), 2);
1388 }
1389
1390 #[tokio::test]
1391 async fn test_llm_client_pool_register_and_execute() {
1392 let mut pool = LLMClientPool::new();
1393 pool.register("openai", Box::new(MockLLMClient));
1394
1395 let request = create_test_request("Test");
1396 let provider = LLMProvider::OpenAI {
1397 model: Some("gpt-4".to_string()),
1398 };
1399 let response = pool.execute_request(&request, &provider).await.unwrap();
1400 assert!(!response.content.is_empty());
1401 }
1402
1403 #[tokio::test]
1404 async fn test_llm_client_pool_no_client_registered() {
1405 let pool = LLMClientPool::new();
1406 let request = create_test_request("Test");
1407 let provider = LLMProvider::OpenAI { model: None };
1408 let result = pool.execute_request(&request, &provider).await;
1409 assert!(result.is_err());
1410 }
1411}