// saorsa_core/adaptive/routing.rs

// Copyright 2024 Saorsa Labs Limited
//
// This software is dual-licensed under:
// - GNU Affero General Public License v3.0 or later (AGPL-3.0-or-later)
// - Commercial License
//
// For AGPL-3.0 license, see LICENSE-AGPL-3.0
// For commercial licensing, contact: saorsalabs@gmail.com
//
// Unless required by applicable law or agreed to in writing, software
// distributed under these licenses is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
//! Adaptive routing system combining multiple strategies.
//!
//! Implements multi-armed bandit routing selection using Thompson Sampling
//! to dynamically choose between Kademlia, hyperbolic, trust-based, and SOM routing.
18
19use super::*;
20use async_trait::async_trait;
21use std::collections::HashMap;
22use std::sync::Arc;
23use tokio::sync::RwLock;
24
25// Re-export types that other modules need
26pub use super::{ContentType, StrategyChoice};
27
/// Aggregate counters describing how recorded routing requests have fared.
#[derive(Debug, Clone, Default)]
pub struct RoutingStats {
    /// Number of routing requests recorded so far.
    pub total_requests: u64,

    /// Number of recorded requests that succeeded.
    pub successful_requests: u64,

    /// Number of recorded requests that failed.
    pub failed_requests: u64,

    /// Mean latency of the recorded requests, in milliseconds.
    pub avg_latency_ms: f64,

    /// Success rate per strategy, keyed by a strategy identifier string.
    pub strategy_success: HashMap<String, f64>,
}

