saorsa_core/adaptive/
routing.rs

1// Copyright 2024 Saorsa Labs Limited
2//
3// This software is dual-licensed under:
4// - GNU Affero General Public License v3.0 or later (AGPL-3.0-or-later)
5// - Commercial License
6//
7// For AGPL-3.0 license, see LICENSE-AGPL-3.0
8// For commercial licensing, contact: saorsalabs@gmail.com
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under these licenses is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
14//! Adaptive routing system combining multiple strategies
15//!
16//! Implements multi-armed bandit routing selection using Thompson Sampling
17//! to dynamically choose between Kademlia, Hyperbolic, Trust-based, and SOM routing
18
19use async_trait::async_trait;
20use std::collections::HashMap;
21use tokio::sync::RwLock;
22
23use super::AdaptiveNetworkError;
24use super::NodeId;
25use super::Result;
26use super::learning::ThompsonSampling;
27use super::{HyperbolicCoordinate, RoutingStrategy, TrustProvider};
28use std::sync::Arc;
29
30// Re-export types that other modules need
31pub use super::{ContentType, StrategyChoice};
32
/// Aggregate counters describing how routing requests have gone so far.
#[derive(Debug, Clone, Default)]
pub struct RoutingStats {
    /// Number of routing requests recorded.
    pub total_requests: u64,

    /// Number of recorded requests that succeeded.
    pub successful_requests: u64,

    /// Number of recorded requests that failed.
    pub failed_requests: u64,

    /// Running mean of request latency, in milliseconds.
    pub avg_latency_ms: f64,

    /// Success rate keyed by strategy name (not populated by this module).
    pub strategy_success: HashMap<String, f64>,
}

