// shadow_dht/node.rs
//! DHT node implementation combining routing and storage

use shadow_core::{PeerId, PeerInfo};
use shadow_core::error::{Result, ShadowError};
use crate::routing::RoutingTable;
use crate::store::{DHTStore, StoredValue};
use bytes::Bytes;
use tokio::time::{interval, Duration};
use std::sync::{Arc, RwLock};

/// DHT node configuration.
///
/// Tunables for the Kademlia-style lookup parameters (`k`, `alpha`) and
/// for the local value store (capacity, TTL, republish cadence).
#[derive(Debug, Clone)]
pub struct NodeConfig {
    /// K-bucket size (maximum peers kept per routing-table bucket).
    pub k: usize,
    /// Alpha parameter: number of parallel queries during lookups.
    pub alpha: usize,
    /// Storage size limit in bytes (see `Default`: 100 MB).
    pub max_storage: usize,
    /// Time-to-live applied to stored values, in seconds.
    pub value_ttl: u64,
    /// How often the republish task runs, in seconds.
    pub republish_interval: u64,
}
26impl Default for NodeConfig {
27    fn default() -> Self {
28        Self {
29            k: 20,
30            alpha: 3,
31            max_storage: 100 * 1024 * 1024, // 100 MB
32            value_ttl: 3600, // 1 hour
33            republish_interval: 3600,
34        }
35    }
36}
/// DHT node.
///
/// Combines a routing table (peer discovery) with a local key/value store.
/// Both are wrapped in `Arc<RwLock<..>>` so the node can be shared across
/// async tasks (see `start_maintenance`).
pub struct DHTNode {
    /// Our peer ID.
    local_id: PeerId,
    /// Routing table used for closest-peer lookups.
    routing: Arc<RwLock<RoutingTable>>,
    /// Local storage for DHT values.
    store: Arc<RwLock<DHTStore>>,
    /// Configuration (k, alpha, storage limits, TTLs).
    config: NodeConfig,
}
50impl DHTNode {
51    /// Create new DHT node
52    pub fn new(local_id: PeerId, config: NodeConfig) -> Self {
53        let routing = Arc::new(RwLock::new(RoutingTable::new(local_id, config.k)));
54        let store = Arc::new(RwLock::new(DHTStore::new(config.max_storage)));
55
56        Self {
57            local_id,
58            routing,
59            store,
60            config,
61        }
62    }
63
64    /// Get our peer ID
65    pub fn local_id(&self) -> PeerId {
66        self.local_id
67    }
68
69    /// Add peer to routing table
70    pub fn add_peer(&self, peer: PeerInfo) -> Result<bool> {
71        self.routing.write().unwrap().add_peer(peer)
72    }
73
74    /// Remove peer from routing table
75    pub fn remove_peer(&self, peer_id: &PeerId) -> Result<bool> {
76        self.routing.write().unwrap().remove_peer(peer_id)
77    }
78
79    /// Find closest peers to target
80    pub fn find_closest_peers(&self, target: &PeerId, count: usize) -> Vec<PeerInfo> {
81        self.routing.read().unwrap().find_closest(target, count)
82    }
83
84    /// Store value locally
85    pub fn store_value(&self, key: [u8; 32], data: Bytes, publisher: [u8; 32]) -> Result<()> {
86        let value = StoredValue::new(data, publisher, self.config.value_ttl);
87        self.store.write().unwrap().put(key, value)
88    }
89
90    /// Get value from local storage
91    pub fn get_value(&self, key: &[u8; 32]) -> Option<StoredValue> {
92        self.store.read().unwrap().get(key).cloned()
93    }
94
95    /// Check if we have a value
96    pub fn has_value(&self, key: &[u8; 32]) -> bool {
97        self.store.read().unwrap().contains(key)
98    }
99
100    /// Remove value from storage
101    pub fn remove_value(&self, key: &[u8; 32]) -> Option<StoredValue> {
102        self.store.write().unwrap().remove(key)
103    }
104
105    /// Get all known peers
106    pub fn all_peers(&self) -> Vec<PeerInfo> {
107        self.routing.read().unwrap().all_peers()
108    }
109
110    /// Get peer count
111    pub fn peer_count(&self) -> usize {
112        self.routing.read().unwrap().peer_count()
113    }
114
115    /// Get stored value count
116    pub fn value_count(&self) -> usize {
117        self.store.read().unwrap().len()
118    }
119
120    /// Get storage size
121    pub fn storage_size(&self) -> usize {
122        self.store.read().unwrap().size()
123    }
124
125    /// Mark peer as failed
126    pub fn mark_peer_failed(&self, peer_id: &PeerId) -> Result<()> {
127        self.routing.write().unwrap().mark_failed(peer_id)
128    }
129
130    /// Cleanup expired values
131    pub fn cleanup_expired(&self) -> usize {
132        self.store.write().unwrap().cleanup_expired()
133    }
134
135    /// Start background maintenance tasks
136    pub async fn start_maintenance(self: Arc<Self>) {
137        // Spawn cleanup task
138        let node_clone = self.clone();
139        tokio::spawn(async move {
140            let mut cleanup_interval = interval(Duration::from_secs(300)); // 5 minutes
141            
142            loop {
143                cleanup_interval.tick().await;
144                let cleaned = node_clone.cleanup_expired();
145                tracing::debug!("Cleaned {} expired values", cleaned);
146            }
147        });
148
149        // Spawn republish task
150        let node_clone = self.clone();
151        tokio::spawn(async move {
152            let mut republish_interval = interval(Duration::from_secs(node_clone.config.republish_interval));
153            
154            loop {
155                republish_interval.tick().await;
156                node_clone.republish_values();
157            }
158        });
159    }
160
161    /// Republish values that are expiring soon
162    fn republish_values(&self) {
163        let threshold = Duration::from_secs(self.config.republish_interval / 2);
164        let expiring = self.store.read().unwrap().get_expiring_soon(threshold);
165        
166        tracing::debug!("Found {} values expiring soon", expiring.len());
167        
168        // In a real implementation, we would:
169        // 1. Find K closest nodes to each key
170        // 2. Send STORE RPCs to those nodes
171        // For now, just log
172    }
173
174    /// Get DHT statistics
175    pub fn stats(&self) -> DHTStats {
176        DHTStats {
177            peer_count: self.peer_count(),
178            value_count: self.value_count(),
179            storage_size: self.storage_size(),
180            max_storage: self.config.max_storage,
181        }
182    }
183}
/// DHT statistics.
///
/// Point-in-time snapshot produced by `DHTNode::stats`.
#[derive(Debug, Clone)]
pub struct DHTStats {
    /// Number of peers in the routing table.
    pub peer_count: usize,
    /// Number of values in local storage.
    pub value_count: usize,
    /// Current storage usage (as reported by the store).
    pub storage_size: usize,
    /// Configured storage capacity limit.
    pub max_storage: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh node knows its own ID and starts empty.
    #[test]
    fn test_dht_node_creation() {
        let id = PeerId::random();
        let node = DHTNode::new(id, NodeConfig::default());

        assert_eq!(node.local_id(), id);
        assert_eq!(node.peer_count(), 0);
        assert_eq!(node.value_count(), 0);
    }

