1use async_trait::async_trait;
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::RwLock;
8
9use super::classifier::TaskClassifier;
10use super::confidence::{ConfidenceMonitorTrait, NoOpConfidenceMonitor};
11use super::config::RoutingConfig;
12use super::decision::{
13 FinishReason, LLMProvider, ModelRequest, ModelResponse, RouteDecision, RoutingContext,
14 RoutingStatistics, TokenUsage,
15};
16use super::error::RoutingError;
17use super::policy::PolicyEvaluator;
18use crate::logging::{
19 ModelInteractionType, ModelLogger, RequestData, ResponseData, TokenUsage as LogTokenUsage,
20};
21use crate::models::{ModelCatalog, SlmRunnerError};
22
/// Abstraction over the request-routing pipeline: deciding where a request
/// should run and, optionally, executing it on the chosen backend.
#[async_trait]
pub trait RoutingEngine: Send + Sync {
    /// Evaluates routing policies for `context` and returns the decision
    /// without executing the request.
    async fn route_request(&self, context: &RoutingContext) -> Result<RouteDecision, RoutingError>;

    /// Routes `request` according to `context` and executes it on the
    /// selected model/provider, returning the model's response.
    async fn execute_with_routing(
        &self,
        context: RoutingContext,
        request: ModelRequest,
    ) -> Result<ModelResponse, RoutingError>;

    /// Checks the currently loaded policies for internal consistency.
    fn validate_policies(&self) -> Result<(), RoutingError>;

    /// Returns a snapshot of the accumulated routing statistics.
    async fn get_routing_stats(&self) -> RoutingStatistics;

    /// Replaces the engine's routing configuration at runtime.
    async fn update_config(&self, config: RoutingConfig) -> Result<(), RoutingError>;
}
45
/// Default [`RoutingEngine`] implementation backed by a policy evaluator,
/// a model catalog, and a pool of (currently mock) LLM clients.
pub struct DefaultRoutingEngine {
    /// Evaluates routing rules to produce a [`RouteDecision`] per request.
    policy_evaluator: Arc<RwLock<PolicyEvaluator>>,
    /// Catalog of locally available SLM models.
    model_catalog: Arc<ModelCatalog>,
    /// Confidence-monitor hook; stored but not consulted yet (the
    /// monitoring logic currently lives inline in `execute_slm_route`).
    #[allow(dead_code)]
    confidence_monitor: Arc<RwLock<Box<dyn ConfidenceMonitorTrait>>>,
    /// Optional structured logger for model interactions.
    model_logger: Option<Arc<ModelLogger>>,
    /// Aggregated counters/timings, updated after every routed request.
    statistics: Arc<RwLock<RoutingStatistics>>,
    /// Live routing configuration; replaceable via `update_config()`.
    config: Arc<RwLock<RoutingConfig>>,
    /// Per-provider LLM clients used for direct LLM routes and fallbacks.
    llm_clients: Arc<LLMClientPool>,
}
64
/// Maps provider keys ("openai", "anthropic", "custom") to client handles.
struct LLMClientPool {
    clients: HashMap<String, Box<dyn LLMClient>>,
}
69
/// Minimal client interface for executing a request against an LLM provider.
#[async_trait]
trait LLMClient: Send + Sync {
    /// Sends `request` to `provider` and returns the provider's response.
    async fn execute_request(
        &self,
        request: &ModelRequest,
        provider: &LLMProvider,
    ) -> Result<ModelResponse, RoutingError>;
}
79
/// Stand-in LLM client used until real provider integrations exist.
#[derive(Debug)]
struct MockLLMClient;
83
84#[async_trait]
85impl LLMClient for MockLLMClient {
86 async fn execute_request(
87 &self,
88 request: &ModelRequest,
89 provider: &LLMProvider,
90 ) -> Result<ModelResponse, RoutingError> {
91 tokio::time::sleep(Duration::from_millis(100)).await;
93
94 Ok(ModelResponse {
95 content: format!("LLM response to: {}", request.prompt),
96 finish_reason: FinishReason::Stop,
97 token_usage: Some(TokenUsage {
98 prompt_tokens: request.prompt.len() as u32 / 4,
99 completion_tokens: 50,
100 total_tokens: (request.prompt.len() as u32 / 4) + 50,
101 }),
102 metadata: {
103 let mut meta = HashMap::new();
104 meta.insert(
105 "provider".to_string(),
106 serde_json::Value::String(provider.to_string()),
107 );
108 meta.insert("mock".to_string(), serde_json::Value::Bool(true));
109 meta
110 },
111 confidence_score: Some(0.95),
112 })
113 }
114}
115
116impl LLMClientPool {
117 fn new() -> Self {
118 let mut clients: HashMap<String, Box<dyn LLMClient>> = HashMap::new();
119 clients.insert("openai".to_string(), Box::new(MockLLMClient));
120 clients.insert("anthropic".to_string(), Box::new(MockLLMClient));
121 clients.insert("custom".to_string(), Box::new(MockLLMClient));
122
123 Self { clients }
124 }
125
126 async fn execute_request(
127 &self,
128 request: &ModelRequest,
129 provider: &LLMProvider,
130 ) -> Result<ModelResponse, RoutingError> {
131 let client_key = match provider {
132 LLMProvider::OpenAI { .. } => "openai",
133 LLMProvider::Anthropic { .. } => "anthropic",
134 LLMProvider::Custom { .. } => "custom",
135 };
136
137 let client =
138 self.clients
139 .get(client_key)
140 .ok_or_else(|| RoutingError::LLMFallbackFailed {
141 provider: provider.to_string(),
142 reason: "No client available for provider".to_string(),
143 })?;
144
145 client.execute_request(request, provider).await
146 }
147}
148
impl DefaultRoutingEngine {
    /// Creates an engine with a no-op confidence monitor.
    ///
    /// Convenience wrapper around [`Self::new_with_confidence_monitor`].
    pub async fn new(
        config: RoutingConfig,
        model_catalog: ModelCatalog,
        model_logger: Option<Arc<ModelLogger>>,
    ) -> Result<Self, RoutingError> {
        Self::new_with_confidence_monitor(
            config,
            model_catalog,
            model_logger,
            Box::new(NoOpConfidenceMonitor),
        )
        .await
    }

