Skip to main content

shadow_dht/
routing.rs

1//! Kademlia routing table implementation
2
3use shadow_core::{PeerId, PeerInfo};
4use shadow_core::error::{Result, ShadowError};
5use std::collections::VecDeque;
6use std::time::{Instant, Duration};
7
8/// K-bucket containing peers at a specific distance range
9#[derive(Debug, Clone)]
10pub struct KBucket {
11    /// Maximum bucket size (k parameter)
12    pub k: usize,
13    /// Peers in this bucket (oldest first)
14    pub peers: VecDeque<BucketEntry>,
15    /// Last update time
16    pub last_updated: Instant,
17}
18
19/// Entry in a K-bucket
20#[derive(Debug, Clone)]
21pub struct BucketEntry {
22    pub peer: PeerInfo,
23    pub last_seen: Instant,
24    pub rtt: Option<Duration>,
25    pub failed_queries: u32,
26}
27
28impl KBucket {
29    /// Create new K-bucket
30    pub fn new(k: usize) -> Self {
31        Self {
32            k,
33            peers: VecDeque::new(),
34            last_updated: Instant::now(),
35        }
36    }
37
38    /// Try to add peer to bucket
39    pub fn add_peer(&mut self, peer: PeerInfo) -> bool {
40        // Check if peer already exists
41        if let Some(pos) = self.peers.iter().position(|e| e.peer.id == peer.id) {
42            // Move to end (most recently seen)
43            let mut entry = self.peers.remove(pos).unwrap();
44            entry.last_seen = Instant::now();
45            self.peers.push_back(entry);
46            self.last_updated = Instant::now();
47            return true;
48        }
49
50        // If bucket not full, add peer
51        if self.peers.len() < self.k {
52            self.peers.push_back(BucketEntry {
53                peer,
54                last_seen: Instant::now(),
55                rtt: None,
56                failed_queries: 0,
57            });
58            self.last_updated = Instant::now();
59            return true;
60        }
61
62        // Bucket full: try to replace stale peer
63        // Check if oldest peer (front) is stale
64        if let Some(oldest) = self.peers.front() {
65            if oldest.last_seen.elapsed() > Duration::from_secs(900) || oldest.failed_queries > 3 {
66                // Replace stale peer
67                self.peers.pop_front();
68                self.peers.push_back(BucketEntry {
69                    peer,
70                    last_seen: Instant::now(),
71                    rtt: None,
72                    failed_queries: 0,
73                });
74                self.last_updated = Instant::now();
75                return true;
76            }
77        }
78
79        false // Couldn't add peer
80    }
81
82    /// Remove peer from bucket
83    pub fn remove_peer(&mut self, peer_id: &PeerId) -> bool {
84        if let Some(pos) = self.peers.iter().position(|e| &e.peer.id == peer_id) {
85            self.peers.remove(pos);
86            self.last_updated = Instant::now();
87            return true;
88        }
89        false
90    }
91
92    /// Get peer by ID
93    pub fn get_peer(&self, peer_id: &PeerId) -> Option<&PeerInfo> {
94        self.peers.iter()
95            .find(|e| &e.peer.id == peer_id)
96            .map(|e| &e.peer)
97    }
98
99    /// Mark peer as failed
100    pub fn mark_failed(&mut self, peer_id: &PeerId) {
101        if let Some(entry) = self.peers.iter_mut().find(|e| &e.peer.id == peer_id) {
102            entry.failed_queries += 1;
103        }
104    }
105
106    /// Get all peers in bucket
107    pub fn all_peers(&self) -> Vec<PeerInfo> {
108        self.peers.iter().map(|e| e.peer.clone()).collect()
109    }
110
111    /// Get number of peers
112    pub fn len(&self) -> usize {
113        self.peers.len()
114    }
115
116    /// Check if bucket is empty
117    pub fn is_empty(&self) -> bool {
118        self.peers.is_empty()
119    }
120}
121
122/// Kademlia routing table
123pub struct RoutingTable {
124    /// Our peer ID
125    local_id: PeerId,
126    /// K-buckets (160 buckets for 160-bit IDs)
127    buckets: Vec<KBucket>,
128    /// K parameter (bucket size)
129    k: usize,
130}
131
132impl RoutingTable {
133    /// Create new routing table
134    pub fn new(local_id: PeerId, k: usize) -> Self {
135        let buckets = (0..160)
136            .map(|_| KBucket::new(k))
137            .collect();
138
139        Self {
140            local_id,
141            buckets,
142            k,
143        }
144    }
145
146    /// Add peer to routing table
147    pub fn add_peer(&mut self, peer: PeerInfo) -> Result<bool> {
148        if peer.id == self.local_id {
149            return Ok(false); // Don't add ourselves
150        }
151
152        let bucket_index = self.bucket_index(&peer.id)?;
153        Ok(self.buckets[bucket_index].add_peer(peer))
154    }
155
156    /// Remove peer from routing table
157    pub fn remove_peer(&mut self, peer_id: &PeerId) -> Result<bool> {
158        let bucket_index = self.bucket_index(peer_id)?;
159        Ok(self.buckets[bucket_index].remove_peer(peer_id))
160    }
161
162    /// Find closest K peers to target ID
163    pub fn find_closest(&self, target: &PeerId, count: usize) -> Vec<PeerInfo> {
164        let mut peers: Vec<(PeerInfo, [u8; 32])> = Vec::new();
165
166        // Collect all peers with their XOR distances to target
167        for bucket in &self.buckets {
168            for peer in bucket.all_peers() {
169                let distance = target.xor_distance(&peer.id);
170                peers.push((peer, distance));
171            }
172        }
173
174        // Sort by distance to target (compare byte arrays directly)
175        peers.sort_by(|a, b| a.1.cmp(&b.1));
176
177        // Return top K
178        peers.into_iter()
179            .take(count)
180            .map(|(peer, _)| peer)
181            .collect()
182    }
183
184    /// Get bucket index for peer ID (based on XOR distance)
185    fn bucket_index(&self, peer_id: &PeerId) -> Result<usize> {
186        let distance = self.local_id.xor_distance(peer_id);
187
188        // Find first differing bit (bucket index)
189        for (i, &byte) in distance.iter().enumerate() {
190            if byte != 0 {
191                // Find position of first 1 bit
192                let bit_pos = 7 - byte.leading_zeros() as usize;
193                let bucket_idx = i * 8 + bit_pos;
194                return Ok(bucket_idx.min(159)); // Max bucket index
195            }
196        }
197
198        // All bits same (shouldn't happen with different IDs)
199        Err(ShadowError::Dht("Identical peer IDs".into()))
200    }
201
202    /// Get all known peers
203    pub fn all_peers(&self) -> Vec<PeerInfo> {
204        let mut all = Vec::new();
205        for bucket in &self.buckets {
206            all.extend(bucket.all_peers());
207        }
208        all
209    }
210
211    /// Get total number of peers
212    pub fn peer_count(&self) -> usize {
213        self.buckets.iter().map(|b| b.len()).sum()
214    }
215
216    /// Mark peer as failed
217    pub fn mark_failed(&mut self, peer_id: &PeerId) -> Result<()> {
218        let bucket_index = self.bucket_index(peer_id)?;
219        self.buckets[bucket_index].mark_failed(peer_id);
220        Ok(())
221    }
222
223    /// Get our local ID
224    pub fn local_id(&self) -> &PeerId {
225        &self.local_id
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a peer with a fresh random ID and no addresses or keys.
    fn random_peer() -> PeerInfo {
        PeerInfo::new(PeerId::random(), vec![], [0u8; 32], [0u8; 32])
    }

    #[test]
    fn test_kbucket_add() {
        let mut bucket = KBucket::new(3);

        let peer1 = PeerInfo::new(
            PeerId::random(),
            vec!["127.0.0.1:9000".to_string()],
            [0u8; 32],
            [0u8; 32],
        );
        let peer2 = random_peer();

        assert!(bucket.add_peer(peer1.clone()));
        assert!(bucket.add_peer(peer2.clone()));
        assert_eq!(bucket.len(), 2);

        // Re-adding a known peer must not duplicate it, and moves it to the back
        // (most recently seen).
        assert!(bucket.add_peer(peer1.clone()));
        assert_eq!(bucket.len(), 2);
        let order = bucket.all_peers();
        assert!(order[0].id == peer2.id);
        assert!(order[1].id == peer1.id);
    }

    #[test]
    fn test_kbucket_full() {
        let mut bucket = KBucket::new(2);

        let peer1 = random_peer();
        let peer2 = random_peer();
        let peer3 = random_peer();

        assert!(bucket.add_peer(peer1));
        assert!(bucket.add_peer(peer2));
        assert!(!bucket.add_peer(peer3)); // Bucket full of healthy peers, can't add

        assert_eq!(bucket.len(), 2);
    }

    #[test]
    fn test_kbucket_evicts_failed_peer_when_full() {
        let mut bucket = KBucket::new(2);

        let peer1 = random_peer();
        let peer2 = random_peer();
        let peer3 = random_peer();

        assert!(bucket.add_peer(peer1.clone()));
        assert!(bucket.add_peer(peer2.clone()));

        // More than 3 failures make peer1 eligible for replacement.
        for _ in 0..4 {
            bucket.mark_failed(&peer1.id);
        }
        assert!(bucket.add_peer(peer3.clone()));
        assert_eq!(bucket.len(), 2);
        assert!(bucket.get_peer(&peer1.id).is_none());
        assert!(bucket.get_peer(&peer2.id).is_some());
        assert!(bucket.get_peer(&peer3.id).is_some());
    }

    #[test]
    fn test_kbucket_remove() {
        let mut bucket = KBucket::new(2);
        let peer = random_peer();

        assert!(bucket.add_peer(peer.clone()));
        assert!(bucket.remove_peer(&peer.id));
        assert!(!bucket.remove_peer(&peer.id));
        assert!(bucket.is_empty());
    }

    #[test]
    fn test_routing_table() {
        let local_id = PeerId::random();
        let mut table = RoutingTable::new(local_id, 20);

        let peer1 = PeerInfo::new(
            PeerId::random(),
            vec!["127.0.0.1:9001".to_string()],
            [0u8; 32],
            [0u8; 32],
        );

        assert!(table.add_peer(peer1.clone()).unwrap());
        assert_eq!(table.peer_count(), 1);

        // Don't add ourselves
        let self_peer = PeerInfo::new(local_id, vec![], [0u8; 32], [0u8; 32]);
        assert!(!table.add_peer(self_peer).unwrap());
        assert_eq!(table.peer_count(), 1);

        // Our own ID has no bucket.
        assert!(table.remove_peer(&local_id).is_err());

        // Removing a known peer succeeds once, then reports it absent.
        assert!(table.remove_peer(&peer1.id).unwrap());
        assert!(!table.remove_peer(&peer1.id).unwrap());
        assert_eq!(table.peer_count(), 0);
    }

    #[test]
    fn test_find_closest() {
        let local_id = PeerId::random();
        let mut table = RoutingTable::new(local_id, 20);

        // Add several peers; with k = 20 every one of them fits.
        let mut added = Vec::new();
        for _ in 0..10 {
            let peer = random_peer();
            assert!(table.add_peer(peer.clone()).unwrap());
            added.push(peer);
        }

        let target = PeerId::random();
        let closest = table.find_closest(&target, 5);
        assert_eq!(closest.len(), 5);

        // Brute-force the expected answer: the added peers nearest `target`, nearest first.
        added.sort_by_key(|p| target.xor_distance(&p.id));
        for (got, want) in closest.iter().zip(&added) {
            assert!(got.id == want.id);
        }

        // Asking for more peers than are known returns all of them.
        assert_eq!(table.find_closest(&target, 100).len(), added.len());
    }
}