    /// Adding a peer is reflected in the peer count.
    #[test]
    fn test_add_peer() {
        let node = DHTNode::new(PeerId::random(), NodeConfig::default());

        let peer = PeerInfo::new(
            PeerId::random(),
            vec!["127.0.0.1:9000".to_string()],
            [0u8; 32],
            [0u8; 32],
        );
        node.add_peer(peer).unwrap();

        assert_eq!(node.peer_count(), 1);
    }

    /// A stored value round-trips with its data and publisher intact.
    #[test]
    fn test_store_retrieve() {
        let node = DHTNode::new(PeerId::random(), NodeConfig::default());

        let key = [1u8; 32];
        let payload = Bytes::from("test data");
        let publisher_id = [2u8; 32];
        node.store_value(key, payload.clone(), publisher_id).unwrap();

        let stored = node.get_value(&key).unwrap();
        assert_eq!(stored.data, payload);
        assert_eq!(stored.publisher, publisher_id);
    }

    /// find_closest_peers returns exactly the requested number of peers
    /// when enough are known.
    #[test]
    fn test_find_closest() {
        let node = DHTNode::new(PeerId::random(), NodeConfig::default());

        // Populate the routing table with ten random peers.
        for _ in 0..10 {
            node.add_peer(PeerInfo::new(PeerId::random(), vec![], [0u8; 32], [0u8; 32]))
                .unwrap();
        }

        let nearest = node.find_closest_peers(&PeerId::random(), 3);
        assert_eq!(nearest.len(), 3);
    }

    /// Stats on a fresh node report zero peers and values.
    #[test]
    fn test_stats() {
        let node = DHTNode::new(PeerId::random(), NodeConfig::default());

        let snapshot = node.stats();
        assert_eq!(snapshot.peer_count, 0);
        assert_eq!(snapshot.value_count, 0);
    }
}