ruvector_tiny_dancer_core/
router.rs

1//! Main routing engine combining all components
2
3use crate::circuit_breaker::CircuitBreaker;
4use crate::error::{Result, TinyDancerError};
5use crate::feature_engineering::FeatureEngineer;
6use crate::model::FastGRNN;
7use crate::types::{RouterConfig, RoutingDecision, RoutingRequest, RoutingResponse};
8use crate::uncertainty::UncertaintyEstimator;
9use std::sync::Arc;
10use std::time::Instant;
11use parking_lot::RwLock;
12
13/// Main router for AI agent routing
14pub struct Router {
15    config: RouterConfig,
16    model: Arc<RwLock<FastGRNN>>,
17    feature_engineer: FeatureEngineer,
18    uncertainty_estimator: UncertaintyEstimator,
19    circuit_breaker: Option<CircuitBreaker>,
20}
21
22impl Router {
23    /// Create a new router with the given configuration
24    pub fn new(config: RouterConfig) -> Result<Self> {
25        // Load or create model
26        let model = if std::path::Path::new(&config.model_path).exists() {
27            FastGRNN::load(&config.model_path)?
28        } else {
29            FastGRNN::new(Default::default())?
30        };
31
32        let circuit_breaker = if config.enable_circuit_breaker {
33            Some(CircuitBreaker::new(config.circuit_breaker_threshold))
34        } else {
35            None
36        };
37
38        Ok(Self {
39            config,
40            model: Arc::new(RwLock::new(model)),
41            feature_engineer: FeatureEngineer::new(),
42            uncertainty_estimator: UncertaintyEstimator::new(),
43            circuit_breaker,
44        })
45    }
46
47    /// Create a router with default configuration
48    pub fn default() -> Result<Self> {
49        Self::new(RouterConfig::default())
50    }
51
52    /// Route a request through the system
53    pub fn route(&self, request: RoutingRequest) -> Result<RoutingResponse> {
54        let start = Instant::now();
55
56        // Check circuit breaker
57        if let Some(ref cb) = self.circuit_breaker {
58            if !cb.is_closed() {
59                return Err(TinyDancerError::CircuitBreakerError(
60                    "Circuit breaker is open".to_string(),
61                ));
62            }
63        }
64
65        // Feature engineering
66        let feature_start = Instant::now();
67        let feature_vectors = self.feature_engineer.extract_batch_features(
68            &request.query_embedding,
69            &request.candidates,
70            request.metadata.as_ref(),
71        )?;
72        let feature_time_us = feature_start.elapsed().as_micros() as u64;
73
74        // Model inference
75        let model = self.model.read();
76        let mut decisions = Vec::new();
77
78        for (candidate, features) in request.candidates.iter().zip(feature_vectors.iter()) {
79            match model.forward(&features.features, None) {
80                Ok(score) => {
81                    // Estimate uncertainty
82                    let uncertainty = self.uncertainty_estimator.estimate(&features.features, score);
83
84                    // Determine routing decision
85                    let use_lightweight = score >= self.config.confidence_threshold
86                        && uncertainty <= self.config.max_uncertainty;
87
88                    decisions.push(RoutingDecision {
89                        candidate_id: candidate.id.clone(),
90                        confidence: score,
91                        use_lightweight,
92                        uncertainty,
93                    });
94
95                    // Record success with circuit breaker
96                    if let Some(ref cb) = self.circuit_breaker {
97                        cb.record_success();
98                    }
99                }
100                Err(e) => {
101                    // Record failure with circuit breaker
102                    if let Some(ref cb) = self.circuit_breaker {
103                        cb.record_failure();
104                    }
105                    return Err(e);
106                }
107            }
108        }
109
110        // Sort by confidence (descending)
111        decisions.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
112
113        let inference_time_us = start.elapsed().as_micros() as u64;
114
115        Ok(RoutingResponse {
116            decisions,
117            inference_time_us,
118            candidates_processed: request.candidates.len(),
119            feature_time_us,
120        })
121    }
122
123    /// Reload the model from disk
124    pub fn reload_model(&self) -> Result<()> {
125        let new_model = FastGRNN::load(&self.config.model_path)?;
126        let mut model = self.model.write();
127        *model = new_model;
128        Ok(())
129    }
130
131    /// Get router configuration
132    pub fn config(&self) -> &RouterConfig {
133        &self.config
134    }
135
136    /// Get circuit breaker status
137    pub fn circuit_breaker_status(&self) -> Option<bool> {
138        self.circuit_breaker.as_ref().map(|cb| cb.is_closed())
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Candidate;
    use chrono::Utc;
    use std::collections::HashMap;

    /// A router built from the default configuration reports a circuit
    /// breaker status.
    #[test]
    fn test_router_creation() {
        let router = Router::default().unwrap();
        assert!(router.circuit_breaker_status().is_some());
    }

    /// Routing two candidates yields two decisions, ordered by descending
    /// confidence, with consistent timing fields.
    #[test]
    fn test_routing() {
        let router = Router::default().unwrap();

        // The default FastGRNN model expects input dimension to match feature count (5)
        // Features: semantic_similarity, recency, frequency, success_rate, metadata_overlap
        let candidates = vec![
            Candidate {
                id: "1".to_string(),
                embedding: vec![0.5; 384], // Embeddings can be any size
                metadata: HashMap::new(),
                created_at: Utc::now().timestamp(),
                access_count: 10,
                success_rate: 0.95,
            },
            Candidate {
                id: "2".to_string(),
                embedding: vec![0.3; 384],
                metadata: HashMap::new(),
                created_at: Utc::now().timestamp(),
                access_count: 5,
                success_rate: 0.85,
            },
        ];

        let request = RoutingRequest {
            query_embedding: vec![0.5; 384],
            candidates,
            metadata: None,
        };

        let response = router.route(request).unwrap();
        assert_eq!(response.decisions.len(), 2);
        assert_eq!(response.candidates_processed, 2);

        // Decisions are returned best-first.
        assert!(response
            .decisions
            .windows(2)
            .all(|pair| pair[0].confidence >= pair[1].confidence));

        // Asserting `inference_time_us > 0` is timing-dependent: a fast run can
        // truncate to 0 µs. The total is measured around feature extraction,
        // so it can never be smaller than the feature time.
        assert!(response.inference_time_us >= response.feature_time_us);
    }
}