impl RoutingStats {
    /// Fraction of recorded requests that succeeded.
    ///
    /// With no requests recorded yet this optimistically reports `1.0`.
    pub fn success_rate(&self) -> f64 {
        match self.total_requests {
            0 => 1.0,
            total => self.successful_requests as f64 / total as f64,
        }
    }
}
62
63/// Adaptive router that combines multiple routing strategies
64pub struct AdaptiveRouter {
65    /// Local node ID
66    _local_id: NodeId,
67
68    /// Routing strategies
69    strategies: Arc<RwLock<HashMap<StrategyChoice, Box<dyn RoutingStrategy>>>>,
70
71    /// Multi-armed bandit for strategy selection
72    bandit: Arc<RwLock<ThompsonSampling>>,
73
74    /// Metrics collector
75    metrics: Arc<RwLock<HashMap<String, f64>>>,
76
77    /// Routing statistics
78    stats: Arc<RwLock<RoutingStats>>,
79}
80
81impl AdaptiveRouter {
82    /// Create a new adaptive router with multiple strategies
83    pub fn new(trust_provider: Arc<dyn TrustProvider>) -> Self {
84        let node_id = NodeId::from_bytes([0u8; 32]); // Default node ID
85        Self::new_with_id(node_id, trust_provider)
86    }
87    /// Create a new adaptive router with specific node ID
88    pub fn new_with_id(node_id: NodeId, _trust_provider: Arc<dyn TrustProvider>) -> Self {
89        Self {
90            _local_id: node_id,
91            strategies: Arc::new(RwLock::new(HashMap::new())),
92            bandit: Arc::new(RwLock::new(ThompsonSampling::new())),
93            metrics: Arc::new(RwLock::new(HashMap::new())),
94            stats: Arc::new(RwLock::new(RoutingStats::default())),
95        }
96    }
97
98    /// Register a routing strategy
99    pub async fn register_strategy(
100        &self,
101        choice: StrategyChoice,
102        strategy: Box<dyn RoutingStrategy>,
103    ) {
104        let mut strategies = self.strategies.write().await;
105        strategies.insert(choice, strategy);
106    }
107
108    /// Route a message to a target using the best strategy
109    pub async fn route(
110        &self,
111        target: &NodeId,
112        content_type: ContentType,
113    ) -> std::result::Result<Vec<NodeId>, AdaptiveNetworkError> {
114        // Select strategy using multi-armed bandit
115        let strategy_choice = self
116            .bandit
117            .read()
118            .await
119            .select_strategy(content_type)
120            .await
121            .unwrap_or(StrategyChoice::Kademlia);
122
123        // Record strategy selection
124        {
125            let mut metrics = self.metrics.write().await;
126            let key = format!("route_attempts_{strategy_choice:?}");
127            let count = metrics.get(&key).copied().unwrap_or(0.0) + 1.0;
128            metrics.insert(key, count);
129        }
130
131        // Execute routing with selected strategy
132        let start = std::time::Instant::now();
133        let strategies = self.strategies.read().await;
134
135        let result = if let Some(strategy) = strategies.get(&strategy_choice) {
136            let primary_result = strategy.find_path(target).await;
137
138            // If primary strategy fails and it's not Kademlia, try Kademlia as fallback
139            if primary_result.is_err() && strategy_choice != StrategyChoice::Kademlia {
140                if let Some(kademlia) = strategies.get(&StrategyChoice::Kademlia) {
141                    kademlia.find_path(target).await
142                } else {
143                    primary_result
144                }
145            } else {
146                primary_result
147            }
148        } else {
149            // If strategy not found, try Kademlia
150            if let Some(kademlia) = strategies.get(&StrategyChoice::Kademlia) {
151                kademlia.find_path(target).await
152            } else {
153                Err(AdaptiveNetworkError::Routing(
154                    "No routing strategies available".to_string(),
155                ))
156            }
157        };
158
159        // Update bandit based on result
160        let success = result.is_ok();
161        let latency = start.elapsed().as_millis() as f64;
162
163        self.bandit
164            .write()
165            .await
166            .update(content_type, strategy_choice, success, latency as u64)
167            .await
168            .unwrap_or(());
169
170        // Update metrics
171        if success {
172            let mut metrics = self.metrics.write().await;
173            let success_key = format!("route_success_{strategy_choice:?}");
174            let count = metrics.get(&success_key).copied().unwrap_or(0.0) + 1.0;
175            metrics.insert(success_key, count);
176            metrics.insert(format!("route_latency_{strategy_choice:?}"), latency);
177        }
178
179        result
180    }
181
182    /// Get routing metrics
183    pub async fn get_metrics(&self) -> std::collections::HashMap<String, f64> {
184        self.metrics.read().await.clone()
185    }
186
187    /// Get all routing strategies
188    pub fn get_all_strategies(&self) -> HashMap<String, Arc<dyn RoutingStrategy>> {
189        // For now, return empty as strategies are stored differently
190        HashMap::new()
191    }
192
193    /// Mark a node as unreliable
194    pub async fn mark_node_unreliable(&self, _node_id: &NodeId) {
195        // Update routing metrics to reflect unreliability
196        let strategies = self.strategies.read().await;
197        for (_choice, _strategy) in strategies.iter() {
198            // Would update metrics in real implementation
199        }
200    }
201
202    /// Remove a node from all routing tables
203    pub async fn remove_node(&self, _node_id: &NodeId) {
204        // In a real implementation, would remove from K-buckets, etc.
205        // log::info!("Removing node {:?} from routing tables", node_id);
206    }
207
208    /// Remove node's hyperbolic coordinates
209    pub async fn remove_hyperbolic_coordinate(&self, _node_id: &NodeId) {
210        // log::info!("Removing hyperbolic coordinates for {:?}", node_id);
211    }
212
213    /// Remove node from SOM
214    pub async fn remove_from_som(&self, _node_id: &NodeId) {
215        // log::info!("Removing {:?} from SOM", node_id);
216    }
217
218    /// Enable aggressive caching during high churn
219    pub async fn enable_aggressive_caching(&self) {
220        // log::info!("Enabling aggressive caching due to high churn");
221    }
222
223    /// Rebalance hyperbolic space after failures
224    pub async fn rebalance_hyperbolic_space(&self) {
225        // log::info!("Rebalancing hyperbolic space");
226    }
227
228    /// Update SOM grid after topology changes
229    pub async fn update_som_grid(&self) {
230        // log::info!("Updating SOM grid");
231    }
232
233    /// Trigger trust score recomputation
234    pub async fn trigger_trust_recomputation(&self) {
235        // log::info!("Triggering trust score recomputation");
236    }
237
238    /// Update routing statistics
239    pub async fn update_statistics(&self, node_id: &NodeId, success: bool, latency_ms: u64) {
240        let mut metrics = self.metrics.write().await;
241        let key = format!("node_{node_id:?}_success_rate");
242        let current = metrics.get(&key).copied().unwrap_or(0.0);
243        let new_value = if success {
244            current * 0.9 + 0.1
245        } else {
246            current * 0.9
247        };
248        metrics.insert(key, new_value);
249
250        // Update routing stats
251        let mut stats = self.stats.write().await;
252        stats.total_requests += 1;
253        if success {
254            stats.successful_requests += 1;
255        } else {
256            stats.failed_requests += 1;
257        }
258
259        // Update average latency
260        let current_avg = stats.avg_latency_ms;
261        let count = stats.total_requests as f64;
262        stats.avg_latency_ms = (current_avg * (count - 1.0) + latency_ms as f64) / count;
263    }
264
265    /// Get routing statistics
266    pub async fn get_stats(&self) -> RoutingStats {
267        self.stats.read().await.clone()
268    }
269}
270
271/// Kademlia routing implementation
272pub struct KademliaRouting {
273    _node_id: NodeId,
274    // Placeholder for routing table - would use actual implementation
275    _routing_table: Arc<RwLock<HashMap<NodeId, Vec<NodeId>>>>,
276}
277
278impl KademliaRouting {
279    pub fn new(node_id: NodeId) -> Self {
280        Self {
281            _node_id: node_id.clone(),
282            _routing_table: Arc::new(RwLock::new(HashMap::new())),
283        }
284    }
285}
286
287#[async_trait]
288impl RoutingStrategy for KademliaRouting {
289    async fn find_path(&self, target: &NodeId) -> Result<Vec<NodeId>> {
290        // Implementation would use the actual Kademlia lookup
291        // For now, return a placeholder
292        Ok(vec![target.clone()])
293    }
294
295    fn route_score(&self, _neighbor: &NodeId, _target: &NodeId) -> f64 {
296        // XOR distance metric
297        let neighbor_bytes = &_neighbor.hash;
298        let target_bytes = &_target.hash;
299        let mut distance = 0u32;
300
301        for i in 0..32 {
302            distance += (neighbor_bytes[i] ^ target_bytes[i]).count_ones();
303        }
304
305        // Convert to score (closer = higher score)
306        1.0 / (1.0 + distance as f64)
307    }
308
309    fn update_metrics(&mut self, _path: &[NodeId], _success: bool) {
310        // Update routing table based on success/failure
311    }
312}
313
314/// Hyperbolic routing implementation
315pub struct HyperbolicRouting {
316    _coordinates: Arc<RwLock<HashMap<NodeId, HyperbolicCoordinate>>>,
317}
318
319impl Default for HyperbolicRouting {
320    fn default() -> Self {
321        Self::new()
322    }
323}
324
325impl HyperbolicRouting {
326    pub fn new() -> Self {
327        Self {
328            _coordinates: Arc::new(RwLock::new(HashMap::new())),
329        }
330    }
331
332    /// Calculate hyperbolic distance between two coordinates
333    pub fn distance(a: &HyperbolicCoordinate, b: &HyperbolicCoordinate) -> f64 {
334        let delta = 2.0 * ((a.r - b.r).powi(2) + (a.theta - b.theta).cos().acos().powi(2)).sqrt();
335        let denominator = (1.0 - a.r.powi(2)) * (1.0 - b.r.powi(2));
336
337        (1.0 + delta / denominator).acosh()
338    }
339}
340
341#[async_trait]
342impl RoutingStrategy for HyperbolicRouting {
343    async fn find_path(&self, target: &NodeId) -> Result<Vec<NodeId>> {
344        // Greedy routing in hyperbolic space
345        // For now, return a placeholder
346        Ok(vec![target.clone()])
347    }
348
349    fn route_score(&self, _neighbor: &NodeId, _target: &NodeId) -> f64 {
350        // Score based on hyperbolic distance
351        // Note: This is synchronous, so we can't use async
352        0.0 // Placeholder for now
353    }
354
355    fn update_metrics(&mut self, _path: &[NodeId], _success: bool) {
356        // Update coordinate estimates
357    }
358}
359
360/// Trust-based routing implementation
361pub struct TrustRouting {
362    trust_provider: Arc<dyn TrustProvider>,
363}
364
365impl TrustRouting {
366    pub fn new(trust_provider: Arc<dyn TrustProvider>) -> Self {
367        Self { trust_provider }
368    }
369}
370
371#[async_trait]
372impl RoutingStrategy for TrustRouting {
373    async fn find_path(&self, target: &NodeId) -> Result<Vec<NodeId>> {
374        // Route through high-trust nodes
375        Ok(vec![target.clone()])
376    }
377
378    fn route_score(&self, neighbor: &NodeId, _target: &NodeId) -> f64 {
379        self.trust_provider.get_trust(neighbor)
380    }
381
382    fn update_metrics(&mut self, _path: &[NodeId], _success: bool) {
383        // Trust updates handled by trust provider
384    }
385}
386
387/// SOM-based routing implementation
388pub struct SOMRouting {
389    _som_positions: Arc<RwLock<HashMap<NodeId, [f64; 4]>>>,
390}
391
392impl Default for SOMRouting {
393    fn default() -> Self {
394        Self::new()
395    }
396}
397
398impl SOMRouting {
399    pub fn new() -> Self {
400        Self {
401            _som_positions: Arc::new(RwLock::new(HashMap::new())),
402        }
403    }
404}
405
406#[async_trait]
407impl RoutingStrategy for SOMRouting {
408    async fn find_path(&self, target: &NodeId) -> Result<Vec<NodeId>> {
409        // Route through similar nodes in SOM space
410        Ok(vec![target.clone()])
411    }
412
413    fn route_score(&self, _neighbor: &NodeId, _target: &NodeId) -> f64 {
414        // Score based on SOM distance
415        // Note: This is synchronous, so we can't use async
416        0.0 // Placeholder for now
417    }
418
419    fn update_metrics(&mut self, _path: &[NodeId], _success: bool) {
420        // Update SOM positions
421    }
422}
423
// `ThompsonSampling` now lives in `learning.rs`, which holds the more complete
// implementation; this module imports it via `use super::learning::ThompsonSampling`.
428
/// Beta distribution parameters for Thompson Sampling.
///
/// Note: [`BetaDistribution::sample`] is a deterministic placeholder that returns
/// the distribution mean rather than a random draw.
#[derive(Debug, Clone)]
pub struct BetaDistribution {
    alpha: f64,
    beta: f64,
}