    /// Creates an engine with an explicit confidence monitor.
    ///
    /// Builds the task classifier and policy evaluator from `config` and
    /// initializes the (mock) LLM client pool. Fails if either the
    /// classifier or the evaluator rejects its configuration.
    pub async fn new_with_confidence_monitor(
        config: RoutingConfig,
        model_catalog: ModelCatalog,
        model_logger: Option<Arc<ModelLogger>>,
        confidence_monitor: Box<dyn ConfidenceMonitorTrait>,
    ) -> Result<Self, RoutingError> {
        let classifier = TaskClassifier::new(config.classification.clone())?;

        let policy_evaluator =
            PolicyEvaluator::new(config.policy.clone(), classifier, model_catalog.clone())?;

        let llm_clients = Arc::new(LLMClientPool::new());

        let engine = Self {
            policy_evaluator: Arc::new(RwLock::new(policy_evaluator)),
            model_catalog: Arc::new(model_catalog),
            confidence_monitor: Arc::new(RwLock::new(confidence_monitor)),
            model_logger,
            statistics: Arc::new(RwLock::new(RoutingStatistics::default())),
            config: Arc::new(RwLock::new(config)),
            llm_clients,
        };

        Ok(engine)
    }

    /// Executes `request` on the SLM identified by `model_id`, applying the
    /// requested monitoring level and falling back to an LLM when allowed.
    ///
    /// Fallback happens in two cases, both only when `fallback_on_failure`
    /// is true: the SLM call itself errors, or the response fails the
    /// monitoring check (non-`Stop` finish for `Basic`; confidence below
    /// threshold for `Enhanced`). Each fallback increments
    /// `statistics.fallback_routes`.
    async fn execute_slm_route(
        &self,
        context: &RoutingContext,
        request: &ModelRequest,
        model_id: &str,
        monitoring_level: &super::decision::MonitoringLevel,
        fallback_on_failure: bool,
    ) -> Result<ModelResponse, RoutingError> {
        // Reserved for future per-route latency tracking; not used yet.
        let _start_time = Instant::now();

        // An unknown model id is reported as "no suitable model" for the
        // context's task type.
        let model = self.model_catalog.get_model(model_id).ok_or_else(|| {
            RoutingError::NoSuitableModel {
                task_type: context.task_type.clone(),
            }
        })?;

        let slm_result = self.execute_slm_mock(request, model).await;

        match slm_result {
            Ok(response) => {
                // Decide whether response quality mandates an LLM fallback.
                let should_fallback = match monitoring_level {
                    super::decision::MonitoringLevel::None => false,
                    super::decision::MonitoringLevel::Basic => {
                        response.finish_reason != FinishReason::Stop
                    }
                    super::decision::MonitoringLevel::Enhanced {
                        confidence_threshold,
                    } => {
                        // A missing confidence score is treated as middling.
                        let confidence_score = response.confidence_score.unwrap_or(0.5);
                        confidence_score < *confidence_threshold
                    }
                };

                if should_fallback && fallback_on_failure {
                    tracing::warn!(
                        "SLM response did not meet confidence threshold, falling back to LLM"
                    );

                    // Scope the write lock to the counter bump only.
                    {
                        let mut stats = self.statistics.write().await;
                        stats.fallback_routes += 1;
                    }

                    self.execute_llm_fallback(request, "Low confidence SLM response")
                        .await
                } else {
                    Ok(response)
                }
            }
            Err(e) => {
                if fallback_on_failure {
                    tracing::error!("SLM execution failed, falling back to LLM: {}", e);

                    {
                        let mut stats = self.statistics.write().await;
                        stats.fallback_routes += 1;
                    }

                    self.execute_llm_fallback(request, &format!("SLM execution failed: {}", e))
                        .await
                } else {
                    Err(RoutingError::ModelExecutionFailed {
                        model_id: model_id.to_string(),
                        reason: e.to_string(),
                    })
                }
            }
        }
    }

