1use shadow_core::{PeerId, PeerInfo};
4use shadow_core::error::{Result, ShadowError};
5use crate::routing::RoutingTable;
6use crate::store::{DHTStore, StoredValue};
7use bytes::Bytes;
8use tokio::time::{interval, Duration};
9use std::sync::{Arc, RwLock};
10
/// Tunable parameters for a [`DHTNode`].
#[derive(Debug, Clone)]
pub struct NodeConfig {
    /// Value handed to `RoutingTable::new` when the node is created
    /// (presumably the Kademlia bucket size; confirm against `RoutingTable`).
    pub k: usize,

    /// Not read by the node code in this file
    /// (presumably the lookup parallelism factor; confirm against callers).
    pub alpha: usize,

    /// Capacity handed to `DHTStore::new` when the node is created
    /// (presumably a byte count; confirm against `DHTStore`).
    pub max_storage: usize,

    /// Time-to-live passed to `StoredValue::new` for every stored value
    /// (presumably seconds; confirm against `StoredValue`).
    pub value_ttl: u64,

    /// Period, in seconds, between runs of the republish maintenance task.
    pub republish_interval: u64,
}
25
26impl Default for NodeConfig {
27 fn default() -> Self {
28 Self {
29 k: 20,
30 alpha: 3,
31 max_storage: 100 * 1024 * 1024, value_ttl: 3600, republish_interval: 3600,
34 }
35 }
36}
37
38pub struct DHTNode {
40 local_id: PeerId,
42 routing: Arc<RwLock<RoutingTable>>,
44 store: Arc<RwLock<DHTStore>>,
46 config: NodeConfig,
48}
49
50impl DHTNode {
51 pub fn new(local_id: PeerId, config: NodeConfig) -> Self {
53 let routing = Arc::new(RwLock::new(RoutingTable::new(local_id, config.k)));
54 let store = Arc::new(RwLock::new(DHTStore::new(config.max_storage)));
55
56 Self {
57 local_id,
58 routing,
59 store,
60 config,
61 }
62 }
63
64 pub fn local_id(&self) -> PeerId {
66 self.local_id
67 }
68
69 pub fn add_peer(&self, peer: PeerInfo) -> Result<bool> {
71 self.routing.write().unwrap().add_peer(peer)
72 }
73
74 pub fn remove_peer(&self, peer_id: &PeerId) -> Result<bool> {
76 self.routing.write().unwrap().remove_peer(peer_id)
77 }
78
79 pub fn find_closest_peers(&self, target: &PeerId, count: usize) -> Vec<PeerInfo> {
81 self.routing.read().unwrap().find_closest(target, count)
82 }
83
84 pub fn store_value(&self, key: [u8; 32], data: Bytes, publisher: [u8; 32]) -> Result<()> {
86 let value = StoredValue::new(data, publisher, self.config.value_ttl);
87 self.store.write().unwrap().put(key, value)
88 }
89
90 pub fn get_value(&self, key: &[u8; 32]) -> Option<StoredValue> {
92 self.store.read().unwrap().get(key).cloned()
93 }
94
95 pub fn has_value(&self, key: &[u8; 32]) -> bool {
97 self.store.read().unwrap().contains(key)
98 }
99
100 pub fn remove_value(&self, key: &[u8; 32]) -> Option<StoredValue> {
102 self.store.write().unwrap().remove(key)
103 }
104
105 pub fn all_peers(&self) -> Vec<PeerInfo> {
107 self.routing.read().unwrap().all_peers()
108 }
109
110 pub fn peer_count(&self) -> usize {
112 self.routing.read().unwrap().peer_count()
113 }
114
115 pub fn value_count(&self) -> usize {
117 self.store.read().unwrap().len()
118 }
119
120 pub fn storage_size(&self) -> usize {
122 self.store.read().unwrap().size()
123 }
124
125 pub fn mark_peer_failed(&self, peer_id: &PeerId) -> Result<()> {
127 self.routing.write().unwrap().mark_failed(peer_id)
128 }
129
130 pub fn cleanup_expired(&self) -> usize {
132 self.store.write().unwrap().cleanup_expired()
133 }
134
135 pub async fn start_maintenance(self: Arc<Self>) {
137 let node_clone = self.clone();
139 tokio::spawn(async move {
140 let mut cleanup_interval = interval(Duration::from_secs(300)); loop {
143 cleanup_interval.tick().await;
144 let cleaned = node_clone.cleanup_expired();
145 tracing::debug!("Cleaned {} expired values", cleaned);
146 }
147 });
148
149 let node_clone = self.clone();
151 tokio::spawn(async move {
152 let mut republish_interval = interval(Duration::from_secs(node_clone.config.republish_interval));
153
154 loop {
155 republish_interval.tick().await;
156 node_clone.republish_values();
157 }
158 });
159 }
160
161 fn republish_values(&self) {
163 let threshold = Duration::from_secs(self.config.republish_interval / 2);
164 let expiring = self.store.read().unwrap().get_expiring_soon(threshold);
165
166 tracing::debug!("Found {} values expiring soon", expiring.len());
167
168 }
173
174 pub fn stats(&self) -> DHTStats {
176 DHTStats {
177 peer_count: self.peer_count(),
178 value_count: self.value_count(),
179 storage_size: self.storage_size(),
180 max_storage: self.config.max_storage,
181 }
182 }
183}
184
/// Point-in-time counters describing a node, as produced by `DHTNode::stats`.
#[derive(Debug, Clone)]
pub struct DHTStats {
    /// Number of peers in the routing table.
    pub peer_count: usize,

