1use shadow_core::error::{Result, ShadowError};
4use std::collections::HashMap;
5use std::time::{Instant, Duration};
6use bytes::Bytes;
7use serde::{Serialize, Deserialize};
8
9#[derive(Debug, Clone)]
11pub struct StoredValue {
12 pub data: Bytes,
14 pub publisher: [u8; 32],
16 pub timestamp: Instant,
18 pub ttl: u64,
20 pub signature: Option<Vec<u8>>,
22}
23
24impl StoredValue {
25 pub fn new(data: Bytes, publisher: [u8; 32], ttl: u64) -> Self {
27 Self {
28 data,
29 publisher,
30 timestamp: Instant::now(),
31 ttl,
32 signature: None,
33 }
34 }
35
36 pub fn is_expired(&self) -> bool {
38 self.timestamp.elapsed() > Duration::from_secs(self.ttl)
39 }
40
41 pub fn time_until_expiry(&self) -> Option<Duration> {
43 let elapsed = self.timestamp.elapsed();
44 let ttl_duration = Duration::from_secs(self.ttl);
45
46 if elapsed < ttl_duration {
47 Some(ttl_duration - elapsed)
48 } else {
49 None
50 }
51 }
52}
53
54pub struct DHTStore {
56 store: HashMap<[u8; 32], StoredValue>,
58 max_size: usize,
60 current_size: usize,
62}
63
64impl DHTStore {
65 pub fn new(max_size: usize) -> Self {
67 Self {
68 store: HashMap::new(),
69 max_size,
70 current_size: 0,
71 }
72 }
73
74 pub fn put(&mut self, key: [u8; 32], value: StoredValue) -> Result<()> {
76 if !self.store.contains_key(&key) {
78 let value_size = value.data.len();
79
80 self.cleanup_expired();
82
83 if self.current_size + value_size > self.max_size {
85 if !self.make_space(value_size) {
87 return Err(ShadowError::Storage("Storage full".into()));
88 }
89 }
90
91 self.current_size += value_size;
92 } else {
93 if let Some(old_value) = self.store.get(&key) {
95 self.current_size -= old_value.data.len();
96 }
97 self.current_size += value.data.len();
98 }
99
100 self.store.insert(key, value);
101 Ok(())
102 }
103
104 pub fn get(&self, key: &[u8; 32]) -> Option<&StoredValue> {
106 self.store.get(key).filter(|v| !v.is_expired())
107 }
108
109 pub fn remove(&mut self, key: &[u8; 32]) -> Option<StoredValue> {
111 if let Some(value) = self.store.remove(key) {
112 self.current_size -= value.data.len();
113 Some(value)
114 } else {
115 None
116 }
117 }
118
119 pub fn contains(&self, key: &[u8; 32]) -> bool {
121 self.store.get(key).map_or(false, |v| !v.is_expired())
122 }
123
124 pub fn cleanup_expired(&mut self) -> usize {
126 let expired_keys: Vec<[u8; 32]> = self.store
127 .iter()
128 .filter(|(_, v)| v.is_expired())
129 .map(|(k, _)| *k)
130 .collect();
131
132 let count = expired_keys.len();
133 for key in expired_keys {
134 self.remove(&key);
135 }
136
137 count
138 }
139
140 fn make_space(&mut self, needed: usize) -> bool {
142 let mut entries: Vec<_> = self.store.iter().collect();
144 entries.sort_by_key(|(_, v)| v.timestamp);
145
146 let mut freed = 0;
147 let mut to_remove = Vec::new();
148
149 for (key, value) in entries {
150 if freed >= needed {
151 break;
152 }
153 to_remove.push(*key);
154 freed += value.data.len();
155 }
156
157 for key in to_remove {
159 self.remove(&key);
160 }
161
162 freed >= needed
163 }
164
165 pub fn keys(&self) -> Vec<[u8; 32]> {
167 self.store.keys().copied().collect()
168 }
169
170 pub fn len(&self) -> usize {
172 self.store.len()
173 }
174
175 pub fn is_empty(&self) -> bool {
177 self.store.is_empty()
178 }
179
180 pub fn size(&self) -> usize {
182 self.current_size
183 }
184
185 pub fn get_expiring_soon(&self, threshold: Duration) -> Vec<([u8; 32], StoredValue)> {
187 self.store
188 .iter()
189 .filter(|(_, v)| {
190 if let Some(time_left) = v.time_until_expiry() {
191 time_left < threshold
192 } else {
193 false
194 }
195 })
196 .map(|(k, v)| (*k, v.clone()))
197 .collect()
198 }
199
200 pub fn clear(&mut self) {
202 self.store.clear();
203 self.current_size = 0;
204 }
205}
206
207impl Default for DHTStore {
208 fn default() -> Self {
209 Self::new(100 * 1024 * 1024) }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a zero-TTL value whose timestamp lies ten seconds in the past,
    /// so it is already expired when it is handed to the store.
    fn expired_value() -> StoredValue {
        let mut value = StoredValue::new(Bytes::from("test"), [2u8; 32], 0);
        value.timestamp = Instant::now()
            .checked_sub(Duration::from_secs(10))
            .expect("monotonic clock should be at least 10s past its origin");
        value
    }

    #[test]
    fn test_store_put_get() {
        let mut store = DHTStore::new(1024 * 1024);

        let key = [1u8; 32];
        let value = StoredValue::new(Bytes::from("test data"), [2u8; 32], 3600);

        store.put(key, value.clone()).unwrap();

        let retrieved = store.get(&key).unwrap();
        assert_eq!(retrieved.data, value.data);
        assert_eq!(retrieved.publisher, [2u8; 32]);
        // Size accounting tracks payload bytes only.
        assert_eq!(store.size(), "test data".len());
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_store_remove() {
        let mut store = DHTStore::new(1024 * 1024);

        let key = [1u8; 32];
        let value = StoredValue::new(Bytes::from("test"), [2u8; 32], 3600);

        store.put(key, value).unwrap();
        assert!(store.contains(&key));

        assert!(store.remove(&key).is_some());
        assert!(!store.contains(&key));
        assert_eq!(store.size(), 0);
        // Removing a missing key is a no-op.
        assert!(store.remove(&key).is_none());
    }

    #[test]
    fn test_store_expiry() {
        let mut store = DHTStore::new(1024 * 1024);

        let key = [1u8; 32];
        store.put(key, expired_value()).unwrap();

        // Expired values are hidden from lookups.
        assert!(store.get(&key).is_none());
        assert!(!store.contains(&key));
    }

    #[test]
    fn test_store_size_limit() {
        // 100-byte budget: two 60-byte values cannot coexist.
        let mut store = DHTStore::new(100);
        let key1 = [1u8; 32];
        let key2 = [2u8; 32];

        let value1 = StoredValue::new(Bytes::from(vec![0u8; 60]), [0u8; 32], 3600);
        let value2 = StoredValue::new(Bytes::from(vec![0u8; 60]), [0u8; 32], 3600);

        store.put(key1, value1).unwrap();

        let result = store.put(key2, value2);
        assert!(result.is_ok());

        // The older entry must have been evicted to make room.
        assert!(!store.contains(&key1));
        assert!(store.contains(&key2));
        assert_eq!(store.size(), 60);
        assert!(store.size() <= 100);
    }

    #[test]
    fn test_cleanup_expired() {
        let mut store = DHTStore::new(1024 * 1024);

        let key = [1u8; 32];
        store.put(key, expired_value()).unwrap();

        assert_eq!(store.len(), 1);
        let cleaned = store.cleanup_expired();
        assert_eq!(cleaned, 1);
        assert_eq!(store.len(), 0);
        assert!(store.is_empty());
        assert_eq!(store.size(), 0);
    }
}