    /// Mock SLM execution used until a real runner is wired in.
    ///
    /// Sleeps briefly, fails when the prompt contains "error" (exercised by
    /// tests), and otherwise fabricates a response with a deterministic,
    /// prompt-length-derived confidence score in [0.8, 1.0).
    async fn execute_slm_mock(
        &self,
        request: &ModelRequest,
        model: &crate::config::Model,
    ) -> Result<ModelResponse, SlmRunnerError> {
        tokio::time::sleep(Duration::from_millis(200)).await;

        if request.prompt.contains("error") {
            return Err(SlmRunnerError::ExecutionFailed {
                reason: "Simulated execution error".to_string(),
            });
        }

        Ok(ModelResponse {
            content: format!("SLM ({}) response: {}", model.name, request.prompt),
            finish_reason: FinishReason::Stop,
            token_usage: Some(TokenUsage {
                // Rough estimate: ~4 characters per token.
                prompt_tokens: request.prompt.len() as u32 / 4,
                completion_tokens: 30,
                total_tokens: (request.prompt.len() as u32 / 4) + 30,
            }),
            metadata: {
                let mut meta = HashMap::new();
                meta.insert(
                    "model_id".to_string(),
                    serde_json::Value::String(model.id.clone()),
                );
                meta.insert(
                    "provider".to_string(),
                    serde_json::Value::String(format!("{:?}", model.provider)),
                );
                meta
            },
            confidence_score: Some(0.8 + (request.prompt.len() % 20) as f64 / 100.0),
        })
    }

    /// Routes `request` to an LLM after an SLM route failed or was rejected.
    ///
    /// Returns `LLMFallbackFailed` when fallback is disabled in the current
    /// config or the provider call itself errors.
    async fn execute_llm_fallback(
        &self,
        request: &ModelRequest,
        _reason: &str,
    ) -> Result<ModelResponse, RoutingError> {
        let config = self.config.read().await;
        let fallback_config = &config.policy.fallback_config;

        if !fallback_config.enabled {
            return Err(RoutingError::LLMFallbackFailed {
                provider: "disabled".to_string(),
                reason: "LLM fallback is disabled".to_string(),
            });
        }

        // NOTE(review): fallback provider is hard-coded to default OpenAI;
        // presumably this should come from fallback_config — confirm.
        let provider = LLMProvider::OpenAI { model: None };

        match self.llm_clients.execute_request(request, &provider).await {
            Ok(response) => Ok(response),
            Err(e) => Err(RoutingError::LLMFallbackFailed {
                provider: provider.to_string(),
                reason: e.to_string(),
            }),
        }
    }

