Skip to main content

vex_router/router/
mod.rs

1//! Router - Core routing logic for VEX
2
3use serde::{Deserialize, Serialize};
4#[allow(unused_imports)]
5use std::sync::Arc;
6use thiserror::Error;
7
8use crate::classifier::{QueryClassifier, QueryComplexity};
9use crate::compress::CompressionLevel;
10use crate::models::{Model, ModelPool};
11use crate::observability::Observability;
12use std::collections::HashMap;
13
14/// Routing strategy (re-exported from config)
15pub use crate::config::RoutingStrategy;
16
/// A routing decision: which model was chosen and the router's estimates for it.
///
/// Produced by `Router::route` / `Router::route_with_complexity` without executing the query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingDecision {
    /// Id of the selected model (copied from the pool's `Model::id`).
    pub model_id: String,
    /// The selected model's configured `input_cost`.
    /// NOTE(review): the unit (per token / per 1K tokens / per request) is not visible here.
    pub estimated_cost: f64,
    /// The selected model's configured `latency_ms`.
    pub estimated_latency_ms: u64,
    /// Heuristic savings estimate as a percentage (e.g. `95.0`). It is a fixed value per
    /// strategy / complexity band, not computed from actual prices.
    pub estimated_savings: f64,
    /// Human-readable explanation of why this model was chosen.
    pub reason: String,
}
26
/// Router configuration.
///
/// Usually assembled through `RouterBuilder`, or taken from `RouterConfig::default()`.
#[derive(Debug, Clone)]
pub struct RouterConfig {
    /// Strategy used by `Router::route_with_complexity` to pick a model.
    pub strategy: RoutingStrategy,
    /// Minimum model `quality_score` accepted by the cost-optimized strategy.
    pub quality_threshold: f64,
    /// Per-request cost ceiling.
    /// NOTE(review): not read by the routing logic in this module — confirm whether it is
    /// enforced elsewhere.
    pub max_cost_per_request: f64,
    /// Latency ceiling, presumably in milliseconds (per the name).
    /// NOTE(review): not read by the routing logic in this module.
    pub max_latency_ms: u64,
    /// Whether response caching is enabled.
    /// NOTE(review): not read in this module.
    pub cache_enabled: bool,
    /// Whether guardrails are enabled.
    /// NOTE(review): not read in this module.
    pub guardrails_enabled: bool,
    /// Prompt compression level.
    /// NOTE(review): not read in this module.
    pub compression_level: CompressionLevel,
    /// Public key for verifying VEP packet signatures in `Router::verify_and_route`.
    /// When `None`, `verify_and_route` performs no signature check.
    pub chora_public_key: Option<Vec<u8>>,
    /// When `true`, `Router::execute` refuses to run requests and returns
    /// `RouterError::VepVerificationFailed`, directing callers to VEP encapsulation.
    pub strict_mode: bool,
}
40
41impl Default for RouterConfig {
42    fn default() -> Self {
43        Self {
44            strategy: RoutingStrategy::Auto,
45            quality_threshold: 0.85,
46            max_cost_per_request: 1.0,
47            max_latency_ms: 10000,
48            cache_enabled: true,
49            guardrails_enabled: true,
50            compression_level: CompressionLevel::Balanced,
51            chora_public_key: None,
52            strict_mode: false,
53        }
54    }
55}
56
/// Errors produced by [`Router`] routing, execution and VEP verification.
#[derive(Debug, Error)]
pub enum RouterError {
    /// The model pool is empty, or no model satisfies the active routing strategy.
    #[error("No models available")]
    NoModelsAvailable,
    /// A registered provider's `complete` call failed; carries that error's message.
    #[error("Request failed: {0}")]
    RequestFailed(String),
    /// NOTE(review): not constructed anywhere in this module.
    #[error("All models failed")]
    AllModelsFailed,
    /// NOTE(review): not constructed anywhere in this module.
    #[error("Guardrails blocked request")]
    GuardrailsBlocked,
    /// A VEP packet could not be parsed, verified or converted to a capsule. Also returned
    /// when strict mode rejects a request.
    #[error("VEP verification failed: {0}")]
    VepVerificationFailed(String),
}
71
/// The main Router: picks a model for each query and optionally forwards the query to a
/// registered [`LlmProvider`].
///
/// `Router` also implements `LlmProvider` itself, so it can be used wherever a VEX LLM
/// provider is expected.
pub struct Router {
    /// Candidate models the routing strategies choose from.
    pool: ModelPool,
    /// Scores prompt complexity for `route`.
    classifier: QueryClassifier,
    /// Strategy, thresholds and VEP / strict-mode settings.
    config: RouterConfig,
    /// Metrics handle exposed via `Router::observability`; not otherwise used in this module.
    observability: Observability,
    /// Registered LLM providers keyed by model_id
    providers: HashMap<String, Arc<dyn LlmProvider>>,
    /// Fallback provider used when no specific provider is registered for a model
    fallback_provider: Option<Arc<dyn LlmProvider>>,
}
83
84impl std::fmt::Debug for Router {
85    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86        f.debug_struct("Router")
87            .field("pool", &self.pool)
88            .field("config", &self.config)
89            .field("provider_count", &self.providers.len())
90            .field("has_fallback", &self.fallback_provider.is_some())
91            .finish()
92    }
93}
94
95impl Router {
96    /// Create a new router with default settings
97    pub fn new() -> Self {
98        Self {
99            pool: ModelPool::default(),
100            classifier: QueryClassifier::new(),
101            config: RouterConfig::default(),
102            observability: Observability::default(),
103            providers: HashMap::new(),
104            fallback_provider: None,
105        }
106    }
107
108    /// Create a router with a custom configuration
109    pub fn with_config(config: RouterConfig) -> Self {
110        Self {
111            pool: ModelPool::default(),
112            classifier: QueryClassifier::new(),
113            config,
114            observability: Observability::default(),
115            providers: HashMap::new(),
116            fallback_provider: None,
117        }
118    }
119
120    /// Get a builder for configuration
121    pub fn builder() -> RouterBuilder {
122        RouterBuilder::new()
123    }
124
125    /// Route a query and return a decision (without executing)
126    pub fn route(&self, prompt: &str, system: &str) -> Result<RoutingDecision, RouterError> {
127        let mut complexity = self.classifier.classify(prompt);
128
129        // ADVERSARIAL ROUTING: If system prompt implies an attacker/shadow role,
130        // bump the complexity/quality requirements to ensure a strong adversary.
131        let system_lower = system.to_lowercase();
132        if system_lower.contains("shadow")
133            || system_lower.contains("adversarial")
134            || system_lower.contains("red agent")
135        {
136            complexity.score = (complexity.score + 0.4).min(1.0);
137            complexity.capabilities.push("adversarial".to_string());
138        }
139
140        self.route_with_complexity(&complexity)
141    }
142
143    /// Route with pre-computed complexity
144    pub fn route_with_complexity(
145        &self,
146        complexity: &QueryComplexity,
147    ) -> Result<RoutingDecision, RouterError> {
148        if self.pool.is_empty() {
149            return Err(RouterError::NoModelsAvailable);
150        }
151
152        match self.config.strategy {
153            RoutingStrategy::Auto | RoutingStrategy::Balanced => self.route_auto(complexity),
154            RoutingStrategy::CostOptimized => self.route_cost_optimized(complexity),
155            RoutingStrategy::QualityOptimized => self.route_quality_optimized(complexity),
156            RoutingStrategy::LatencyOptimized => self.route_latency_optimized(complexity),
157            RoutingStrategy::Custom => {
158                // Fall back to auto for custom
159                self.route_auto(complexity)
160            }
161        }
162    }
163
164    /// Execute a query through the router.
165    /// If a provider is registered for the routed model, calls it directly.
166    /// Otherwise falls back to mock response for backward compatibility.
167    pub async fn execute(&self, prompt: &str, system: &str) -> Result<String, RouterError> {
168        if self.config.strict_mode {
169            return Err(RouterError::VepVerificationFailed(
170                "Strict Mode Enabled: Unenveloped requests are blocked. Use VEP encapsulation."
171                    .to_string(),
172            ));
173        }
174        let decision = self.route(prompt, system)?;
175
176        // Try to call a registered provider
177        let provider = self
178            .providers
179            .get(&decision.model_id)
180            .or(self.fallback_provider.as_ref());
181
182        if let Some(provider) = provider {
183            let request = LlmRequest::with_role(system, prompt);
184            let response = provider
185                .complete(request)
186                .await
187                .map_err(|e| RouterError::RequestFailed(e.to_string()))?;
188            return Ok(response.content);
189        }
190
191        // Mock fallback when no provider is registered
192        Ok(format!(
193            "[vex-router: {}] Query routed based on complexity: {:.2}, Role: {}, Estimated savings: {:.0}%",
194            decision.model_id,
195            0.5,
196            if system.to_lowercase().contains("shadow") { "Adversarial" } else { "Primary" },
197            decision.estimated_savings
198        ))
199    }
200
201    /// Stateless verification and routing of a VEP packet (Phase 3.2)
202    pub async fn verify_and_route(&self, vep_data: &[u8]) -> Result<String, RouterError> {
203        use vex_core::VepPacket;
204
205        // 1. Parse and verify the VEP packet
206        let packet = VepPacket::new(vep_data)
207            .map_err(|e| RouterError::VepVerificationFailed(e.to_string()))?;
208
209        if let Some(pub_key) = &self.config.chora_public_key {
210            if !packet
211                .verify(pub_key)
212                .map_err(RouterError::VepVerificationFailed)?
213            {
214                return Err(RouterError::VepVerificationFailed(
215                    "Cryptographic signature mismatch".to_string(),
216                ));
217            }
218        }
219
220        // 2. Reconstruct the capsule (intent + authority)
221        let capsule = packet
222            .to_capsule()
223            .map_err(RouterError::VepVerificationFailed)?;
224
225        // 3. Extract prompt from intent (hardened identifier)
226        let intent_id = match &capsule.intent {
227            vex_core::segment::IntentData::Transparent { request_sha256, .. } => {
228                request_sha256.clone()
229            }
230            vex_core::segment::IntentData::Shadow {
231                commitment_hash, ..
232            } => commitment_hash.clone(),
233        };
234
235        let prompt = format!("Encapsulated Intent (Hardened): {}", intent_id);
236        let system = "VEP Enveloped Request";
237
238        self.execute(&prompt, system).await
239    }
240
241    /// Convenience method - ask a question
242    pub async fn ask(&self, prompt: &str) -> Result<String, RouterError> {
243        self.execute(prompt, "").await
244    }
245
246    // =========================================================================
247    // Routing Strategies
248    // =========================================================================
249
250    fn route_auto(&self, complexity: &QueryComplexity) -> Result<RoutingDecision, RouterError> {
251        // Simple heuristic: low complexity = cheap model, high complexity = premium
252        let model = if complexity.score < 0.3 {
253            self.pool.get_cheapest()
254        } else if complexity.score < 0.7 {
255            self.pool.get_medium()
256        } else {
257            self.pool.get_best()
258        };
259
260        let model = model.ok_or(RouterError::NoModelsAvailable)?;
261
262        let savings = if complexity.score < 0.3 {
263            95.0
264        } else if complexity.score < 0.7 {
265            60.0
266        } else {
267            20.0
268        };
269
270        Ok(RoutingDecision {
271            model_id: model.id.clone(),
272            estimated_cost: model.config.input_cost,
273            estimated_latency_ms: model.config.latency_ms,
274            estimated_savings: savings,
275            reason: format!(
276                "Auto-selected based on complexity score: {:.2}",
277                complexity.score
278            ),
279        })
280    }
281
282    fn route_cost_optimized(
283        &self,
284        _complexity: &QueryComplexity,
285    ) -> Result<RoutingDecision, RouterError> {
286        // Find cheapest model that meets quality threshold
287        let mut models: Vec<&Model> = self.pool.models.iter().collect();
288        models.sort_by(|a, b| {
289            a.config
290                .input_cost
291                .partial_cmp(&b.config.input_cost)
292                .unwrap()
293        });
294
295        for model in models {
296            let meets_quality = model.config.quality_score >= self.config.quality_threshold;
297            if meets_quality {
298                return Ok(RoutingDecision {
299                    model_id: model.id.clone(),
300                    estimated_cost: model.config.input_cost,
301                    estimated_latency_ms: model.config.latency_ms,
302                    estimated_savings: 80.0,
303                    reason: "Cost-optimized: cheapest model meeting quality threshold".to_string(),
304                });
305            }
306        }
307
308        Err(RouterError::NoModelsAvailable)
309    }
310
311    fn route_quality_optimized(
312        &self,
313        _complexity: &QueryComplexity,
314    ) -> Result<RoutingDecision, RouterError> {
315        let model = self.pool.get_best().ok_or(RouterError::NoModelsAvailable)?;
316
317        Ok(RoutingDecision {
318            model_id: model.id.clone(),
319            estimated_cost: model.config.input_cost,
320            estimated_latency_ms: model.config.latency_ms,
321            estimated_savings: 0.0,
322            reason: "Quality-optimized: selected best available model".to_string(),
323        })
324    }
325
326    fn route_latency_optimized(
327        &self,
328        _complexity: &QueryComplexity,
329    ) -> Result<RoutingDecision, RouterError> {
330        let mut models: Vec<&Model> = self.pool.models.iter().collect();
331        models.sort_by(|a, b| a.config.latency_ms.cmp(&b.config.latency_ms));
332
333        let model = models.first().ok_or(RouterError::NoModelsAvailable)?;
334
335        Ok(RoutingDecision {
336            model_id: model.id.clone(),
337            estimated_cost: model.config.input_cost,
338            estimated_latency_ms: model.config.latency_ms,
339            estimated_savings: 50.0,
340            reason: "Latency-optimized: fastest model".to_string(),
341        })
342    }
343
344    /// Get the current configuration
345    pub fn config(&self) -> &RouterConfig {
346        &self.config
347    }
348
349    /// Get the model pool
350    pub fn pool(&self) -> &ModelPool {
351        &self.pool
352    }
353
354    /// Get the observability metrics
355    pub fn observability(&self) -> &Observability {
356        &self.observability
357    }
358}
359
360impl Default for Router {
361    fn default() -> Self {
362        Self::new()
363    }
364}
365
366/// Builder for Router
367pub struct RouterBuilder {
368    config: RouterConfig,
369    custom_models: Vec<crate::config::ModelConfig>,
370    providers: HashMap<String, Arc<dyn LlmProvider>>,
371    fallback_provider: Option<Arc<dyn LlmProvider>>,
372}
373
374impl std::fmt::Debug for RouterBuilder {
375    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376        f.debug_struct("RouterBuilder")
377            .field("config", &self.config)
378            .field("custom_models", &self.custom_models)
379            .field("provider_count", &self.providers.len())
380            .finish()
381    }
382}
383
384impl RouterBuilder {
385    pub fn new() -> Self {
386        Self {
387            config: RouterConfig::default(),
388            custom_models: Vec::new(),
389            providers: HashMap::new(),
390            fallback_provider: None,
391        }
392    }
393
394    pub fn strategy(mut self, strategy: RoutingStrategy) -> Self {
395        self.config.strategy = strategy;
396        self
397    }
398
399    pub fn quality_threshold(mut self, threshold: f64) -> Self {
400        self.config.quality_threshold = threshold;
401        self
402    }
403
404    pub fn max_cost(mut self, cost: f64) -> Self {
405        self.config.max_cost_per_request = cost;
406        self
407    }
408
409    pub fn cache_enabled(mut self, enabled: bool) -> Self {
410        self.config.cache_enabled = enabled;
411        self
412    }
413
414    pub fn guardrails_enabled(mut self, enabled: bool) -> Self {
415        self.config.guardrails_enabled = enabled;
416        self
417    }
418
419    pub fn compression_level(mut self, level: crate::compress::CompressionLevel) -> Self {
420        self.config.compression_level = level;
421        self
422    }
423
424    pub fn chora_public_key(mut self, key: Vec<u8>) -> Self {
425        self.config.chora_public_key = Some(key);
426        self
427    }
428
429    pub fn strict_mode(mut self, enabled: bool) -> Self {
430        self.config.strict_mode = enabled;
431        self
432    }
433
434    pub fn add_model(mut self, model: crate::config::ModelConfig) -> Self {
435        self.custom_models.push(model);
436        self
437    }
438
439    /// Register an LLM provider for a specific model_id
440    pub fn add_provider(
441        mut self,
442        model_id: impl Into<String>,
443        provider: Arc<dyn LlmProvider>,
444    ) -> Self {
445        self.providers.insert(model_id.into(), provider);
446        self
447    }
448
449    /// Set a fallback provider used when no model-specific provider is found
450    pub fn with_fallback_provider(mut self, provider: Arc<dyn LlmProvider>) -> Self {
451        self.fallback_provider = Some(provider);
452        self
453    }
454
455    pub fn build(self) -> Router {
456        let pool = if self.custom_models.is_empty() {
457            ModelPool::default()
458        } else {
459            ModelPool::new(self.custom_models)
460        };
461
462        Router {
463            pool,
464            classifier: QueryClassifier::new(),
465            config: self.config,
466            observability: Observability::new(1000),
467            providers: self.providers,
468            fallback_provider: self.fallback_provider,
469        }
470    }
471}
472
473impl Default for RouterBuilder {
474    fn default() -> Self {
475        Self::new()
476    }
477}
478
479// =============================================================================
480// VEX LlmProvider Trait Implementation (for VEX integration)
481// =============================================================================
482
483// Re-using official VEX LLM types
484use async_trait::async_trait;
485use vex_llm::{LlmError, LlmProvider, LlmRequest, LlmResponse};
486
487#[async_trait]
488impl LlmProvider for Router {
489    /// Complete a request (implements vex_llm::LlmProvider::complete)
490    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
491        let start = std::time::Instant::now();
492
493        let response = self
494            .execute(&request.prompt, &request.system)
495            .await
496            .map_err(|e| LlmError::RequestFailed(e.to_string()))?;
497
498        let response_len = response.len();
499        let latency = start.elapsed().as_millis() as u64;
500
501        let decision = self
502            .route(&request.prompt, &request.system)
503            .map_err(|e| LlmError::RequestFailed(e.to_string()))?;
504
505        Ok(LlmResponse {
506            content: response,
507            model: decision.model_id,
508            tokens_used: Some(((request.prompt.len() + response_len) as f64 / 4.0) as u32),
509            latency_ms: latency,
510            trace_root: None,
511        })
512    }
513
514    /// Check if router is available
515    async fn is_available(&self) -> bool {
516        !self.pool.is_empty()
517    }
518
519    /// Get provider name
520    fn name(&self) -> &str {
521        "vex-router"
522    }
523}
524
#[cfg(test)]
mod tests {
    use super::*;