impl RoutingStats {
    /// Fraction of recorded requests that succeeded.
    ///
    /// With no requests recorded yet, this optimistically reports `1.0`.
    pub fn success_rate(&self) -> f64 {
        match self.total_requests {
            0 => 1.0,
            total => self.successful_requests as f64 / total as f64,
        }
    }
}
57
58/// Adaptive router that combines multiple routing strategies
59pub struct AdaptiveRouter {
60    /// Local node ID
61    _local_id: NodeId,
62
63    /// Routing strategies
64    strategies: Arc<RwLock<HashMap<StrategyChoice, Box<dyn RoutingStrategy>>>>,
65
66    /// Multi-armed bandit for strategy selection
67    bandit: Arc<RwLock<ThompsonSampling>>,
68
69    /// Metrics collector
70    metrics: Arc<RwLock<HashMap<String, f64>>>,
71
72    /// Routing statistics
73    stats: Arc<RwLock<RoutingStats>>,
74}
75
76impl AdaptiveRouter {
77    /// Create a new adaptive router with multiple strategies
78    pub fn new(
79        trust_provider: Arc<dyn TrustProvider>,
80        _hyperbolic: Arc<hyperbolic::HyperbolicSpace>,
81        _som: Arc<som::SelfOrganizingMap>,
82    ) -> Self {
83        let node_id = NodeId { hash: [0u8; 32] }; // Default node ID
84        Self::new_with_id(node_id, trust_provider)
85    }
86
87    /// Create a new adaptive router with specific node ID
88    pub fn new_with_id(node_id: NodeId, _trust_provider: Arc<dyn TrustProvider>) -> Self {
89        Self {
90            _local_id: node_id,
91            strategies: Arc::new(RwLock::new(HashMap::new())),
92            bandit: Arc::new(RwLock::new(ThompsonSampling::new())),
93            metrics: Arc::new(RwLock::new(HashMap::new())),
94            stats: Arc::new(RwLock::new(RoutingStats::default())),
95        }
96    }
97
98    /// Register a routing strategy
99    pub async fn register_strategy(
100        &self,
101        choice: StrategyChoice,
102        strategy: Box<dyn RoutingStrategy>,
103    ) {
104        let mut strategies = self.strategies.write().await;
105        strategies.insert(choice, strategy);
106    }
107
108    /// Route a message to a target using the best strategy
109    pub async fn route(&self, target: &NodeId, content_type: ContentType) -> Result<Vec<NodeId>> {
110        // Select strategy using multi-armed bandit
111        let strategy_choice = self
112            .bandit
113            .read()
114            .await
115            .select_strategy(content_type)
116            .await
117            .unwrap_or(StrategyChoice::Kademlia);
118
119        // Record strategy selection
120        {
121            let mut metrics = self.metrics.write().await;
122            let key = format!("route_attempts_{strategy_choice:?}");
123            let count = metrics.get(&key).copied().unwrap_or(0.0) + 1.0;
124            metrics.insert(key, count);
125        }
126
127        // Execute routing with selected strategy
128        let start = std::time::Instant::now();
129        let strategies = self.strategies.read().await;
130
131        let result = if let Some(strategy) = strategies.get(&strategy_choice) {
132            let primary_result = strategy.find_path(target).await;
133
134            // If primary strategy fails and it's not Kademlia, try Kademlia as fallback
135            if primary_result.is_err() && strategy_choice != StrategyChoice::Kademlia {
136                if let Some(kademlia) = strategies.get(&StrategyChoice::Kademlia) {
137                    kademlia.find_path(target).await
138                } else {
139                    primary_result
140                }
141            } else {
142                primary_result
143            }
144        } else {
145            // If strategy not found, try Kademlia
146            if let Some(kademlia) = strategies.get(&StrategyChoice::Kademlia) {
147                kademlia.find_path(target).await
148            } else {
149                Err(AdaptiveNetworkError::Routing(
150                    "No routing strategies available".to_string(),
151                ))
152            }
153        };
154
155        // Update bandit based on result
156        let success = result.is_ok();
157        let latency = start.elapsed().as_millis() as f64;
158
159        self.bandit
160            .write()
161            .await
162            .update(content_type, strategy_choice, success, latency as u64)
163            .await
164            .unwrap_or(());
165
166        // Update metrics
167        if success {
168            let mut metrics = self.metrics.write().await;
169            let success_key = format!("route_success_{strategy_choice:?}");
170            let count = metrics.get(&success_key).copied().unwrap_or(0.0) + 1.0;
171            metrics.insert(success_key, count);
172            metrics.insert(format!("route_latency_{strategy_choice:?}"), latency);
173        }
174
175        result
176    }
177
178    /// Get routing metrics
179    pub async fn get_metrics(&self) -> std::collections::HashMap<String, f64> {
180        self.metrics.read().await.clone()
181    }
182
183    /// Get all routing strategies
184    pub fn get_all_strategies(&self) -> HashMap<String, Arc<dyn RoutingStrategy>> {
185        // For now, return empty as strategies are stored differently
186        HashMap::new()
187    }
188
189    /// Mark a node as unreliable
190    pub async fn mark_node_unreliable(&self, _node_id: &NodeId) {
191        // Update routing metrics to reflect unreliability
192        let strategies = self.strategies.read().await;
193        for (_choice, _strategy) in strategies.iter() {
194            // Would update metrics in real implementation
195        }
196    }
197
198    /// Remove a node from all routing tables
199    pub async fn remove_node(&self, _node_id: &NodeId) {
200        // In a real implementation, would remove from K-buckets, etc.
201        // log::info!("Removing node {:?} from routing tables", node_id);
202    }
203
204    /// Remove node's hyperbolic coordinates
205    pub async fn remove_hyperbolic_coordinate(&self, _node_id: &NodeId) {
206        // log::info!("Removing hyperbolic coordinates for {:?}", node_id);
207    }
208
209    /// Remove node from SOM
210    pub async fn remove_from_som(&self, _node_id: &NodeId) {
211        // log::info!("Removing {:?} from SOM", node_id);
212    }
213
214    /// Enable aggressive caching during high churn
215    pub async fn enable_aggressive_caching(&self) {
216        // log::info!("Enabling aggressive caching due to high churn");
217    }
218
219    /// Rebalance hyperbolic space after failures
220    pub async fn rebalance_hyperbolic_space(&self) {
221        // log::info!("Rebalancing hyperbolic space");
222    }
223
224    /// Update SOM grid after topology changes
225    pub async fn update_som_grid(&self) {
226        // log::info!("Updating SOM grid");
227    }
228
229    /// Trigger trust score recomputation
230    pub async fn trigger_trust_recomputation(&self) {
231        // log::info!("Triggering trust score recomputation");
232    }
233
234    /// Update routing statistics
235    pub async fn update_statistics(&self, node_id: &NodeId, success: bool, latency_ms: u64) {
236        let mut metrics = self.metrics.write().await;
237        let key = format!("node_{node_id:?}_success_rate");
238        let current = metrics.get(&key).copied().unwrap_or(0.0);
239        let new_value = if success {
240            current * 0.9 + 0.1
241        } else {
242            current * 0.9
243        };
244        metrics.insert(key, new_value);
245
246        // Update routing stats
247        let mut stats = self.stats.write().await;
248        stats.total_requests += 1;
249        if success {
250            stats.successful_requests += 1;
251        } else {
252            stats.failed_requests += 1;
253        }
254
255        // Update average latency
256        let current_avg = stats.avg_latency_ms;
257        let count = stats.total_requests as f64;
258        stats.avg_latency_ms = (current_avg * (count - 1.0) + latency_ms as f64) / count;
259    }
260
261    /// Get routing statistics
262    pub async fn get_stats(&self) -> RoutingStats {
263        self.stats.read().await.clone()
264    }
265}
266
267/// Kademlia routing implementation
268pub struct KademliaRouting {
269    _node_id: NodeId,
270    // Placeholder for routing table - would use actual implementation
271    _routing_table: Arc<RwLock<HashMap<NodeId, Vec<NodeId>>>>,
272}
273
274impl KademliaRouting {
275    pub fn new(node_id: NodeId) -> Self {
276        Self {
277            _node_id: node_id.clone(),
278            _routing_table: Arc::new(RwLock::new(HashMap::new())),
279        }
280    }
281}
282
283#[async_trait]
284impl RoutingStrategy for KademliaRouting {
285    async fn find_path(&self, target: &NodeId) -> Result<Vec<NodeId>> {
286        // Implementation would use the actual Kademlia lookup
287        // For now, return a placeholder
288        Ok(vec![target.clone()])
289    }
290
291    fn route_score(&self, _neighbor: &NodeId, _target: &NodeId) -> f64 {
292        // XOR distance metric
293        let neighbor_bytes = &_neighbor.hash;
294        let target_bytes = &_target.hash;
295        let mut distance = 0u32;
296
297        for i in 0..32 {
298            distance += (neighbor_bytes[i] ^ target_bytes[i]).count_ones();
299        }
300
301        // Convert to score (closer = higher score)
302        1.0 / (1.0 + distance as f64)
303    }
304
305    fn update_metrics(&mut self, _path: &[NodeId], _success: bool) {
306        // Update routing table based on success/failure
307    }
308}
309
310/// Hyperbolic routing implementation
311pub struct HyperbolicRouting {
312    _coordinates: Arc<RwLock<HashMap<NodeId, HyperbolicCoordinate>>>,
313}
314
315impl Default for HyperbolicRouting {
316    fn default() -> Self {
317        Self::new()
318    }
319}
320
321impl HyperbolicRouting {
322    pub fn new() -> Self {
323        Self {
324            _coordinates: Arc::new(RwLock::new(HashMap::new())),
325        }
326    }
327
328    /// Calculate hyperbolic distance between two coordinates
329    pub fn distance(a: &HyperbolicCoordinate, b: &HyperbolicCoordinate) -> f64 {
330        let delta = 2.0 * ((a.r - b.r).powi(2) + (a.theta - b.theta).cos().acos().powi(2)).sqrt();
331        let denominator = (1.0 - a.r.powi(2)) * (1.0 - b.r.powi(2));
332
333        (1.0 + delta / denominator).acosh()
334    }
335}
336
337#[async_trait]
338impl RoutingStrategy for HyperbolicRouting {
339    async fn find_path(&self, target: &NodeId) -> Result<Vec<NodeId>> {
340        // Greedy routing in hyperbolic space
341        // For now, return a placeholder
342        Ok(vec![target.clone()])
343    }
344
345    fn route_score(&self, _neighbor: &NodeId, _target: &NodeId) -> f64 {
346        // Score based on hyperbolic distance
347        // Note: This is synchronous, so we can't use async
348        0.0 // Placeholder for now
349    }
350
351    fn update_metrics(&mut self, _path: &[NodeId], _success: bool) {
352        // Update coordinate estimates
353    }
354}
355
356/// Trust-based routing implementation
357pub struct TrustRouting {
358    trust_provider: Arc<dyn TrustProvider>,
359}
360
361impl TrustRouting {
362    pub fn new(trust_provider: Arc<dyn TrustProvider>) -> Self {
363        Self { trust_provider }
364    }
365}
366
367#[async_trait]
368impl RoutingStrategy for TrustRouting {
369    async fn find_path(&self, target: &NodeId) -> Result<Vec<NodeId>> {
370        // Route through high-trust nodes
371        Ok(vec![target.clone()])
372    }
373
374    fn route_score(&self, neighbor: &NodeId, _target: &NodeId) -> f64 {
375        self.trust_provider.get_trust(neighbor)
376    }
377
378    fn update_metrics(&mut self, _path: &[NodeId], _success: bool) {
379        // Trust updates handled by trust provider
380    }
381}
382
383/// SOM-based routing implementation
384pub struct SOMRouting {
385    _som_positions: Arc<RwLock<HashMap<NodeId, [f64; 4]>>>,
386}
387
388impl Default for SOMRouting {
389    fn default() -> Self {
390        Self::new()
391    }
392}
393
394impl SOMRouting {
395    pub fn new() -> Self {
396        Self {
397            _som_positions: Arc::new(RwLock::new(HashMap::new())),
398        }
399    }
400}
401
402#[async_trait]
403impl RoutingStrategy for SOMRouting {
404    async fn find_path(&self, target: &NodeId) -> Result<Vec<NodeId>> {
405        // Route through similar nodes in SOM space
406        Ok(vec![target.clone()])
407    }
408
409    fn route_score(&self, _neighbor: &NodeId, _target: &NodeId) -> f64 {
410        // Score based on SOM distance
411        // Note: This is synchronous, so we can't use async
412        0.0 // Placeholder for now
413    }
414
415    fn update_metrics(&mut self, _path: &[NodeId], _success: bool) {
416        // Update SOM positions
417    }
418}
419
// ThompsonSampling has moved to learning.rs, which has a more comprehensive implementation;
// re-export it from the learning module.
422pub use crate::adaptive::learning::ThompsonSampling;
423
// The implementation now lives in learning.rs; this module uses that more comprehensive version.
425
/// Beta distribution parameters for Thompson Sampling.
///
/// Note: [`BetaDistribution::sample`] does not draw a random variate; it returns the
/// distribution's mean. A real sampler needs an RNG (for example `rand_distr::Beta`).
#[derive(Debug, Clone)]
pub struct BetaDistribution {
    alpha: f64,
    beta: f64,
}

