1use async_trait::async_trait;
20use std::collections::HashMap;
21use tokio::sync::RwLock;
22
23use super::AdaptiveNetworkError;
24use super::NodeId;
25use super::Result;
26use super::learning::ThompsonSampling;
27use super::{HyperbolicCoordinate, RoutingStrategy, TrustProvider};
28use std::sync::Arc;
29
30pub use super::{ContentType, StrategyChoice};
32
/// Aggregate request counters and latency kept by [`AdaptiveRouter`].
#[derive(Debug, Clone, Default)]
pub struct RoutingStats {
    /// Number of requests recorded so far.
    pub total_requests: u64,
    /// How many of the recorded requests succeeded.
    pub successful_requests: u64,
    /// How many of the recorded requests failed.
    pub failed_requests: u64,
    /// Running mean latency over all recorded requests, in milliseconds.
    pub avg_latency_ms: f64,
    /// Per-strategy success figures keyed by name.
    ///
    /// NOTE(review): no code in this module writes this map — presumably
    /// strategy name → success rate; confirm with callers.
    pub strategy_success: HashMap<String, f64>,
}
51
52impl RoutingStats {
53 pub fn success_rate(&self) -> f64 {
55 if self.total_requests == 0 {
56 1.0
57 } else {
58 self.successful_requests as f64 / self.total_requests as f64
59 }
60 }
61}
62
63pub struct AdaptiveRouter {
65 _local_id: NodeId,
67
68 strategies: Arc<RwLock<HashMap<StrategyChoice, Box<dyn RoutingStrategy>>>>,
70
71 bandit: Arc<RwLock<ThompsonSampling>>,
73
74 metrics: Arc<RwLock<HashMap<String, f64>>>,
76
77 stats: Arc<RwLock<RoutingStats>>,
79}
80
81impl AdaptiveRouter {
82 pub fn new(trust_provider: Arc<dyn TrustProvider>) -> Self {
84 let node_id = NodeId::from_bytes([0u8; 32]); Self::new_with_id(node_id, trust_provider)
86 }
87 pub fn new_with_id(node_id: NodeId, _trust_provider: Arc<dyn TrustProvider>) -> Self {
89 Self {
90 _local_id: node_id,
91 strategies: Arc::new(RwLock::new(HashMap::new())),
92 bandit: Arc::new(RwLock::new(ThompsonSampling::new())),
93 metrics: Arc::new(RwLock::new(HashMap::new())),
94 stats: Arc::new(RwLock::new(RoutingStats::default())),
95 }
96 }
97
98 pub async fn register_strategy(
100 &self,
101 choice: StrategyChoice,
102 strategy: Box<dyn RoutingStrategy>,
103 ) {
104 let mut strategies = self.strategies.write().await;
105 strategies.insert(choice, strategy);
106 }
107
108 pub async fn route(
110 &self,
111 target: &NodeId,
112 content_type: ContentType,
113 ) -> std::result::Result<Vec<NodeId>, AdaptiveNetworkError> {
114 let strategy_choice = self
116 .bandit
117 .read()
118 .await
119 .select_strategy(content_type)
120 .await
121 .unwrap_or(StrategyChoice::Kademlia);
122
123 {
125 let mut metrics = self.metrics.write().await;
126 let key = format!("route_attempts_{strategy_choice:?}");
127 let count = metrics.get(&key).copied().unwrap_or(0.0) + 1.0;
128 metrics.insert(key, count);
129 }
130
131 let start = std::time::Instant::now();
133 let strategies = self.strategies.read().await;
134
135 let result = if let Some(strategy) = strategies.get(&strategy_choice) {
136 let primary_result = strategy.find_path(target).await;
137
138 if primary_result.is_err() && strategy_choice != StrategyChoice::Kademlia {
140 if let Some(kademlia) = strategies.get(&StrategyChoice::Kademlia) {
141 kademlia.find_path(target).await
142 } else {
143 primary_result
144 }
145 } else {
146 primary_result
147 }
148 } else {
149 if let Some(kademlia) = strategies.get(&StrategyChoice::Kademlia) {
151 kademlia.find_path(target).await
152 } else {
153 Err(AdaptiveNetworkError::Routing(
154 "No routing strategies available".to_string(),
155 ))
156 }
157 };
158
159 let success = result.is_ok();
161 let latency = start.elapsed().as_millis() as f64;
162
163 self.bandit
164 .write()
165 .await
166 .update(content_type, strategy_choice, success, latency as u64)
167 .await
168 .unwrap_or(());
169
170 if success {
172 let mut metrics = self.metrics.write().await;
173 let success_key = format!("route_success_{strategy_choice:?}");
174 let count = metrics.get(&success_key).copied().unwrap_or(0.0) + 1.0;
175 metrics.insert(success_key, count);
176 metrics.insert(format!("route_latency_{strategy_choice:?}"), latency);
177 }
178
179 result
180 }
181
182 pub async fn get_metrics(&self) -> std::collections::HashMap<String, f64> {
184 self.metrics.read().await.clone()
185 }
186
187 pub fn get_all_strategies(&self) -> HashMap<String, Arc<dyn RoutingStrategy>> {
189 HashMap::new()
191 }
192
193 pub async fn mark_node_unreliable(&self, _node_id: &NodeId) {
195 let strategies = self.strategies.read().await;
197 for (_choice, _strategy) in strategies.iter() {
198 }
200 }
201
202 pub async fn remove_node(&self, _node_id: &NodeId) {
204 }
207
208 pub async fn remove_hyperbolic_coordinate(&self, _node_id: &NodeId) {
210 }
212
213 pub async fn remove_from_som(&self, _node_id: &NodeId) {
215 }
217
218 pub async fn enable_aggressive_caching(&self) {
220 }
222
223 pub async fn rebalance_hyperbolic_space(&self) {
225 }
227
228 pub async fn update_som_grid(&self) {
230 }
232
233 pub async fn trigger_trust_recomputation(&self) {
235 }
237
238 pub async fn update_statistics(&self, node_id: &NodeId, success: bool, latency_ms: u64) {
240 let mut metrics = self.metrics.write().await;
241 let key = format!("node_{node_id:?}_success_rate");
242 let current = metrics.get(&key).copied().unwrap_or(0.0);
243 let new_value = if success {
244 current * 0.9 + 0.1
245 } else {
246 current * 0.9
247 };
248 metrics.insert(key, new_value);
249
250 let mut stats = self.stats.write().await;
252 stats.total_requests += 1;
253 if success {
254 stats.successful_requests += 1;
255 } else {
256 stats.failed_requests += 1;
257 }
258
259 let current_avg = stats.avg_latency_ms;
261 let count = stats.total_requests as f64;
262 stats.avg_latency_ms = (current_avg * (count - 1.0) + latency_ms as f64) / count;
263 }
264
265 pub async fn get_stats(&self) -> RoutingStats {
267 self.stats.read().await.clone()
268 }
269}
270
271pub struct KademliaRouting {
273 _node_id: NodeId,
274 _routing_table: Arc<RwLock<HashMap<NodeId, Vec<NodeId>>>>,
276}
277
278impl KademliaRouting {
279 pub fn new(node_id: NodeId) -> Self {
280 Self {
281 _node_id: node_id.clone(),
282 _routing_table: Arc::new(RwLock::new(HashMap::new())),
283 }
284 }
285}
286
287#[async_trait]
288impl RoutingStrategy for KademliaRouting {
289 async fn find_path(&self, target: &NodeId) -> Result<Vec<NodeId>> {
290 Ok(vec![target.clone()])
293 }
294
295 fn route_score(&self, _neighbor: &NodeId, _target: &NodeId) -> f64 {
296 let neighbor_bytes = &_neighbor.hash;
298 let target_bytes = &_target.hash;
299 let mut distance = 0u32;
300
301 for i in 0..32 {
302 distance += (neighbor_bytes[i] ^ target_bytes[i]).count_ones();
303 }
304
305 1.0 / (1.0 + distance as f64)
307 }
308
309 fn update_metrics(&mut self, _path: &[NodeId], _success: bool) {
310 }
312}
313
314pub struct HyperbolicRouting {
316 _coordinates: Arc<RwLock<HashMap<NodeId, HyperbolicCoordinate>>>,
317}
318
319impl Default for HyperbolicRouting {
320 fn default() -> Self {
321 Self::new()
322 }
323}
324
325impl HyperbolicRouting {
326 pub fn new() -> Self {
327 Self {
328 _coordinates: Arc::new(RwLock::new(HashMap::new())),
329 }
330 }
331
332 pub fn distance(a: &HyperbolicCoordinate, b: &HyperbolicCoordinate) -> f64 {
334 let delta = 2.0 * ((a.r - b.r).powi(2) + (a.theta - b.theta).cos().acos().powi(2)).sqrt();
335 let denominator = (1.0 - a.r.powi(2)) * (1.0 - b.r.powi(2));
336
337 (1.0 + delta / denominator).acosh()
338 }
339}
340
341#[async_trait]
342impl RoutingStrategy for HyperbolicRouting {
343 async fn find_path(&self, target: &NodeId) -> Result<Vec<NodeId>> {
344 Ok(vec![target.clone()])
347 }
348
349 fn route_score(&self, _neighbor: &NodeId, _target: &NodeId) -> f64 {
350 0.0 }
354
355 fn update_metrics(&mut self, _path: &[NodeId], _success: bool) {
356 }
358}
359
360pub struct TrustRouting {
362 trust_provider: Arc<dyn TrustProvider>,
363}
364
365impl TrustRouting {
366 pub fn new(trust_provider: Arc<dyn TrustProvider>) -> Self {
367 Self { trust_provider }
368 }
369}
370
371#[async_trait]
372impl RoutingStrategy for TrustRouting {
373 async fn find_path(&self, target: &NodeId) -> Result<Vec<NodeId>> {
374 Ok(vec![target.clone()])
376 }
377
378 fn route_score(&self, neighbor: &NodeId, _target: &NodeId) -> f64 {
379 self.trust_provider.get_trust(neighbor)
380 }
381
382 fn update_metrics(&mut self, _path: &[NodeId], _success: bool) {
383 }
385}
386
387pub struct SOMRouting {
389 _som_positions: Arc<RwLock<HashMap<NodeId, [f64; 4]>>>,
390}
391
392impl Default for SOMRouting {
393 fn default() -> Self {
394 Self::new()
395 }
396}
397
398impl SOMRouting {
399 pub fn new() -> Self {
400 Self {
401 _som_positions: Arc::new(RwLock::new(HashMap::new())),
402 }
403 }
404}
405
406#[async_trait]
407impl RoutingStrategy for SOMRouting {
408 async fn find_path(&self, target: &NodeId) -> Result<Vec<NodeId>> {
409 Ok(vec![target.clone()])
411 }
412
413 fn route_score(&self, _neighbor: &NodeId, _target: &NodeId) -> f64 {
414 0.0 }
418
419 fn update_metrics(&mut self, _path: &[NodeId], _success: bool) {
420 }
422}
423
/// Parameters of a Beta distribution.
#[derive(Debug, Clone)]
pub struct BetaDistribution {
    /// The `alpha` parameter (numerator weight of the mean).
    alpha: f64,
    /// The `beta` parameter.
    beta: f64,
}
435
436impl BetaDistribution {
437 pub fn new(alpha: f64, beta: f64) -> Self {
438 Self { alpha, beta }
439 }
440
441 pub fn sample(&self) -> f64 {
442 let total = self.alpha + self.beta;
445 self.alpha / total
446 }
447}
448
/// Test double that knows a fixed list of nodes.
#[cfg(test)]
pub struct MockRoutingStrategy {
    /// Nodes this mock treats as known, in the order they are returned.
    nodes: Vec<NodeId>,
}
454
/// Same as [`MockRoutingStrategy::new`].
#[cfg(test)]
impl Default for MockRoutingStrategy {
    fn default() -> Self {
        MockRoutingStrategy::new()
    }
}
#[cfg(test)]
impl MockRoutingStrategy {
    /// Builds a mock that knows five nodes, whose hashes are filled with the
    /// bytes `1` through `5` respectively (in that order).
    pub fn new() -> Self {
        let nodes = (1u8..=5).map(|byte| NodeId { hash: [byte; 32] }).collect();
        Self { nodes }
    }
}
475
#[cfg(test)]
#[async_trait]
impl RoutingStrategy for MockRoutingStrategy {
    /// Returns up to `count` known nodes in stored order; the target is ignored.
    async fn find_closest_nodes(
        &self,
        _target: &super::ContentHash,
        count: usize,
    ) -> std::result::Result<Vec<NodeId>, AdaptiveNetworkError> {
        let limit = count.min(self.nodes.len());
        Ok(self.nodes[..limit].to_vec())
    }

