// ruvector_tiny_dancer_core/router.rs
use crate::circuit_breaker::CircuitBreaker;
use crate::error::{Result, TinyDancerError};
use crate::feature_engineering::FeatureEngineer;
use crate::model::FastGRNN;
use crate::types::{RouterConfig, RoutingDecision, RoutingRequest, RoutingResponse};
use crate::uncertainty::UncertaintyEstimator;
use parking_lot::RwLock;
use std::path::Path;
use std::sync::Arc;
use std::time::Instant;

13pub struct Router {
15 config: RouterConfig,
16 model: Arc<RwLock<FastGRNN>>,
17 feature_engineer: FeatureEngineer,
18 uncertainty_estimator: UncertaintyEstimator,
19 circuit_breaker: Option<CircuitBreaker>,
20}
21
22impl Router {
23 pub fn new(config: RouterConfig) -> Result<Self> {
25 let model = if std::path::Path::new(&config.model_path).exists() {
27 FastGRNN::load(&config.model_path)?
28 } else {
29 FastGRNN::new(Default::default())?
30 };
31
32 let circuit_breaker = if config.enable_circuit_breaker {
33 Some(CircuitBreaker::new(config.circuit_breaker_threshold))
34 } else {
35 None
36 };
37
38 Ok(Self {
39 config,
40 model: Arc::new(RwLock::new(model)),
41 feature_engineer: FeatureEngineer::new(),
42 uncertainty_estimator: UncertaintyEstimator::new(),
43 circuit_breaker,
44 })
45 }
46
47 pub fn default() -> Result<Self> {
49 Self::new(RouterConfig::default())
50 }
51
52 pub fn route(&self, request: RoutingRequest) -> Result<RoutingResponse> {
54 let start = Instant::now();
55
56 if let Some(ref cb) = self.circuit_breaker {
58 if !cb.is_closed() {
59 return Err(TinyDancerError::CircuitBreakerError(
60 "Circuit breaker is open".to_string(),
61 ));
62 }
63 }
64
65 let feature_start = Instant::now();
67 let feature_vectors = self.feature_engineer.extract_batch_features(
68 &request.query_embedding,
69 &request.candidates,
70 request.metadata.as_ref(),
71 )?;
72 let feature_time_us = feature_start.elapsed().as_micros() as u64;
73
74 let model = self.model.read();
76 let mut decisions = Vec::new();
77
78 for (candidate, features) in request.candidates.iter().zip(feature_vectors.iter()) {
79 match model.forward(&features.features, None) {
80 Ok(score) => {
81 let uncertainty = self
83 .uncertainty_estimator
84 .estimate(&features.features, score);
85
86 let use_lightweight = score >= self.config.confidence_threshold
88 && uncertainty <= self.config.max_uncertainty;
89
90 decisions.push(RoutingDecision {
91 candidate_id: candidate.id.clone(),
92 confidence: score,
93 use_lightweight,
94 uncertainty,
95 });
96
97 if let Some(ref cb) = self.circuit_breaker {
99 cb.record_success();
100 }
101 }
102 Err(e) => {
103 if let Some(ref cb) = self.circuit_breaker {
105 cb.record_failure();
106 }
107 return Err(e);
108 }
109 }
110 }
111
112 decisions.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
114
115 let inference_time_us = start.elapsed().as_micros() as u64;
116
117 Ok(RoutingResponse {
118 decisions,
119 inference_time_us,
120 candidates_processed: request.candidates.len(),
121 feature_time_us,
122 })
123 }
124
125 pub fn reload_model(&self) -> Result<()> {
127 let new_model = FastGRNN::load(&self.config.model_path)?;
128 let mut model = self.model.write();
129 *model = new_model;
130 Ok(())
131 }
132
133 pub fn config(&self) -> &RouterConfig {
135 &self.config
136 }
137
138 pub fn circuit_breaker_status(&self) -> Option<bool> {
140 self.circuit_breaker.as_ref().map(|cb| cb.is_closed())
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Candidate;
    use chrono::Utc;
    use std::collections::HashMap;

    #[test]
    fn test_router_creation() {
        let router = Router::default().unwrap();
        assert!(router.circuit_breaker_status().is_some());
    }

    #[test]
    fn test_routing() {
        let router = Router::default().unwrap();

        let candidates = vec![
            Candidate {
                id: "1".to_string(),
                embedding: vec![0.5; 384],
                metadata: HashMap::new(),
                created_at: Utc::now().timestamp(),
                access_count: 10,
                success_rate: 0.95,
            },
            Candidate {
                id: "2".to_string(),
                embedding: vec![0.3; 384],
                metadata: HashMap::new(),
                created_at: Utc::now().timestamp(),
                access_count: 5,
                success_rate: 0.85,
            },
        ];

        let request = RoutingRequest {
            query_embedding: vec![0.5; 384],
            candidates,
            metadata: None,
        };

        let response = router.route(request).unwrap();
        assert_eq!(response.decisions.len(), 2);
        assert_eq!(response.candidates_processed, 2);

        // Every candidate receives exactly one decision.
        let mut ids: Vec<&str> = response
            .decisions
            .iter()
            .map(|d| d.candidate_id.as_str())
            .collect();
        ids.sort_unstable();
        assert_eq!(ids, ["1", "2"]);

        // Decisions come back ordered by descending confidence.
        assert!(response
            .decisions
            .windows(2)
            .all(|w| w[0].confidence >= w[1].confidence));

        // `as_micros` truncates, so a very fast route can legitimately report 0 µs.
        // The total span always contains the feature-extraction span, so check that instead.
        assert!(response.inference_time_us >= response.feature_time_us);
    }
}