1use super::*;
20use async_trait::async_trait;
21use std::collections::HashMap;
22use std::sync::Arc;
23use tokio::sync::RwLock;
24
25pub use super::{ContentType, StrategyChoice};
27
/// Aggregate counters describing routing outcomes.
///
/// `AdaptiveRouter::update_statistics` increments the request counters and
/// maintains `avg_latency_ms` as a running mean.
#[derive(Clone, Debug, Default)]
pub struct RoutingStats {
    /// Number of requests recorded.
    pub total_requests: u64,

    /// Number of recorded requests that succeeded.
    pub successful_requests: u64,

    /// Number of recorded requests that failed.
    pub failed_requests: u64,

    /// Running mean latency, in milliseconds, over all recorded requests.
    pub avg_latency_ms: f64,

    /// Per-strategy success values keyed by a string name.
    /// NOTE(review): not populated anywhere in this module — confirm who writes it.
    pub strategy_success: HashMap<String, f64>,
}
46
47impl RoutingStats {
48 pub fn success_rate(&self) -> f64 {
50 if self.total_requests == 0 {
51 1.0
52 } else {
53 self.successful_requests as f64 / self.total_requests as f64
54 }
55 }
56}
57
58pub struct AdaptiveRouter {
60 _local_id: NodeId,
62
63 strategies: Arc<RwLock<HashMap<StrategyChoice, Box<dyn RoutingStrategy>>>>,
65
66 bandit: Arc<RwLock<ThompsonSampling>>,
68
69 metrics: Arc<RwLock<HashMap<String, f64>>>,
71
72 stats: Arc<RwLock<RoutingStats>>,
74}
75
76impl AdaptiveRouter {
77 pub fn new(
79 trust_provider: Arc<dyn TrustProvider>,
80 _hyperbolic: Arc<hyperbolic::HyperbolicSpace>,
81 _som: Arc<som::SelfOrganizingMap>,
82 ) -> Self {
83 let node_id = NodeId { hash: [0u8; 32] }; Self::new_with_id(node_id, trust_provider)
85 }
86
87 pub fn new_with_id(node_id: NodeId, _trust_provider: Arc<dyn TrustProvider>) -> Self {
89 Self {
90 _local_id: node_id,
91 strategies: Arc::new(RwLock::new(HashMap::new())),
92 bandit: Arc::new(RwLock::new(ThompsonSampling::new())),
93 metrics: Arc::new(RwLock::new(HashMap::new())),
94 stats: Arc::new(RwLock::new(RoutingStats::default())),
95 }
96 }
97
98 pub async fn register_strategy(
100 &self,
101 choice: StrategyChoice,
102 strategy: Box<dyn RoutingStrategy>,
103 ) {
104 let mut strategies = self.strategies.write().await;
105 strategies.insert(choice, strategy);
106 }
107
108 pub async fn route(&self, target: &NodeId, content_type: ContentType) -> Result<Vec<NodeId>> {
110 let strategy_choice = self
112 .bandit
113 .read()
114 .await
115 .select_strategy(content_type)
116 .await
117 .unwrap_or(StrategyChoice::Kademlia);
118
119 {
121 let mut metrics = self.metrics.write().await;
122 let key = format!("route_attempts_{strategy_choice:?}");
123 let count = metrics.get(&key).copied().unwrap_or(0.0) + 1.0;
124 metrics.insert(key, count);
125 }
126
127 let start = std::time::Instant::now();
129 let strategies = self.strategies.read().await;
130
131 let result = if let Some(strategy) = strategies.get(&strategy_choice) {
132 let primary_result = strategy.find_path(target).await;
133
134 if primary_result.is_err() && strategy_choice != StrategyChoice::Kademlia {
136 if let Some(kademlia) = strategies.get(&StrategyChoice::Kademlia) {
137 kademlia.find_path(target).await
138 } else {
139 primary_result
140 }
141 } else {
142 primary_result
143 }
144 } else {
145 if let Some(kademlia) = strategies.get(&StrategyChoice::Kademlia) {
147 kademlia.find_path(target).await
148 } else {
149 Err(AdaptiveNetworkError::Routing(
150 "No routing strategies available".to_string(),
151 ))
152 }
153 };
154
155 let success = result.is_ok();
157 let latency = start.elapsed().as_millis() as f64;
158
159 self.bandit
160 .write()
161 .await
162 .update(content_type, strategy_choice, success, latency as u64)
163 .await
164 .unwrap_or(());
165
166 if success {
168 let mut metrics = self.metrics.write().await;
169 let success_key = format!("route_success_{strategy_choice:?}");
170 let count = metrics.get(&success_key).copied().unwrap_or(0.0) + 1.0;
171 metrics.insert(success_key, count);
172 metrics.insert(format!("route_latency_{strategy_choice:?}"), latency);
173 }
174
175 result
176 }
177
178 pub async fn get_metrics(&self) -> std::collections::HashMap<String, f64> {
180 self.metrics.read().await.clone()
181 }
182
183 pub fn get_all_strategies(&self) -> HashMap<String, Arc<dyn RoutingStrategy>> {
185 HashMap::new()
187 }
188
189 pub async fn mark_node_unreliable(&self, _node_id: &NodeId) {
191 let strategies = self.strategies.read().await;
193 for (_choice, _strategy) in strategies.iter() {
194 }
196 }
197
198 pub async fn remove_node(&self, _node_id: &NodeId) {
200 }
203
204 pub async fn remove_hyperbolic_coordinate(&self, _node_id: &NodeId) {
206 }
208
209 pub async fn remove_from_som(&self, _node_id: &NodeId) {
211 }
213
214 pub async fn enable_aggressive_caching(&self) {
216 }
218
219 pub async fn rebalance_hyperbolic_space(&self) {
221 }
223
224 pub async fn update_som_grid(&self) {
226 }
228
229 pub async fn trigger_trust_recomputation(&self) {
231 }
233
234 pub async fn update_statistics(&self, node_id: &NodeId, success: bool, latency_ms: u64) {
236 let mut metrics = self.metrics.write().await;
237 let key = format!("node_{node_id:?}_success_rate");
238 let current = metrics.get(&key).copied().unwrap_or(0.0);
239 let new_value = if success {
240 current * 0.9 + 0.1
241 } else {
242 current * 0.9
243 };
244 metrics.insert(key, new_value);
245
246 let mut stats = self.stats.write().await;
248 stats.total_requests += 1;
249 if success {
250 stats.successful_requests += 1;
251 } else {
252 stats.failed_requests += 1;
253 }
254
255 let current_avg = stats.avg_latency_ms;
257 let count = stats.total_requests as f64;
258 stats.avg_latency_ms = (current_avg * (count - 1.0) + latency_ms as f64) / count;
259 }
260
261 pub async fn get_stats(&self) -> RoutingStats {
263 self.stats.read().await.clone()
264 }
265}
266
267pub struct KademliaRouting {
269 _node_id: NodeId,
270 _routing_table: Arc<RwLock<HashMap<NodeId, Vec<NodeId>>>>,
272}
273
274impl KademliaRouting {
275 pub fn new(node_id: NodeId) -> Self {
276 Self {
277 _node_id: node_id.clone(),
278 _routing_table: Arc::new(RwLock::new(HashMap::new())),
279 }
280 }
281}
282
283#[async_trait]
284impl RoutingStrategy for KademliaRouting {
285 async fn find_path(&self, target: &NodeId) -> Result<Vec<NodeId>> {
286 Ok(vec![target.clone()])
289 }
290
291 fn route_score(&self, _neighbor: &NodeId, _target: &NodeId) -> f64 {
292 let neighbor_bytes = &_neighbor.hash;
294 let target_bytes = &_target.hash;
295 let mut distance = 0u32;
296
297 for i in 0..32 {
298 distance += (neighbor_bytes[i] ^ target_bytes[i]).count_ones();
299 }
300
301 1.0 / (1.0 + distance as f64)
303 }
304
305 fn update_metrics(&mut self, _path: &[NodeId], _success: bool) {
306 }
308}
309
310pub struct HyperbolicRouting {
312 _coordinates: Arc<RwLock<HashMap<NodeId, HyperbolicCoordinate>>>,
313}
314
315impl Default for HyperbolicRouting {
316 fn default() -> Self {
317 Self::new()
318 }
319}
320
321impl HyperbolicRouting {
322 pub fn new() -> Self {
323 Self {
324 _coordinates: Arc::new(RwLock::new(HashMap::new())),
325 }
326 }
327
328 pub fn distance(a: &HyperbolicCoordinate, b: &HyperbolicCoordinate) -> f64 {
330 let delta = 2.0 * ((a.r - b.r).powi(2) + (a.theta - b.theta).cos().acos().powi(2)).sqrt();
331 let denominator = (1.0 - a.r.powi(2)) * (1.0 - b.r.powi(2));
332
333 (1.0 + delta / denominator).acosh()
334 }
335}
336
337#[async_trait]
338impl RoutingStrategy for HyperbolicRouting {
339 async fn find_path(&self, target: &NodeId) -> Result<Vec<NodeId>> {
340 Ok(vec![target.clone()])
343 }
344
345 fn route_score(&self, _neighbor: &NodeId, _target: &NodeId) -> f64 {
346 0.0 }
350
351 fn update_metrics(&mut self, _path: &[NodeId], _success: bool) {
352 }
354}
355
356pub struct TrustRouting {
358 trust_provider: Arc<dyn TrustProvider>,
359}
360
361impl TrustRouting {
362 pub fn new(trust_provider: Arc<dyn TrustProvider>) -> Self {
363 Self { trust_provider }
364 }
365}
366
367#[async_trait]
368impl RoutingStrategy for TrustRouting {
369 async fn find_path(&self, target: &NodeId) -> Result<Vec<NodeId>> {
370 Ok(vec![target.clone()])
372 }
373
374 fn route_score(&self, neighbor: &NodeId, _target: &NodeId) -> f64 {
375 self.trust_provider.get_trust(neighbor)
376 }
377
378 fn update_metrics(&mut self, _path: &[NodeId], _success: bool) {
379 }
381}
382
383pub struct SOMRouting {
385 _som_positions: Arc<RwLock<HashMap<NodeId, [f64; 4]>>>,
386}
387
388impl Default for SOMRouting {
389 fn default() -> Self {
390 Self::new()
391 }
392}
393
394impl SOMRouting {
395 pub fn new() -> Self {
396 Self {
397 _som_positions: Arc::new(RwLock::new(HashMap::new())),
398 }
399 }
400}
401
402#[async_trait]
403impl RoutingStrategy for SOMRouting {
404 async fn find_path(&self, target: &NodeId) -> Result<Vec<NodeId>> {
405 Ok(vec![target.clone()])
407 }
408
409 fn route_score(&self, _neighbor: &NodeId, _target: &NodeId) -> f64 {
410 0.0 }
414
415 fn update_metrics(&mut self, _path: &[NodeId], _success: bool) {
416 }
418}
419
420pub use crate::adaptive::learning::ThompsonSampling;
423
/// Shape parameters of a Beta(α, β) distribution.
#[derive(Clone, Debug)]
pub struct BetaDistribution {
    /// First shape parameter (α).
    alpha: f64,
    /// Second shape parameter (β).
    beta: f64,
}
432
433impl BetaDistribution {
434 pub fn new(alpha: f64, beta: f64) -> Self {
435 Self { alpha, beta }
436 }
437
438 pub fn sample(&self) -> f64 {
439 let total = self.alpha + self.beta;
442 self.alpha / total
443 }
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_adaptive_router_creation() {
        use crate::peer_record::UserId;
        struct MockTrustProvider;
        impl TrustProvider for MockTrustProvider {
            fn get_trust(&self, _node: &NodeId) -> f64 {
                0.5
            }
            fn update_trust(&self, _from: &NodeId, _to: &NodeId, _success: bool) {}
            fn get_global_trust(&self) -> std::collections::HashMap<NodeId, f64> {
                std::collections::HashMap::new()
            }
            fn remove_node(&self, _node: &NodeId) {}
        }

        // A fixed id keeps the test deterministic; its value is irrelevant here.
        let node_id = UserId::from_bytes([42u8; 32]);
        let trust_provider = Arc::new(MockTrustProvider);
        let router = AdaptiveRouter::new_with_id(node_id, trust_provider);

        let metrics = router.get_metrics().await;
        assert!(metrics.is_empty());
    }

    #[tokio::test]
    async fn test_thompson_sampling() {
        let bandit = ThompsonSampling::new();

        let strategy = bandit
            .select_strategy(ContentType::DHTLookup)
            .await
            .unwrap_or(StrategyChoice::Kademlia);
        assert!(matches!(
            strategy,
            StrategyChoice::Kademlia
                | StrategyChoice::Hyperbolic
                | StrategyChoice::TrustPath
                | StrategyChoice::SOMRegion
        ));

        bandit
            .update(ContentType::DHTLookup, strategy, true, 100)
            .await
            .unwrap();
    }

    #[test]
    fn test_hyperbolic_distance() {
        let a = HyperbolicCoordinate { r: 0.0, theta: 0.0 };
        let b = HyperbolicCoordinate {
            r: 0.5,
            theta: std::f64::consts::PI,
        };

        let distance = HyperbolicRouting::distance(&a, &b);
        assert!(distance > 0.0);

        let self_distance = HyperbolicRouting::distance(&a, &a);
        assert!(self_distance.abs() < 1e-10);
    }
}
514
/// Test-only routing strategy backed by a fixed list of known nodes.
#[cfg(test)]
pub struct MockRoutingStrategy {
    /// Nodes the mock reports, in the order they are returned.
    nodes: Vec<NodeId>,
}
520
#[cfg(test)]
impl MockRoutingStrategy {
    /// Creates a mock that knows five nodes with ids `[1; 32]` through `[5; 32]`,
    /// in that order.
    pub fn new() -> Self {
        Self {
            nodes: (1u8..=5).map(|byte| NodeId { hash: [byte; 32] }).collect(),
        }
    }
}