impl BetaDistribution {
    /// Create a Beta(`alpha`, `beta`) distribution. No parameter validation is done.
    pub fn new(alpha: f64, beta: f64) -> Self {
        BetaDistribution { alpha, beta }
    }

    /// "Sample" the distribution.
    ///
    /// Placeholder: returns the mean `alpha / (alpha + beta)` instead of a random
    /// draw (a proper sampler such as `rand_distr::Beta` would be needed for real
    /// Thompson Sampling). The result is not finite when `alpha + beta == 0`.
    pub fn sample(&self) -> f64 {
        self.alpha / (self.alpha + self.beta)
    }
}
448
/// Mock routing strategy for testing.
///
/// Holds five canned nodes whose IDs are `[1u8; 32]` through `[5u8; 32]`.
#[cfg(test)]
pub struct MockRoutingStrategy {
    nodes: Vec<NodeId>,
}

#[cfg(test)]
impl Default for MockRoutingStrategy {
    fn default() -> Self {
        MockRoutingStrategy::new()
    }
}

#[cfg(test)]
impl MockRoutingStrategy {
    /// Create a mock whose node list is `[1u8; 32]`, `[2u8; 32]`, ..., `[5u8; 32]`,
    /// in that order.
    pub fn new() -> Self {
        let nodes = (1u8..=5).map(|byte| NodeId { hash: [byte; 32] }).collect();
        Self { nodes }
    }
}
475
#[cfg(test)]
#[async_trait]
impl RoutingStrategy for MockRoutingStrategy {
    /// Return the first `count` canned nodes (fewer if the list is shorter),
    /// ignoring the requested content hash.
    async fn find_closest_nodes(
        &self,
        _target: &super::ContentHash,
        count: usize,
    ) -> std::result::Result<Vec<NodeId>, AdaptiveNetworkError> {
        let available = count.min(self.nodes.len());
        Ok(self.nodes[..available].to_vec())
    }

