ruvector_tiny_dancer_core/
router.rs1use crate::circuit_breaker::CircuitBreaker;
4use crate::error::{Result, TinyDancerError};
5use crate::feature_engineering::FeatureEngineer;
6use crate::model::FastGRNN;
7use crate::types::{RouterConfig, RoutingDecision, RoutingRequest, RoutingResponse};
8use crate::uncertainty::UncertaintyEstimator;
9use std::sync::Arc;
10use std::time::Instant;
11use parking_lot::RwLock;
12
13pub struct Router {
15 config: RouterConfig,
16 model: Arc<RwLock<FastGRNN>>,
17 feature_engineer: FeatureEngineer,
18 uncertainty_estimator: UncertaintyEstimator,
19 circuit_breaker: Option<CircuitBreaker>,
20}
21
22impl Router {
23 pub fn new(config: RouterConfig) -> Result<Self> {
25 let model = if std::path::Path::new(&config.model_path).exists() {
27 FastGRNN::load(&config.model_path)?
28 } else {
29 FastGRNN::new(Default::default())?
30 };
31
32 let circuit_breaker = if config.enable_circuit_breaker {
33 Some(CircuitBreaker::new(config.circuit_breaker_threshold))
34 } else {
35 None
36 };
37
38 Ok(Self {
39 config,
40 model: Arc::new(RwLock::new(model)),
41 feature_engineer: FeatureEngineer::new(),
42 uncertainty_estimator: UncertaintyEstimator::new(),
43 circuit_breaker,
44 })
45 }
46
47 pub fn default() -> Result<Self> {
49 Self::new(RouterConfig::default())
50 }
51
52 pub fn route(&self, request: RoutingRequest) -> Result<RoutingResponse> {
54 let start = Instant::now();
55
56 if let Some(ref cb) = self.circuit_breaker {
58 if !cb.is_closed() {
59 return Err(TinyDancerError::CircuitBreakerError(
60 "Circuit breaker is open".to_string(),
61 ));
62 }
63 }
64
65 let feature_start = Instant::now();
67 let feature_vectors = self.feature_engineer.extract_batch_features(
68 &request.query_embedding,
69 &request.candidates,
70 request.metadata.as_ref(),
71 )?;
72 let feature_time_us = feature_start.elapsed().as_micros() as u64;
73
74 let model = self.model.read();
76 let mut decisions = Vec::new();
77
78 for (candidate, features) in request.candidates.iter().zip(feature_vectors.iter()) {
79 match model.forward(&features.features, None) {
80 Ok(score) => {
81 let uncertainty = self.uncertainty_estimator.estimate(&features.features, score);
83
84 let use_lightweight = score >= self.config.confidence_threshold
86 && uncertainty <= self.config.max_uncertainty;
87
88 decisions.push(RoutingDecision {
89 candidate_id: candidate.id.clone(),
90 confidence: score,
91 use_lightweight,
92 uncertainty,
93 });
94
95 if let Some(ref cb) = self.circuit_breaker {
97 cb.record_success();
98 }
99 }
100 Err(e) => {
101 if let Some(ref cb) = self.circuit_breaker {
103 cb.record_failure();
104 }
105 return Err(e);
106 }
107 }
108 }
109
110 decisions.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
112
113 let inference_time_us = start.elapsed().as_micros() as u64;
114
115 Ok(RoutingResponse {
116 decisions,
117 inference_time_us,
118 candidates_processed: request.candidates.len(),
119 feature_time_us,
120 })
121 }
122
123 pub fn reload_model(&self) -> Result<()> {
125 let new_model = FastGRNN::load(&self.config.model_path)?;
126 let mut model = self.model.write();
127 *model = new_model;
128 Ok(())
129 }
130
131 pub fn config(&self) -> &RouterConfig {
133 &self.config
134 }
135
136 pub fn circuit_breaker_status(&self) -> Option<bool> {
138 self.circuit_breaker.as_ref().map(|cb| cb.is_closed())
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Candidate;
    use chrono::Utc;
    use std::collections::HashMap;

    /// A router built from the default configuration reports a breaker status.
    #[test]
    fn test_router_creation() {
        let status = Router::default().unwrap().circuit_breaker_status();
        assert!(status.is_some());
    }

    /// Routing two 384-dimensional candidates yields one decision per candidate
    /// and a non-zero measured inference time.
    #[test]
    fn test_routing() {
        let router = Router::default().unwrap();

        // Every candidate embedding is 384 copies of `fill`, with empty metadata
        // and a creation timestamp of "now".
        let make_candidate = |id: &str, fill, access_count, success_rate| Candidate {
            id: id.to_string(),
            embedding: vec![fill; 384],
            metadata: HashMap::new(),
            created_at: Utc::now().timestamp(),
            access_count,
            success_rate,
        };

        let candidates = vec![
            make_candidate("1", 0.5, 10, 0.95),
            make_candidate("2", 0.3, 5, 0.85),
        ];

        let response = router
            .route(RoutingRequest {
                query_embedding: vec![0.5; 384],
                candidates,
                metadata: None,
            })
            .unwrap();

        assert_eq!(response.decisions.len(), 2);
        assert!(response.inference_time_us > 0);
    }
}