/// Matches the other strategies in this module, which pair `new()` with `Default`.
#[cfg(test)]
impl Default for MockRoutingStrategy {
    fn default() -> Self {
        Self::new()
    }
}
535
#[cfg(test)]
#[async_trait]
impl RoutingStrategy for MockRoutingStrategy {
    /// Returns up to `count` of the mock's nodes in stored order; `_target` is ignored.
    async fn find_closest_nodes(&self, _target: &ContentHash, count: usize) -> Result<Vec<NodeId>> {
        let limit = count.min(self.nodes.len());
        Ok(self.nodes[..limit].to_vec())
    }

    /// Path that starts at the all-zero id and ends at `target` when the mock knows it.
    async fn find_path(&self, target: &NodeId) -> Result<Vec<NodeId>> {
        let origin = NodeId { hash: [0u8; 32] };
        let path = if self.nodes.contains(target) {
            vec![origin, target.clone()]
        } else {
            vec![origin]
        };
        Ok(path)
    }

    /// Fixed score of `0.5` for every neighbour.
    fn route_score(&self, _neighbor: &NodeId, _target: &NodeId) -> f64 {
        0.5
    }

    /// No-op: the mock records nothing.
    fn update_metrics(&mut self, _path: &[NodeId], _success: bool) {}
}