    // Synchronous tests use plain `#[test]`; a tokio runtime is only spun up where the test
    // actually awaits.

    #[test]
    fn test_router_auto() {
        let router = Router::builder().strategy(RoutingStrategy::Auto).build();

        let decision = router.route("What is 2+2?", "").unwrap();
        assert!(!decision.model_id.is_empty());
    }

    #[tokio::test]
    async fn test_router_execute() {
        let router = Router::new();
        let response = router.ask("Hello").await.unwrap();
        assert!(response.contains("vex-router"));
    }

    #[test]
    fn test_router_builder() {
        let router = Router::builder()
            .strategy(RoutingStrategy::CostOptimized)
            .quality_threshold(0.9)
            .cache_enabled(false)
            .build();

        assert_eq!(router.config().strategy, RoutingStrategy::CostOptimized);
        assert_eq!(router.config().quality_threshold, 0.9);
        assert!(!router.config().cache_enabled);
    }

    #[test]
    fn test_llm_request() {
        let request = LlmRequest::simple("test");
        assert_eq!(request.system, "You are a helpful assistant.");
        assert_eq!(request.prompt, "test");
    }

    #[tokio::test]
    async fn test_strict_mode() {
        let router = Router::builder().strict_mode(true).build();

        // Standard (unenveloped) execution must be rejected with the strict-mode error.
        match router.execute("hello", "").await {
            Err(RouterError::VepVerificationFailed(msg)) => {
                assert!(msg.contains("Strict Mode Enabled"), "unexpected message: {msg}");
            }
            other => panic!("Expected Strict Mode error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_strict_mode_blocks_ask() {
        let router = Router::builder().strict_mode(true).build();
        assert!(matches!(
            router.ask("hello").await,
            Err(RouterError::VepVerificationFailed(_))
        ));
    }

    #[tokio::test]
    async fn test_complete_wraps_routed_response() {
        let router = Router::new();
        let response = router.complete(LlmRequest::simple("Hello")).await.unwrap();

        assert!(response.content.contains("vex-router"));
        assert!(!response.model.is_empty());
        assert!(response.tokens_used.is_some());
    }

    #[tokio::test]
    async fn test_provider_metadata() {
        let router = Router::new();
        assert_eq!(router.name(), "vex-router");
        assert!(router.is_available().await);
    }
}