    /// Returns the all-zero node followed by `target` when `target` is known,
    /// or just the all-zero node otherwise.
    async fn find_path(
        &self,
        target: &NodeId,
    ) -> std::result::Result<Vec<NodeId>, AdaptiveNetworkError> {
        let origin = NodeId { hash: [0u8; 32] };
        let path = if self.nodes.contains(target) {
            vec![origin, target.clone()]
        } else {
            vec![origin]
        };
        Ok(path)
    }

    /// Constant neutral score.
    fn route_score(&self, _neighbor: &NodeId, _target: &NodeId) -> f64 {
        0.5
    }

    /// No-op.
    fn update_metrics(&mut self, _path: &[NodeId], _success: bool) {}
}
506
#[cfg(test)]
mod tests {
    use super::{
        AdaptiveRouter, ContentType, HyperbolicCoordinate, HyperbolicRouting,
        MockRoutingStrategy, NodeId, StrategyChoice, ThompsonSampling, TrustProvider,
    };
    use rand::RngCore;
    use std::sync::Arc;

    /// Trust provider that reports a neutral 0.5 for every node and ignores updates.
    struct MockTrustProvider;

    impl TrustProvider for MockTrustProvider {
        fn get_trust(&self, _node: &NodeId) -> f64 {
            0.5
        }
        fn update_trust(&self, _from: &NodeId, _to: &NodeId, _success: bool) {}
        fn get_global_trust(&self) -> std::collections::HashMap<NodeId, f64> {
            std::collections::HashMap::new()
        }
        fn remove_node(&self, _node: &NodeId) {}
    }