impl BetaDistribution {
    /// Create a distribution with shape parameters `alpha` and `beta` (both should be > 0).
    pub fn new(alpha: f64, beta: f64) -> Self {
        Self { alpha, beta }
    }

    /// Deterministic stand-in for sampling: the mean `alpha / (alpha + beta)`, clamped to
    /// `[0, 1]`.
    ///
    /// Degenerate parameters (a non-positive or non-finite `alpha + beta`, including NaN)
    /// yield `0.5`, the mean of the uniform `Beta(1, 1)` prior. This avoids returning NaN,
    /// which would silently break comparisons between arms.
    pub fn sample(&self) -> f64 {
        let total = self.alpha + self.beta;
        if !(total.is_finite() && total > 0.0) {
            return 0.5;
        }
        (self.alpha / total).clamp(0.0, 1.0)
    }
}
445
#[cfg(test)]
mod tests {
    use super::*;

    /// Trust provider that rates every node 0.5 and ignores updates.
    struct MockTrustProvider;

    impl TrustProvider for MockTrustProvider {
        fn get_trust(&self, _node: &NodeId) -> f64 {
            0.5
        }
        fn update_trust(&self, _from: &NodeId, _to: &NodeId, _success: bool) {}
        fn get_global_trust(&self) -> std::collections::HashMap<NodeId, f64> {
            std::collections::HashMap::new()
        }
        fn remove_node(&self, _node: &NodeId) {}
    }

