1use serde::{Deserialize, Serialize};
4#[allow(unused_imports)]
5use std::sync::Arc;
6use thiserror::Error;
7
8use crate::classifier::{QueryClassifier, QueryComplexity};
9use crate::compress::CompressionLevel;
10use crate::models::{Model, ModelPool};
11use crate::observability::Observability;
12use std::collections::HashMap;
13
14pub use crate::config::RoutingStrategy;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct RoutingDecision {
20 pub model_id: String,
21 pub estimated_cost: f64,
22 pub estimated_latency_ms: u64,
23 pub estimated_savings: f64,
24 pub reason: String,
25}
26
27#[derive(Debug, Clone)]
29pub struct RouterConfig {
30 pub strategy: RoutingStrategy,
31 pub quality_threshold: f64,
32 pub max_cost_per_request: f64,
33 pub max_latency_ms: u64,
34 pub cache_enabled: bool,
35 pub guardrails_enabled: bool,
36 pub compression_level: CompressionLevel,
37 pub chora_public_key: Option<Vec<u8>>,
38 pub strict_mode: bool,
39}
40
41impl Default for RouterConfig {
42 fn default() -> Self {
43 Self {
44 strategy: RoutingStrategy::Auto,
45 quality_threshold: 0.85,
46 max_cost_per_request: 1.0,
47 max_latency_ms: 10000,
48 cache_enabled: true,
49 guardrails_enabled: true,
50 compression_level: CompressionLevel::Balanced,
51 chora_public_key: None,
52 strict_mode: false,
53 }
54 }
55}
56
57#[derive(Debug, Error)]
59pub enum RouterError {
60 #[error("No models available")]
61 NoModelsAvailable,
62 #[error("Request failed: {0}")]
63 RequestFailed(String),
64 #[error("All models failed")]
65 AllModelsFailed,
66 #[error("Guardrails blocked request")]
67 GuardrailsBlocked,
68 #[error("VEP verification failed: {0}")]
69 VepVerificationFailed(String),
70}
71
72pub struct Router {
74 pool: ModelPool,
75 classifier: QueryClassifier,
76 config: RouterConfig,
77 observability: Observability,
78 providers: HashMap<String, Arc<dyn LlmProvider>>,
80 fallback_provider: Option<Arc<dyn LlmProvider>>,
82}
83
84impl std::fmt::Debug for Router {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 f.debug_struct("Router")
87 .field("pool", &self.pool)
88 .field("config", &self.config)
89 .field("provider_count", &self.providers.len())
90 .field("has_fallback", &self.fallback_provider.is_some())
91 .finish()
92 }
93}
94
95impl Router {
96 pub fn new() -> Self {
98 Self {
99 pool: ModelPool::default(),
100 classifier: QueryClassifier::new(),
101 config: RouterConfig::default(),
102 observability: Observability::default(),
103 providers: HashMap::new(),
104 fallback_provider: None,
105 }
106 }
107
108 pub fn with_config(config: RouterConfig) -> Self {
110 Self {
111 pool: ModelPool::default(),
112 classifier: QueryClassifier::new(),
113 config,
114 observability: Observability::default(),
115 providers: HashMap::new(),
116 fallback_provider: None,
117 }
118 }
119
120 pub fn builder() -> RouterBuilder {
122 RouterBuilder::new()
123 }
124
125 pub fn route(&self, prompt: &str, system: &str) -> Result<RoutingDecision, RouterError> {
127 let mut complexity = self.classifier.classify(prompt);
128
129 let system_lower = system.to_lowercase();
132 if system_lower.contains("shadow")
133 || system_lower.contains("adversarial")
134 || system_lower.contains("red agent")
135 {
136 complexity.score = (complexity.score + 0.4).min(1.0);
137 complexity.capabilities.push("adversarial".to_string());
138 }
139
140 self.route_with_complexity(&complexity)
141 }
142
143 pub fn route_with_complexity(
145 &self,
146 complexity: &QueryComplexity,
147 ) -> Result<RoutingDecision, RouterError> {
148 if self.pool.is_empty() {
149 return Err(RouterError::NoModelsAvailable);
150 }
151
152 match self.config.strategy {
153 RoutingStrategy::Auto | RoutingStrategy::Balanced => self.route_auto(complexity),
154 RoutingStrategy::CostOptimized => self.route_cost_optimized(complexity),
155 RoutingStrategy::QualityOptimized => self.route_quality_optimized(complexity),
156 RoutingStrategy::LatencyOptimized => self.route_latency_optimized(complexity),
157 RoutingStrategy::Custom => {
158 self.route_auto(complexity)
160 }
161 }
162 }
163
164 pub async fn execute(&self, prompt: &str, system: &str) -> Result<String, RouterError> {
168 if self.config.strict_mode {
169 return Err(RouterError::VepVerificationFailed(
170 "Strict Mode Enabled: Unenveloped requests are blocked. Use VEP encapsulation."
171 .to_string(),
172 ));
173 }
174 let decision = self.route(prompt, system)?;
175
176 let provider = self
178 .providers
179 .get(&decision.model_id)
180 .or(self.fallback_provider.as_ref());
181
182 if let Some(provider) = provider {
183 let request = LlmRequest::with_role(system, prompt);
184 let response = provider
185 .complete(request)
186 .await
187 .map_err(|e| RouterError::RequestFailed(e.to_string()))?;
188 return Ok(response.content);
189 }
190
191 Ok(format!(
193 "[vex-router: {}] Query routed based on complexity: {:.2}, Role: {}, Estimated savings: {:.0}%",
194 decision.model_id,
195 0.5,
196 if system.to_lowercase().contains("shadow") { "Adversarial" } else { "Primary" },
197 decision.estimated_savings
198 ))
199 }
200
201 pub async fn verify_and_route(&self, vep_data: &[u8]) -> Result<String, RouterError> {
203 use vex_core::VepPacket;
204
205 let packet = VepPacket::new(vep_data)
207 .map_err(|e| RouterError::VepVerificationFailed(e.to_string()))?;
208
209 if let Some(pub_key) = &self.config.chora_public_key {
210 if !packet
211 .verify(pub_key)
212 .map_err(RouterError::VepVerificationFailed)?
213 {
214 return Err(RouterError::VepVerificationFailed(
215 "Cryptographic signature mismatch".to_string(),
216 ));
217 }
218 }
219
220 let capsule = packet
222 .to_capsule()
223 .map_err(RouterError::VepVerificationFailed)?;
224
225 let intent_id = match &capsule.intent {
227 vex_core::segment::IntentData::Transparent { request_sha256, .. } => {
228 request_sha256.clone()
229 }
230 vex_core::segment::IntentData::Shadow {
231 commitment_hash, ..
232 } => commitment_hash.clone(),
233 };
234
235 let prompt = format!("Encapsulated Intent (Hardened): {}", intent_id);
236 let system = "VEP Enveloped Request";
237
238 self.execute(&prompt, system).await
239 }
240
241 pub async fn ask(&self, prompt: &str) -> Result<String, RouterError> {
243 self.execute(prompt, "").await
244 }
245
246 fn route_auto(&self, complexity: &QueryComplexity) -> Result<RoutingDecision, RouterError> {
251 let model = if complexity.score < 0.3 {
253 self.pool.get_cheapest()
254 } else if complexity.score < 0.7 {
255 self.pool.get_medium()
256 } else {
257 self.pool.get_best()
258 };
259
260 let model = model.ok_or(RouterError::NoModelsAvailable)?;
261
262 let savings = if complexity.score < 0.3 {
263 95.0
264 } else if complexity.score < 0.7 {
265 60.0
266 } else {
267 20.0
268 };
269
270 Ok(RoutingDecision {
271 model_id: model.id.clone(),
272 estimated_cost: model.config.input_cost,
273 estimated_latency_ms: model.config.latency_ms,
274 estimated_savings: savings,
275 reason: format!(
276 "Auto-selected based on complexity score: {:.2}",
277 complexity.score
278 ),
279 })
280 }
281
282 fn route_cost_optimized(
283 &self,
284 _complexity: &QueryComplexity,
285 ) -> Result<RoutingDecision, RouterError> {
286 let mut models: Vec<&Model> = self.pool.models.iter().collect();
288 models.sort_by(|a, b| {
289 a.config
290 .input_cost
291 .partial_cmp(&b.config.input_cost)
292 .unwrap()
293 });
294
295 for model in models {
296 let meets_quality = model.config.quality_score >= self.config.quality_threshold;
297 if meets_quality {
298 return Ok(RoutingDecision {
299 model_id: model.id.clone(),
300 estimated_cost: model.config.input_cost,
301 estimated_latency_ms: model.config.latency_ms,
302 estimated_savings: 80.0,
303 reason: "Cost-optimized: cheapest model meeting quality threshold".to_string(),
304 });
305 }
306 }
307
308 Err(RouterError::NoModelsAvailable)
309 }
310
311 fn route_quality_optimized(
312 &self,
313 _complexity: &QueryComplexity,
314 ) -> Result<RoutingDecision, RouterError> {
315 let model = self.pool.get_best().ok_or(RouterError::NoModelsAvailable)?;
316
317 Ok(RoutingDecision {
318 model_id: model.id.clone(),
319 estimated_cost: model.config.input_cost,
320 estimated_latency_ms: model.config.latency_ms,
321 estimated_savings: 0.0,
322 reason: "Quality-optimized: selected best available model".to_string(),
323 })
324 }
325
326 fn route_latency_optimized(
327 &self,
328 _complexity: &QueryComplexity,
329 ) -> Result<RoutingDecision, RouterError> {
330 let mut models: Vec<&Model> = self.pool.models.iter().collect();
331 models.sort_by(|a, b| a.config.latency_ms.cmp(&b.config.latency_ms));
332
333 let model = models.first().ok_or(RouterError::NoModelsAvailable)?;
334
335 Ok(RoutingDecision {
336 model_id: model.id.clone(),
337 estimated_cost: model.config.input_cost,
338 estimated_latency_ms: model.config.latency_ms,
339 estimated_savings: 50.0,
340 reason: "Latency-optimized: fastest model".to_string(),
341 })
342 }
343
344 pub fn config(&self) -> &RouterConfig {
346 &self.config
347 }
348
349 pub fn pool(&self) -> &ModelPool {
351 &self.pool
352 }
353
354 pub fn observability(&self) -> &Observability {
356 &self.observability
357 }
358}
359
360impl Default for Router {
361 fn default() -> Self {
362 Self::new()
363 }
364}
365
366pub struct RouterBuilder {
368 config: RouterConfig,
369 custom_models: Vec<crate::config::ModelConfig>,
370 providers: HashMap<String, Arc<dyn LlmProvider>>,
371 fallback_provider: Option<Arc<dyn LlmProvider>>,
372}
373
374impl std::fmt::Debug for RouterBuilder {
375 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376 f.debug_struct("RouterBuilder")
377 .field("config", &self.config)
378 .field("custom_models", &self.custom_models)
379 .field("provider_count", &self.providers.len())
380 .finish()
381 }
382}
383
384impl RouterBuilder {
385 pub fn new() -> Self {
386 Self {
387 config: RouterConfig::default(),
388 custom_models: Vec::new(),
389 providers: HashMap::new(),
390 fallback_provider: None,
391 }
392 }
393
394 pub fn strategy(mut self, strategy: RoutingStrategy) -> Self {
395 self.config.strategy = strategy;
396 self
397 }
398
399 pub fn quality_threshold(mut self, threshold: f64) -> Self {
400 self.config.quality_threshold = threshold;
401 self
402 }
403
404 pub fn max_cost(mut self, cost: f64) -> Self {
405 self.config.max_cost_per_request = cost;
406 self
407 }
408
409 pub fn cache_enabled(mut self, enabled: bool) -> Self {
410 self.config.cache_enabled = enabled;
411 self
412 }
413
414 pub fn guardrails_enabled(mut self, enabled: bool) -> Self {
415 self.config.guardrails_enabled = enabled;
416 self
417 }
418
419 pub fn compression_level(mut self, level: crate::compress::CompressionLevel) -> Self {
420 self.config.compression_level = level;
421 self
422 }
423
424 pub fn chora_public_key(mut self, key: Vec<u8>) -> Self {
425 self.config.chora_public_key = Some(key);
426 self
427 }
428
429 pub fn strict_mode(mut self, enabled: bool) -> Self {
430 self.config.strict_mode = enabled;
431 self
432 }
433
434 pub fn add_model(mut self, model: crate::config::ModelConfig) -> Self {
435 self.custom_models.push(model);
436 self
437 }
438
439 pub fn add_provider(
441 mut self,
442 model_id: impl Into<String>,
443 provider: Arc<dyn LlmProvider>,
444 ) -> Self {
445 self.providers.insert(model_id.into(), provider);
446 self
447 }
448
449 pub fn with_fallback_provider(mut self, provider: Arc<dyn LlmProvider>) -> Self {
451 self.fallback_provider = Some(provider);
452 self
453 }
454
455 pub fn build(self) -> Router {
456 let pool = if self.custom_models.is_empty() {
457 ModelPool::default()
458 } else {
459 ModelPool::new(self.custom_models)
460 };
461
462 Router {
463 pool,
464 classifier: QueryClassifier::new(),
465 config: self.config,
466 observability: Observability::new(1000),
467 providers: self.providers,
468 fallback_provider: self.fallback_provider,
469 }
470 }
471}
472
473impl Default for RouterBuilder {
474 fn default() -> Self {
475 Self::new()
476 }
477}
478
479use async_trait::async_trait;
485use vex_llm::{LlmError, LlmProvider, LlmRequest, LlmResponse};
486
487#[async_trait]
488impl LlmProvider for Router {
489 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
491 let start = std::time::Instant::now();
492
493 let response = self
494 .execute(&request.prompt, &request.system)
495 .await
496 .map_err(|e| LlmError::RequestFailed(e.to_string()))?;
497
498 let response_len = response.len();
499 let latency = start.elapsed().as_millis() as u64;
500
501 let decision = self
502 .route(&request.prompt, &request.system)
503 .map_err(|e| LlmError::RequestFailed(e.to_string()))?;
504
505 Ok(LlmResponse {
506 content: response,
507 model: decision.model_id,
508 tokens_used: Some(((request.prompt.len() + response_len) as f64 / 4.0) as u32),
509 latency_ms: latency,
510 trace_root: None,
511 })
512 }
513
514 async fn is_available(&self) -> bool {
516 !self.pool.is_empty()
517 }
518
519 fn name(&self) -> &str {
521 "vex-router"
522 }
523}
524
#[cfg(test)]
mod tests {
    use super::*;