    /// Return a path that starts at the all-zero node and ends at `target` when
    /// `target` is one of the canned nodes; otherwise the path is just the start node.
    async fn find_path(
        &self,
        target: &NodeId,
    ) -> std::result::Result<Vec<NodeId>, AdaptiveNetworkError> {
        let start = NodeId { hash: [0u8; 32] };
        let path = if self.nodes.contains(target) {
            vec![start, target.clone()]
        } else {
            vec![start]
        };
        Ok(path)
    }

    /// Every neighbor gets the same neutral score.
    fn route_score(&self, _from: &NodeId, _to: &NodeId) -> f64 {
        0.5
    }

    /// Mock implementation: intentionally does nothing.
    fn update_metrics(&mut self, _visited: &[NodeId], _succeeded: bool) {}
}
506
#[cfg(test)]
mod tests {
    use super::{
        AdaptiveRouter, ContentType, HyperbolicCoordinate, HyperbolicRouting, NodeId,
        StrategyChoice, ThompsonSampling, TrustProvider,
    };
    use rand::RngCore;
    use std::sync::Arc;

    #[tokio::test]
    async fn test_adaptive_router_creation() {
        // Trust provider that reports a constant trust of 0.5 and ignores updates.
        struct MockTrustProvider;
        impl TrustProvider for MockTrustProvider {
            fn get_trust(&self, _node: &NodeId) -> f64 {
                0.5
            }
            fn update_trust(&self, _from: &NodeId, _to: &NodeId, _success: bool) {}
            fn get_global_trust(&self) -> std::collections::HashMap<NodeId, f64> {
                std::collections::HashMap::new()
            }
            fn remove_node(&self, _node: &NodeId) {}
        }

        let mut hash = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut hash);
        let node_id = NodeId { hash };
        let trust_provider = Arc::new(MockTrustProvider);
        let router = AdaptiveRouter::new_with_id(node_id, trust_provider);