    /// Build a router with a fixed, deterministic local node ID.
    fn test_router() -> AdaptiveRouter {
        AdaptiveRouter::new_with_id(NodeId { hash: [0x42; 32] }, Arc::new(MockTrustProvider))
    }

    #[tokio::test]
    async fn test_adaptive_router_creation() {
        let router = test_router();

        assert!(router.get_metrics().await.is_empty());
        assert_eq!(router.get_stats().await.total_requests, 0);
    }

    #[tokio::test]
    async fn test_route_reaches_target_via_kademlia() {
        // Only Kademlia is registered: whichever arm the bandit picks, the request is
        // served either directly or via the Kademlia fallback.
        let router = test_router();
        router
            .register_strategy(StrategyChoice::Kademlia, Box::new(MockRoutingStrategy::new()))
            .await;

        let target = NodeId { hash: [3u8; 32] };
        let path = router
            .route(&target, ContentType::DHTLookup)
            .await
            .expect("Kademlia is registered, so routing must succeed");
        assert_eq!(path.last(), Some(&target));
    }

    #[tokio::test]
    async fn test_route_without_strategies_fails() {
        let router = test_router();
        let target = NodeId { hash: [3u8; 32] };

        assert!(router.route(&target, ContentType::DHTLookup).await.is_err());
    }