    /// Automatic routing over the default pool yields a decision naming some model.
    #[tokio::test]
    async fn test_router_auto() {
        let router = Router::builder().strategy(RoutingStrategy::Auto).build();

        let outcome = router.route("What is 2+2?", "");
        let decision = outcome.unwrap();
        assert!(!decision.model_id.is_empty());
    }

    /// With no providers registered, `ask` returns the router's own placeholder text.
    #[tokio::test]
    async fn test_router_execute() {
        let router = Router::new();

        let reply = router.ask("Hello").await.unwrap();
        assert!(reply.contains("vex-router"));
    }

    /// Builder setters are reflected in the built router's configuration.
    #[test]
    fn test_router_builder() {
        let router = Router::builder()
            .strategy(RoutingStrategy::CostOptimized)
            .quality_threshold(0.9)
            .cache_enabled(false)
            .build();
        let config = router.config();

        assert_eq!(config.strategy, RoutingStrategy::CostOptimized);
        assert_eq!(config.quality_threshold, 0.9);
        assert!(!config.cache_enabled);
    }

    /// `LlmRequest::simple` fills in the stock system prompt.
    #[tokio::test]
    async fn test_llm_request() {
        let request = LlmRequest::simple("test");

        assert_eq!(request.system, "You are a helpful assistant.");
        assert_eq!(request.prompt, "test");
    }

    /// Strict mode refuses plain `execute` calls with a VEP verification error.
    #[tokio::test]
    async fn test_strict_mode() {
        let router = Router::builder().strict_mode(true).build();

        let outcome = router.execute("hello", "").await;
        assert!(outcome.is_err());

        match outcome {
            Err(RouterError::VepVerificationFailed(message)) => {
                assert!(message.contains("Strict Mode Enabled"));
            }
            _ => panic!("Expected Strict Mode error"),
        }
    }
}