    /// Records a completed (or failed) routing execution with the model
    /// logger, when one is configured. Logging failures only emit a warning
    /// and never affect the routing result.
    async fn log_routing_execution(
        &self,
        context: &RoutingContext,
        decision: &RouteDecision,
        request: &ModelRequest,
        response: &ModelResponse,
        execution_time: Duration,
        error: Option<&RoutingError>,
    ) {
        if let Some(ref logger) = self.model_logger {
            // Derive a loggable model name from the routing decision.
            let model_name = match decision {
                RouteDecision::UseSLM { model_id, .. } => model_id.clone(),
                RouteDecision::UseLLM { provider, .. } => provider.to_string(),
                RouteDecision::Deny { .. } => "denied".to_string(),
            };

            let request_data = RequestData {
                prompt: request.prompt.clone(),
                tool_name: None,
                tool_arguments: None,
                parameters: {
                    let mut params = HashMap::new();
                    params.insert(
                        "routing_decision".to_string(),
                        serde_json::Value::String(format!("{:?}", decision)),
                    );
                    params.insert(
                        "task_type".to_string(),
                        serde_json::Value::String(context.task_type.to_string()),
                    );
                    params
                },
            };

            let response_data = ResponseData {
                content: response.content.clone(),
                tool_result: None,
                confidence: response.confidence_score,
                metadata: response.metadata.clone(),
            };

            let metadata = {
                let mut meta = HashMap::new();
                meta.insert("routing_engine".to_string(), "default".to_string());
                meta.insert("agent_id".to_string(), context.agent_id.to_string());
                meta.insert("request_id".to_string(), context.request_id.clone());
                meta
            };

            if let Err(e) = logger
                .log_interaction(
                    context.agent_id,
                    ModelInteractionType::Completion,
                    &model_name,
                    request_data,
                    response_data,
                    execution_time,
                    metadata,
                    // Translate routing token usage into the logger's type.
                    response.token_usage.as_ref().map(|t| LogTokenUsage {
                        input_tokens: t.prompt_tokens,
                        output_tokens: t.completion_tokens,
                        total_tokens: t.total_tokens,
                    }),
                    error.map(|e| e.to_string()),
                )
                .await
            {
                tracing::warn!("Failed to log routing execution: {}", e);
            }
        }
    }
}
418
419#[async_trait]
420impl RoutingEngine for DefaultRoutingEngine {
421 async fn route_request(&self, context: &RoutingContext) -> Result<RouteDecision, RoutingError> {
422 let policy_result = self
423 .policy_evaluator
424 .read()
425 .await
426 .evaluate_policies(context)
427 .await?;
428
429 tracing::debug!(
430 "Routing decision for agent {}: {:?} (rule: {:?})",
431 context.agent_id,
432 policy_result.decision,
433 policy_result.matched_rule
434 );
435
436 Ok(policy_result.decision)
437 }
438
439 async fn execute_with_routing(
440 &self,
441 context: RoutingContext,
442 request: ModelRequest,
443 ) -> Result<ModelResponse, RoutingError> {
444 let start_time = Instant::now();
445 let route_decision = self.route_request(&context).await?;
446
447 let result = match &route_decision {
448 RouteDecision::UseSLM {
449 model_id,
450 monitoring,
451 fallback_on_failure,
452 sandbox_tier: _,
453 } => {
454 self.execute_slm_route(
455 &context,
456 &request,
457 model_id,
458 monitoring,
459 *fallback_on_failure,
460 )
461 .await
462 }
463 RouteDecision::UseLLM {
464 provider,
465 reason,
466 sandbox_tier: _,
467 } => {
468 tracing::info!("Routing to LLM: {}", reason);
469 self.llm_clients.execute_request(&request, provider).await
470 }
471 RouteDecision::Deny {
472 reason,
473 policy_violated,
474 } => {
475 return Err(RoutingError::RoutingDenied {
476 policy: policy_violated.clone(),
477 reason: reason.clone(),
478 });
479 }
480 };
481
482 let execution_time = start_time.elapsed();
483
484 {
486 let mut stats = self.statistics.write().await;
487 stats.update(&route_decision, execution_time, result.is_ok());
488
489 if let Ok(ref response) = result {
490 if let Some(confidence) = response.confidence_score {
491 stats.add_confidence_score(confidence);
492 }
493 }
494 }
495
496 match &result {
498 Ok(response) => {
499 self.log_routing_execution(
500 &context,
501 &route_decision,
502 &request,
503 response,
504 execution_time,
505 None,
506 )
507 .await;
508 }
509 Err(error) => {
510 let dummy_response = ModelResponse {
512 content: "Error occurred".to_string(),
513 finish_reason: FinishReason::Error,
514 token_usage: None,
515 metadata: HashMap::new(),
516 confidence_score: Some(0.0),
517 };
518
519 self.log_routing_execution(
520 &context,
521 &route_decision,
522 &request,
523 &dummy_response,
524 execution_time,
525 Some(error),
526 )
527 .await;
528 }
529 }
530
531 result
532 }
533
534 fn validate_policies(&self) -> Result<(), RoutingError> {
535 Ok(())
537 }
538
539 async fn get_routing_stats(&self) -> RoutingStatistics {
540 self.statistics.read().await.clone()
541 }
542
543 async fn update_config(&self, config: RoutingConfig) -> Result<(), RoutingError> {
544 *self.config.write().await = config.clone();
546
547 self.policy_evaluator
549 .write()
550 .await
551 .update_config(config.policy)?;
552
553 Ok(())
554 }
555}
556
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{
        Model, ModelAllowListConfig, ModelProvider, ModelResourceRequirements, SandboxProfile, Slm,
    };
    use crate::types::AgentId;
    use std::collections::HashMap;
    use std::path::PathBuf;

    /// Engine fixture: two local models ("test-slm", "error-slm") plus a
    /// high-priority rule routing CodeGeneration tasks to the best SLM,
    /// with LLM fallback on low confidence.
    async fn create_test_engine() -> DefaultRoutingEngine {
        let global_models = vec![
            Model {
                id: "test-slm".to_string(),
                name: "Test SLM".to_string(),
                provider: ModelProvider::LocalFile {
                    file_path: PathBuf::from("/tmp/test.gguf"),
                },
                capabilities: vec![
                    crate::config::ModelCapability::TextGeneration,
                    crate::config::ModelCapability::CodeGeneration,
                ],
                resource_requirements: ModelResourceRequirements {
                    min_memory_mb: 1024,
                    preferred_cpu_cores: 2.0,
                    gpu_requirements: None,
                },
            },
            Model {
                id: "error-slm".to_string(),
                name: "Error SLM".to_string(),
                provider: ModelProvider::LocalFile {
                    file_path: PathBuf::from("/tmp/error.gguf"),
                },
                capabilities: vec![crate::config::ModelCapability::TextGeneration],
                resource_requirements: ModelResourceRequirements {
                    min_memory_mb: 512,
                    preferred_cpu_cores: 1.0,
                    gpu_requirements: None,
                },
            },
        ];

        let mut sandbox_profiles = HashMap::new();
        sandbox_profiles.insert("default".to_string(), SandboxProfile::secure_default());

        let slm_config = Slm {
            enabled: true,
            model_allow_lists: ModelAllowListConfig {
                global_models,
                agent_model_maps: HashMap::new(),
                allow_runtime_overrides: false,
            },
            sandbox_profiles,
            default_sandbox_profile: "default".to_string(),
        };

        let model_catalog = ModelCatalog::new(slm_config).unwrap();
        let mut config = RoutingConfig::default();

        // Route CodeGeneration to an SLM; fall back when confidence < 0.8.
        config.policy.rules.push(super::super::config::RoutingRule {
            name: "slm_code_rule".to_string(),
            priority: 100,
            conditions: super::super::config::RoutingConditions {
                task_types: Some(vec![super::super::error::TaskType::CodeGeneration]),
                agent_ids: None,
                resource_constraints: None,
                security_level: None,
                custom_conditions: None,
            },
            action: super::super::config::RouteAction::UseSLM {
                model_preference: super::super::config::ModelPreference::BestAvailable,
                monitoring_level: crate::routing::MonitoringLevel::Basic,
                fallback_on_low_confidence: true,
                confidence_threshold: Some(0.8),
            },
            override_allowed: true,
            action_extension: None,
        });

        DefaultRoutingEngine::new(config, model_catalog, None)
            .await
            .unwrap()
    }

    /// Engine fixture with a real ModelLogger (default logging config) and
    /// a single "test-slm" model; no extra routing rules.
    async fn create_test_engine_with_logger() -> DefaultRoutingEngine {
        let logger =
            ModelLogger::new(super::super::super::logging::LoggingConfig::default(), None).unwrap();

        let global_models = vec![Model {
            id: "test-slm".to_string(),
            name: "Test SLM".to_string(),
            provider: ModelProvider::LocalFile {
                file_path: PathBuf::from("/tmp/test.gguf"),
            },
            capabilities: vec![crate::config::ModelCapability::TextGeneration],
            resource_requirements: ModelResourceRequirements {
                min_memory_mb: 1024,
                preferred_cpu_cores: 2.0,
                gpu_requirements: None,
            },
        }];

        let mut sandbox_profiles = HashMap::new();
        sandbox_profiles.insert("default".to_string(), SandboxProfile::secure_default());

        let slm_config = Slm {
            enabled: true,
            model_allow_lists: ModelAllowListConfig {
                global_models,
                agent_model_maps: HashMap::new(),
                allow_runtime_overrides: false,
            },
            sandbox_profiles,
            default_sandbox_profile: "default".to_string(),
        };

        let model_catalog = ModelCatalog::new(slm_config).unwrap();
        let config = RoutingConfig::default();

        DefaultRoutingEngine::new(config, model_catalog, Some(Arc::new(logger)))
            .await
            .unwrap()
    }

    /// Builds a minimal request carrying only the prompt text.
    fn create_test_request(prompt: &str) -> ModelRequest {
        ModelRequest::from_task(prompt.to_string())
    }

    /// Builds a routing context with a fresh agent id for the given task type.
    fn create_test_context(
        prompt: &str,
        task_type: super::super::error::TaskType,
    ) -> RoutingContext {
        RoutingContext::new(AgentId::new(), task_type, prompt.to_string())
    }

    #[tokio::test]
    async fn test_routing_engine_creation() {
        let engine = create_test_engine().await;

        // A new engine starts with empty stats and valid policies.
        let stats = engine.get_routing_stats().await;
        assert_eq!(stats.total_requests, 0);

        assert!(engine.validate_policies().is_ok());
    }

    #[tokio::test]
    async fn test_routing_engine_with_logger() {
        let engine = create_test_engine_with_logger().await;

        assert!(engine.model_logger.is_some());

        let stats = engine.get_routing_stats().await;
        assert_eq!(stats.total_requests, 0);
    }

    #[tokio::test]
    async fn test_routing_engine_basic_flow() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "Write a hello world function",
            super::super::error::TaskType::CodeGeneration,
        );

        let decision = engine.route_request(&context).await.unwrap();

        // Either model route is fine; a Deny would be a policy bug.
        match decision {
            RouteDecision::UseSLM { .. } | RouteDecision::UseLLM { .. } => {
            }
            RouteDecision::Deny { .. } => {
                panic!("Should not deny basic request");
            }
        }
    }

    #[tokio::test]
    async fn test_execute_with_routing_slm_success() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "Write a hello world function",
            super::super::error::TaskType::CodeGeneration,
        );

        let request = create_test_request("Write a hello world function");

        let response = engine.execute_with_routing(context, request).await.unwrap();

        assert!(!response.content.is_empty());
        assert!(response.confidence_score.is_some());
        assert_eq!(response.finish_reason, FinishReason::Stop);

        let stats = engine.get_routing_stats().await;
        assert!(stats.total_requests > 0);
    }

    #[tokio::test]
    async fn test_execute_with_routing_slm_error_fallback() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "This should trigger an error in SLM",
            super::super::error::TaskType::CodeGeneration,
        );

        // The word "error" in the prompt makes the mock SLM fail,
        // forcing the LLM fallback path.
        let request = create_test_request("error trigger");

        let response = engine.execute_with_routing(context, request).await.unwrap();

        assert!(!response.content.is_empty());
        assert!(response.content.contains("LLM response"));

        let stats = engine.get_routing_stats().await;
        assert!(stats.fallback_routes > 0);
    }

    #[tokio::test]
    async fn test_slm_execution_success() {
        let engine = create_test_engine().await;

        let context =
            create_test_context("Test prompt", super::super::error::TaskType::CodeGeneration);

        let request = create_test_request("Test prompt");

        let response = engine
            .execute_slm_route(
                &context,
                &request,
                "test-slm",
                &crate::routing::MonitoringLevel::Basic,
                true,
            )
            .await
            .unwrap();

        assert!(!response.content.is_empty());
        assert!(response.content.contains("Test SLM"));
        assert!(response.confidence_score.is_some());
    }

    #[tokio::test]
    async fn test_slm_execution_with_enhanced_monitoring() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "Test prompt with monitoring",
            super::super::error::TaskType::CodeGeneration,
        );

        let request = create_test_request("Test prompt with monitoring");

        // With a 0.9 threshold the mock's score may or may not pass; either
        // the SLM response or the LLM fallback is an acceptable outcome.
        let response = engine
            .execute_slm_route(
                &context,
                &request,
                "test-slm",
                &crate::routing::MonitoringLevel::Enhanced {
                    confidence_threshold: 0.9,
                },
                true,
            )
            .await
            .unwrap();

        assert!(!response.content.is_empty());
    }

    #[tokio::test]
    async fn test_slm_execution_no_monitoring() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "Test prompt no monitoring",
            super::super::error::TaskType::CodeGeneration,
        );

        let request = create_test_request("Test prompt no monitoring");

        let response = engine
            .execute_slm_route(
                &context,
                &request,
                "test-slm",
                &crate::routing::MonitoringLevel::None,
                true,
            )
            .await
            .unwrap();

        // No monitoring means the SLM answer is returned as-is.
        assert!(!response.content.is_empty());
        assert!(response.content.contains("Test SLM"));
    }

    #[tokio::test]
    async fn test_slm_execution_error_no_fallback() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "error trigger",
            super::super::error::TaskType::CodeGeneration,
        );

        let request = create_test_request("error trigger");

        // fallback_on_failure = false: the SLM error must surface directly.
        let result = engine
            .execute_slm_route(
                &context,
                &request,
                "test-slm",
                &crate::routing::MonitoringLevel::Basic,
                false,
            )
            .await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RoutingError::ModelExecutionFailed { .. }
        ));
    }

    #[tokio::test]
    async fn test_slm_execution_nonexistent_model() {
        let engine = create_test_engine().await;

        let context =
            create_test_context("Test prompt", super::super::error::TaskType::CodeGeneration);

        let request = create_test_request("Test prompt");

        let result = engine
            .execute_slm_route(
                &context,
                &request,
                "nonexistent-model",
                &crate::routing::MonitoringLevel::Basic,
                true,
            )
            .await;

        // Unknown model ids fail before execution, even with fallback on.
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RoutingError::NoSuitableModel { .. }
        ));
    }

    #[tokio::test]
    async fn test_llm_fallback_execution() {
        let engine = create_test_engine().await;

        let request = create_test_request("Test LLM fallback");

        let response = engine
            .execute_llm_fallback(&request, "Test reason")
            .await
            .unwrap();

        assert!(!response.content.is_empty());
        assert!(response.content.contains("LLM response"));
        assert_eq!(response.finish_reason, FinishReason::Stop);
        assert!(response.confidence_score.is_some());
    }

    #[tokio::test]
    async fn test_llm_fallback_disabled() {
        let engine = create_test_engine().await;

        // Disable fallback directly in the live config.
        {
            let mut config = engine.config.write().await;
            config.policy.fallback_config.enabled = false;
        }

        let request = create_test_request("Test disabled fallback");

        let result = engine.execute_llm_fallback(&request, "Test reason").await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RoutingError::LLMFallbackFailed { .. }
        ));
    }

    #[tokio::test]
    async fn test_llm_client_pool() {
        let pool = LLMClientPool::new();

        let request = create_test_request("Test LLM client");

        // Each provider variant dispatches to its registered client.
        let openai_response = pool
            .execute_request(
                &request,
                &LLMProvider::OpenAI {
                    model: Some("gpt-3.5-turbo".to_string()),
                },
            )
            .await
            .unwrap();

        assert!(!openai_response.content.is_empty());
        assert!(openai_response.metadata.contains_key("provider"));

        let anthropic_response = pool
            .execute_request(
                &request,
                &LLMProvider::Anthropic {
                    model: Some("claude-3".to_string()),
                },
            )
            .await
            .unwrap();

        assert!(!anthropic_response.content.is_empty());
        assert!(anthropic_response.metadata.contains_key("provider"));

        let custom_response = pool
            .execute_request(
                &request,
                &LLMProvider::Custom {
                    endpoint: "http://localhost:8080".to_string(),
                    model: None,
                },
            )
            .await
            .unwrap();

        assert!(!custom_response.content.is_empty());
    }

    #[tokio::test]
    async fn test_mock_slm_execution() {
        let engine = create_test_engine().await;

        let request = create_test_request("Test SLM execution");
        let model = engine.model_catalog.get_model("test-slm").unwrap();

        let response = engine.execute_slm_mock(&request, model).await.unwrap();

        assert!(!response.content.is_empty());
        assert!(response.content.contains("Test SLM"));
        assert!(response.content.contains("Test SLM execution"));
        assert_eq!(response.finish_reason, FinishReason::Stop);
        assert!(response.confidence_score.is_some());
        assert!(response.token_usage.is_some());
    }

    #[tokio::test]
    async fn test_mock_slm_execution_error() {
        let engine = create_test_engine().await;

        // "error" in the prompt triggers the mock's simulated failure.
        let request = create_test_request("This should error out");
        let model = engine.model_catalog.get_model("test-slm").unwrap();

        let result = engine.execute_slm_mock(&request, model).await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            SlmRunnerError::ExecutionFailed { .. }
        ));
    }

    #[tokio::test]
    async fn test_routing_statistics_tracking() {
        let engine = create_test_engine().await;

        // One successful route...
        let context1 = create_test_context("Test 1", super::super::error::TaskType::CodeGeneration);
        let request1 = create_test_request("Test 1");

        let _response1 = engine
            .execute_with_routing(context1, request1)
            .await
            .unwrap();

        // ...and one that forces a fallback.
        let context2 = create_test_context(
            "error trigger",
            super::super::error::TaskType::CodeGeneration,
        );
        let request2 = create_test_request("error trigger");

        let _response2 = engine
            .execute_with_routing(context2, request2)
            .await
            .unwrap();

        let stats = engine.get_routing_stats().await;

        assert!(stats.total_requests > 0);
        assert!(stats.fallback_routes > 0);
        assert!(stats.average_response_time > Duration::from_millis(0));
    }

    #[tokio::test]
    async fn test_config_update() {
        let engine = create_test_engine().await;

        let mut new_config = RoutingConfig::default();
        new_config.policy.fallback_config.enabled = false;

        let result = engine.update_config(new_config.clone()).await;
        assert!(result.is_ok());

        let updated_config = engine.config.read().await;
        assert!(!updated_config.policy.fallback_config.enabled);
    }

    #[tokio::test]
    async fn test_routing_with_deny_decision() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "forbidden operation",
            super::super::error::TaskType::Custom("forbidden".to_string()),
        );

        let request = create_test_request("forbidden operation");

        let result = engine.execute_with_routing(context, request).await;

        // Default policies may or may not deny a custom task type; both a
        // successful route and a RoutingDenied error are acceptable here.
        match result {
            Ok(response) => {
                assert!(!response.content.is_empty());
            }
            Err(RoutingError::RoutingDenied { .. }) => {
            }
            Err(e) => {
                panic!("Unexpected error: {:?}", e);
            }
        }
    }

    #[tokio::test]
    async fn test_logging_integration() {
        let engine = create_test_engine_with_logger().await;

        let context = create_test_context(
            "Test logging integration",
            super::super::error::TaskType::CodeGeneration,
        );

        let request = create_test_request("Test logging integration");

        let response = engine.execute_with_routing(context, request).await.unwrap();

        assert!(!response.content.is_empty());
    }

    #[tokio::test]
    async fn test_confidence_monitoring_integration() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "Test confidence monitoring",
            super::super::error::TaskType::CodeGeneration,
        );

        let request = create_test_request("Test confidence monitoring");

        let response = engine.execute_with_routing(context, request).await.unwrap();

        assert!(!response.content.is_empty());
        assert!(response.confidence_score.is_some());

    }

    #[tokio::test]
    async fn test_policy_evaluation_integration() {
        let engine = create_test_engine().await;

        let task_types = vec![
            super::super::error::TaskType::CodeGeneration,
            super::super::error::TaskType::CodeGeneration,
            super::super::error::TaskType::Analysis,
            super::super::error::TaskType::Reasoning,
        ];

        // Every task type must yield some decision without erroring.
        for task_type in task_types {
            let context = create_test_context("Test policy evaluation", task_type.clone());

            let decision = engine.route_request(&context).await.unwrap();

            match decision {
                RouteDecision::UseSLM { .. }
                | RouteDecision::UseLLM { .. }
                | RouteDecision::Deny { .. } => {
                }
            }
        }
    }

    #[tokio::test]
    async fn test_concurrent_routing_requests() {
        let engine = Arc::new(create_test_engine().await);

        let mut handles = Vec::new();

        // Fire 10 requests in parallel to exercise the RwLock-guarded state.
        for i in 0..10 {
            let engine_clone = Arc::clone(&engine);
            let handle = tokio::spawn(async move {
                let context = create_test_context(
                    &format!("Concurrent request {}", i),
                    super::super::error::TaskType::CodeGeneration,
                );

                let request = create_test_request(&format!("Concurrent request {}", i));

                engine_clone.execute_with_routing(context, request).await
            });
            handles.push(handle);
        }

        let results = futures::future::join_all(handles).await;

        for result in results {
            let response = result.unwrap().unwrap();
            assert!(!response.content.is_empty());
        }

        let stats = engine.get_routing_stats().await;
        assert_eq!(stats.total_requests, 10);
    }

    #[tokio::test]
    async fn test_error_handling_and_recovery() {
        let engine = create_test_engine().await;

        let error_scenarios = vec![
            ("error trigger", "Should trigger SLM execution error"),
            ("", "Empty prompt"),
            (" ", "Whitespace-only prompt"),
        ];

        for (prompt, description) in error_scenarios {
            let context =
                create_test_context(prompt, super::super::error::TaskType::CodeGeneration);
            let request = create_test_request(prompt);

            let result = engine.execute_with_routing(context, request).await;

            // Degenerate prompts may route successfully or fail; either way
            // the engine must not panic and errors must be well-formed.
            match result {
                Ok(response) => {
                    assert!(!response.content.is_empty(), "Failed for: {}", description);
                }
                Err(e) => {
                    tracing::info!("Expected error for '{}': {:?}", description, e);
                }
            }
        }
    }

    #[tokio::test]
    async fn test_model_metadata_and_token_usage() {
        let engine = create_test_engine().await;

        let context = create_test_context(
            "Test metadata and token usage",
            super::super::error::TaskType::CodeGeneration,
        );

        let request = create_test_request("Test metadata and token usage");

        let response = engine.execute_with_routing(context, request).await.unwrap();

        assert!(!response.content.is_empty());
        assert!(response.token_usage.is_some());
        assert!(!response.metadata.is_empty());

        // Token accounting must be internally consistent.
        let token_usage = response.token_usage.unwrap();
        assert!(token_usage.prompt_tokens > 0);
        assert!(token_usage.completion_tokens > 0);
        assert_eq!(
            token_usage.total_tokens,
            token_usage.prompt_tokens + token_usage.completion_tokens
        );
    }

    #[tokio::test]
    async fn test_validate_policies() {
        let engine = create_test_engine().await;

        let result = engine.validate_policies();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_engine_state_consistency() {
        let engine = create_test_engine().await;

        let initial_stats = engine.get_routing_stats().await;
        assert_eq!(initial_stats.total_requests, 0);

        // Five sequential requests should be counted exactly once each.
        for i in 0..5 {
            let context = create_test_context(
                &format!("Test request {}", i),
                super::super::error::TaskType::CodeGeneration,
            );
            let request = create_test_request(&format!("Test request {}", i));

            let _response = engine.execute_with_routing(context, request).await.unwrap();
        }

        let final_stats = engine.get_routing_stats().await;
        assert_eq!(final_stats.total_requests, 5);
        assert!(final_stats.average_response_time > Duration::from_millis(0));

    }
}