    #[tokio::test]
    async fn test_update_statistics() {
        let router = test_router();
        let node = NodeId { hash: [9u8; 32] };

        router.update_statistics(&node, true, 100).await;
        router.update_statistics(&node, false, 200).await;

        let stats = router.get_stats().await;
        assert_eq!(stats.total_requests, 2);
        assert_eq!(stats.successful_requests, 1);
        assert_eq!(stats.failed_requests, 1);
        assert!((stats.avg_latency_ms - 150.0).abs() < 1e-9);
        assert!((stats.success_rate() - 0.5).abs() < 1e-12);
    }

    #[tokio::test]
    async fn test_thompson_sampling() {
        let bandit = ThompsonSampling::new();

        // Test that it returns a strategy
        let strategy = bandit
            .select_strategy(ContentType::DHTLookup)
            .await
            .unwrap_or(StrategyChoice::Kademlia);
        assert!(matches!(
            strategy,
            StrategyChoice::Kademlia
                | StrategyChoice::Hyperbolic
                | StrategyChoice::TrustPath
                | StrategyChoice::SOMRegion
        ));

        // Test update
        bandit
            .update(ContentType::DHTLookup, strategy, true, 100)
            .await
            .unwrap();
    }

    #[test]
    fn test_hyperbolic_distance() {
        let a = HyperbolicCoordinate { r: 0.0, theta: 0.0 };
        let b = HyperbolicCoordinate {
            r: 0.5,
            theta: std::f64::consts::PI,
        };

        let distance = HyperbolicRouting::distance(&a, &b);
        assert!(distance > 0.0);

        // Distance is symmetric
        let reverse = HyperbolicRouting::distance(&b, &a);
        assert!((distance - reverse).abs() < 1e-12);

        // Distance to self should be 0
        let self_distance = HyperbolicRouting::distance(&a, &a);
        assert!((self_distance - 0.0).abs() < 1e-10);
    }
}
514
/// Deterministic routing strategy for tests.
///
/// It knows five nodes whose IDs are the bytes `1..=5`, each repeated across all 32 bytes.
#[cfg(test)]
pub struct MockRoutingStrategy {
    nodes: Vec<NodeId>,
}

#[cfg(test)]
impl MockRoutingStrategy {
    pub fn new() -> Self {
        let nodes = (1u8..=5).map(|byte| NodeId { hash: [byte; 32] }).collect();
        Self { nodes }
    }
}

#[cfg(test)]
#[async_trait]
impl RoutingStrategy for MockRoutingStrategy {
    async fn find_closest_nodes(&self, _target: &ContentHash, count: usize) -> Result<Vec<NodeId>> {
        // The first `count` known nodes (or all of them, if fewer), regardless of target.
        let limit = count.min(self.nodes.len());
        Ok(self.nodes[..limit].to_vec())
    }

    async fn find_path(&self, target: &NodeId) -> Result<Vec<NodeId>> {
        // Every path starts at the all-zero origin node; known targets are appended.
        let origin = NodeId { hash: [0u8; 32] };
        let path = if self.nodes.contains(target) {
            vec![origin, target.clone()]
        } else {
            vec![origin]
        };
        Ok(path)
    }

    fn route_score(&self, _neighbor: &NodeId, _target: &NodeId) -> f64 {
        0.5
    }

    fn update_metrics(&mut self, _path: &[NodeId], _success: bool) {
        // Intentionally a no-op for the mock.
    }
}