    /// Number of values held in the store.
    pub value_count: usize,

    /// Current size reported by the store.
    pub storage_size: usize,

    /// Storage limit taken from the node's configuration.
    pub max_storage: usize,
}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a node with a random identity and the default configuration,
    /// returning the identity alongside it.
    fn default_node() -> (PeerId, DHTNode) {
        let id = PeerId::random();
        (id, DHTNode::new(id, NodeConfig::default()))
    }

    /// A freshly created node reports its own id and holds no peers or values.
    #[test]
    fn test_dht_node_creation() {
        let (id, node) = default_node();

        assert_eq!(node.local_id(), id);
        assert_eq!(node.peer_count(), 0);
        assert_eq!(node.value_count(), 0);
    }

    /// Adding a single peer raises the peer count to one.
    #[test]
    fn test_add_peer() {
        let (_, node) = default_node();

        let addresses = vec!["127.0.0.1:9000".to_string()];
        let peer = PeerInfo::new(PeerId::random(), addresses, [0u8; 32], [0u8; 32]);

        node.add_peer(peer).unwrap();
        assert_eq!(node.peer_count(), 1);
    }

    /// A stored value can be read back with its data and publisher intact.
    #[test]
    fn test_store_retrieve() {
        let (_, node) = default_node();

        let key = [1u8; 32];
        let publisher = [2u8; 32];
        let payload = Bytes::from("test data");

        node.store_value(key, payload.clone(), publisher).unwrap();

        let fetched = node.get_value(&key).unwrap();
        assert_eq!(fetched.data, payload);
        assert_eq!(fetched.publisher, publisher);
    }

    /// With ten known peers, asking for three closest yields exactly three.
    #[test]
    fn test_find_closest() {
        let (_, node) = default_node();

        for _ in 0..10 {
            let candidate = PeerInfo::new(PeerId::random(), vec![], [0u8; 32], [0u8; 32]);
            node.add_peer(candidate).unwrap();
        }

        let target = PeerId::random();
        let nearest = node.find_closest_peers(&target, 3);

        assert_eq!(nearest.len(), 3);
    }

    /// Stats of a fresh node show zero peers and zero values.
    #[test]
    fn test_stats() {
        let (_, node) = default_node();

        let snapshot = node.stats();
        assert_eq!(snapshot.peer_count, 0);
        assert_eq!(snapshot.value_count, 0);
    }
}