        // A freshly created router has not recorded anything yet.
        assert!(router.get_metrics().await.is_empty());
        assert_eq!(router.get_stats().await.total_requests, 0);
    }

    // unwrap is acceptable here: a failure should abort the test.
    #[allow(clippy::unwrap_used)]
    #[test]
    fn test_thompson_sampling() {
        let bandit = ThompsonSampling::new();

        let rt = tokio::runtime::Runtime::new().unwrap();
        let strategy = rt
            .block_on(bandit.select_strategy(ContentType::DHTLookup))
            .unwrap_or(StrategyChoice::Kademlia);

        // The bandit must pick one of the known strategies.
        assert!(
            matches!(
                strategy,
                StrategyChoice::Kademlia
                    | StrategyChoice::Hyperbolic
                    | StrategyChoice::TrustPath
                    | StrategyChoice::SOMRegion
            ),
            "unexpected strategy selected: {strategy:?}"
        );

        // Feeding back a positive outcome must succeed.
        rt.block_on(bandit.update(ContentType::DHTLookup, strategy, true, 100))
            .unwrap();
    }

    #[test]
    fn test_hyperbolic_distance() {
        let a = HyperbolicCoordinate { r: 0.0, theta: 0.0 };
        let b = HyperbolicCoordinate {
            r: 0.5,
            theta: std::f64::consts::PI,
        };

        let distance = HyperbolicRouting::distance(&a, &b);
        assert!(distance > 0.0);

        // Distance to self should be 0
        let self_distance = HyperbolicRouting::distance(&a, &a);
        assert!((self_distance - 0.0).abs() < 1e-10);
    }
}
580}