use super::beta_distribution::BetaDistribution;
use super::*;
use rand::Rng;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

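/// Thompson Sampling bandit for routing-strategy selection: each
/// (content type, strategy) pair is an arm backed by a Beta distribution.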
pub struct ThompsonSampling {
    /// Beta posterior per (content type, strategy) arm.
    arms: Arc<RwLock<HashMap<(ContentType, StrategyChoice), BetaParams>>>,

    /// Minimum trials before the exploration bonus is dropped.
    min_samples: u32,

    /// Hourly decay applied to stale arm statistics.
    decay_factor: f64,

    /// Aggregate routing metrics.
    metrics: Arc<RwLock<RoutingMetrics>>,
}

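/// Posterior state for a single bandit arm.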
#[derive(Debug, Clone)]
struct BetaParams {
    /// Success/failure posterior for this arm.
    distribution: BetaDistribution,
    /// Number of observed trials.
    trials: u32,
    /// When the arm was last updated (drives the time decay).
    last_update: std::time::Instant,
}

impl Default for BetaParams {
    fn default() -> Self {
        let distribution = BetaDistribution::new(1.0, 1.0).unwrap_or(BetaDistribution {
            alpha: 1.0,
            beta: 1.0,
        });
        Self {
            distribution,
            trials: 0,
            last_update: std::time::Instant::now(),
        }
    }
}

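/// Aggregate statistics about routing decisions and outcomes.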
#[derive(Debug, Default, Clone)]
pub struct RoutingMetrics {
    /// Total number of strategy selections made.
    pub total_decisions: u64,
    /// Decision counts broken down by content type.
    pub decisions_by_type: HashMap<ContentType, u64>,
    /// Exponentially weighted success rate per strategy.
    pub strategy_success_rates: HashMap<StrategyChoice, f64>,
    /// Exponentially weighted average latency (ms) per strategy.
    pub strategy_latencies: HashMap<StrategyChoice, f64>,
}

impl Default for ThompsonSampling {
    fn default() -> Self {
        Self::new()
    }
}

impl ThompsonSampling {
    pub fn new() -> Self {
        Self {
            arms: Arc::new(RwLock::new(HashMap::new())),
            min_samples: 10,
            decay_factor: 0.99,
            metrics: Arc::new(RwLock::new(RoutingMetrics::default())),
        }
    }

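    /// Samples each arm's Beta posterior (after applying time decay) and picks
    /// the strategy with the highest sample plus any exploration bonus.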
    pub async fn select_strategy(&self, content_type: ContentType) -> Result<StrategyChoice> {
        let mut arms = self.arms.write().await;
        let mut metrics = self.metrics.write().await;

        metrics.total_decisions += 1;
        *metrics.decisions_by_type.entry(content_type).or_insert(0) += 1;

        let strategies = vec![
            StrategyChoice::Kademlia,
            StrategyChoice::Hyperbolic,
            StrategyChoice::TrustPath,
            StrategyChoice::SOMRegion,
        ];

        let mut best_strategy = StrategyChoice::Kademlia;
        let mut best_sample = 0.0;

        for strategy in &strategies {
            let key = (content_type, *strategy);
            let params = arms.entry(key).or_default();

            if params.trials > 0 {
                // Decay stale evidence back towards the uniform Beta(1, 1) prior,
                // with `decay_factor` applied per hour since the last update.
                let elapsed = params.last_update.elapsed().as_secs() as f64;
                let decay = self.decay_factor.powf(elapsed / 3600.0);
                let alpha = params.distribution.alpha;
                let beta = params.distribution.beta;
                let new_alpha = 1.0 + (alpha - 1.0) * decay;
                let new_beta = 1.0 + (beta - 1.0) * decay;
                params.distribution =
                    BetaDistribution::new(new_alpha, new_beta).unwrap_or(BetaDistribution {
                        alpha: new_alpha.max(f64::MIN_POSITIVE),
                        beta: new_beta.max(f64::MIN_POSITIVE),
                    });
            }

            let mut rng = rand::thread_rng();
            let sample = params.distribution.sample(&mut rng);

            // Under-sampled arms get a bonus that shrinks to zero at `min_samples`.
            let exploration_bonus = if params.trials < self.min_samples {
                0.1 * (1.0 - (params.trials as f64 / self.min_samples as f64))
            } else {
                0.0
            };

            let adjusted_sample = sample + exploration_bonus;

            if adjusted_sample > best_sample {
                best_sample = adjusted_sample;
                best_strategy = *strategy;
            }
        }

        Ok(best_strategy)
    }

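    /// Records a routing outcome: updates the arm's Beta posterior and the
    /// exponentially weighted success-rate and latency metrics.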
    pub async fn update(
        &self,
        content_type: ContentType,
        strategy: StrategyChoice,
        success: bool,
        latency_ms: u64,
    ) -> anyhow::Result<()> {
        let mut arms = self.arms.write().await;
        let mut metrics = self.metrics.write().await;

        let key = (content_type, strategy);
        let params = arms.entry(key).or_default();

        params.distribution.update(success);
        params.trials += 1;
        params.last_update = std::time::Instant::now();

        // Exponentially weighted moving averages (0.9 old / 0.1 new).
        let success_rate = params.distribution.mean();
        let current_rate = metrics
            .strategy_success_rates
            .entry(strategy)
            .or_insert(0.5);
        *current_rate = 0.9 * (*current_rate) + 0.1 * success_rate;

        let current_latency = metrics
            .strategy_latencies
            .entry(strategy)
            .or_insert(latency_ms as f64);
        *current_latency = 0.9 * (*current_latency) + 0.1 * (latency_ms as f64);

        Ok(())
    }

    pub async fn get_metrics(&self) -> RoutingMetrics {
        self.metrics.read().await.clone()
    }

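    /// Returns the (lower, upper) confidence interval for a strategy's success
    /// rate, or (0.0, 1.0) when no evidence has been collected yet.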
    pub async fn get_confidence_interval(
        &self,
        content_type: ContentType,
        strategy: StrategyChoice,
    ) -> (f64, f64) {
        let arms = self.arms.read().await;
        let key = (content_type, strategy);

        if let Some(params) = arms.get(&key) {
            if params.trials == 0 {
                return (0.0, 1.0);
            }

            params.distribution.confidence_interval()
        } else {
            (0.0, 1.0)
        }
    }

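    /// Discards all learned state for a (content type, strategy) arm.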
    pub async fn reset_strategy(&self, content_type: ContentType, strategy: StrategyChoice) {
        let mut arms = self.arms.write().await;
        arms.remove(&(content_type, strategy));
    }
}

#[async_trait]
impl LearningSystem for ThompsonSampling {
    async fn select_strategy(&self, context: &LearningContext) -> StrategyChoice {
        self.select_strategy(context.content_type)
            .await
            .unwrap_or(StrategyChoice::Kademlia)
    }

    async fn update(
        &mut self,
        context: &LearningContext,
        choice: &StrategyChoice,
        outcome: &Outcome,
    ) {
        let _ = ThompsonSampling::update(
            self,
            context.content_type,
            *choice,
            outcome.success,
            outcome.latency_ms,
        )
        .await;
    }

    async fn metrics(&self) -> LearningMetrics {
        let metrics = self.get_metrics().await;

        LearningMetrics {
            total_decisions: metrics.total_decisions,
            success_rate: metrics.strategy_success_rates.values().sum::<f64>()
                / metrics.strategy_success_rates.len().max(1) as f64,
            avg_latency_ms: metrics.strategy_latencies.values().sum::<f64>()
                / metrics.strategy_latencies.len().max(1) as f64,
            strategy_performance: metrics.strategy_success_rates.clone(),
        }
    }
}

#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of cache hits.
    pub hits: u64,

    /// Number of cache misses.
    pub misses: u64,

    /// Total bytes currently stored.
    pub size_bytes: u64,

    /// Number of cached items.
    pub item_count: u64,

    /// Number of evictions performed.
    pub evictions: u64,

    /// hits / (hits + misses), or 0.0 before any requests.
    pub hit_rate: f64,
}

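/// Q-learning based cache manager that learns which cache actions pay off
/// in which cache states.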
pub struct QLearnCacheManager {
    /// Q-values indexed by state, then action.
    q_table: Arc<tokio::sync::RwLock<HashMap<CacheState, HashMap<CacheAction, f64>>>>,

    /// Step size for Q-value updates.
    learning_rate: f64,

    /// Discount applied to future rewards.
    discount_factor: f64,

    /// Probability of taking a random (exploratory) action.
    epsilon: f64,

    /// The cached content itself.
    cache: Arc<tokio::sync::RwLock<HashMap<ContentHash, CachedContent>>>,

    /// Maximum cache size in bytes.
    capacity: usize,

    /// Current cache size in bytes.
    current_size: Arc<std::sync::atomic::AtomicUsize>,

    /// Per-content request statistics.
    request_stats: Arc<tokio::sync::RwLock<HashMap<ContentHash, RequestStats>>>,

    /// Hit/miss counters.
    hit_count: Arc<std::sync::atomic::AtomicU64>,
    miss_count: Arc<std::sync::atomic::AtomicU64>,

    _bandwidth_used: Arc<std::sync::atomic::AtomicU64>,
}

#[derive(Debug, Clone)]
pub struct RequestStats {
    request_count: u64,
    hourly_requests: u64,
    last_request: std::time::Instant,
    content_size: usize,
}

#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct CacheState {
    /// Cache utilization, bucketed 0-10.
    utilization_bucket: u8,

    /// Recent request rate, bucketed 0-10.
    request_rate_bucket: u8,

    /// Content popularity, bucketed 0-10.
    content_popularity: u8,

    /// Content size on a roughly logarithmic scale, bucketed 0-7.
    size_bucket: u8,
}

#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum CacheAction {
    Cache,
    Evict(EvictionPolicy),
    IncreaseReplication,
    DecreaseReplication,
    NoAction,
}

#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum EvictionPolicy {
    LRU,
    LFU,
    Random,
}

#[derive(Debug, Clone)]
pub struct CachedContent {
    pub data: Vec<u8>,
    pub access_count: u64,
    pub last_access: std::time::Instant,
    pub insertion_time: std::time::Instant,
}

impl QLearnCacheManager {
    pub fn new(capacity: usize) -> Self {
        Self {
            q_table: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            learning_rate: 0.1,
            discount_factor: 0.9,
            epsilon: 0.1,
            cache: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            capacity,
            current_size: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
            request_stats: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            hit_count: Arc::new(std::sync::atomic::AtomicU64::new(0)),
            miss_count: Arc::new(std::sync::atomic::AtomicU64::new(0)),
            _bandwidth_used: Arc::new(std::sync::atomic::AtomicU64::new(0)),
        }
    }

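    /// Epsilon-greedy action selection: with probability `epsilon` explore a
    /// random action, otherwise exploit the best known Q-value for the state.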
    pub async fn decide_action(&self, content_hash: &ContentHash) -> CacheAction {
        let state = self.get_current_state(content_hash);

        if rand::random::<f64>() < self.epsilon {
            self.random_action()
        } else {
            let q_table = self.q_table.read().await;
            q_table
                .get(&state)
                .and_then(|actions| {
                    actions
                        .iter()
                        .max_by(|(_, a), (_, b)| {
                            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                        })
                        .map(|(action, _)| *action)
                })
                .unwrap_or(CacheAction::NoAction)
        }
    }

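    /// Standard Q-learning update:
    /// Q(s, a) += lr * (reward + gamma * max_a' Q(s', a') - Q(s, a)).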
    pub async fn update_q_value(
        &self,
        state: CacheState,
        action: CacheAction,
        reward: f64,
        next_state: CacheState,
    ) {
        let mut q_table = self.q_table.write().await;

        let current_q = q_table
            .entry(state.clone())
            .or_insert_with(HashMap::new)
            .get(&action)
            .copied()
            .unwrap_or(0.0);

        let max_next_q = q_table
            .get(&next_state)
            .and_then(|actions| {
                actions
                    .values()
                    .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .copied()
            })
            .unwrap_or(0.0);

        let new_q = current_q
            + self.learning_rate * (reward + self.discount_factor * max_next_q - current_q);

        q_table
            .entry(state)
            .or_insert_with(HashMap::new)
            .insert(action, new_q);
    }

    fn get_current_state(&self, _content_hash: &ContentHash) -> CacheState {
        let utilization = (self.current_size.load(std::sync::atomic::Ordering::Relaxed) * 10
            / self.capacity) as u8;

        // The synchronous path has no access to request stats; use neutral buckets.
        let (request_rate_bucket, content_popularity, size_bucket) = (5, 5, 5);

        CacheState {
            utilization_bucket: utilization.min(10),
            request_rate_bucket,
            content_popularity,
            size_bucket,
        }
    }

    pub async fn get_current_state_async(&self, content_hash: &ContentHash) -> CacheState {
        let utilization = (self.current_size.load(std::sync::atomic::Ordering::Relaxed) * 10
            / self.capacity) as u8;

        let stats = self.request_stats.read().await;
        let (request_rate_bucket, content_popularity, size_bucket) =
            if let Some(stat) = stats.get(content_hash) {
                let hourly_rate = stat.hourly_requests.min(100) / 10;

                let popularity = (stat.request_count.min(1000) / 100) as u8;

                // Bucket content size on a roughly logarithmic scale (1 KiB .. 1 GiB).
                let size_bucket = match stat.content_size {
                    0..=1_024 => 0,
                    1_025..=10_240 => 1,
                    10_241..=102_400 => 2,
                    102_401..=1_048_576 => 3,
                    1_048_577..=10_485_760 => 4,
                    10_485_761..=104_857_600 => 5,
                    104_857_601..=1_073_741_824 => 6,
                    _ => 7,
                };

                (hourly_rate as u8, popularity, size_bucket)
            } else {
                (0, 0, 0)
            };

        CacheState {
            utilization_bucket: utilization.min(10),
            request_rate_bucket,
            content_popularity,
            size_bucket,
        }
    }

    fn random_action(&self) -> CacheAction {
        match rand::random::<u8>() % 5 {
            0 => CacheAction::Cache,
            1 => CacheAction::Evict(EvictionPolicy::LRU),
            2 => CacheAction::Evict(EvictionPolicy::LFU),
            3 => CacheAction::IncreaseReplication,
            4 => CacheAction::DecreaseReplication,
            _ => CacheAction::NoAction,
        }
    }

    pub async fn insert(&self, hash: ContentHash, data: Vec<u8>) -> bool {
        let size = data.len();

        // Evict until the new item fits, or give up if nothing can be evicted.
        while self.current_size.load(std::sync::atomic::Ordering::Relaxed) + size > self.capacity {
            if !self.evict_one().await {
                return false;
            }
        }

        let mut cache = self.cache.write().await;
        cache.insert(
            hash,
            CachedContent {
                data,
                access_count: 0,
                last_access: std::time::Instant::now(),
                insertion_time: std::time::Instant::now(),
            },
        );

        self.current_size
            .fetch_add(size, std::sync::atomic::Ordering::Relaxed);
        true
    }

    /// Evicts the least recently used entry; returns false if the cache is empty.
    async fn evict_one(&self) -> bool {
        let mut cache = self.cache.write().await;
        let oldest = cache
            .iter()
            .min_by_key(|(_, content)| content.last_access)
            .map(|(k, _)| *k);

        if let Some(key) = oldest
            && let Some(value) = cache.remove(&key)
        {
            self.current_size
                .fetch_sub(value.data.len(), std::sync::atomic::Ordering::Relaxed);
            return true;
        }

        false
    }

    pub async fn get(&self, hash: &ContentHash) -> Option<Vec<u8>> {
        let cache_result = {
            let mut cache = self.cache.write().await;
            cache.get_mut(hash).map(|entry| {
                entry.access_count += 1;
                entry.last_access = std::time::Instant::now();
                entry.data.clone()
            })
        };

        if cache_result.is_some() {
            self.hit_count
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        } else {
            self.miss_count
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        }

        if let Some(ref data) = cache_result {
            let mut stats = self.request_stats.write().await;
            let stat = stats.entry(*hash).or_insert_with(|| RequestStats {
                request_count: 0,
                hourly_requests: 0,
                last_request: std::time::Instant::now(),
                content_size: data.len(),
            });
            stat.request_count += 1;
            stat.hourly_requests += 1;
            stat.last_request = std::time::Instant::now();
        }

        cache_result
    }

    pub fn calculate_reward(&self, action: CacheAction, hit: bool, bandwidth_cost: u64) -> f64 {
        let hits = self.hit_count.load(std::sync::atomic::Ordering::Relaxed) as f64;
        let misses = self.miss_count.load(std::sync::atomic::Ordering::Relaxed) as f64;
        let hit_rate = if hits + misses > 0.0 {
            hits / (hits + misses)
        } else {
            0.0
        };

        let storage_cost = self.current_size.load(std::sync::atomic::Ordering::Relaxed) as f64
            / self.capacity as f64;

        // Normalize bandwidth cost to megabytes.
        let bandwidth_cost_normalized = bandwidth_cost as f64 / 1_000_000.0;

        match action {
            CacheAction::Cache => {
                if hit {
                    hit_rate - storage_cost * 0.1 - bandwidth_cost_normalized * 0.01
                } else {
                    -0.1 - bandwidth_cost_normalized * 0.1
                }
            }
            CacheAction::Evict(_) => {
                if hit {
                    // Penalize evicting content that is still being requested.
                    -0.5
                } else {
                    0.1 - storage_cost * 0.05
                }
            }
            CacheAction::IncreaseReplication => {
                if hit {
                    hit_rate * 0.5 - bandwidth_cost_normalized * 0.2
                } else {
                    -0.2 - bandwidth_cost_normalized * 0.2
                }
            }
            CacheAction::DecreaseReplication => {
                if hit {
                    -0.3
                } else {
                    0.05 + storage_cost * 0.05
                }
            }
            CacheAction::NoAction => {
                // Small reward proportional to hit rate for leaving things alone.
                hit_rate * 0.1 - storage_cost * 0.01
            }
        }
    }

    pub async fn execute_action(
        &self,
        hash: &ContentHash,
        action: CacheAction,
        data: Option<Vec<u8>>,
    ) -> Result<()> {
        match action {
            CacheAction::Cache => {
                if let Some(content) = data {
                    self.insert(*hash, content).await;
                }
            }
            CacheAction::Evict(policy) => {
                match policy {
                    EvictionPolicy::LRU => self.evict_lru().await,
                    EvictionPolicy::LFU => self.evict_lfu().await,
                    EvictionPolicy::Random => self.evict_random().await,
                };
            }
            // Replication changes are coordinated elsewhere; nothing to do locally.
            CacheAction::IncreaseReplication => {}
            CacheAction::DecreaseReplication => {}
            CacheAction::NoAction => {}
        }
        Ok(())
    }

    async fn evict_lru(&self) -> bool {
        self.evict_one().await
    }

    async fn evict_lfu(&self) -> bool {
        let mut cache = self.cache.write().await;
        let least_frequent = cache
            .iter()
            .min_by_key(|(_, content)| content.access_count)
            .map(|(k, _)| *k);

        if let Some(key) = least_frequent
            && let Some(value) = cache.remove(&key)
        {
            self.current_size
                .fetch_sub(value.data.len(), std::sync::atomic::Ordering::Relaxed);
            return true;
        }
        false
    }

    async fn evict_random(&self) -> bool {
        let cache = self.cache.read().await;
        if cache.is_empty() {
            return false;
        }

        let random_idx = rand::random::<usize>() % cache.len();
        let random_key = cache.keys().nth(random_idx).cloned();
        drop(cache);

        if let Some(key) = random_key {
            let mut cache = self.cache.write().await;
            if let Some(value) = cache.remove(&key) {
                self.current_size
                    .fetch_sub(value.data.len(), std::sync::atomic::Ordering::Relaxed);
                return true;
            }
        }
        false
    }

    pub fn get_stats(&self) -> CacheStats {
        let hits = self.hit_count.load(std::sync::atomic::Ordering::Relaxed);
        let misses = self.miss_count.load(std::sync::atomic::Ordering::Relaxed);
        let hit_rate = if hits + misses > 0 {
            hits as f64 / (hits + misses) as f64
        } else {
            0.0
        };

        CacheStats {
            hits,
            misses,
            size_bytes: self.current_size.load(std::sync::atomic::Ordering::Relaxed) as u64,
            // Item and eviction counts are not tracked on this synchronous path.
            item_count: 0,
            evictions: 0,
            hit_rate,
        }
    }

    pub async fn decide_caching(
        &self,
        hash: ContentHash,
        data: Vec<u8>,
        _content_type: ContentType,
    ) -> Result<()> {
        let _state = self.get_current_state_async(&hash).await;
        let action = self.decide_action(&hash).await;

        if matches!(action, CacheAction::Cache) {
            let _ = self.insert(hash, data).await;
        }

        Ok(())
    }

    pub async fn get_stats_async(&self) -> CacheStats {
        let cache = self.cache.read().await;
        let hit_count = self.hit_count.load(std::sync::atomic::Ordering::Relaxed);
        let miss_count = self.miss_count.load(std::sync::atomic::Ordering::Relaxed);
        let total = hit_count + miss_count;

        CacheStats {
            hits: hit_count,
            misses: miss_count,
            size_bytes: self.current_size.load(std::sync::atomic::Ordering::Relaxed) as u64,
            item_count: cache.len() as u64,
            // Eviction counting is not wired up yet.
            evictions: 0,
            hit_rate: if total > 0 {
                hit_count as f64 / total as f64
            } else {
                0.0
            },
        }
    }
}

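/// Behavioral feature vector for a node, used as input to churn prediction.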
#[derive(Debug, Clone)]
pub struct NodeFeatures {
    /// Seconds online in the current session.
    pub online_duration: f64,
    /// Average response time in milliseconds.
    pub avg_response_time: f64,
    /// Normalized resource contribution (0.0 to 1.0).
    pub resource_contribution: f64,
    /// Messages per time unit.
    pub message_frequency: f64,
    /// Hour of day (0.0 to 23.0).
    pub time_of_day: f64,
    /// Day of week (0.0 = Sunday to 6.0 = Saturday).
    pub day_of_week: f64,
    /// Historical uptime ratio (0.0 to 1.0).
    pub historical_reliability: f64,
    /// Disconnection count over the recent window.
    pub recent_disconnections: f64,
    /// Average session length in hours.
    pub avg_session_length: f64,
    /// Connection stability score (0.0 to 1.0).
    pub connection_stability: f64,
}

#[derive(Debug, Clone)]
pub struct FeatureHistory {
    pub node_id: NodeId,
    /// Timestamped feature snapshots.
    pub snapshots: Vec<(std::time::Instant, NodeFeatures)>,
    /// Sessions as (start, optional end); a `None` end means still online.
    pub sessions: Vec<(std::time::Instant, Option<std::time::Instant>)>,
    /// Total uptime in seconds.
    pub total_uptime: u64,
    /// Total downtime in seconds.
    pub total_downtime: u64,
}

impl Default for FeatureHistory {
    fn default() -> Self {
        Self::new()
    }
}

impl FeatureHistory {
    pub fn new() -> Self {
        Self {
            node_id: NodeId { hash: [0u8; 32] },
            snapshots: Vec::new(),
            sessions: Vec::new(),
            total_uptime: 0,
            total_downtime: 0,
        }
    }
}

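/// Predicts node churn over 1h/6h/24h horizons using a lightweight online
/// logistic model over node features and behavioral patterns.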
#[derive(Debug)]
pub struct ChurnPredictor {
    /// Cached predictions, valid for a short TTL.
    prediction_cache: Arc<tokio::sync::RwLock<HashMap<NodeId, ChurnPrediction>>>,

    /// Per-node feature and session history.
    feature_history: Arc<tokio::sync::RwLock<HashMap<NodeId, FeatureHistory>>>,

    /// Current model weights.
    model_weights: Arc<tokio::sync::RwLock<ModelWeights>>,

    /// Replay buffer of labeled training examples.
    experience_buffer: Arc<tokio::sync::RwLock<Vec<TrainingExample>>>,

    /// Maximum replay buffer size.
    max_buffer_size: usize,

    _update_interval: std::time::Duration,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModelWeights {
    /// One weight per input feature.
    pub feature_weights: Vec<f64>,
    /// Per-horizon decay applied to pattern scores (1h, 6h, 24h).
    pub time_decay: Vec<f64>,
    /// Learned weights for named behavioral patterns.
    pub pattern_weights: HashMap<String, f64>,
    /// Per-horizon bias terms.
    pub bias: Vec<f64>,
}

impl Default for ModelWeights {
    fn default() -> Self {
        Self {
            // Order matches the ten NodeFeatures fields.
            feature_weights: vec![
                0.15, 0.20, 0.10, 0.05, 0.05, 0.05, 0.25, 0.10, 0.05, 0.00,
            ],
            time_decay: vec![0.9, 0.8, 0.7], // 1h, 6h, 24h horizons
            pattern_weights: HashMap::new(),
            bias: vec![0.1, 0.2, 0.3],
        }
    }
}

#[derive(Debug, Clone)]
pub struct TrainingExample {
    pub node_id: NodeId,
    pub features: NodeFeatures,
    pub timestamp: std::time::Instant,
    pub actual_churn_1h: bool,
    pub actual_churn_6h: bool,
    pub actual_churn_24h: bool,
}

#[derive(Debug, Clone)]
pub struct ChurnPrediction {
    pub probability_1h: f64,
    pub probability_6h: f64,
    pub probability_24h: f64,
    pub confidence: f64,
    pub timestamp: std::time::Instant,
}

impl Default for ChurnPredictor {
    fn default() -> Self {
        Self::new()
    }
}

impl ChurnPredictor {
    pub fn new() -> Self {
        Self {
            prediction_cache: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            feature_history: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            model_weights: Arc::new(tokio::sync::RwLock::new(ModelWeights::default())),
            experience_buffer: Arc::new(tokio::sync::RwLock::new(Vec::new())),
            max_buffer_size: 10000,
            _update_interval: std::time::Duration::from_secs(3600), // 1 hour
        }
    }

    /// A node is a replication candidate when it is likely to churn within the hour.
    pub async fn should_replicate(&self, node_id: &NodeId) -> bool {
        let prediction = self.predict(node_id).await;
        prediction.probability_1h > 0.7
    }

    pub async fn update_node_features(
        &self,
        node_id: &NodeId,
        features: Vec<f64>,
    ) -> anyhow::Result<()> {
        if features.len() != 10 {
            return Err(anyhow::anyhow!(
                "Expected 10 features, got {}",
                features.len()
            ));
        }

        let node_features = NodeFeatures {
            online_duration: features[0],
            avg_response_time: features[1],
            resource_contribution: features[2],
            message_frequency: features[3],
            time_of_day: features[4],
            day_of_week: features[5],
            historical_reliability: features[6],
            recent_disconnections: features[7],
            avg_session_length: features[8],
            connection_stability: features[9],
        };

        let mut history = self.feature_history.write().await;
        let entry = history
            .entry(node_id.clone())
            .or_insert(FeatureHistory::new());

        // Open a new session if there is none, or the last one has ended.
        if entry.sessions.is_empty() || entry.sessions.last().map(|s| s.1.is_some()).unwrap_or(true)
        {
            entry.sessions.push((std::time::Instant::now(), None));
        }

        entry
            .snapshots
            .push((std::time::Instant::now(), node_features));

        // Keep at most the 100 most recent snapshots.
        while entry.snapshots.len() > 100 {
            entry.snapshots.remove(0);
        }

        Ok(())
    }

    pub async fn extract_features(&self, node_id: &NodeId) -> Option<NodeFeatures> {
        let history = self.feature_history.read().await;
        let node_history = history.get(node_id)?;

        let now = std::time::Instant::now();
        let current_session = node_history.sessions.last()?;
        let online_duration = if current_session.1.is_none() {
            now.duration_since(current_session.0).as_secs() as f64
        } else {
            0.0
        };

        let completed_sessions: Vec<_> = node_history
            .sessions
            .iter()
            .filter_map(|(start, end)| {
                end.as_ref()
                    .map(|e| e.duration_since(*start).as_secs() as f64)
            })
            .collect();
        let avg_session_length = if !completed_sessions.is_empty() {
            completed_sessions.iter().sum::<f64>() / completed_sessions.len() as f64 / 3600.0
        } else {
            1.0 // default to one hour when no session has completed
        };

        let recent_disconnections = if let Some(one_week_ago) =
            now.checked_sub(std::time::Duration::from_secs(7 * 24 * 3600))
        {
            node_history
                .sessions
                .iter()
                .filter(|(start, end)| end.is_some() && *start > one_week_ago)
                .count() as f64
        } else {
            // Clock cannot go back a full week; count all completed sessions.
            node_history
                .sessions
                .iter()
                .filter(|(_, end)| end.is_some())
                .count() as f64
        };

        let latest_snapshot = node_history
            .snapshots
            .last()
            .map(|(_, features)| features.clone())
            .unwrap_or_else(|| NodeFeatures {
                online_duration,
                avg_response_time: 100.0,
                resource_contribution: 0.5,
                message_frequency: 10.0,
                time_of_day: 12.0,
                day_of_week: 3.0,
                historical_reliability: node_history.total_uptime as f64
                    / (node_history.total_uptime + node_history.total_downtime).max(1) as f64,
                recent_disconnections,
                avg_session_length,
                connection_stability: 1.0 - (recent_disconnections / 7.0).min(1.0),
            });

        // Overlay the freshly derived session features on the latest snapshot.
        Some(NodeFeatures {
            online_duration,
            recent_disconnections,
            avg_session_length,
            historical_reliability: node_history.total_uptime as f64
                / (node_history.total_uptime + node_history.total_downtime).max(1) as f64,
            connection_stability: 1.0 - (recent_disconnections / 7.0).min(1.0),
            ..latest_snapshot
        })
    }

    /// Derives binary behavioral pattern indicators from a feature vector.
    async fn analyze_patterns(&self, features: &NodeFeatures) -> HashMap<String, f64> {
        let mut patterns = HashMap::new();

        let is_night = features.time_of_day < 6.0 || features.time_of_day > 22.0;
        let is_weekend = features.day_of_week == 0.0 || features.day_of_week == 6.0;

        patterns.insert("night_time".to_string(), if is_night { 1.0 } else { 0.0 });
        patterns.insert("weekend".to_string(), if is_weekend { 1.0 } else { 0.0 });

        patterns.insert(
            "short_session".to_string(),
            if features.online_duration < 1800.0 {
                1.0
            } else {
                0.0
            },
        );
        patterns.insert(
            "unstable".to_string(),
            if features.recent_disconnections > 5.0 {
                1.0
            } else {
                0.0
            },
        );
        patterns.insert(
            "low_contribution".to_string(),
            if features.resource_contribution < 0.3 {
                1.0
            } else {
                0.0
            },
        );
        patterns.insert(
            "slow_response".to_string(),
            if features.avg_response_time > 500.0 {
                1.0
            } else {
                0.0
            },
        );

        // Composite risk score from disconnections, reliability, and stability.
        let risk_score = (features.recent_disconnections / 10.0).min(1.0) * 0.3
            + (1.0 - features.historical_reliability) * 0.4
            + (1.0 - features.connection_stability) * 0.3;
        patterns.insert(
            "high_risk".to_string(),
            if risk_score > 0.6 { 1.0 } else { 0.0 },
        );

        patterns
    }

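    /// Predicts churn probabilities for a node, serving a cached prediction
    /// when it is fresher than five minutes.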
    pub async fn predict(&self, node_id: &NodeId) -> ChurnPrediction {
        {
            let cache = self.prediction_cache.read().await;
            if let Some(cached) = cache.get(node_id)
                && cached.timestamp.elapsed() < std::time::Duration::from_secs(300)
            {
                return cached.clone();
            }
        }

        let features = match self.extract_features(node_id).await {
            Some(f) => f,
            None => {
                // Unknown node: return a low-confidence baseline prediction.
                return ChurnPrediction {
                    probability_1h: 0.1,
                    probability_6h: 0.2,
                    probability_24h: 0.3,
                    confidence: 0.1,
                    timestamp: std::time::Instant::now(),
                };
            }
        };

        let patterns = self.analyze_patterns(&features).await;

        let model = self.model_weights.read().await;
        let prediction = self.apply_model(&features, &patterns, &model).await;

        let mut cache = self.prediction_cache.write().await;
        cache.insert(node_id.clone(), prediction.clone());
        prediction
    }

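    /// Scores the normalized feature vector against the model weights and maps
    /// each horizon's raw score through a logistic sigmoid.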
    async fn apply_model(
        &self,
        features: &NodeFeatures,
        patterns: &HashMap<String, f64>,
        model: &ModelWeights,
    ) -> ChurnPrediction {
        // Normalize features to comparable ranges.
        let feature_vec = [
            features.online_duration / 3600.0,
            features.avg_response_time / 1000.0,
            features.resource_contribution,
            features.message_frequency / 100.0,
            features.time_of_day / 24.0,
            features.day_of_week / 7.0,
            features.historical_reliability,
            features.recent_disconnections / 10.0,
            features.avg_session_length / 24.0,
            features.connection_stability,
        ];

        // One base score per prediction horizon (1h, 6h, 24h).
        let mut base_scores = [0.0; 3];
        for (i, &weight) in model.feature_weights.iter().enumerate() {
            if i < feature_vec.len() {
                for score in &mut base_scores {
                    *score += weight * feature_vec[i];
                }
            }
        }

        let mut pattern_score = 0.0;
        for (pattern, &value) in patterns {
            if let Some(&weight) = model.pattern_weights.get(pattern) {
                pattern_score += weight * value;
            } else {
                // Patterns without a learned weight get a small default.
                pattern_score += 0.1 * value;
            }
        }

        let probabilities: Vec<f64> = base_scores
            .iter()
            .zip(&model.time_decay)
            .zip(&model.bias)
            .map(|((base, decay), bias)| {
                let raw_score = base + pattern_score * decay + bias;
                // Logistic sigmoid squashes the raw score into (0, 1).
                1.0 / (1.0 + (-raw_score).exp())
            })
            .collect();

        // Confidence is a fixed placeholder for now.
        let confidence = 0.8;

        ChurnPrediction {
            probability_1h: probabilities[0].min(0.99),
            probability_6h: probabilities[1].min(0.99),
            probability_24h: probabilities[2].min(0.99),
            confidence,
            timestamp: std::time::Instant::now(),
        }
    }

    pub async fn update_node_behavior(
        &self,
        node_id: &NodeId,
        features: NodeFeatures,
    ) -> anyhow::Result<()> {
        let mut history = self.feature_history.write().await;
        let node_history = history
            .entry(node_id.clone())
            .or_insert_with(|| FeatureHistory {
                node_id: node_id.clone(),
                snapshots: Vec::new(),
                sessions: vec![(std::time::Instant::now(), None)],
                total_uptime: 0,
                total_downtime: 0,
            });

        node_history
            .snapshots
            .push((std::time::Instant::now(), features));

        // Drop snapshots older than 24 hours.
        if let Some(cutoff) =
            std::time::Instant::now().checked_sub(std::time::Duration::from_secs(24 * 3600))
        {
            node_history
                .snapshots
                .retain(|(timestamp, _)| *timestamp > cutoff);
        }
        Ok(())
    }

    pub async fn record_node_event(&self, node_id: &NodeId, event: NodeEvent) -> Result<()> {
        let mut history = self.feature_history.write().await;
        let node_history = history
            .entry(node_id.clone())
            .or_insert_with(|| FeatureHistory {
                node_id: node_id.clone(),
                snapshots: Vec::new(),
                sessions: Vec::new(),
                total_uptime: 0,
                total_downtime: 0,
            });

        match event {
            NodeEvent::Connected => {
                node_history
                    .sessions
                    .push((std::time::Instant::now(), None));
            }
            NodeEvent::Disconnected => {
                // Close the open session, if any, and credit its length as uptime.
                if let Some((start, end)) = node_history.sessions.last_mut()
                    && end.is_none()
                {
                    let now = std::time::Instant::now();
                    *end = Some(now);
                    let session_length = now.duration_since(*start).as_secs();
                    node_history.total_uptime += session_length;
                }
            }
        }

        Ok(())
    }

    pub async fn add_training_example(
        &self,
        node_id: &NodeId,
        features: NodeFeatures,
        actual_churn_1h: bool,
        actual_churn_6h: bool,
        actual_churn_24h: bool,
    ) -> anyhow::Result<()> {
        let example = TrainingExample {
            node_id: node_id.clone(),
            features,
            timestamp: std::time::Instant::now(),
            actual_churn_1h,
            actual_churn_6h,
            actual_churn_24h,
        };

        let mut buffer = self.experience_buffer.write().await;
        buffer.push(example);

        // Bound the replay buffer, dropping the oldest examples first.
        if buffer.len() > self.max_buffer_size {
            let drain_count = buffer.len() - self.max_buffer_size;
            buffer.drain(0..drain_count);
        }

        // Retrain after every full batch of 32 new examples. The write guard
        // must be released first: update_model() re-acquires the buffer for
        // reading, and awaiting it while holding the guard would deadlock.
        let should_update = buffer.len() >= 32 && buffer.len() % 32 == 0;
        drop(buffer);

        if should_update {
            self.update_model().await?;
        }

        Ok(())
    }

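    /// Runs one stochastic-gradient pass over a random batch from the
    /// experience buffer, nudging feature and pattern weights toward the
    /// observed churn outcomes.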
    async fn update_model(&self) -> anyhow::Result<()> {
        let buffer = self.experience_buffer.read().await;
        if buffer.is_empty() {
            return Ok(());
        }

        let mut model = self.model_weights.write().await;

        let learning_rate = 0.01;
        let batch_size = 32.min(buffer.len());

        // Sample a batch uniformly at random (with replacement).
        let mut rng = rand::thread_rng();
        let batch: Vec<_> = (0..batch_size)
            .map(|_| &buffer[rng.gen_range(0..buffer.len())])
            .collect();

        for example in batch {
            let feature_vec = [
                example.features.online_duration / 3600.0,
                example.features.avg_response_time / 1000.0,
                example.features.resource_contribution,
                example.features.message_frequency / 100.0,
                example.features.time_of_day / 24.0,
                example.features.day_of_week / 7.0,
                example.features.historical_reliability,
                example.features.recent_disconnections / 10.0,
                example.features.avg_session_length / 24.0,
                example.features.connection_stability,
            ];

            let patterns = self.analyze_patterns(&example.features).await;

            let prediction = self.apply_model(&example.features, &patterns, &model).await;

            // Per-horizon prediction errors (label minus predicted probability).
            let errors = [
                if example.actual_churn_1h { 1.0 } else { 0.0 } - prediction.probability_1h,
                if example.actual_churn_6h { 1.0 } else { 0.0 } - prediction.probability_6h,
                if example.actual_churn_24h { 1.0 } else { 0.0 } - prediction.probability_24h,
            ];

            for (i, &feature_value) in feature_vec.iter().enumerate() {
                if i < model.feature_weights.len() {
                    for (j, &error) in errors.iter().enumerate() {
                        model.feature_weights[i] +=
                            learning_rate * error * feature_value * model.time_decay[j];
                    }
                }
            }

            for (pattern, &value) in &patterns {
                let avg_error = errors.iter().sum::<f64>() / errors.len() as f64;
                model
                    .pattern_weights
                    .entry(pattern.clone())
                    .and_modify(|w| *w += learning_rate * avg_error * value)
                    .or_insert(learning_rate * avg_error * value);
            }
        }

        Ok(())
    }

    /// Serializes the current model weights to JSON at `path`.
    pub async fn save_model(&self, path: &std::path::Path) -> anyhow::Result<()> {
        let model = self.model_weights.read().await;
        let serialized = serde_json::to_string(&*model)?;
        tokio::fs::write(path, serialized).await?;
        Ok(())
    }

    /// Replaces the current model weights with those loaded from `path`.
    pub async fn load_model(&self, path: &std::path::Path) -> anyhow::Result<()> {
        let data = tokio::fs::read_to_string(path).await?;
        let loaded_model: ModelWeights = serde_json::from_str(&data)?;
        let mut model = self.model_weights.write().await;
        *model = loaded_model;
        Ok(())
    }
}

#[derive(Debug, Clone)]
pub enum NodeEvent {
    Connected,
    Disconnected,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_thompson_sampling_initialization() {
        let ts = ThompsonSampling::new();
        let metrics = ts.get_metrics().await;

        assert_eq!(metrics.total_decisions, 0);
        assert!(metrics.decisions_by_type.is_empty());
        assert!(metrics.strategy_success_rates.is_empty());
    }

    #[tokio::test]
    async fn test_thompson_sampling_selection() -> Result<()> {
        let ts = ThompsonSampling::new();

        for content_type in [
            ContentType::DHTLookup,
            ContentType::DataRetrieval,
            ContentType::ComputeRequest,
            ContentType::RealtimeMessage,
        ] {
            let strategy = ts.select_strategy(content_type).await?;
            assert!(matches!(
                strategy,
                StrategyChoice::Kademlia
                    | StrategyChoice::Hyperbolic
                    | StrategyChoice::TrustPath
                    | StrategyChoice::SOMRegion
            ));
        }

        let metrics = ts.get_metrics().await;
        assert_eq!(metrics.total_decisions, 4);
        assert_eq!(metrics.decisions_by_type.len(), 4);
        Ok(())
    }

    #[tokio::test]
    async fn test_thompson_sampling_update() -> Result<()> {
        let ts = ThompsonSampling::new();

        // Hyperbolic succeeds consistently...
        for _ in 0..20 {
            ts.update(
                ContentType::DataRetrieval,
                StrategyChoice::Hyperbolic,
                true,
                50,
            )
            .await?;
        }

        // ...while Kademlia consistently fails.
        for _ in 0..10 {
            ts.update(
                ContentType::DataRetrieval,
                StrategyChoice::Kademlia,
                false,
                200,
            )
            .await?;
        }

        let mut hyperbolic_count = 0;
        for _ in 0..100 {
            let strategy = ts.select_strategy(ContentType::DataRetrieval).await?;
            if matches!(strategy, StrategyChoice::Hyperbolic) {
                hyperbolic_count += 1;
            }
        }

        assert!(
            hyperbolic_count >= 60,
            "Expected Hyperbolic to be selected at least 60% of the time, got {}%",
            hyperbolic_count
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_confidence_intervals() -> Result<()> {
        let ts = ThompsonSampling::new();

        for i in 0..10 {
            ts.update(
                ContentType::DHTLookup,
                StrategyChoice::Kademlia,
                i % 3 != 0, // fails every third update
                100,
            )
            .await?;
        }

        let (lower, upper) = ts
            .get_confidence_interval(ContentType::DHTLookup, StrategyChoice::Kademlia)
            .await;

        assert!(lower > 0.0 && lower < 1.0);
        assert!(upper > lower && upper <= 1.0);
        // With 10 samples the interval should already be reasonably tight.
        assert!(upper - lower < 0.5);
        Ok(())
    }

    #[tokio::test]
    async fn test_exploration_bonus() -> Result<()> {
        let ts = ThompsonSampling::new();

        for _ in 0..15 {
            ts.update(
                ContentType::ComputeRequest,
                StrategyChoice::TrustPath,
                true,
                100,
            )
            .await?;
        }

        let mut strategy_counts = HashMap::new();
        for _ in 0..100 {
            let strategy = ts.select_strategy(ContentType::ComputeRequest).await?;
            *strategy_counts.entry(strategy).or_insert(0) += 1;
        }

        // Under-sampled strategies should still get explored.
        assert!(
            strategy_counts.len() >= 3,
            "Expected at least 3 different strategies to be tried"
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_reset_strategy() -> Result<()> {
        let ts = ThompsonSampling::new();

        for _ in 0..10 {
            ts.update(
                ContentType::RealtimeMessage,
                StrategyChoice::SOMRegion,
                true,
                50,
            )
            .await?;
        }

        ts.reset_strategy(ContentType::RealtimeMessage, StrategyChoice::SOMRegion)
            .await;

        // After a reset the interval reverts to the uninformed (0, 1) range.
        let (lower, upper) = ts
            .get_confidence_interval(ContentType::RealtimeMessage, StrategyChoice::SOMRegion)
            .await;

        assert_eq!(lower, 0.0);
        assert_eq!(upper, 1.0);
        Ok(())
    }

    #[tokio::test]
    async fn test_learning_system_trait() {
        let mut ts = ThompsonSampling::new();

        let context = LearningContext {
            content_type: ContentType::DataRetrieval,
            network_conditions: NetworkConditions {
                connected_peers: 100,
                avg_latency_ms: 50.0,
                churn_rate: 0.1,
            },
            historical_performance: vec![0.8, 0.85, 0.9],
        };

        let choice = <ThompsonSampling as LearningSystem>::select_strategy(&ts, &context).await;
        assert!(matches!(
            choice,
            StrategyChoice::Kademlia
                | StrategyChoice::Hyperbolic
                | StrategyChoice::TrustPath
                | StrategyChoice::SOMRegion
        ));

        let outcome = Outcome {
            success: true,
            latency_ms: 45,
            hops: 3,
        };

        <ThompsonSampling as LearningSystem>::update(&mut ts, &context, &choice, &outcome).await;

        let metrics = <ThompsonSampling as LearningSystem>::metrics(&ts).await;
        assert_eq!(metrics.total_decisions, 1);
    }

    #[tokio::test]
    async fn test_cache_manager() {
        let manager = QLearnCacheManager::new(1024);
        let hash = ContentHash([1u8; 32]);

        assert!(manager.insert(hash, vec![0u8; 100]).await);

        assert!(manager.get(&hash).await.is_some());

        let action = manager.decide_action(&hash).await;
        assert!(matches!(
            action,
            CacheAction::Cache
                | CacheAction::Evict(_)
                | CacheAction::IncreaseReplication
                | CacheAction::DecreaseReplication
                | CacheAction::NoAction
        ));
    }

    #[tokio::test]
    async fn test_q_value_update() {
        let manager = QLearnCacheManager::new(1024);

        let state = CacheState {
            utilization_bucket: 5,
            request_rate_bucket: 5,
            content_popularity: 5,
            size_bucket: 3,
        };

        let next_state = CacheState {
            utilization_bucket: 6,
            request_rate_bucket: 5,
            content_popularity: 5,
            size_bucket: 3,
        };

        manager
            .update_q_value(state, CacheAction::Cache, 1.0, next_state)
            .await;

        let q_table = manager.q_table.read().await;
        assert!(!q_table.is_empty());
    }

    #[tokio::test]
    async fn test_churn_predictor() {
        use crate::peer_record::UserId;
        use rand::RngCore;

        let predictor = ChurnPredictor::new();
        let mut hash = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut hash);
        let node_id = UserId::from_bytes(hash);

        let prediction = predictor.predict(&node_id).await;
        assert!(prediction.probability_1h >= 0.0 && prediction.probability_1h <= 1.0);
        assert!(prediction.probability_6h >= 0.0 && prediction.probability_6h <= 1.0);
        assert!(prediction.probability_24h >= 0.0 && prediction.probability_24h <= 1.0);

        // A second call within the cache TTL must return the cached prediction.
        let prediction2 = predictor.predict(&node_id).await;
        assert_eq!(prediction.probability_1h, prediction2.probability_1h);
    }

    #[tokio::test]
    async fn test_cache_eviction_policies() {
        // Capacity of 300 bytes holds exactly three 100-byte items.
        let manager = QLearnCacheManager::new(300);
        let hash1 = ContentHash([1u8; 32]);
        let hash2 = ContentHash([2u8; 32]);
        let hash3 = ContentHash([3u8; 32]);

        manager.insert(hash1, vec![0u8; 100]).await;
        manager.insert(hash2, vec![0u8; 100]).await;
        manager.insert(hash3, vec![0u8; 100]).await;

        // Touch hash1 and hash2 so hash3 becomes the least recently used entry.
        manager.get(&hash1).await;
        manager.get(&hash2).await;

        let hash4 = ContentHash([4u8; 32]);
        manager.insert(hash4, vec![0u8; 100]).await;

        assert!(manager.get(&hash1).await.is_some());
        assert!(manager.get(&hash2).await.is_some());
        assert!(manager.get(&hash3).await.is_none());
        assert!(manager.get(&hash4).await.is_some());
    }

    #[tokio::test]
    async fn test_reward_calculation() {
        let manager = QLearnCacheManager::new(1024);

        let hash = ContentHash([1u8; 32]);
        manager.insert(hash, vec![0u8; 100]).await;

        for _ in 0..5 {
            manager.get(&hash).await;
        }

        let miss_hash = ContentHash([2u8; 32]);
        for _ in 0..2 {
            manager.get(&miss_hash).await;
        }

        // Caching content that gets hit should be rewarded.
        let cache_reward = manager.calculate_reward(CacheAction::Cache, true, 1000);
        assert!(cache_reward > 0.0);

        // Evicting content that is not requested is neutral or better.
        let evict_reward =
            manager.calculate_reward(CacheAction::Evict(EvictionPolicy::LRU), false, 0);
        assert!(evict_reward >= 0.0);

        // Evicting content that is still requested is penalized.
        let evict_penalty =
            manager.calculate_reward(CacheAction::Evict(EvictionPolicy::LRU), true, 0);
        assert!(evict_penalty < 0.0);
    }

    #[tokio::test]
    async fn test_cache_statistics() {
        let manager = QLearnCacheManager::new(1024);

        let hash1 = ContentHash([1u8; 32]);
        let hash2 = ContentHash([2u8; 32]);

        manager.insert(hash1, vec![0u8; 100]).await;

        manager.get(&hash1).await; // hit
        manager.get(&hash1).await; // hit
        manager.get(&hash2).await; // miss

        let stats = manager.get_stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate - 0.666).abs() < 0.01); // 2 hits out of 3 requests
        assert_eq!(stats.size_bytes, 100);
    }

    #[tokio::test]
    async fn test_exploration_vs_exploitation() {
        let manager = QLearnCacheManager::new(1024);

        // Matches the state produced by get_current_state() for an empty cache.
        let state = CacheState {
            utilization_bucket: 0,
            request_rate_bucket: 5,
            content_popularity: 5,
            size_bucket: 5,
        };

        for _ in 0..10 {
            manager
                .update_q_value(state.clone(), CacheAction::Cache, 1.0, state.clone())
                .await;
        }

        let mut cache_count = 0;
        for _ in 0..100 {
            let action = manager.decide_action(&ContentHash([1u8; 32])).await;
            if matches!(action, CacheAction::Cache) {
                cache_count += 1;
            }
        }

        // With epsilon = 0.1 the learned action should dominate but not monopolize.
        assert!((50..=100).contains(&cache_count));
    }

    #[tokio::test]
    async fn test_state_representation() {
        let manager = QLearnCacheManager::new(1024);

        let hash = ContentHash([1u8; 32]);

        manager.insert(hash, vec![0u8; 100]).await;

        for _ in 0..5 {
            manager.get(&hash).await;
        }

        let state = manager.get_current_state_async(&hash).await;

        assert!(state.utilization_bucket <= 10);
        assert!(state.request_rate_bucket <= 10);
        assert!(state.content_popularity <= 10);
        assert!(state.size_bucket <= 10);
    }

    #[tokio::test]
    async fn test_action_execution() -> Result<()> {
        let manager = QLearnCacheManager::new(1024);
        let hash = ContentHash([1u8; 32]);

        manager
            .execute_action(&hash, CacheAction::Cache, Some(vec![0u8; 100]))
            .await?;
        assert!(manager.get(&hash).await.is_some());

        manager
            .execute_action(&hash, CacheAction::NoAction, None)
            .await?;
        assert!(manager.get(&hash).await.is_some());

        manager
            .execute_action(&hash, CacheAction::Evict(EvictionPolicy::LRU), None)
            .await?;
        let stats = manager.get_stats();
        assert!(stats.size_bytes <= 100);
        Ok(())
    }

    #[tokio::test]
    async fn test_churn_predictor_initialization() {
        let predictor = ChurnPredictor::new();

        let node_id = NodeId { hash: [1u8; 32] };
        let prediction = predictor.predict(&node_id).await;

        // An unknown node gets the low-confidence baseline prediction.
        assert!(prediction.confidence < 0.2);
        assert!(prediction.probability_1h < 0.3);
        assert!(prediction.probability_6h < 0.4);
        assert!(prediction.probability_24h < 0.5);
    }

    #[tokio::test]
    async fn test_churn_predictor_node_events() -> Result<()> {
        let predictor = ChurnPredictor::new();
        let node_id = NodeId { hash: [1u8; 32] };

        predictor
            .record_node_event(&node_id, NodeEvent::Connected)
            .await?;

        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        predictor
            .record_node_event(&node_id, NodeEvent::Disconnected)
            .await?;

        let features = predictor.extract_features(&node_id).await;
        assert!(features.is_some());
        Ok(())
    }

    #[tokio::test]
    async fn test_churn_predictor_feature_extraction() -> Result<()> {
        let predictor = ChurnPredictor::new();
        let node_id = NodeId { hash: [1u8; 32] };

        predictor
            .record_node_event(&node_id, NodeEvent::Connected)
            .await?;

        let features = NodeFeatures {
            online_duration: 3600.0,
            avg_response_time: 50.0,
            resource_contribution: 0.8,
            message_frequency: 20.0,
            time_of_day: 14.0,
            day_of_week: 2.0,
            historical_reliability: 0.9,
            recent_disconnections: 1.0,
            avg_session_length: 4.0,
            connection_stability: 0.85,
        };

        predictor
            .update_node_behavior(&node_id, features.clone())
            .await?;

        let extracted = predictor
            .extract_features(&node_id)
            .await
            .ok_or(anyhow::anyhow!("no features extracted"))?;
        assert_eq!(extracted.resource_contribution, 0.8);
        assert_eq!(extracted.avg_response_time, 50.0);
        Ok(())
    }

    #[tokio::test]
    async fn test_churn_predictor_pattern_analysis() {
        let predictor = ChurnPredictor::new();

        let night_features = NodeFeatures {
            online_duration: 1000.0,
            avg_response_time: 100.0,
            resource_contribution: 0.5,
            message_frequency: 10.0,
            time_of_day: 2.0, // night time
            day_of_week: 3.0,
            historical_reliability: 0.8,
            recent_disconnections: 2.0,
            avg_session_length: 2.0,
            connection_stability: 0.8,
        };

        let patterns = predictor.analyze_patterns(&night_features).await;
        assert_eq!(patterns.get("night_time"), Some(&1.0));
        assert_eq!(patterns.get("weekend"), Some(&0.0));

        let unstable_features = NodeFeatures {
            recent_disconnections: 7.0,
            connection_stability: 0.3,
            ..night_features
        };

        let patterns = predictor.analyze_patterns(&unstable_features).await;
        assert_eq!(patterns.get("unstable"), Some(&1.0));
        // The risk score here is 0.5, just below the 0.6 high-risk threshold.
        assert_eq!(patterns.get("high_risk"), Some(&0.0));
    }

    #[tokio::test]
    async fn test_churn_predictor_proactive_replication() -> Result<()> {
        let predictor = ChurnPredictor::new();
        let node_id = NodeId { hash: [1u8; 32] };

        predictor
            .record_node_event(&node_id, NodeEvent::Connected)
            .await?;

        // A deliberately unreliable node profile.
        let features = NodeFeatures {
            online_duration: 600.0,
            avg_response_time: 500.0,
            resource_contribution: 0.1,
            message_frequency: 2.0,
            time_of_day: 23.0,
            day_of_week: 5.0,
            historical_reliability: 0.3,
            recent_disconnections: 10.0,
            avg_session_length: 0.5,
            connection_stability: 0.1,
        };

        predictor.update_node_behavior(&node_id, features).await?;

        let _should_replicate = predictor.should_replicate(&node_id).await;
        let prediction = predictor.predict(&node_id).await;
        assert!(prediction.probability_1h > 0.0);
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_churn_predictor_online_learning() -> Result<()> {
        let predictor = ChurnPredictor::new();
        let node_id = NodeId { hash: [1u8; 32] };

        let train_future = async {
            for i in 0..20 {
                let features = NodeFeatures {
                    online_duration: (i * 1000) as f64,
                    avg_response_time: 100.0,
                    resource_contribution: 0.5,
                    message_frequency: 10.0,
                    time_of_day: 12.0,
                    day_of_week: 3.0,
                    historical_reliability: 0.8,
                    recent_disconnections: (i % 5) as f64,
                    avg_session_length: 2.0,
                    connection_stability: 0.8,
                };

                let churned = i % 3 == 0;
                predictor
                    .add_training_example(&node_id, features, churned, churned, churned)
                    .await?;
            }
            anyhow::Ok(())
        };

        tokio::time::timeout(std::time::Duration::from_secs(5), train_future)
            .await
            .map_err(|_| anyhow::anyhow!("training timed out"))??;

        let prediction = tokio::time::timeout(
            std::time::Duration::from_secs(2),
            predictor.predict(&node_id),
        )
        .await
        .map_err(|_| anyhow::anyhow!("predict timed out"))?;
        assert!(prediction.confidence > 0.0);
        Ok(())
    }

2099
2100 #[tokio::test]
2101 async fn test_churn_predictor_model_persistence() -> Result<()> {
2102 let predictor = ChurnPredictor::new();
2103 let temp_dir = std::env::temp_dir();
2105 let temp_path = temp_dir.join("test_churn_model.json");
2106
2107 predictor.save_model(&temp_path).await?;
2109
2110 predictor.load_model(&temp_path).await?;
2112
2113 let _ = std::fs::remove_file(&temp_path);
2115 Ok(())
2116 }
2117}