ruvector_tiny_dancer_core/
router.rs

1//! Main routing engine combining all components
2
3use crate::circuit_breaker::CircuitBreaker;
4use crate::error::{Result, TinyDancerError};
5use crate::feature_engineering::FeatureEngineer;
6use crate::model::FastGRNN;
7use crate::types::{RouterConfig, RoutingDecision, RoutingRequest, RoutingResponse};
8use crate::uncertainty::UncertaintyEstimator;
9use parking_lot::RwLock;
10use std::sync::Arc;
11use std::time::Instant;
12
13/// Main router for AI agent routing
14pub struct Router {
15    config: RouterConfig,
16    model: Arc<RwLock<FastGRNN>>,
17    feature_engineer: FeatureEngineer,
18    uncertainty_estimator: UncertaintyEstimator,
19    circuit_breaker: Option<CircuitBreaker>,
20}
21
22impl Router {
23    /// Create a new router with the given configuration
24    pub fn new(config: RouterConfig) -> Result<Self> {
25        // Load or create model
26        let model = if std::path::Path::new(&config.model_path).exists() {
27            FastGRNN::load(&config.model_path)?
28        } else {
29            FastGRNN::new(Default::default())?
30        };
31
32        let circuit_breaker = if config.enable_circuit_breaker {
33            Some(CircuitBreaker::new(config.circuit_breaker_threshold))
34        } else {
35            None
36        };
37
38        Ok(Self {
39            config,
40            model: Arc::new(RwLock::new(model)),
41            feature_engineer: FeatureEngineer::new(),
42            uncertainty_estimator: UncertaintyEstimator::new(),
43            circuit_breaker,
44        })
45    }
46
47    /// Create a router with default configuration
48    pub fn default() -> Result<Self> {
49        Self::new(RouterConfig::default())
50    }
51
52    /// Route a request through the system
53    pub fn route(&self, request: RoutingRequest) -> Result<RoutingResponse> {
54        let start = Instant::now();
55
56        // Check circuit breaker
57        if let Some(ref cb) = self.circuit_breaker {
58            if !cb.is_closed() {
59                return Err(TinyDancerError::CircuitBreakerError(
60                    "Circuit breaker is open".to_string(),
61                ));
62            }
63        }
64
65        // Feature engineering
66        let feature_start = Instant::now();
67        let feature_vectors = self.feature_engineer.extract_batch_features(
68            &request.query_embedding,
69            &request.candidates,
70            request.metadata.as_ref(),
71        )?;
72        let feature_time_us = feature_start.elapsed().as_micros() as u64;
73
74        // Model inference
75        let model = self.model.read();
76        let mut decisions = Vec::new();
77
78        for (candidate, features) in request.candidates.iter().zip(feature_vectors.iter()) {
79            match model.forward(&features.features, None) {
80                Ok(score) => {
81                    // Estimate uncertainty
82                    let uncertainty = self
83                        .uncertainty_estimator
84                        .estimate(&features.features, score);
85
86                    // Determine routing decision
87                    let use_lightweight = score >= self.config.confidence_threshold
88                        && uncertainty <= self.config.max_uncertainty;
89
90                    decisions.push(RoutingDecision {
91                        candidate_id: candidate.id.clone(),
92                        confidence: score,
93                        use_lightweight,
94                        uncertainty,
95                    });
96
97                    // Record success with circuit breaker
98                    if let Some(ref cb) = self.circuit_breaker {
99                        cb.record_success();
100                    }
101                }
102                Err(e) => {
103                    // Record failure with circuit breaker
104                    if let Some(ref cb) = self.circuit_breaker {
105                        cb.record_failure();
106                    }
107                    return Err(e);
108                }
109            }
110        }
111
112        // Sort by confidence (descending)
113        decisions.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
114
115        let inference_time_us = start.elapsed().as_micros() as u64;
116
117        Ok(RoutingResponse {
118            decisions,
119            inference_time_us,
120            candidates_processed: request.candidates.len(),
121            feature_time_us,
122        })
123    }
124
125    /// Reload the model from disk
126    pub fn reload_model(&self) -> Result<()> {
127        let new_model = FastGRNN::load(&self.config.model_path)?;
128        let mut model = self.model.write();
129        *model = new_model;
130        Ok(())
131    }
132
133    /// Get router configuration
134    pub fn config(&self) -> &RouterConfig {
135        &self.config
136    }
137
138    /// Get circuit breaker status
139    pub fn circuit_breaker_status(&self) -> Option<bool> {
140        self.circuit_breaker.as_ref().map(|cb| cb.is_closed())
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Candidate;
    use chrono::Utc;
    use std::collections::HashMap;

    /// A router built with the default configuration reports a circuit breaker status.
    #[test]
    fn test_router_creation() {
        let status = Router::default().unwrap().circuit_breaker_status();
        assert!(status.is_some());
    }

    /// Routing two candidates yields two decisions and a non-zero timing.
    #[test]
    fn test_routing() {
        let router = Router::default().unwrap();

        // The default FastGRNN model expects input dimension to match feature count (5)
        // Features: semantic_similarity, recency, frequency, success_rate, metadata_overlap
        //
        // Each row: (id, embedding fill value, access count, success rate).
        // Embeddings can be any size.
        let candidates: Vec<Candidate> = vec![("1", 0.5, 10, 0.95), ("2", 0.3, 5, 0.85)]
            .into_iter()
            .map(|(id, fill, access_count, success_rate)| Candidate {
                id: id.to_string(),
                embedding: vec![fill; 384],
                metadata: HashMap::new(),
                created_at: Utc::now().timestamp(),
                access_count,
                success_rate,
            })
            .collect();

        let response = router
            .route(RoutingRequest {
                query_embedding: vec![0.5; 384],
                candidates,
                metadata: None,
            })
            .unwrap();

        assert_eq!(response.decisions.len(), 2);
        assert!(response.inference_time_us > 0);
    }
}
192}