saorsa_core/adaptive/
gossip.rs

1// Copyright 2024 Saorsa Labs Limited
2//
3// This software is dual-licensed under:
4// - GNU Affero General Public License v3.0 or later (AGPL-3.0-or-later)
5// - Commercial License
6//
7// For AGPL-3.0 license, see LICENSE-AGPL-3.0
8// For commercial licensing, contact: david@saorsalabs.com
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under these licenses is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
14//! Adaptive GossipSub implementation
15//!
16//! Enhanced gossip protocol with adaptive mesh degree, peer scoring,
17//! and priority message types
18
19use super::*;
20use serde::{Deserialize, Serialize};
21use std::collections::{HashMap, HashSet, VecDeque};
22use std::sync::Arc;
23use std::time::{Duration, Instant};
24use tokio::sync::{RwLock, mpsc};
25
// Type aliases to reduce type complexity for channels.
// Both are only used for the `Option`-wrapped channel fields of
// `AdaptiveGossipSub`.
type GossipMessageRx = mpsc::Receiver<(NodeId, GossipMessage)>;
type ControlMessageTx = mpsc::Sender<(NodeId, ControlMessage)>;

/// Topic identifier for gossip messages (a plain string name).
pub type Topic = String;

/// Message identifier: 32 raw bytes.
///
/// `AdaptiveGossipSub::compute_message_id` fills it with a SHA-256 digest.
pub type MessageId = [u8; 32];
35
36/// Control messages for gossip protocol
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub enum ControlMessage {
39    Graft {
40        topic: Topic,
41    },
42    Prune {
43        topic: Topic,
44        backoff: Duration,
45    },
46    IHave {
47        topic: Topic,
48        message_ids: Vec<MessageId>,
49    },
50    IWant {
51        message_ids: Vec<MessageId>,
52    },
53}
54
/// Relative importance of a topic.
///
/// Variants are declared from least to most important, so the derived
/// ordering yields `Low < Normal < High < Critical`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum TopicPriority {
    /// Least important.
    Low,
    /// Regular traffic.
    Normal,
    /// More important than regular traffic.
    High,
    /// Most important.
    Critical,
}
63
/// Message validation trait
///
/// Implementations are registered per topic through
/// `AdaptiveGossipSub::register_validator` and are consulted before a
/// message is published.
#[async_trait::async_trait]
pub trait MessageValidator: Send + Sync {
    /// Validate a message before propagation
    ///
    /// Return `Ok(true)` to accept the message and `Ok(false)` to reject it.
    /// `AdaptiveGossipSub::publish` reports a rejection as an
    /// `AdaptiveNetworkError::Gossip` error and propagates an `Err` unchanged.
    async fn validate(&self, message: &GossipMessage) -> Result<bool>;
}
70
71/// Gossip statistics
72#[derive(Debug, Clone, Default)]
73pub struct GossipStats {
74    /// Total messages sent
75    pub messages_sent: u64,
76
77    /// Total messages received
78    pub messages_received: u64,
79
80    /// Current mesh size
81    pub mesh_size: usize,
82
83    /// Number of active topics
84    pub topic_count: usize,
85
86    /// Total peers
87    pub peer_count: usize,
88
89    /// Messages by topic
90    pub messages_by_topic: HashMap<Topic, u64>,
91}
92
/// Adaptive GossipSub implementation
///
/// All mutable state sits behind `Arc<RwLock<..>>`, so every method takes
/// `&self`.
pub struct AdaptiveGossipSub {
    /// Local node ID (stored by `new`; not read in this file)
    _local_id: NodeId,

    /// Mesh peers for each topic.
    /// `subscribe` creates an entry and `unsubscribe` removes it.
    mesh: Arc<RwLock<HashMap<Topic, HashSet<NodeId>>>>,

    /// Fanout peers for topics we're not subscribed to.
    /// `publish` reads this when the topic has no mesh entry.
    fanout: Arc<RwLock<HashMap<Topic, HashSet<NodeId>>>>,

    /// Seen messages cache: message id -> time it was recorded.
    /// Entries older than 300 seconds are dropped by the heartbeat.
    seen_messages: Arc<RwLock<HashMap<MessageId, Instant>>>,

    /// Message cache for IWANT requests (full message per id)
    message_cache: Arc<RwLock<HashMap<MessageId, GossipMessage>>>,

    /// Peer scores
    peer_scores: Arc<RwLock<HashMap<NodeId, PeerScore>>>,

    /// Topic parameters
    topics: Arc<RwLock<HashMap<Topic, TopicParams>>>,

    /// Topic priorities, set through `set_topic_priority`
    topic_priorities: Arc<RwLock<HashMap<Topic, TopicPriority>>>,

    /// Heartbeat interval (`new` sets one second; not read in this file)
    _heartbeat_interval: Duration,

    /// Message validators by topic
    message_validators: Arc<RwLock<HashMap<Topic, Box<dyn MessageValidator + Send + Sync>>>>,