    #[tokio::test]
    async fn test_adaptive_router_creation() {
        let mut hash = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut hash);
        let node_id = NodeId { hash };
        let trust_provider = Arc::new(MockTrustProvider);
        let router = AdaptiveRouter::new_with_id(node_id, trust_provider);

        // Nothing has been routed or recorded yet.
        assert!(router.get_metrics().await.is_empty());
        let stats = router.get_stats().await;
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.success_rate(), 1.0);
    }

    #[tokio::test]
    async fn test_route_uses_registered_kademlia() {
        let router = AdaptiveRouter::new(Arc::new(MockTrustProvider));
        router
            .register_strategy(StrategyChoice::Kademlia, Box::new(MockRoutingStrategy::new()))
            .await;

        // Whatever the bandit selects, either that strategy is Kademlia or it is
        // unregistered and Kademlia is used as the fallback.
        let target = NodeId { hash: [3u8; 32] };
        let path = router.route(&target, ContentType::DHTLookup).await;
        assert!(path.is_ok(), "routing with a registered Kademlia strategy must succeed");
        assert_eq!(path.ok().and_then(|p| p.last().cloned()), Some(target));

        // Every route call records at least an attempt counter.
        assert!(!router.get_metrics().await.is_empty());
    }

    // `unwrap` is intended: a failing bandit update should fail the test.
    #[allow(clippy::unwrap_used)]
    #[tokio::test]
    async fn test_thompson_sampling() {
        let bandit = ThompsonSampling::new();

        let strategy = bandit
            .select_strategy(ContentType::DHTLookup)
            .await
            .unwrap_or(StrategyChoice::Kademlia);

        assert!(matches!(
            strategy,
            StrategyChoice::Kademlia
                | StrategyChoice::Hyperbolic
                | StrategyChoice::TrustPath
                | StrategyChoice::SOMRegion
        ));

        bandit
            .update(ContentType::DHTLookup, strategy, true, 100)
            .await
            .unwrap();
    }

    #[test]
    fn test_hyperbolic_distance() {
        let a = HyperbolicCoordinate { r: 0.0, theta: 0.0 };
        let b = HyperbolicCoordinate {
            r: 0.5,
            theta: std::f64::consts::PI,
        };

        let distance = HyperbolicRouting::distance(&a, &b);
        assert!(distance > 0.0);
        // A distance must be symmetric.
        assert!((HyperbolicRouting::distance(&b, &a) - distance).abs() < 1e-10);

        let self_distance = HyperbolicRouting::distance(&a, &a);
        assert!(self_distance.abs() < 1e-10);
    }
}