use super::beta_distribution::BetaDistribution;
use super::*;
use rand::Rng;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

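/// Thompson-sampling strategy selector: maintains a Beta posterior for each
/// (content type, strategy) arm and picks strategies by sampling those
/// posteriors, so better-performing strategies win more often over time.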
pub struct ThompsonSampling {
    /// Beta posterior per (content type, strategy) arm.
    arms: Arc<RwLock<HashMap<(ContentType, StrategyChoice), BetaParams>>>,

    /// Minimum trials per arm before the exploration bonus is dropped.
    min_samples: u32,

    /// Hourly decay factor applied to stale arm statistics.
    decay_factor: f64,

    /// Aggregate routing metrics.
    metrics: Arc<RwLock<RoutingMetrics>>,
}

#[derive(Debug, Clone)]
struct BetaParams {
    distribution: BetaDistribution,
    trials: u32,
    last_update: std::time::Instant,
}

impl Default for BetaParams {
    fn default() -> Self {
        // Beta(1, 1) is the uniform prior; fall back to constructing the
        // fields directly if validation fails.
        let distribution = BetaDistribution::new(1.0, 1.0).unwrap_or(BetaDistribution {
            alpha: 1.0,
            beta: 1.0,
        });
        Self {
            distribution,
            trials: 0,
            last_update: std::time::Instant::now(),
        }
    }
}

#[derive(Debug, Default, Clone)]
pub struct RoutingMetrics {
    pub total_decisions: u64,
    pub decisions_by_type: HashMap<ContentType, u64>,
    pub strategy_success_rates: HashMap<StrategyChoice, f64>,
    pub strategy_latencies: HashMap<StrategyChoice, f64>,
}

impl Default for ThompsonSampling {
    fn default() -> Self {
        Self::new()
    }
}

impl ThompsonSampling {
    pub fn new() -> Self {
        Self {
            arms: Arc::new(RwLock::new(HashMap::new())),
            min_samples: 10,
            decay_factor: 0.99,
            metrics: Arc::new(RwLock::new(RoutingMetrics::default())),
        }
    }

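    /// Choose a strategy for `content_type`: decay stale arms toward the
    /// prior, sample each arm's Beta posterior, and add a small bonus for
    /// under-explored arms.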
    pub async fn select_strategy(&self, content_type: ContentType) -> Result<StrategyChoice> {
        let mut arms = self.arms.write().await;
        let mut metrics = self.metrics.write().await;

        metrics.total_decisions += 1;
        *metrics.decisions_by_type.entry(content_type).or_insert(0) += 1;

        let strategies = vec![
            StrategyChoice::Kademlia,
            StrategyChoice::Hyperbolic,
            StrategyChoice::TrustPath,
            StrategyChoice::SOMRegion,
        ];

        let mut best_strategy = StrategyChoice::Kademlia;
        let mut best_sample = 0.0;

        for strategy in &strategies {
            let key = (content_type, *strategy);
            let params = arms.entry(key).or_default();

            if params.trials > 0 {
                // Decay stale evidence toward the uniform prior, one
                // decay_factor step per hour since the last update.
                let elapsed = params.last_update.elapsed().as_secs() as f64;
                let decay = self.decay_factor.powf(elapsed / 3600.0);
                let alpha = params.distribution.alpha;
                let beta = params.distribution.beta;
                let new_alpha = 1.0 + (alpha - 1.0) * decay;
                let new_beta = 1.0 + (beta - 1.0) * decay;
                params.distribution =
                    BetaDistribution::new(new_alpha, new_beta).unwrap_or(BetaDistribution {
                        alpha: new_alpha.max(f64::MIN_POSITIVE),
                        beta: new_beta.max(f64::MIN_POSITIVE),
                    });
            }

            let mut rng = rand::thread_rng();
            let sample = params.distribution.sample(&mut rng);

            // Linear exploration bonus that fades out as the arm approaches
            // min_samples trials.
            let exploration_bonus = if params.trials < self.min_samples {
                0.1 * (1.0 - (params.trials as f64 / self.min_samples as f64))
            } else {
                0.0
            };

            let adjusted_sample = sample + exploration_bonus;

            if adjusted_sample > best_sample {
                best_sample = adjusted_sample;
                best_strategy = *strategy;
            }
        }

        Ok(best_strategy)
    }

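    /// Record a routing outcome: update the arm's Beta posterior and the
    /// exponentially weighted success-rate and latency metrics.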
    pub async fn update(
        &self,
        content_type: ContentType,
        strategy: StrategyChoice,
        success: bool,
        latency_ms: u64,
    ) -> anyhow::Result<()> {
        let mut arms = self.arms.write().await;
        let mut metrics = self.metrics.write().await;

        let key = (content_type, strategy);
        let params = arms.entry(key).or_default();

        params.distribution.update(success);
        params.trials += 1;
        params.last_update = std::time::Instant::now();

        // Exponentially weighted moving averages with a 0.9 decay.
        let success_rate = params.distribution.mean();
        let current_rate = metrics
            .strategy_success_rates
            .entry(strategy)
            .or_insert(0.5);
        *current_rate = 0.9 * (*current_rate) + 0.1 * success_rate;

        let current_latency = metrics
            .strategy_latencies
            .entry(strategy)
            .or_insert(latency_ms as f64);
        *current_latency = 0.9 * (*current_latency) + 0.1 * (latency_ms as f64);

        Ok(())
    }

    pub async fn get_metrics(&self) -> RoutingMetrics {
        self.metrics.read().await.clone()
    }

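    /// Confidence interval for an arm's success rate; returns the
    /// uninformative (0.0, 1.0) when the arm has no observations.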
    pub async fn get_confidence_interval(
        &self,
        content_type: ContentType,
        strategy: StrategyChoice,
    ) -> (f64, f64) {
        let arms = self.arms.read().await;
        let key = (content_type, strategy);

        if let Some(params) = arms.get(&key) {
            if params.trials == 0 {
                return (0.0, 1.0);
            }

            params.distribution.confidence_interval()
        } else {
            (0.0, 1.0)
        }
    }

    /// Drop all learned state for one (content type, strategy) arm.
    pub async fn reset_strategy(&self, content_type: ContentType, strategy: StrategyChoice) {
        let mut arms = self.arms.write().await;
        arms.remove(&(content_type, strategy));
    }
}

#[async_trait]
impl LearningSystem for ThompsonSampling {
    async fn select_strategy(&self, context: &LearningContext) -> StrategyChoice {
        self.select_strategy(context.content_type)
            .await
            .unwrap_or(StrategyChoice::Kademlia)
    }

    async fn update(
        &mut self,
        context: &LearningContext,
        choice: &StrategyChoice,
        outcome: &Outcome,
    ) {
        // Errors are deliberately dropped; the trait update is fire-and-forget.
        let _ = ThompsonSampling::update(
            self,
            context.content_type,
            *choice,
            outcome.success,
            outcome.latency_ms,
        )
        .await;
    }

    async fn metrics(&self) -> LearningMetrics {
        let metrics = self.get_metrics().await;

        LearningMetrics {
            total_decisions: metrics.total_decisions,
            success_rate: metrics.strategy_success_rates.values().sum::<f64>()
                / metrics.strategy_success_rates.len().max(1) as f64,
            avg_latency_ms: metrics.strategy_latencies.values().sum::<f64>()
                / metrics.strategy_latencies.len().max(1) as f64,
            strategy_performance: metrics.strategy_success_rates.clone(),
        }
    }
}

#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of cache hits.
    pub hits: u64,

    /// Number of cache misses.
    pub misses: u64,

    /// Total bytes currently stored.
    pub size_bytes: u64,

    /// Number of cached items.
    pub item_count: u64,

    /// Number of evictions performed.
    pub evictions: u64,

    /// hits / (hits + misses), or 0.0 when there has been no traffic.
    pub hit_rate: f64,
}

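/// Q-learning cache manager: an epsilon-greedy policy over discretized cache
/// states, trained with shaped rewards that weigh hit rate against storage
/// and bandwidth costs.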
pub struct QLearnCacheManager {
    /// Q-values per (state, action) pair.
    q_table: Arc<tokio::sync::RwLock<HashMap<CacheState, HashMap<CacheAction, f64>>>>,

    /// Step size for Q-value updates.
    learning_rate: f64,

    /// Discount applied to estimated future rewards.
    discount_factor: f64,

    /// Probability of taking a random exploratory action.
    epsilon: f64,

    /// Cached content keyed by hash.
    cache: Arc<tokio::sync::RwLock<HashMap<ContentHash, CachedContent>>>,

    /// Maximum cache size in bytes.
    capacity: usize,

    /// Current cache size in bytes.
    current_size: Arc<std::sync::atomic::AtomicUsize>,

    /// Per-content request statistics.
    request_stats: Arc<tokio::sync::RwLock<HashMap<ContentHash, RequestStats>>>,

    hit_count: Arc<std::sync::atomic::AtomicU64>,
    miss_count: Arc<std::sync::atomic::AtomicU64>,

    _bandwidth_used: Arc<std::sync::atomic::AtomicU64>,
}

#[derive(Debug, Clone)]
pub struct RequestStats {
    request_count: u64,
    hourly_requests: u64,
    last_request: std::time::Instant,
    content_size: usize,
}

/// Discretized cache state used as the Q-table key.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct CacheState {
    /// Cache utilization, bucketed 0-10.
    utilization_bucket: u8,

    /// Hourly request rate, bucketed 0-10.
    request_rate_bucket: u8,

    /// Lifetime request count, bucketed 0-10.
    content_popularity: u8,

    /// Content size, bucketed 0-7 (1 KiB up to over 1 GiB).
    size_bucket: u8,
}

#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum CacheAction {
    Cache,
    Evict(EvictionPolicy),
    IncreaseReplication,
    DecreaseReplication,
    NoAction,
}

#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum EvictionPolicy {
    LRU,
    LFU,
    Random,
}

#[derive(Debug, Clone)]
pub struct CachedContent {
    pub data: Vec<u8>,
    pub access_count: u64,
    pub last_access: std::time::Instant,
    pub insertion_time: std::time::Instant,
}

impl QLearnCacheManager {
    pub fn new(capacity: usize) -> Self {
        Self {
            q_table: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            learning_rate: 0.1,
            discount_factor: 0.9,
            epsilon: 0.1,
            cache: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            capacity,
            current_size: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
            request_stats: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            hit_count: Arc::new(std::sync::atomic::AtomicU64::new(0)),
            miss_count: Arc::new(std::sync::atomic::AtomicU64::new(0)),
            _bandwidth_used: Arc::new(std::sync::atomic::AtomicU64::new(0)),
        }
    }

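    /// Epsilon-greedy action selection: with probability `epsilon` take a
    /// random action, otherwise exploit the highest Q-value known for the
    /// current state.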
    pub async fn decide_action(&self, content_hash: &ContentHash) -> CacheAction {
        let state = self.get_current_state(content_hash);

        if rand::random::<f64>() < self.epsilon {
            self.random_action()
        } else {
            let q_table = self.q_table.read().await;
            q_table
                .get(&state)
                .and_then(|actions| {
                    actions
                        .iter()
                        .max_by(|(_, a), (_, b)| {
                            a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                        })
                        .map(|(action, _)| *action)
                })
                .unwrap_or(CacheAction::NoAction)
        }
    }

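    /// Standard Q-learning update:
    /// Q(s, a) += learning_rate * (reward + discount * max_a' Q(s', a') - Q(s, a)).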
    pub async fn update_q_value(
        &self,
        state: CacheState,
        action: CacheAction,
        reward: f64,
        next_state: CacheState,
    ) {
        let mut q_table = self.q_table.write().await;

        let current_q = q_table
            .entry(state.clone())
            .or_default()
            .get(&action)
            .copied()
            .unwrap_or(0.0);

        let max_next_q = q_table
            .get(&next_state)
            .and_then(|actions| {
                actions
                    .values()
                    .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .copied()
            })
            .unwrap_or(0.0);

        let new_q = current_q
            + self.learning_rate * (reward + self.discount_factor * max_next_q - current_q);

        q_table.entry(state).or_default().insert(action, new_q);
    }

    fn get_current_state(&self, _content_hash: &ContentHash) -> CacheState {
        let utilization = (self.current_size.load(std::sync::atomic::Ordering::Relaxed) * 10
            / self.capacity) as u8;

        // Synchronous variant: request stats are unavailable without awaiting,
        // so fall back to neutral mid-range buckets.
        let (request_rate_bucket, content_popularity, size_bucket) = (5, 5, 5);

        CacheState {
            utilization_bucket: utilization.min(10),
            request_rate_bucket,
            content_popularity,
            size_bucket,
        }
    }

    pub async fn get_current_state_async(&self, content_hash: &ContentHash) -> CacheState {
        let utilization = (self.current_size.load(std::sync::atomic::Ordering::Relaxed) * 10
            / self.capacity) as u8;

        let stats = self.request_stats.read().await;
        let (request_rate_bucket, content_popularity, size_bucket) =
            if let Some(stat) = stats.get(content_hash) {
                let hourly_rate = stat.hourly_requests.min(100) / 10;

                let popularity = (stat.request_count.min(1000) / 100) as u8;

                let size_bucket = match stat.content_size {
                    0..=1_024 => 0,                   // <= 1 KiB
                    1_025..=10_240 => 1,              // <= 10 KiB
                    10_241..=102_400 => 2,            // <= 100 KiB
                    102_401..=1_048_576 => 3,         // <= 1 MiB
                    1_048_577..=10_485_760 => 4,      // <= 10 MiB
                    10_485_761..=104_857_600 => 5,    // <= 100 MiB
                    104_857_601..=1_073_741_824 => 6, // <= 1 GiB
                    _ => 7,                           // > 1 GiB
                };

                (hourly_rate as u8, popularity, size_bucket)
            } else {
                (0, 0, 0) // unseen content
            };

        CacheState {
            utilization_bucket: utilization.min(10),
            request_rate_bucket,
            content_popularity,
            size_bucket,
        }
    }

    fn random_action(&self) -> CacheAction {
        // Sample uniformly over all six actions. Note: modulo 6, not 5, so
        // that the NoAction arm is actually reachable during exploration.
        match rand::random::<u8>() % 6 {
            0 => CacheAction::Cache,
            1 => CacheAction::Evict(EvictionPolicy::LRU),
            2 => CacheAction::Evict(EvictionPolicy::LFU),
            3 => CacheAction::IncreaseReplication,
            4 => CacheAction::DecreaseReplication,
            _ => CacheAction::NoAction,
        }
    }

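    /// Insert content, evicting least-recently-used entries until it fits;
    /// returns false if enough space cannot be freed.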
    pub async fn insert(&self, hash: ContentHash, data: Vec<u8>) -> bool {
        let size = data.len();

        // Evict until the new item fits.
        while self.current_size.load(std::sync::atomic::Ordering::Relaxed) + size > self.capacity {
            if !self.evict_one().await {
                return false;
            }
        }

        let mut cache = self.cache.write().await;
        cache.insert(
            hash,
            CachedContent {
                data,
                access_count: 0,
                last_access: std::time::Instant::now(),
                insertion_time: std::time::Instant::now(),
            },
        );

        self.current_size
            .fetch_add(size, std::sync::atomic::Ordering::Relaxed);
        true
    }

    /// Evict the least recently used entry; returns true if one was removed.
    async fn evict_one(&self) -> bool {
        let mut cache = self.cache.write().await;
        let oldest = cache
            .iter()
            .min_by_key(|(_, content)| content.last_access)
            .map(|(k, _)| *k);

        if let Some(key) = oldest
            && let Some(value) = cache.remove(&key)
        {
            self.current_size
                .fetch_sub(value.data.len(), std::sync::atomic::Ordering::Relaxed);
            return true;
        }

        false
    }

    pub async fn get(&self, hash: &ContentHash) -> Option<Vec<u8>> {
        let cache_result = {
            let mut cache = self.cache.write().await;
            cache.get_mut(hash).map(|entry| {
                entry.access_count += 1;
                entry.last_access = std::time::Instant::now();
                entry.data.clone()
            })
        };

        if cache_result.is_some() {
            self.hit_count
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        } else {
            self.miss_count
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        }

        if let Some(ref data) = cache_result {
            let mut stats = self.request_stats.write().await;
            let stat = stats.entry(*hash).or_insert_with(|| RequestStats {
                request_count: 0,
                hourly_requests: 0,
                last_request: std::time::Instant::now(),
                content_size: data.len(),
            });
            stat.request_count += 1;
            stat.hourly_requests += 1;
            stat.last_request = std::time::Instant::now();
        }

        cache_result
    }

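    /// Shaped reward for a cache action, given whether the content was hit
    /// and the bandwidth it cost; see the per-action cases below.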
    pub fn calculate_reward(&self, action: CacheAction, hit: bool, bandwidth_cost: u64) -> f64 {
        let hits = self.hit_count.load(std::sync::atomic::Ordering::Relaxed) as f64;
        let misses = self.miss_count.load(std::sync::atomic::Ordering::Relaxed) as f64;
        let hit_rate = if hits + misses > 0.0 {
            hits / (hits + misses)
        } else {
            0.0
        };

        let storage_cost = self.current_size.load(std::sync::atomic::Ordering::Relaxed) as f64
            / self.capacity as f64;

        // Normalize bandwidth to megabytes.
        let bandwidth_cost_normalized = bandwidth_cost as f64 / 1_000_000.0;

        match action {
            CacheAction::Cache => {
                if hit {
                    hit_rate - storage_cost * 0.1 - bandwidth_cost_normalized * 0.01
                } else {
                    -0.1 - bandwidth_cost_normalized * 0.1 // cached but unused
                }
            }
            CacheAction::Evict(_) => {
                if hit {
                    -0.5 // evicted content that was still needed
                } else {
                    0.1 - storage_cost * 0.05 // freed space without losing hits
                }
            }
            CacheAction::IncreaseReplication => {
                if hit {
                    hit_rate * 0.5 - bandwidth_cost_normalized * 0.2
                } else {
                    -0.2 - bandwidth_cost_normalized * 0.2
                }
            }
            CacheAction::DecreaseReplication => {
                if hit {
                    -0.3 // reduced replication of content that is still popular
                } else {
                    0.05 + storage_cost * 0.05 // saved resources on cold content
                }
            }
            CacheAction::NoAction => {
                hit_rate * 0.1 - storage_cost * 0.01 // small steady-state reward
            }
        }
    }

    pub async fn execute_action(
        &self,
        hash: &ContentHash,
        action: CacheAction,
        data: Option<Vec<u8>>,
    ) -> Result<()> {
        match action {
            CacheAction::Cache => {
                if let Some(content) = data {
                    self.insert(*hash, content).await;
                }
            }
            CacheAction::Evict(policy) => {
                match policy {
                    EvictionPolicy::LRU => self.evict_lru().await,
                    EvictionPolicy::LFU => self.evict_lfu().await,
                    EvictionPolicy::Random => self.evict_random().await,
                };
            }
            CacheAction::IncreaseReplication => {
                // Replication adjustments are not handled at the cache layer.
            }
            CacheAction::DecreaseReplication => {
                // Replication adjustments are not handled at the cache layer.
            }
            CacheAction::NoAction => {}
        }
        Ok(())
    }

    async fn evict_lru(&self) -> bool {
        self.evict_one().await
    }

    async fn evict_lfu(&self) -> bool {
        let mut cache = self.cache.write().await;
        let least_frequent = cache
            .iter()
            .min_by_key(|(_, content)| content.access_count)
            .map(|(k, _)| *k);

        if let Some(key) = least_frequent
            && let Some(value) = cache.remove(&key)
        {
            self.current_size
                .fetch_sub(value.data.len(), std::sync::atomic::Ordering::Relaxed);
            return true;
        }
        false
    }

    async fn evict_random(&self) -> bool {
        let cache = self.cache.read().await;
        if cache.is_empty() {
            return false;
        }

        let random_idx = rand::random::<usize>() % cache.len();
        let random_key = cache.keys().nth(random_idx).cloned();
        drop(cache);

        if let Some(key) = random_key {
            let mut cache = self.cache.write().await;
            if let Some(value) = cache.remove(&key) {
                self.current_size
                    .fetch_sub(value.data.len(), std::sync::atomic::Ordering::Relaxed);
                return true;
            }
        }
        false
    }

    pub fn get_stats(&self) -> CacheStats {
        let hits = self.hit_count.load(std::sync::atomic::Ordering::Relaxed);
        let misses = self.miss_count.load(std::sync::atomic::Ordering::Relaxed);
        let hit_rate = if hits + misses > 0 {
            hits as f64 / (hits + misses) as f64
        } else {
            0.0
        };

        CacheStats {
            hits,
            misses,
            size_bytes: self.current_size.load(std::sync::atomic::Ordering::Relaxed) as u64,
            item_count: 0, // not tracked here; see get_stats_async
            evictions: 0,  // eviction counting not yet implemented
            hit_rate,
        }
    }

    pub async fn decide_caching(
        &self,
        hash: ContentHash,
        data: Vec<u8>,
        _content_type: ContentType,
    ) -> Result<()> {
        let _state = self.get_current_state_async(&hash).await;
        let action = self.decide_action(&hash).await;

        if matches!(action, CacheAction::Cache) {
            let _ = self.insert(hash, data).await;
        }

        Ok(())
    }

    pub async fn get_stats_async(&self) -> CacheStats {
        let cache = self.cache.read().await;
        let hit_count = self.hit_count.load(std::sync::atomic::Ordering::Relaxed);
        let miss_count = self.miss_count.load(std::sync::atomic::Ordering::Relaxed);
        let total = hit_count + miss_count;

        CacheStats {
            hits: hit_count,
            misses: miss_count,
            size_bytes: self.current_size.load(std::sync::atomic::Ordering::Relaxed) as u64,
            item_count: cache.len() as u64,
            evictions: 0, // eviction counting not yet implemented
            hit_rate: if total > 0 {
                hit_count as f64 / total as f64
            } else {
                0.0
            },
        }
    }
}

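/// Snapshot of a node's behavior used as input to churn prediction.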
#[derive(Debug, Clone)]
pub struct NodeFeatures {
    /// Current session length in seconds.
    pub online_duration: f64,
    /// Average response time in milliseconds.
    pub avg_response_time: f64,
    /// Normalized resource contribution (0.0 to 1.0).
    pub resource_contribution: f64,
    /// Message frequency.
    pub message_frequency: f64,
    /// Hour of day (0 to 23).
    pub time_of_day: f64,
    /// Day of week (0 to 6).
    pub day_of_week: f64,
    /// Historical uptime ratio (0.0 to 1.0).
    pub historical_reliability: f64,
    /// Disconnections within the last week.
    pub recent_disconnections: f64,
    /// Average completed session length in hours.
    pub avg_session_length: f64,
    /// Connection stability score (0.0 to 1.0).
    pub connection_stability: f64,
}

#[derive(Debug, Clone)]
pub struct FeatureHistory {
    pub node_id: NodeId,
    /// Timestamped feature snapshots.
    pub snapshots: Vec<(std::time::Instant, NodeFeatures)>,
    /// Session (start, optional end) pairs; an open session has no end.
    pub sessions: Vec<(std::time::Instant, Option<std::time::Instant>)>,
    /// Cumulative uptime in seconds.
    pub total_uptime: u64,
    /// Cumulative downtime in seconds.
    pub total_downtime: u64,
}

impl Default for FeatureHistory {
    fn default() -> Self {
        Self::new()
    }
}

impl FeatureHistory {
    pub fn new() -> Self {
        Self {
            node_id: NodeId { hash: [0u8; 32] },
            snapshots: Vec::new(),
            sessions: Vec::new(),
            total_uptime: 0,
            total_downtime: 0,
        }
    }
}

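/// Churn predictor: a lightweight logistic model over normalized node
/// features and behavioral patterns, updated online from a buffer of
/// labeled training examples. Predicts departure probability at the
/// 1-hour, 6-hour, and 24-hour horizons.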
#[derive(Debug)]
pub struct ChurnPredictor {
    /// Cached predictions, valid for five minutes.
    prediction_cache: Arc<tokio::sync::RwLock<HashMap<NodeId, ChurnPrediction>>>,

    /// Per-node behavioral history.
    feature_history: Arc<tokio::sync::RwLock<HashMap<NodeId, FeatureHistory>>>,

    /// Learned model weights.
    model_weights: Arc<tokio::sync::RwLock<ModelWeights>>,

    /// Buffered training examples for online updates.
    experience_buffer: Arc<tokio::sync::RwLock<Vec<TrainingExample>>>,

    /// Maximum number of buffered training examples.
    max_buffer_size: usize,

    _update_interval: std::time::Duration,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModelWeights {
    /// One weight per entry in the normalized feature vector.
    pub feature_weights: Vec<f64>,
    /// Per-horizon decay applied to the pattern score (1h, 6h, 24h).
    pub time_decay: Vec<f64>,
    /// Learned weights for named behavioral patterns.
    pub pattern_weights: HashMap<String, f64>,
    /// Per-horizon bias terms (1h, 6h, 24h).
    pub bias: Vec<f64>,
}

impl Default for ModelWeights {
    fn default() -> Self {
        Self {
            feature_weights: vec![
                0.15, // online_duration
                0.20, // avg_response_time
                0.10, // resource_contribution
                0.05, // message_frequency
                0.05, // time_of_day
                0.05, // day_of_week
                0.25, // historical_reliability
                0.10, // recent_disconnections
                0.05, // avg_session_length
                0.00, // connection_stability
            ],
            time_decay: vec![0.9, 0.8, 0.7],
            pattern_weights: HashMap::new(),
            bias: vec![0.1, 0.2, 0.3],
        }
    }
}

#[derive(Debug, Clone)]
pub struct TrainingExample {
    pub node_id: NodeId,
    pub features: NodeFeatures,
    pub timestamp: std::time::Instant,
    pub actual_churn_1h: bool,
    pub actual_churn_6h: bool,
    pub actual_churn_24h: bool,
}

#[derive(Debug, Clone)]
pub struct ChurnPrediction {
    pub probability_1h: f64,
    pub probability_6h: f64,
    pub probability_24h: f64,
    pub confidence: f64,
    pub timestamp: std::time::Instant,
}

impl Default for ChurnPredictor {
    fn default() -> Self {
        Self::new()
    }
}

impl ChurnPredictor {
    pub fn new() -> Self {
        Self {
            prediction_cache: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            feature_history: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            model_weights: Arc::new(tokio::sync::RwLock::new(ModelWeights::default())),
            experience_buffer: Arc::new(tokio::sync::RwLock::new(Vec::new())),
            max_buffer_size: 10000,
            _update_interval: std::time::Duration::from_secs(3600), // one hour
        }
    }

    /// Recommend proactive replication when the node is likely (probability
    /// above 0.7) to churn within the next hour.
    pub async fn should_replicate(&self, node_id: &NodeId) -> bool {
        let prediction = self.predict(node_id).await;
        prediction.probability_1h > 0.7
    }

    /// Record a raw 10-element feature vector for a node.
    pub async fn update_node_features(
        &self,
        node_id: &NodeId,
        features: Vec<f64>,
    ) -> anyhow::Result<()> {
        if features.len() != 10 {
            return Err(anyhow::anyhow!(
                "Expected 10 features, got {}",
                features.len()
            ));
        }

        let node_features = NodeFeatures {
            online_duration: features[0],
            avg_response_time: features[1],
            resource_contribution: features[2],
            message_frequency: features[3],
            time_of_day: features[4],
            day_of_week: features[5],
            historical_reliability: features[6],
            recent_disconnections: features[7],
            avg_session_length: features[8],
            connection_stability: features[9],
        };

        let mut history = self.feature_history.write().await;
        let entry = history.entry(node_id.clone()).or_default();

        // Open a new session if there is none, or the last one has ended.
        if entry.sessions.is_empty() || entry.sessions.last().map(|s| s.1.is_some()).unwrap_or(true)
        {
            entry.sessions.push((std::time::Instant::now(), None));
        }

        entry
            .snapshots
            .push((std::time::Instant::now(), node_features));

        // Keep only the most recent 100 snapshots.
        while entry.snapshots.len() > 100 {
            entry.snapshots.remove(0);
        }

        Ok(())
    }

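    /// Build a current feature vector from recorded sessions and snapshots,
    /// filling in neutral defaults where history is missing.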
    pub async fn extract_features(&self, node_id: &NodeId) -> Option<NodeFeatures> {
        let history = self.feature_history.read().await;
        let node_history = history.get(node_id)?;

        let now = std::time::Instant::now();
        let current_session = node_history.sessions.last()?;
        let online_duration = if current_session.1.is_none() {
            now.duration_since(current_session.0).as_secs() as f64
        } else {
            0.0
        };

        let completed_sessions: Vec<_> = node_history
            .sessions
            .iter()
            .filter_map(|(start, end)| {
                end.as_ref()
                    .map(|e| e.duration_since(*start).as_secs() as f64)
            })
            .collect();
        let avg_session_length = if !completed_sessions.is_empty() {
            completed_sessions.iter().sum::<f64>() / completed_sessions.len() as f64 / 3600.0
        } else {
            1.0 // default to one hour when no session has completed
        };

        let one_week_ago = now - std::time::Duration::from_secs(7 * 24 * 3600);
        let recent_disconnections = node_history
            .sessions
            .iter()
            .filter(|(start, end)| end.is_some() && *start > one_week_ago)
            .count() as f64;

        let latest_snapshot = node_history
            .snapshots
            .last()
            .map(|(_, features)| features.clone())
            .unwrap_or_else(|| NodeFeatures {
                online_duration,
                avg_response_time: 100.0,
                resource_contribution: 0.5,
                message_frequency: 10.0,
                time_of_day: 12.0, // midday default
                day_of_week: 3.0,  // midweek default
                historical_reliability: node_history.total_uptime as f64
                    / (node_history.total_uptime + node_history.total_downtime).max(1) as f64,
                recent_disconnections,
                avg_session_length,
                connection_stability: 1.0 - (recent_disconnections / 7.0).min(1.0),
            });

        Some(NodeFeatures {
            online_duration,
            recent_disconnections,
            avg_session_length,
            historical_reliability: node_history.total_uptime as f64
                / (node_history.total_uptime + node_history.total_downtime).max(1) as f64,
            connection_stability: 1.0 - (recent_disconnections / 7.0).min(1.0),
            ..latest_snapshot
        })
    }

    /// Flag named behavioral patterns (night-time activity, instability,
    /// low contribution, and so on) as binary indicators.
    async fn analyze_patterns(&self, features: &NodeFeatures) -> HashMap<String, f64> {
        let mut patterns = HashMap::new();

        let is_night = features.time_of_day < 6.0 || features.time_of_day > 22.0;
        let is_weekend = features.day_of_week == 0.0 || features.day_of_week == 6.0;

        patterns.insert("night_time".to_string(), if is_night { 1.0 } else { 0.0 });
        patterns.insert("weekend".to_string(), if is_weekend { 1.0 } else { 0.0 });

        patterns.insert(
            "short_session".to_string(),
            if features.online_duration < 1800.0 {
                1.0
            } else {
                0.0
            },
        );
        patterns.insert(
            "unstable".to_string(),
            if features.recent_disconnections > 5.0 {
                1.0
            } else {
                0.0
            },
        );
        patterns.insert(
            "low_contribution".to_string(),
            if features.resource_contribution < 0.3 {
                1.0
            } else {
                0.0
            },
        );
        patterns.insert(
            "slow_response".to_string(),
            if features.avg_response_time > 500.0 {
                1.0
            } else {
                0.0
            },
        );

        // Composite risk score from disconnections, reliability, and stability.
        let risk_score = (features.recent_disconnections / 10.0).min(1.0) * 0.3
            + (1.0 - features.historical_reliability) * 0.4
            + (1.0 - features.connection_stability) * 0.3;
        patterns.insert(
            "high_risk".to_string(),
            if risk_score > 0.6 { 1.0 } else { 0.0 },
        );

        patterns
    }

    /// Predict churn probabilities for a node, caching results for five minutes.
    pub async fn predict(&self, node_id: &NodeId) -> ChurnPrediction {
        {
            let cache = self.prediction_cache.read().await;
            if let Some(cached) = cache.get(node_id)
                && cached.timestamp.elapsed() < std::time::Duration::from_secs(300)
            {
                return cached.clone();
            }
        }

        let features = match self.extract_features(node_id).await {
            Some(f) => f,
            None => {
                // No history: return a low-confidence baseline prediction.
                return ChurnPrediction {
                    probability_1h: 0.1,
                    probability_6h: 0.2,
                    probability_24h: 0.3,
                    confidence: 0.1,
                    timestamp: std::time::Instant::now(),
                };
            }
        };

        let patterns = self.analyze_patterns(&features).await;

        let model = self.model_weights.read().await;
        let prediction = self.apply_model(&features, &patterns, &model).await;

        let mut cache = self.prediction_cache.write().await;
        cache.insert(node_id.clone(), prediction.clone());
        prediction
    }

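    /// Score the normalized feature vector against the model weights and
    /// squash each horizon's (1h/6h/24h) score through a sigmoid.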
    async fn apply_model(
        &self,
        features: &NodeFeatures,
        patterns: &HashMap<String, f64>,
        model: &ModelWeights,
    ) -> ChurnPrediction {
        // Normalize every feature into roughly [0, 1].
        let feature_vec = [
            features.online_duration / 3600.0,
            features.avg_response_time / 1000.0,
            features.resource_contribution,
            features.message_frequency / 100.0,
            features.time_of_day / 24.0,
            features.day_of_week / 7.0,
            features.historical_reliability,
            features.recent_disconnections / 10.0,
            features.avg_session_length / 24.0,
            features.connection_stability,
        ];

        // One base score per horizon (1h, 6h, 24h).
        let mut base_scores = [0.0; 3];
        for (i, &weight) in model.feature_weights.iter().enumerate() {
            if i < feature_vec.len() {
                for score in &mut base_scores {
                    *score += weight * feature_vec[i];
                }
            }
        }

        let mut pattern_score = 0.0;
        for (pattern, &value) in patterns {
            if let Some(&weight) = model.pattern_weights.get(pattern) {
                pattern_score += weight * value;
            } else {
                pattern_score += 0.1 * value; // default weight for unseen patterns
            }
        }

        let probabilities: Vec<f64> = base_scores
            .iter()
            .zip(&model.time_decay)
            .zip(&model.bias)
            .map(|((base, decay), bias)| {
                let raw_score = base + pattern_score * decay + bias;
                // Sigmoid squashes the raw score into (0, 1).
                1.0 / (1.0 + (-raw_score).exp())
            })
            .collect();

        let confidence = 0.8; // fixed placeholder confidence

        ChurnPrediction {
            probability_1h: probabilities[0].min(0.99),
            probability_6h: probabilities[1].min(0.99),
            probability_24h: probabilities[2].min(0.99),
            confidence,
            timestamp: std::time::Instant::now(),
        }
    }

    /// Append a feature snapshot for a node, keeping 24 hours of history.
    pub async fn update_node_behavior(
        &self,
        node_id: &NodeId,
        features: NodeFeatures,
    ) -> anyhow::Result<()> {
        let mut history = self.feature_history.write().await;
        let node_history = history
            .entry(node_id.clone())
            .or_insert_with(|| FeatureHistory {
                node_id: node_id.clone(),
                snapshots: Vec::new(),
                sessions: vec![(std::time::Instant::now(), None)],
                total_uptime: 0,
                total_downtime: 0,
            });

        node_history
            .snapshots
            .push((std::time::Instant::now(), features));

        // Drop snapshots older than 24 hours.
        let cutoff = std::time::Instant::now() - std::time::Duration::from_secs(24 * 3600);
        node_history
            .snapshots
            .retain(|(timestamp, _)| *timestamp > cutoff);

        Ok(())
    }

    /// Track connect/disconnect events to maintain session history and uptime.
    pub async fn record_node_event(&self, node_id: &NodeId, event: NodeEvent) -> Result<()> {
        let mut history = self.feature_history.write().await;
        let node_history = history
            .entry(node_id.clone())
            .or_insert_with(|| FeatureHistory {
                node_id: node_id.clone(),
                snapshots: Vec::new(),
                sessions: Vec::new(),
                total_uptime: 0,
                total_downtime: 0,
            });

        match event {
            NodeEvent::Connected => {
                node_history
                    .sessions
                    .push((std::time::Instant::now(), None));
            }
            NodeEvent::Disconnected => {
                // Close the open session, if any, and credit its uptime.
                if let Some((start, end)) = node_history.sessions.last_mut()
                    && end.is_none()
                {
                    let now = std::time::Instant::now();
                    *end = Some(now);
                    let session_length = now.duration_since(*start).as_secs();
                    node_history.total_uptime += session_length;
                }
            }
        }

        Ok(())
    }

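    /// Buffer a labeled example, trim the buffer to `max_buffer_size`, and
    /// retrain after every 32 buffered examples.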
    pub async fn add_training_example(
        &self,
        node_id: &NodeId,
        features: NodeFeatures,
        actual_churn_1h: bool,
        actual_churn_6h: bool,
        actual_churn_24h: bool,
    ) -> anyhow::Result<()> {
        let example = TrainingExample {
            node_id: node_id.clone(),
            features,
            timestamp: std::time::Instant::now(),
            actual_churn_1h,
            actual_churn_6h,
            actual_churn_24h,
        };

        let mut buffer = self.experience_buffer.write().await;
        buffer.push(example);

        // Drop the oldest examples once the buffer overflows.
        if buffer.len() > self.max_buffer_size {
            let drain_count = buffer.len() - self.max_buffer_size;
            buffer.drain(0..drain_count);
        }

        let should_update = buffer.len() >= 32 && buffer.len() % 32 == 0;
        // Release the buffer lock before retraining: update_model re-acquires
        // it, and tokio's RwLock is not reentrant.
        drop(buffer);

        if should_update {
            self.update_model().await?;
        }

        Ok(())
    }

    /// One mini-batch SGD pass over randomly sampled buffered examples.
    async fn update_model(&self) -> anyhow::Result<()> {
        let buffer = self.experience_buffer.read().await;
        if buffer.is_empty() {
            return Ok(());
        }

        let mut model = self.model_weights.write().await;

        let learning_rate = 0.01;
        let batch_size = 32.min(buffer.len());

        // Sample indices up front so the thread-local RNG is not held across
        // await points.
        let batch_indices: Vec<usize> = {
            let mut rng = rand::thread_rng();
            (0..batch_size)
                .map(|_| rng.gen_range(0..buffer.len()))
                .collect()
        };

        for &idx in &batch_indices {
            let example = &buffer[idx];
            // Same normalization as apply_model.
            let feature_vec = [
                example.features.online_duration / 3600.0,
                example.features.avg_response_time / 1000.0,
                example.features.resource_contribution,
                example.features.message_frequency / 100.0,
                example.features.time_of_day / 24.0,
                example.features.day_of_week / 7.0,
                example.features.historical_reliability,
                example.features.recent_disconnections / 10.0,
                example.features.avg_session_length / 24.0,
                example.features.connection_stability,
            ];

            let patterns = self.analyze_patterns(&example.features).await;

            let prediction = self.apply_model(&example.features, &patterns, &model).await;

            // Per-horizon prediction errors (label minus prediction).
            let errors = [
                if example.actual_churn_1h { 1.0 } else { 0.0 } - prediction.probability_1h,
                if example.actual_churn_6h { 1.0 } else { 0.0 } - prediction.probability_6h,
                if example.actual_churn_24h { 1.0 } else { 0.0 } - prediction.probability_24h,
            ];

            // Gradient step on the feature weights.
            for (i, &feature_value) in feature_vec.iter().enumerate() {
                if i < model.feature_weights.len() {
                    for (j, &error) in errors.iter().enumerate() {
                        model.feature_weights[i] +=
                            learning_rate * error * feature_value * model.time_decay[j];
                    }
                }
            }

            // Gradient step on the pattern weights.
            for (pattern, &value) in &patterns {
                let avg_error = errors.iter().sum::<f64>() / errors.len() as f64;
                model
                    .pattern_weights
                    .entry(pattern.clone())
                    .and_modify(|w| *w += learning_rate * avg_error * value)
                    .or_insert(learning_rate * avg_error * value);
            }
        }

        Ok(())
    }

    /// Serialize the model weights to JSON at `path`.
    pub async fn save_model(&self, path: &std::path::Path) -> anyhow::Result<()> {
        let model = self.model_weights.read().await;
        let serialized = serde_json::to_string(&*model)?;
        tokio::fs::write(path, serialized).await?;
        Ok(())
    }

    /// Load model weights from a JSON file, replacing the current weights.
    pub async fn load_model(&self, path: &std::path::Path) -> anyhow::Result<()> {
        let data = tokio::fs::read_to_string(path).await?;
        let loaded_model: ModelWeights = serde_json::from_str(&data)?;
        let mut model = self.model_weights.write().await;
        *model = loaded_model;
        Ok(())
    }
}

#[derive(Debug, Clone)]
pub enum NodeEvent {
    Connected,
    Disconnected,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_thompson_sampling_initialization() {
        let ts = ThompsonSampling::new();
        let metrics = ts.get_metrics().await;

        assert_eq!(metrics.total_decisions, 0);
        assert!(metrics.decisions_by_type.is_empty());
        assert!(metrics.strategy_success_rates.is_empty());
    }

    #[tokio::test]
    async fn test_thompson_sampling_selection() -> Result<()> {
        let ts = ThompsonSampling::new();

        for content_type in [
            ContentType::DHTLookup,
            ContentType::DataRetrieval,
            ContentType::ComputeRequest,
            ContentType::RealtimeMessage,
        ] {
            let strategy = ts.select_strategy(content_type).await?;
            assert!(matches!(
                strategy,
                StrategyChoice::Kademlia
                    | StrategyChoice::Hyperbolic
                    | StrategyChoice::TrustPath
                    | StrategyChoice::SOMRegion
            ));
        }

        let metrics = ts.get_metrics().await;
        assert_eq!(metrics.total_decisions, 4);
        assert_eq!(metrics.decisions_by_type.len(), 4);
        Ok(())
    }

    #[tokio::test]
    async fn test_thompson_sampling_update() -> Result<()> {
        let ts = ThompsonSampling::new();

        // Train Hyperbolic with consistent successes.
        for _ in 0..20 {
            ts.update(
                ContentType::DataRetrieval,
                StrategyChoice::Hyperbolic,
                true,
                50,
            )
            .await?;
        }

        // Train Kademlia with consistent failures.
        for _ in 0..10 {
            ts.update(
                ContentType::DataRetrieval,
                StrategyChoice::Kademlia,
                false,
                200,
            )
            .await?;
        }

        // The learned posterior should now clearly favor Hyperbolic.
        let mut hyperbolic_count = 0;
        for _ in 0..100 {
            let strategy = ts.select_strategy(ContentType::DataRetrieval).await?;
            if matches!(strategy, StrategyChoice::Hyperbolic) {
                hyperbolic_count += 1;
            }
        }

        assert!(
            hyperbolic_count > 60,
            "Expected Hyperbolic to be selected more than 60% of the time, got {}%",
            hyperbolic_count
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_confidence_intervals() -> Result<()> {
        let ts = ThompsonSampling::new();

        for i in 0..10 {
            ts.update(
                ContentType::DHTLookup,
                StrategyChoice::Kademlia,
                i % 3 != 0, // succeed on 6 of 10 trials
                100,
            )
            .await?;
        }

        let (lower, upper) = ts
            .get_confidence_interval(ContentType::DHTLookup, StrategyChoice::Kademlia)
            .await;

        assert!(lower > 0.0 && lower < 1.0);
        assert!(upper > lower && upper <= 1.0);
        assert!(upper - lower < 0.5); // interval should have tightened
        Ok(())
    }

    #[tokio::test]
    async fn test_exploration_bonus() -> Result<()> {
        let ts = ThompsonSampling::new();

        // Make TrustPath the clear favorite.
        for _ in 0..15 {
            ts.update(
                ContentType::ComputeRequest,
                StrategyChoice::TrustPath,
                true,
                100,
            )
            .await?;
        }

        let mut strategy_counts = HashMap::new();
        for _ in 0..100 {
            let strategy = ts.select_strategy(ContentType::ComputeRequest).await?;
            *strategy_counts.entry(strategy).or_insert(0) += 1;
        }

        // The exploration bonus should still surface under-sampled arms.
        assert!(
            strategy_counts.len() >= 3,
            "Expected at least 3 different strategies to be tried"
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_reset_strategy() -> Result<()> {
        let ts = ThompsonSampling::new();

        for _ in 0..10 {
            ts.update(
                ContentType::RealtimeMessage,
                StrategyChoice::SOMRegion,
                true,
                50,
            )
            .await?;
        }

        ts.reset_strategy(ContentType::RealtimeMessage, StrategyChoice::SOMRegion)
            .await;

        // After the reset the arm reports the uninformative interval.
        let (lower, upper) = ts
            .get_confidence_interval(ContentType::RealtimeMessage, StrategyChoice::SOMRegion)
            .await;

        assert_eq!(lower, 0.0);
        assert_eq!(upper, 1.0);
        Ok(())
    }

    #[tokio::test]
    async fn test_learning_system_trait() {
        let mut ts = ThompsonSampling::new();

        let context = LearningContext {
            content_type: ContentType::DataRetrieval,
            network_conditions: NetworkConditions {
                connected_peers: 100,
                avg_latency_ms: 50.0,
                churn_rate: 0.1,
            },
            historical_performance: vec![0.8, 0.85, 0.9],
        };

        let choice = <ThompsonSampling as LearningSystem>::select_strategy(&ts, &context).await;
        assert!(matches!(
            choice,
            StrategyChoice::Kademlia
                | StrategyChoice::Hyperbolic
                | StrategyChoice::TrustPath
                | StrategyChoice::SOMRegion
        ));

        let outcome = Outcome {
            success: true,
            latency_ms: 45,
            hops: 3,
        };

        <ThompsonSampling as LearningSystem>::update(&mut ts, &context, &choice, &outcome).await;

        let metrics = <ThompsonSampling as LearningSystem>::metrics(&ts).await;
        assert_eq!(metrics.total_decisions, 1);
    }

    #[tokio::test]
    async fn test_cache_manager() {
        let manager = QLearnCacheManager::new(1024);
        let hash = ContentHash([1u8; 32]);

        assert!(manager.insert(hash.clone(), vec![0u8; 100]).await);

        assert!(manager.get(&hash).await.is_some());

        let action = manager.decide_action(&hash).await;
        assert!(matches!(
            action,
            CacheAction::Cache
                | CacheAction::Evict(_)
                | CacheAction::IncreaseReplication
                | CacheAction::DecreaseReplication
                | CacheAction::NoAction
        ));
    }

    #[tokio::test]
    async fn test_q_value_update() {
        let manager = QLearnCacheManager::new(1024);

        let state = CacheState {
            utilization_bucket: 5,
            request_rate_bucket: 5,
            content_popularity: 5,
            size_bucket: 3,
        };

        let next_state = CacheState {
            utilization_bucket: 6,
            request_rate_bucket: 5,
            content_popularity: 5,
            size_bucket: 3,
        };

        manager
            .update_q_value(state, CacheAction::Cache, 1.0, next_state)
            .await;

        let q_table = manager.q_table.read().await;
        assert!(!q_table.is_empty());
    }

    #[tokio::test]
    async fn test_churn_predictor() {
        use crate::peer_record::UserId;
        use rand::RngCore;

        let predictor = ChurnPredictor::new();
        let mut hash = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut hash);
        let node_id = UserId::from_bytes(hash);

        let prediction = predictor.predict(&node_id).await;
        assert!(prediction.probability_1h >= 0.0 && prediction.probability_1h <= 1.0);
        assert!(prediction.probability_6h >= 0.0 && prediction.probability_6h <= 1.0);
        assert!(prediction.probability_24h >= 0.0 && prediction.probability_24h <= 1.0);

        // The second call should be served from the five-minute prediction cache.
        let prediction2 = predictor.predict(&node_id).await;
        assert_eq!(prediction.probability_1h, prediction2.probability_1h);
    }

    #[tokio::test]
    async fn test_cache_eviction_policies() {
        // Capacity of 300 bytes holds exactly three 100-byte items.
        let manager = QLearnCacheManager::new(300);

        let hash1 = ContentHash([1u8; 32]);
        let hash2 = ContentHash([2u8; 32]);
        let hash3 = ContentHash([3u8; 32]);

        manager.insert(hash1.clone(), vec![0u8; 100]).await;
        manager.insert(hash2.clone(), vec![0u8; 100]).await;
        manager.insert(hash3.clone(), vec![0u8; 100]).await;

        // Touch hash1 and hash2 so hash3 becomes the LRU entry.
        manager.get(&hash1).await;
        manager.get(&hash2).await;

        // Inserting a fourth item must evict the least recently used.
        let hash4 = ContentHash([4u8; 32]);
        manager.insert(hash4.clone(), vec![0u8; 100]).await;

        assert!(manager.get(&hash1).await.is_some());
        assert!(manager.get(&hash2).await.is_some());
        assert!(manager.get(&hash3).await.is_none());
        assert!(manager.get(&hash4).await.is_some());
    }

    #[tokio::test]
    async fn test_reward_calculation() {
        let manager = QLearnCacheManager::new(1024);

        // Build up some hit/miss history.
        let hash = ContentHash([1u8; 32]);
        manager.insert(hash.clone(), vec![0u8; 100]).await;

        for _ in 0..5 {
            manager.get(&hash).await;
        }

        let miss_hash = ContentHash([2u8; 32]);
        for _ in 0..2 {
            manager.get(&miss_hash).await;
        }

        // Caching content that gets hit should be rewarded.
        let cache_reward = manager.calculate_reward(CacheAction::Cache, true, 1000);
        assert!(cache_reward > 0.0);

        // Evicting content nobody asked for should not be penalized.
        let evict_reward =
            manager.calculate_reward(CacheAction::Evict(EvictionPolicy::LRU), false, 0);
        assert!(evict_reward >= 0.0);

        // Evicting content that was still being hit should be penalized.
        let evict_penalty =
            manager.calculate_reward(CacheAction::Evict(EvictionPolicy::LRU), true, 0);
        assert!(evict_penalty < 0.0);
    }

    #[tokio::test]
    async fn test_cache_statistics() {
        let manager = QLearnCacheManager::new(1024);

        let hash1 = ContentHash([1u8; 32]);
        let hash2 = ContentHash([2u8; 32]);

        manager.insert(hash1.clone(), vec![0u8; 100]).await;

        manager.get(&hash1).await; // hit
        manager.get(&hash1).await; // hit
        manager.get(&hash2).await; // miss

        let stats = manager.get_stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate - 0.666).abs() < 0.01); // 2 hits out of 3
        assert_eq!(stats.size_bytes, 100);
    }

    #[tokio::test]
    async fn test_exploration_vs_exploitation() {
        let manager = QLearnCacheManager::new(1024);

        // Train on the exact state decide_action will observe for an empty
        // cache: utilization bucket 0 plus the neutral (5, 5, 5) defaults
        // produced by get_current_state.
        let state = CacheState {
            utilization_bucket: 0,
            request_rate_bucket: 5,
            content_popularity: 5,
            size_bucket: 5,
        };

        for _ in 0..10 {
            manager
                .update_q_value(state.clone(), CacheAction::Cache, 1.0, state.clone())
                .await;
        }

        let mut cache_count = 0;
        for _ in 0..100 {
            let action = manager.decide_action(&ContentHash([1u8; 32])).await;
            if matches!(action, CacheAction::Cache) {
                cache_count += 1;
            }
        }

        // With epsilon = 0.1 we expect ~90% exploitation (Cache) plus a small
        // random share; the bounds are loose to keep the test stable.
        assert!(cache_count > 80 && cache_count < 99);
    }

    #[tokio::test]
    async fn test_state_representation() {
        let manager = QLearnCacheManager::new(1024);

        let hash = ContentHash([1u8; 32]);

        manager.insert(hash.clone(), vec![0u8; 100]).await;

        for _ in 0..5 {
            manager.get(&hash).await;
        }

        let state = manager.get_current_state_async(&hash).await;

        assert!(state.utilization_bucket <= 10);
        assert!(state.request_rate_bucket <= 10);
        assert!(state.content_popularity <= 10);
        assert!(state.size_bucket <= 10);
    }

    #[tokio::test]
    async fn test_action_execution() -> Result<()> {
        let manager = QLearnCacheManager::new(1024);
        let hash = ContentHash([1u8; 32]);

        manager
            .execute_action(&hash, CacheAction::Cache, Some(vec![0u8; 100]))
            .await?;
        assert!(manager.get(&hash).await.is_some());

        manager
            .execute_action(&hash, CacheAction::NoAction, None)
            .await?;
        assert!(manager.get(&hash).await.is_some()); // NoAction leaves the entry alone

        manager
            .execute_action(&hash, CacheAction::Evict(EvictionPolicy::LRU), None)
            .await?;

        let stats = manager.get_stats();
        assert!(stats.size_bytes <= 100); // the single entry was evicted
        Ok(())
    }

    #[tokio::test]
    async fn test_churn_predictor_initialization() {
        let predictor = ChurnPredictor::new();

        // An unknown node gets the low-confidence baseline prediction.
        let node_id = NodeId { hash: [1u8; 32] };
        let prediction = predictor.predict(&node_id).await;

        assert!(prediction.confidence < 0.2);
        assert!(prediction.probability_1h < 0.3);
        assert!(prediction.probability_6h < 0.4);
        assert!(prediction.probability_24h < 0.5);
    }

    #[tokio::test]
    async fn test_churn_predictor_node_events() -> Result<()> {
        let predictor = ChurnPredictor::new();
        let node_id = NodeId { hash: [1u8; 32] };

        predictor
            .record_node_event(&node_id, NodeEvent::Connected)
            .await?;

        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        predictor
            .record_node_event(&node_id, NodeEvent::Disconnected)
            .await?;

        let features = predictor.extract_features(&node_id).await;
        assert!(features.is_some());
        Ok(())
    }

    #[tokio::test]
    async fn test_churn_predictor_feature_extraction() -> Result<()> {
        let predictor = ChurnPredictor::new();
        let node_id = NodeId { hash: [1u8; 32] };

        predictor
            .record_node_event(&node_id, NodeEvent::Connected)
            .await?;

        let features = NodeFeatures {
            online_duration: 3600.0,
            avg_response_time: 50.0,
            resource_contribution: 0.8,
            message_frequency: 20.0,
            time_of_day: 14.0,
            day_of_week: 2.0,
            historical_reliability: 0.9,
            recent_disconnections: 1.0,
            avg_session_length: 4.0,
            connection_stability: 0.85,
        };

        predictor
            .update_node_behavior(&node_id, features.clone())
            .await?;

        let extracted = predictor
            .extract_features(&node_id)
            .await
            .ok_or(anyhow::anyhow!("no features extracted"))?;
        assert_eq!(extracted.resource_contribution, 0.8);
        assert_eq!(extracted.avg_response_time, 50.0);
        Ok(())
    }

    #[tokio::test]
    async fn test_churn_predictor_pattern_analysis() {
        let predictor = ChurnPredictor::new();

        let night_features = NodeFeatures {
            online_duration: 1000.0,
            avg_response_time: 100.0,
            resource_contribution: 0.5,
            message_frequency: 10.0,
            time_of_day: 2.0, // 2 AM
            day_of_week: 3.0, // midweek
            historical_reliability: 0.8,
            recent_disconnections: 2.0,
            avg_session_length: 2.0,
            connection_stability: 0.8,
        };

        let patterns = predictor.analyze_patterns(&night_features).await;
        assert_eq!(patterns.get("night_time"), Some(&1.0));
        assert_eq!(patterns.get("weekend"), Some(&0.0));

        // Lower reliability as well so the composite risk score actually
        // crosses the 0.6 high-risk threshold.
        let unstable_features = NodeFeatures {
            recent_disconnections: 7.0,
            connection_stability: 0.3,
            historical_reliability: 0.3,
            ..night_features
        };

        let patterns = predictor.analyze_patterns(&unstable_features).await;
        assert_eq!(patterns.get("unstable"), Some(&1.0));
        assert_eq!(patterns.get("high_risk"), Some(&1.0));
    }

    #[tokio::test]
    async fn test_churn_predictor_proactive_replication() -> Result<()> {
        let predictor = ChurnPredictor::new();
        let node_id = NodeId { hash: [1u8; 32] };

        predictor
            .record_node_event(&node_id, NodeEvent::Connected)
            .await?;

        // A short-lived, unreliable, poorly contributing node.
        let features = NodeFeatures {
            online_duration: 600.0, // 10 minutes
            avg_response_time: 500.0,
            resource_contribution: 0.1,
            message_frequency: 2.0,
            time_of_day: 23.0,
            day_of_week: 5.0,
            historical_reliability: 0.3,
            recent_disconnections: 10.0,
            avg_session_length: 0.5,
            connection_stability: 0.1,
        };

        predictor.update_node_behavior(&node_id, features).await?;

        let _should_replicate = predictor.should_replicate(&node_id).await;

        let prediction = predictor.predict(&node_id).await;
        assert!(prediction.probability_1h > 0.0);
        Ok(())
    }

    #[tokio::test]
    async fn test_churn_predictor_online_learning() -> Result<()> {
        let predictor = ChurnPredictor::new();
        let node_id = NodeId { hash: [1u8; 32] };

        for i in 0..40 {
            let features = NodeFeatures {
                online_duration: (i * 1000) as f64,
                avg_response_time: 100.0,
                resource_contribution: 0.5,
                message_frequency: 10.0,
                time_of_day: 12.0,
                day_of_week: 3.0,
                historical_reliability: 0.8,
                recent_disconnections: (i % 5) as f64,
                avg_session_length: 2.0,
                connection_stability: 0.8,
            };

            let churned = i % 3 == 0;

            predictor
                .add_training_example(&node_id, features, churned, churned, churned)
                .await?;
        }

        let prediction = predictor.predict(&node_id).await;
        assert!(prediction.confidence > 0.0);
        Ok(())
    }

    #[tokio::test]
    async fn test_churn_predictor_model_persistence() -> Result<()> {
        let predictor = ChurnPredictor::new();
        let temp_path = std::path::Path::new("/tmp/test_model.json");

        predictor.save_model(temp_path).await?;

        predictor.load_model(temp_path).await?;

        let _ = std::fs::remove_file(temp_path);
        Ok(())
    }
}