    /// Trust provider for peer scoring
    trust_provider: Arc<dyn TrustProvider>,

    /// Message receiver channel (not read in this file)
    _message_rx: Arc<RwLock<Option<GossipMessageRx>>>,

    /// Control message sender used by the `send_*` methods
    control_tx: Arc<RwLock<Option<ControlMessageTx>>>,

    /// Churn detector
    churn_detector: Arc<RwLock<ChurnDetector>>,

    /// Statistics
    stats: Arc<RwLock<GossipStats>>,
}
140
141/// Gossip message
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct GossipMessage {
144    pub topic: Topic,
145    pub data: Vec<u8>,
146    pub from: NodeId,
147    pub seqno: u64,
148    pub timestamp: u64,
149}
150
/// Raw inputs from which a peer's score is computed (see `PeerScore::score`).
#[derive(Clone, Debug)]
pub struct PeerScore {
    /// How long the peer has been in the mesh.
    pub time_in_mesh: Duration,
    /// Number of messages this peer delivered first.
    pub first_message_deliveries: u64,
    /// Number of messages this peer delivered as a mesh member.
    pub mesh_message_deliveries: u64,
    /// Number of invalid messages received from this peer.
    pub invalid_messages: u64,
    /// Behaviour penalty. It is added to the score as-is, so a penalty is a
    /// negative number.
    pub behavior_penalty: f64,
    /// Application-specific score taken from the trust system.
    pub app_specific_score: f64,
}
161
162impl PeerScore {
163    #[allow(dead_code)]
164    fn new() -> Self {
165        Self {
166            time_in_mesh: Duration::ZERO,
167            first_message_deliveries: 0,
168            mesh_message_deliveries: 0,
169            invalid_messages: 0,
170            behavior_penalty: 0.0,
171            app_specific_score: 0.5,
172        }
173    }
174
175    pub fn score(&self) -> f64 {
176        let time_score = (self.time_in_mesh.as_secs() as f64 / 60.0).min(10.0) * 0.5;
177        let delivery_score = (self.first_message_deliveries as f64).min(100.0) / 100.0;
178        let mesh_score = (self.mesh_message_deliveries as f64).min(1000.0) / 1000.0 * 0.2;
179        let invalid_penalty = self.invalid_messages as f64 * -10.0;
180
181        time_score
182            + delivery_score
183            + mesh_score
184            + invalid_penalty
185            + self.behavior_penalty
186            + self.app_specific_score
187    }
188}
189
190/// Topic parameters
191#[derive(Debug, Clone)]
192pub struct TopicParams {
193    pub d: usize,                // Target mesh degree
194    pub d_low: usize,            // Lower bound
195    pub d_high: usize,           // Upper bound
196    pub d_out: usize,            // Outbound degree for neighbor exchange
197    pub graylist_threshold: f64, // Score below which peers are graylisted
198    pub mesh_message_deliveries_threshold: f64,
199    pub gossip_factor: f64, // % of peers to send IHave to
200    pub priority: TopicPriority,
201}
202
203impl Default for TopicParams {
204    fn default() -> Self {
205        Self {
206            d: 8,
207            d_low: 6,
208            d_high: 12,
209            d_out: 2,
210            graylist_threshold: -1.0,
211            mesh_message_deliveries_threshold: 0.5,
212            gossip_factor: 0.25,
213            priority: TopicPriority::Normal,
214        }
215    }
216}
217
218/// Churn detection and tracking
219#[derive(Debug, Clone)]
220pub struct ChurnDetector {
221    /// Recent peer join/leave events
222    events: VecDeque<(Instant, ChurnEvent)>,
223    /// Window size for churn calculation
224    window: Duration,
225    /// Current churn rate
226    churn_rate: f64,
227}
228
229#[derive(Debug, Clone)]
230#[allow(dead_code)]
231enum ChurnEvent {
232    PeerJoined(NodeId),
233    PeerLeft(NodeId),
234}
235
236/// Churn statistics for a time window
237#[derive(Debug)]
238pub struct ChurnStats {
239    /// Number of nodes that joined
240    pub joins: usize,
241    /// Number of nodes that left
242    pub leaves: usize,
243    /// Average session duration
244    pub avg_session_duration: Duration,
245    /// Node join times for uptime calculation
246    node_join_times: HashMap<NodeId, Instant>,
247}
248
249impl ChurnStats {
250    /// Get uptime for a specific node
251    pub fn get_node_uptime(&self, node_id: &NodeId) -> Duration {
252        self.node_join_times
253            .get(node_id)
254            .map(|join_time| Instant::now().duration_since(*join_time))
255            .unwrap_or(Duration::from_secs(0))
256    }
257}
258
259impl ChurnDetector {
260    fn new() -> Self {
261        Self {
262            events: VecDeque::new(),
263            window: Duration::from_secs(300), // 5 minute window
264            churn_rate: 0.0,
265        }
266    }
267
268    fn record_join(&mut self, peer: NodeId) {
269        self.events
270            .push_back((Instant::now(), ChurnEvent::PeerJoined(peer)));
271        self.update_rate();
272    }
273
274    fn record_leave(&mut self, peer: NodeId) {
275        self.events
276            .push_back((Instant::now(), ChurnEvent::PeerLeft(peer)));
277        self.update_rate();
278    }
279
280    fn update_rate(&mut self) {
281        // Use checked_sub to avoid panic on Windows when program uptime < window
282        if let Some(cutoff) = Instant::now().checked_sub(self.window) {
283            self.events.retain(|(time, _)| *time > cutoff);
284        }
285
286        let joins = self
287            .events
288            .iter()
289            .filter(|(_, event)| matches!(event, ChurnEvent::PeerJoined(_)))
290            .count();
291        let leaves = self
292            .events
293            .iter()
294            .filter(|(_, event)| matches!(event, ChurnEvent::PeerLeft(_)))
295            .count();
296
297        // Churn rate as percentage of changes
298        self.churn_rate = (joins + leaves) as f64 / self.window.as_secs() as f64;
299    }
300
301    fn get_rate(&self) -> f64 {
302        self.churn_rate
303    }
304
305    pub async fn get_hourly_rates(&self, hours: usize) -> Vec<f64> {
306        let now = Instant::now();
307        let mut hourly_rates = vec![0.0; hours];
308
309        for (time, event) in &self.events {
310            let age = now.duration_since(*time);
311            let hour_index = (age.as_secs() / 3600) as usize;
312
313            if hour_index < hours {
314                match event {
315                    ChurnEvent::PeerJoined(_) | ChurnEvent::PeerLeft(_) => {
316                        hourly_rates[hour_index] += 1.0;
317                    }
318                }
319            }
320        }
321
322        // Normalize to rates
323        for rate in &mut hourly_rates {
324            *rate /= 3600.0; // Events per second
325        }
326
327        hourly_rates
328    }
329
330    pub async fn get_recent_stats(&self, window: Duration) -> ChurnStats {
331        let now = Instant::now();
332        let mut joins = 0;
333        let mut leaves = 0;
334        let mut _session_durations = Vec::new();
335        let mut _node_join_times = HashMap::new();
336
337        for (time, event) in &self.events {
338            if now.duration_since(*time) <= window {
339                match event {
340                    ChurnEvent::PeerJoined(node_id) => {
341                        joins += 1;
342                        _node_join_times.insert(node_id.clone(), *time);
343                    }
344                    ChurnEvent::PeerLeft(_) => leaves += 1,
345                }
346            }
347        }
348
349        let avg_session_duration = if _session_durations.is_empty() {
350            Duration::from_secs(3600) // Default 1 hour
351        } else {
352            Duration::from_secs(
353                _session_durations
354                    .iter()
355                    .map(|d: &Duration| d.as_secs())
356                    .sum::<u64>()
357                    / _session_durations.len() as u64,
358            )
359        };
360
361        ChurnStats {
362            joins,
363            leaves,
364            avg_session_duration,
365            node_join_times: _node_join_times,
366        }
367    }
368}
369
370impl AdaptiveGossipSub {
371    /// Create a new adaptive gossipsub instance
372    pub fn new(local_id: NodeId, trust_provider: Arc<dyn TrustProvider>) -> Self {
373        let (control_tx, _control_rx) = mpsc::channel(1000);
374        let (_message_tx, message_rx) = mpsc::channel(1000);
375
376        Self {
377            _local_id: local_id,
378            mesh: Arc::new(RwLock::new(HashMap::new())),
379            fanout: Arc::new(RwLock::new(HashMap::new())),
380            seen_messages: Arc::new(RwLock::new(HashMap::new())),
381            message_cache: Arc::new(RwLock::new(HashMap::new())),
382            peer_scores: Arc::new(RwLock::new(HashMap::new())),
383            topics: Arc::new(RwLock::new(HashMap::new())),
384            topic_priorities: Arc::new(RwLock::new(HashMap::new())),
385            _heartbeat_interval: Duration::from_secs(1),
386            message_validators: Arc::new(RwLock::new(HashMap::new())),
387            trust_provider,
388            _message_rx: Arc::new(RwLock::new(Some(message_rx))),
389            control_tx: Arc::new(RwLock::new(Some(control_tx))),
390            churn_detector: Arc::new(RwLock::new(ChurnDetector::new())),
391            stats: Arc::new(RwLock::new(GossipStats::default())),
392        }
393    }
394
395    /// Subscribe to a topic
396    pub async fn subscribe(&self, topic: &str) -> Result<()> {
397        let mut topics = self.topics.write().await;
398        topics
399            .entry(topic.to_string())
400            .or_insert_with(TopicParams::default);
401
402        let mut mesh = self.mesh.write().await;
403        mesh.insert(topic.to_string(), HashSet::new());
404
405        Ok(())
406    }
407
408    /// Unsubscribe from a topic
409    pub async fn unsubscribe(&self, topic: &str) -> Result<()> {
410        let mut mesh = self.mesh.write().await;
411        mesh.remove(topic);
412
413        Ok(())
414    }
415
416    /// Publish a message to a topic
417    pub async fn publish(&self, topic: &str, message: GossipMessage) -> Result<()> {
418        // Validate message before publishing
419        if !self.validate_message(&message).await? {
420            return Err(AdaptiveNetworkError::Gossip(
421                "Message validation failed".to_string(),
422            ));
423        }
424
425        let msg_id = self.compute_message_id(&message);
426
427        // Add to seen messages and cache
428        {
429            let mut seen = self.seen_messages.write().await;
430            seen.insert(msg_id, Instant::now());
431
432            let mut cache = self.message_cache.write().await;
433            cache.insert(msg_id, message.clone());
434        }
435
436        // Send to mesh peers
437        let mesh = self.mesh.read().await;
438        if let Some(mesh_peers) = mesh.get(topic) {
439            for peer in mesh_peers {
440                // In real implementation, send via network
441                self.send_message(peer, &message).await?;
442            }
443        } else {
444            // Use fanout if not subscribed
445            let fanout = self.fanout.read().await;
446            let fanout_peers = fanout
447                .get(topic)
448                .cloned()
449                .unwrap_or_else(|| self.get_fanout_peers(topic).unwrap_or_default());
450
451            for peer in &fanout_peers {
452                self.send_message(peer, &message).await?;
453            }
454        }
455
456        Ok(())
457    }
458
459    /// Send GRAFT control message
460    pub async fn send_graft(&self, peer: &NodeId, topic: &str) -> Result<()> {
461        let control_tx = self.control_tx.read().await;
462        if let Some(tx) = control_tx.as_ref() {
463            let msg = ControlMessage::Graft {
464                topic: topic.to_string(),
465            };
466            tx.send((peer.clone(), msg))
467                .await
468                .map_err(|_| AdaptiveNetworkError::Other("Failed to send GRAFT".to_string()))?;
469        }
470        Ok(())
471    }
472
473    /// Send PRUNE control message
474    pub async fn send_prune(&self, peer: &NodeId, topic: &str, backoff: Duration) -> Result<()> {
475        let control_tx = self.control_tx.read().await;
476        if let Some(tx) = control_tx.as_ref() {
477            let msg = ControlMessage::Prune {
478                topic: topic.to_string(),
479                backoff,
480            };
481            tx.send((peer.clone(), msg))
482                .await
483                .map_err(|_| AdaptiveNetworkError::Other("Failed to send PRUNE".to_string()))?;
484        }
485        Ok(())
486    }
487
488    /// Send IHAVE control message
489    pub async fn send_ihave(
490        &self,
491        peer: &NodeId,
492        topic: &str,
493        message_ids: Vec<MessageId>,
494    ) -> Result<()> {
495        let control_tx = self.control_tx.read().await;
496        if let Some(tx) = control_tx.as_ref() {
497            let msg = ControlMessage::IHave {
498                topic: topic.to_string(),
499                message_ids,
500            };
501            tx.send((peer.clone(), msg))
502                .await
503                .map_err(|_| AdaptiveNetworkError::Other("Failed to send IHAVE".to_string()))?;
504        }
505        Ok(())
506    }
507
508    /// Send IWANT control message
509    pub async fn send_iwant(&self, peer: &NodeId, message_ids: Vec<MessageId>) -> Result<()> {
510        let control_tx = self.control_tx.read().await;
511        if let Some(tx) = control_tx.as_ref() {
512            let msg = ControlMessage::IWant { message_ids };
513            tx.send((peer.clone(), msg))
514                .await
515                .map_err(|_| AdaptiveNetworkError::Other("Failed to send IWANT".to_string()))?;
516        }
517        Ok(())
518    }
519
520    /// Handle periodic heartbeat
521    pub async fn heartbeat(&self) {
522        let mesh = self.mesh.read().await.clone();
523
524        for (topic, mesh_peers) in mesh {
525            let params = {
526                let topics = self.topics.read().await;
527                topics.get(&topic).cloned().unwrap_or_default()
528            };
529
530            // Calculate adaptive mesh size based on churn
531            let target_size = self.calculate_adaptive_mesh_size(&topic).await;
532
533            // Remove low-scoring peers
534            let mut peers_to_remove = Vec::new();
535            {
536                let scores = self.peer_scores.read().await;
537                for peer in &mesh_peers {
538                    if let Some(score) = scores.get(peer)
539                        && score.score() < params.graylist_threshold
540                    {
541                        peers_to_remove.push(peer.clone());
542                    }
543                }
544            }
545
546            // Update mesh
547            let mut mesh_write = self.mesh.write().await;
548            if let Some(topic_mesh) = mesh_write.get_mut(&topic) {
549                // Send PRUNE messages and update churn detector
550                for peer in peers_to_remove {
551                    topic_mesh.remove(&peer);
552                    let _ = self
553                        .send_prune(&peer, &topic, Duration::from_secs(60))
554                        .await;
555
556                    // Record peer leaving mesh
557                    let mut churn = self.churn_detector.write().await;
558                    churn.record_leave(peer);
559                }
560
561                // Add high-scoring peers if below target
562                while topic_mesh.len() < target_size {
563                    if let Some(peer) = self.select_peer_for_mesh(&topic, topic_mesh).await {
564                        topic_mesh.insert(peer.clone());
565                        let _ = self.send_graft(&peer, &topic).await;
566
567                        // Record peer joining mesh
568                        let mut churn = self.churn_detector.write().await;
569                        churn.record_join(peer);
570                    } else {
571                        break;
572                    }
573                }
574            }
575        }
576
577        // Update peer scores
578        self.update_peer_scores().await;
579
580        // Clean old seen messages
581        self.clean_seen_messages().await;
582    }
583
584    /// Calculate adaptive mesh size based on network conditions
585    pub async fn calculate_adaptive_mesh_size(&self, topic: &str) -> usize {
586        let base_size = 8;
587
588        // Get churn rate from detector
589        let churn_rate = {
590            let churn = self.churn_detector.read().await;
591            churn.get_rate()
592        };
593
594        // Get topic priority
595        let priority_factor = {
596            let priorities = self.topic_priorities.read().await;
597            match priorities.get(topic) {
598                Some(TopicPriority::Critical) => 2.0,
599                Some(TopicPriority::High) => 1.5,
600                Some(TopicPriority::Normal) => 1.0,
601                Some(TopicPriority::Low) => 0.8,
602                None => 1.0,
603            }
604        };
605
606        // Increase mesh size based on churn and priority
607        let churn_factor = 1.0 + (churn_rate * 0.1).min(0.5); // Max 50% increase
608
609        (base_size as f64 * churn_factor * priority_factor).round() as usize
610    }
611
612    /// Select a peer to add to mesh
613    async fn select_peer_for_mesh(
614        &self,
615        _topic: &str,
616        current_mesh: &HashSet<NodeId>,
617    ) -> Option<NodeId> {
618        // Select from known peers not in mesh, sorted by score
619        let scores = self.peer_scores.read().await;
620        let mut candidates: Vec<_> = scores
621            .iter()
622            .filter(|(peer_id, _)| !current_mesh.contains(peer_id))
623            .map(|(peer_id, score)| (peer_id.clone(), score.score()))
624            .collect();
625
626        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
627        candidates.first().map(|(peer, _)| peer.clone())
628    }
629
630    /// Update peer scores
631    async fn update_peer_scores(&self) {
632        let mut scores = self.peer_scores.write().await;
633        for (peer_id, score) in scores.iter_mut() {
634            // Update app-specific score from trust system
635            score.app_specific_score = self.trust_provider.get_trust(peer_id);
636
637            // Decay behavior penalty
638            score.behavior_penalty *= 0.99;
639        }
640    }
641
642    /// Clean old seen messages
643    async fn clean_seen_messages(&self) {
644        // Use checked_sub to avoid panic on Windows when program uptime < 5 minutes
645        if let Some(cutoff) = Instant::now().checked_sub(Duration::from_secs(300)) {
646            let mut seen = self.seen_messages.write().await;
647            seen.retain(|_, timestamp| *timestamp > cutoff);
648        }
649    }
650
651    /// Compute message ID
652    pub fn compute_message_id(&self, message: &GossipMessage) -> MessageId {
653        use sha2::{Digest, Sha256};
654        let mut hasher = Sha256::new();
655        hasher.update(message.topic.as_bytes());
656        hasher.update(message.from.hash);
657        hasher.update(message.seqno.to_le_bytes());
658        hasher.update(&message.data);
659
660        let result = hasher.finalize();
661        let mut id = [0u8; 32];
662        id.copy_from_slice(&result);
663        id
664    }
665
    /// Send message to a peer (placeholder)
    ///
    /// Currently does nothing and always returns `Ok(())`; both arguments
    /// are ignored.
    async fn send_message(&self, _peer: &NodeId, _message: &GossipMessage) -> Result<()> {
        // In real implementation, send via network layer
        Ok(())
    }
671
    /// Get fanout peers for a topic
    ///
    /// Placeholder: always returns `None`, so `publish` falls back to an
    /// empty peer set for topics without stored fanout peers.
    fn get_fanout_peers(&self, _topic: &str) -> Option<HashSet<NodeId>> {
        // In real implementation, select high-scoring peers
        None
    }
677
678    /// Handle incoming control message
679    pub async fn handle_control_message(
680        &self,
681        from: &NodeId,
682        message: ControlMessage,
683    ) -> Result<()> {
684        match message {
685            ControlMessage::Graft { topic } => {
686                // Peer wants to join our mesh
687                let mut mesh = self.mesh.write().await;
688                if let Some(topic_mesh) = mesh.get_mut(&topic) {
689                    // Check peer score before accepting
690                    let score = {
691                        let scores = self.peer_scores.read().await;
692                        scores.get(from).map(|s| s.score()).unwrap_or(0.0)
693                    };
694
695                    // If we have no prior score, fall back to trust provider's score
696                    let score = if score == 0.0 {
697                        self.trust_provider.get_trust(from)
698                    } else {
699                        score
700                    };
701
702                    if score > 0.0 {
703                        topic_mesh.insert(from.clone());
704                    } else {
705                        // Send PRUNE back if we don't want them
706                        let _ = self.send_prune(from, &topic, Duration::from_secs(60)).await;
707                    }
708                }
709            }
710            ControlMessage::Prune { topic, backoff: _ } => {
711                // Peer is removing us from their mesh
712                let mut mesh = self.mesh.write().await;
713                if let Some(topic_mesh) = mesh.get_mut(&topic) {
714                    topic_mesh.remove(from);
715                }
716            }
717            ControlMessage::IHave {
718                topic: _,
719                message_ids,
720            } => {
721                // Peer is announcing messages they have
722                let seen = self.seen_messages.read().await;
723                let mut want = Vec::new();
724
725                for msg_id in message_ids {
726                    if !seen.contains_key(&msg_id) {
727                        want.push(msg_id);
728                    }
729                }
730
731                if !want.is_empty() {
732                    let _ = self.send_iwant(from, want).await;
733                }
734            }
735            ControlMessage::IWant { message_ids } => {
736                // Peer wants specific messages
737                let cache = self.message_cache.read().await;
738                for msg_id in message_ids {
739                    if let Some(message) = cache.get(&msg_id) {
740                        let _ = self.send_message(from, message).await;
741                    }
742                }
743            }
744        }
745
746        Ok(())
747    }
748
749    /// Set topic priority
750    pub async fn set_topic_priority(&self, topic: &str, priority: TopicPriority) {
751        let mut priorities = self.topic_priorities.write().await;
752        priorities.insert(topic.to_string(), priority);
753    }
754
755    /// Register a message validator for a topic
756    pub async fn register_validator(
757        &self,
758        topic: &str,
759        validator: Box<dyn MessageValidator + Send + Sync>,
760    ) -> Result<()> {
761        let mut validators = self.message_validators.write().await;
762        validators.insert(topic.to_string(), validator);
763        Ok(())
764    }
765
766    /// Validate a message before processing
767    async fn validate_message(&self, message: &GossipMessage) -> Result<bool> {
768        let validators = self.message_validators.read().await;
769
770        if let Some(validator) = validators.get(&message.topic) {
771            validator.validate(message).await
772        } else {
773            // No validator registered, accept by default
774            Ok(true)
775        }
776    }
777
    /// Reduce gossip fanout during high churn
    ///
    /// Placeholder: `factor` is accepted but ignored and no state changes.
    pub async fn reduce_fanout(&self, factor: f64) {
        // In a real implementation, would reduce mesh degree based on factor
        // This would involve updating the target degree for mesh maintenance
        let _ = factor; // Suppress unused warning
    }
784
785    /// Get gossip statistics
786    pub async fn get_stats(&self) -> GossipStats {
787        let mut stats = self.stats.read().await.clone();
788
789        // Update current values
790        let mesh = self.mesh.read().await;
791        stats.mesh_size = mesh.values().map(|peers| peers.len()).sum();
792        stats.topic_count = mesh.len();
793
794        let peer_scores = self.peer_scores.read().await;
795        stats.peer_count = peer_scores.len();
796
797        stats
798    }
799}
800
801#[cfg(test)]
802mod tests {
803    use super::*;
804
805    #[tokio::test]
806    async fn test_gossipsub_subscribe() {
807        struct MockTrustProvider;
808        impl TrustProvider for MockTrustProvider {
809            fn get_trust(&self, _node: &NodeId) -> f64 {
810                0.5
811            }
812            fn update_trust(&self, _from: &NodeId, _to: &NodeId, _success: bool) {}
813            fn get_global_trust(&self) -> HashMap<NodeId, f64> {
814                HashMap::new()
815            }
816            fn remove_node(&self, _node: &NodeId) {}
817        }
818
819        use crate::peer_record::UserId;
820        use rand::RngCore;
821
822        let mut hash = [0u8; 32];
823        rand::thread_rng().fill_bytes(&mut hash);
824        let local_id = UserId::from_bytes(hash);
825
826        let trust_provider = Arc::new(MockTrustProvider);
827        let gossip = AdaptiveGossipSub::new(local_id, trust_provider);
828
829        gossip.subscribe("test-topic").await.unwrap();
830
831        let mesh = gossip.mesh.read().await;
832        assert!(mesh.contains_key("test-topic"));
833    }
834
835    #[test]
836    fn test_peer_score() {
837        let mut score = PeerScore::new();
838        assert!(score.score() > 0.0);
839
840        score.invalid_messages = 5;
841        assert!(score.score() < 0.0);
842    }
843
844    #[test]
845    fn test_message_id() {
846        struct MockTrustProvider;
847        impl TrustProvider for MockTrustProvider {
848            fn get_trust(&self, _node: &NodeId) -> f64 {
849                0.5
850            }
851            fn update_trust(&self, _from: &NodeId, _to: &NodeId, _success: bool) {}
852            fn get_global_trust(&self) -> HashMap<NodeId, f64> {
853                HashMap::new()
854            }
855            fn remove_node(&self, _node: &NodeId) {}
856        }
857
858        use crate::peer_record::UserId;
859        use rand::RngCore;
860
861        let mut hash = [0u8; 32];
862        rand::thread_rng().fill_bytes(&mut hash);
863        let local_id = UserId::from_bytes(hash);
864
865        let trust_provider = Arc::new(MockTrustProvider);
866        let gossip = AdaptiveGossipSub::new(local_id, trust_provider);
867
868        let mut hash2 = [0u8; 32];
869        rand::thread_rng().fill_bytes(&mut hash2);
870
871        let msg = GossipMessage {
872            topic: "test".to_string(),
873            data: vec![1, 2, 3],
874            from: UserId::from_bytes(hash2),
875            seqno: 1,
876            timestamp: 12345,
877        };
878
879        let id1 = gossip.compute_message_id(&msg);
880        let id2 = gossip.compute_message_id(&msg);
881
882        assert_eq!(id1, id2);
883    }
884
885    #[tokio::test]
886    async fn test_adaptive_mesh_size() {
887        use crate::peer_record::UserId;
888        use rand::RngCore;
889
890        struct MockTrustProvider;
891        impl TrustProvider for MockTrustProvider {
892            fn get_trust(&self, _node: &NodeId) -> f64 {
893                0.5
894            }
895            fn update_trust(&self, _from: &NodeId, _to: &NodeId, _success: bool) {}
896            fn get_global_trust(&self) -> HashMap<NodeId, f64> {
897                HashMap::new()
898            }
899            fn remove_node(&self, _node: &NodeId) {}
900        }
901
902        let mut hash = [0u8; 32];
903        rand::thread_rng().fill_bytes(&mut hash);
904        let local_id = UserId::from_bytes(hash);
905
906        let trust_provider = Arc::new(MockTrustProvider);
907        let gossip = AdaptiveGossipSub::new(local_id, trust_provider);
908
909        // Set topic priority
910        gossip
911            .set_topic_priority("critical-topic", TopicPriority::Critical)
912            .await;
913        gossip
914            .set_topic_priority("low-topic", TopicPriority::Low)
915            .await;
916
917        // Test mesh size calculation
918        let critical_size = gossip.calculate_adaptive_mesh_size("critical-topic").await;
919        let normal_size = gossip.calculate_adaptive_mesh_size("normal-topic").await;
920        let low_size = gossip.calculate_adaptive_mesh_size("low-topic").await;
921
922        assert!(critical_size > normal_size);
923        assert!(normal_size > low_size);
924    }
925
926    #[test]
927    fn test_churn_detector() {
928        use crate::peer_record::UserId;
929        use rand::RngCore;
930
931        let mut detector = ChurnDetector::new();
932
933        // Add some join/leave events
934        for i in 0..10 {
935            let mut hash = [0u8; 32];
936            rand::thread_rng().fill_bytes(&mut hash);
937            hash[0] = i;
938            let peer = UserId::from_bytes(hash);
939
940            if i % 2 == 0 {
941                detector.record_join(peer);
942            } else {
943                detector.record_leave(peer);
944            }
945        }
946
947        let rate = detector.get_rate();
948        assert!(rate > 0.0);
949    }
950
951    #[tokio::test]
952    async fn test_control_messages() {
953        use crate::peer_record::UserId;
954        use rand::RngCore;
955
956        struct MockTrustProvider;
957        impl TrustProvider for MockTrustProvider {
958            fn get_trust(&self, _node: &NodeId) -> f64 {
959                0.8
960            }
961            fn update_trust(&self, _from: &NodeId, _to: &NodeId, _success: bool) {}
962            fn get_global_trust(&self) -> HashMap<NodeId, f64> {
963                HashMap::new()
964            }
965            fn remove_node(&self, _node: &NodeId) {}
966        }
967
968        let mut hash = [0u8; 32];
969        rand::thread_rng().fill_bytes(&mut hash);
970        let local_id = UserId::from_bytes(hash);
971
972        let trust_provider = Arc::new(MockTrustProvider);
973        let gossip = AdaptiveGossipSub::new(local_id, trust_provider);
974
975        // Subscribe to a topic
976        gossip.subscribe("test-topic").await.unwrap();
977
978        // Test GRAFT handling
979        let mut peer_hash = [0u8; 32];
980        rand::thread_rng().fill_bytes(&mut peer_hash);
981        let peer_id = UserId::from_bytes(peer_hash);
982
983        let graft_msg = ControlMessage::Graft {
984            topic: "test-topic".to_string(),
985        };
986        gossip
987            .handle_control_message(&peer_id, graft_msg)
988            .await
989            .unwrap();
990
991        // Peer should be in mesh due to good trust score
992        let mesh = gossip.mesh.read().await;
993        assert!(mesh.get("test-topic").unwrap().contains(&peer_id));
994    }
995
996    #[tokio::test]
997    async fn test_message_validation() {
998        use crate::peer_record::UserId;
999        use rand::RngCore;
1000
1001        struct MockTrustProvider;
1002        impl TrustProvider for MockTrustProvider {
1003            fn get_trust(&self, _node: &NodeId) -> f64 {
1004                0.8
1005            }
1006            fn update_trust(&self, _from: &NodeId, _to: &NodeId, _success: bool) {}
1007            fn get_global_trust(&self) -> HashMap<NodeId, f64> {
1008                HashMap::new()
1009            }
1010            fn remove_node(&self, _node: &NodeId) {}
1011        }
1012
1013        // Custom validator that rejects messages with "bad" in the data
1014        struct TestValidator;
1015        #[async_trait::async_trait]
1016        impl MessageValidator for TestValidator {
1017            async fn validate(&self, message: &GossipMessage) -> Result<bool> {
1018                Ok(!message.data.windows(3).any(|w| w == b"bad"))
1019            }
1020        }
1021
1022        let mut hash = [0u8; 32];
1023        rand::thread_rng().fill_bytes(&mut hash);
1024        let local_id = UserId::from_bytes(hash);
1025
1026        let trust_provider = Arc::new(MockTrustProvider);
1027        let gossip = AdaptiveGossipSub::new(local_id, trust_provider);
1028
1029        // Register validator
1030        gossip
1031            .register_validator("test-topic", Box::new(TestValidator))
1032            .await
1033            .unwrap();
1034
1035        // Test valid message
1036        let valid_message = GossipMessage {
1037            topic: "test-topic".to_string(),
1038            data: vec![1, 2, 3, 4], // No "bad" in data
1039            from: UserId::from_bytes([0; 32]),
1040            seqno: 1,
1041            timestamp: 12345,
1042        };
1043
1044        // Should succeed
1045        assert!(gossip.publish("test-topic", valid_message).await.is_ok());
1046
1047        // Test invalid message
1048        let invalid_message = GossipMessage {
1049            topic: "test-topic".to_string(),
1050            data: vec![b'b', b'a', b'd', b'!'], // Contains "bad"
1051            from: UserId::from_bytes([0; 32]),
1052            seqno: 2,
1053            timestamp: 12346,
1054        };
1055
1056        // Should fail validation
1057        assert!(gossip.publish("test-topic", invalid_message).await.is_err());
1058    }
1059
1060    #[tokio::test]
1061    async fn test_ihave_iwant_flow() {
1062        use crate::peer_record::UserId;
1063        use rand::RngCore;
1064
1065        struct MockTrustProvider;
1066        impl TrustProvider for MockTrustProvider {
1067            fn get_trust(&self, _node: &NodeId) -> f64 {
1068                0.8
1069            }
1070            fn update_trust(&self, _from: &NodeId, _to: &NodeId, _success: bool) {}
1071            fn get_global_trust(&self) -> HashMap<NodeId, f64> {
1072                HashMap::new()
1073            }
1074            fn remove_node(&self, _node: &NodeId) {}
1075        }
1076
1077        let mut hash = [0u8; 32];
1078        rand::thread_rng().fill_bytes(&mut hash);
1079        let local_id = UserId::from_bytes(hash);
1080
1081        let trust_provider = Arc::new(MockTrustProvider);
1082        let gossip = AdaptiveGossipSub::new(local_id, trust_provider);
1083
1084        // Create a test message
1085        let mut peer_hash = [0u8; 32];
1086        rand::thread_rng().fill_bytes(&mut peer_hash);
1087        let from_peer = UserId::from_bytes(peer_hash);
1088
1089        let message = GossipMessage {
1090            topic: "test-topic".to_string(),
1091            data: vec![1, 2, 3, 4],
1092            from: from_peer.clone(),
1093            seqno: 1,
1094            timestamp: 12345,
1095        };
1096
1097        // Publish message (adds to cache)
1098        gossip.publish("test-topic", message.clone()).await.unwrap();
1099
1100        // Message should be in cache
1101        let msg_id = gossip.compute_message_id(&message);
1102        let cache = gossip.message_cache.read().await;
1103        assert!(cache.contains_key(&msg_id));
1104    }
1105}