tycho_network/dht/
storage.rs

1use std::cell::RefCell;
2use std::sync::Arc;
3use std::time::Duration;
4
5use anyhow::Result;
6use bytes::{Bytes, BytesMut};
7use bytesize::ByteSize;
8use moka::Expiry;
9use moka::sync::{Cache, CacheBuilder};
10use tl_proto::TlWrite;
11use tycho_util::FastDashMap;
12use tycho_util::time::now_sec;
13
14use crate::proto::dht::{MergedValue, MergedValueRef, PeerValueRef, ValueRef};
15
16type DhtCache<S> = Cache<StorageKeyId, StoredValue, S>;
17type DhtCacheBuilder<S> = CacheBuilder<StorageKeyId, StoredValue, DhtCache<S>>;
18
/// Origin of a value handed to the storage; forwarded to `DhtValueMerger`
/// callbacks so mergers can treat origins differently.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DhtValueSource {
    /// Presumably a value produced by this node itself — confirm at call sites.
    Local,
    /// Presumably a value received from another peer — confirm at call sites.
    Remote,
}
24
/// Group-specific policy for validating and merging merged DHT values.
///
/// Implementations are registered per group id (`StorageBuilder::with_value_merger`,
/// `Storage::insert_merger`) and shared as `Arc<dyn DhtValueMerger>`, hence the
/// `Send + Sync + 'static` bounds.
pub trait DhtValueMerger: Send + Sync + 'static {
    /// Validates an incoming merged value before it reaches the cache.
    ///
    /// Returning `Err` rejects the value; the storage propagates the error unchanged
    /// to the caller of `Storage::insert`.
    fn check_value(
        &self,
        source: DhtValueSource,
        new: &MergedValueRef<'_>,
    ) -> Result<(), StorageError>;

    /// Merges `new` into the value currently `stored` under the same key.
    ///
    /// Only called when an entry already exists and deserializes successfully.
    /// Return `true` to replace the cache entry with the (mutated) `stored`, or
    /// `false` to keep the existing entry; on `false` any mutation of `stored` is
    /// discarded.
    ///
    /// NOTE(review): the storage calls this from moka's `replace_if` callback, which
    /// may run more than once per insertion — confirm; keep implementations
    /// deterministic and free of external side effects.
    fn merge_value(
        &self,
        source: DhtValueSource,
        new: &MergedValueRef<'_>,
        stored: &mut MergedValue,
    ) -> bool;
}
39
40pub(crate) struct StorageBuilder {
41    cache_builder: DhtCacheBuilder<std::hash::RandomState>,
42    value_mergers: FastDashMap<[u8; 32], Arc<dyn DhtValueMerger>>,
43    max_ttl: Duration,
44}
45
46impl Default for StorageBuilder {
47    fn default() -> Self {
48        Self {
49            cache_builder: Default::default(),
50            value_mergers: Default::default(),
51            max_ttl: Duration::from_secs(3600),
52        }
53    }
54}
55
56impl StorageBuilder {
57    pub fn build(self) -> Storage {
58        fn weigher(_key: &StorageKeyId, value: &StoredValue) -> u32 {
59            std::mem::size_of::<StorageKeyId>() as u32
60                + std::mem::size_of::<StoredValue>() as u32
61                + value.data.len() as u32
62        }
63
64        Storage {
65            cache: self
66                .cache_builder
67                .time_to_live(self.max_ttl)
68                .weigher(weigher)
69                .expire_after(ValueExpiry)
70                .build_with_hasher(ahash::RandomState::default()),
71            value_mergers: self.value_mergers,
72            max_ttl_sec: self.max_ttl.as_secs().try_into().unwrap_or(u32::MAX),
73        }
74    }
75
76    #[allow(unused)]
77    pub fn with_value_merger(
78        self,
79        group_id: &[u8; 32],
80        value_merger: Arc<dyn DhtValueMerger>,
81    ) -> Self {
82        self.value_mergers.insert(*group_id, value_merger);
83        self
84    }
85
86    pub fn with_max_capacity(mut self, max_capacity: ByteSize) -> Self {
87        self.cache_builder = self.cache_builder.max_capacity(max_capacity.0);
88        self
89    }
90
91    pub fn with_max_ttl(mut self, ttl: Duration) -> Self {
92        self.max_ttl = ttl;
93        self
94    }
95
96    pub fn with_max_idle(mut self, duration: Duration) -> Self {
97        self.cache_builder = self.cache_builder.time_to_idle(duration);
98        self
99    }
100}
101
102pub(crate) struct Storage {
103    cache: DhtCache<ahash::RandomState>,
104    value_mergers: FastDashMap<[u8; 32], Arc<dyn DhtValueMerger>>,
105    max_ttl_sec: u32,
106}
107
108impl Storage {
109    pub fn builder() -> StorageBuilder {
110        StorageBuilder::default()
111    }
112
113    pub fn insert_merger(
114        &self,
115        group_id: &[u8; 32],
116        merger: Arc<dyn DhtValueMerger>,
117    ) -> Option<Arc<dyn DhtValueMerger>> {
118        self.value_mergers.insert(*group_id, merger)
119    }
120
121    pub fn remove_merger(&self, group_id: &[u8; 32]) -> Option<Arc<dyn DhtValueMerger>> {
122        self.value_mergers
123            .remove(group_id)
124            .map(|(_, merger)| merger)
125    }
126
127    pub fn get(&self, key: &[u8; 32]) -> Option<Bytes> {
128        let stored_value = self.cache.get(key)?;
129        (stored_value.expires_at > now_sec()).then_some(stored_value.data)
130    }
131
132    pub fn insert(
133        &self,
134        source: DhtValueSource,
135        value: &ValueRef<'_>,
136    ) -> Result<bool, StorageError> {
137        match value.expires_at().checked_sub(now_sec()) {
138            Some(0) | None => return Err(StorageError::ValueExpired),
139            Some(remaining_ttl) if remaining_ttl > self.max_ttl_sec => {
140                return Err(StorageError::UnsupportedTtl);
141            }
142            _ => {}
143        }
144
145        match value {
146            ValueRef::Peer(value) => self.insert_signed_value(value),
147            ValueRef::Merged(value) => self.insert_merged_value(source, value),
148        }
149    }
150
151    fn insert_signed_value(&self, value: &PeerValueRef<'_>) -> Result<bool, StorageError> {
152        let Some(public_key) = value.key.peer_id.as_public_key() else {
153            return Err(StorageError::InvalidSignature);
154        };
155
156        if !public_key.verify_tl(value, value.signature) {
157            return Err(StorageError::InvalidSignature);
158        }
159
160        Ok(self
161            .cache
162            .entry(tl_proto::hash(&value.key))
163            .or_insert_with_if(
164                || StoredValue::new(value, value.expires_at),
165                |prev| prev.expires_at < value.expires_at,
166            )
167            .is_fresh())
168    }
169
170    fn insert_merged_value(
171        &self,
172        source: DhtValueSource,
173        value: &MergedValueRef<'_>,
174    ) -> Result<bool, StorageError> {
175        let merger = match self.value_mergers.get(value.key.group_id) {
176            Some(merger) => merger.clone(),
177            None => return Ok(false),
178        };
179
180        merger.check_value(source, value)?;
181
182        enum MergedValueCow<'a, 'b> {
183            Borrowed(&'a MergedValueRef<'b>),
184            Owned(MergedValue),
185        }
186
187        impl MergedValueCow<'_, '_> {
188            fn make_stored_value(&self) -> StoredValue {
189                match self {
190                    Self::Borrowed(value) => StoredValue::new(*value, value.expires_at),
191                    Self::Owned(value) => StoredValue::new(value, value.expires_at),
192                }
193            }
194        }
195
196        let new_value = RefCell::new(MergedValueCow::Borrowed(value));
197
198        Ok(self
199            .cache
200            .entry(tl_proto::hash(&value.key))
201            .or_insert_with_if(
202                || {
203                    let value = new_value.borrow();
204                    value.make_stored_value()
205                },
206                |prev| {
207                    let Ok(mut prev) = tl_proto::deserialize::<MergedValue>(&prev.data) else {
208                        // Invalid values are always replaced with new values
209                        return true;
210                    };
211
212                    if merger.merge_value(source, value, &mut prev) {
213                        *new_value.borrow_mut() = MergedValueCow::Owned(prev);
214                        true
215                    } else {
216                        false
217                    }
218                },
219            )
220            .is_fresh())
221    }
222}
223
224#[derive(Clone)]
225struct StoredValue {
226    expires_at: u32,
227    data: Bytes,
228}
229
230impl StoredValue {
231    fn new<T: TlWrite<Repr = tl_proto::Boxed>>(value: &T, expires_at: u32) -> Self {
232        let mut data = BytesMut::with_capacity(value.max_size_hint());
233        value.write_to(&mut data);
234
235        StoredValue {
236            expires_at,
237            data: data.freeze(),
238        }
239    }
240}
241
242struct ValueExpiry;
243
244impl Expiry<StorageKeyId, StoredValue> for ValueExpiry {
245    fn expire_after_create(
246        &self,
247        _key: &StorageKeyId,
248        value: &StoredValue,
249        _created_at: std::time::Instant,
250    ) -> Option<Duration> {
251        Some(ttl_since_now(value.expires_at))
252    }
253
254    fn expire_after_update(
255        &self,
256        _key: &StorageKeyId,
257        value: &StoredValue,
258        _updated_at: std::time::Instant,
259        _duration_until_expiry: Option<Duration>,
260    ) -> Option<Duration> {
261        Some(ttl_since_now(value.expires_at))
262    }
263}
264
/// Converts an absolute UNIX timestamp in seconds into the time remaining from
/// now, saturating to zero for timestamps already in the past.
///
/// # Panics
///
/// Panics if the system clock reports a time before the UNIX epoch.
fn ttl_since_now(expires_at: u32) -> Duration {
    let elapsed_since_epoch = std::time::UNIX_EPOCH.elapsed().unwrap();
    let deadline = Duration::from_secs(u64::from(expires_at));
    deadline.saturating_sub(elapsed_since_epoch)
}
272
/// Cache key of a stored value: the 32-byte `tl_proto::hash` of the value's key.
pub type StorageKeyId = [u8; 32];
274
275#[derive(Debug, thiserror::Error)]
276pub enum StorageError {
277    #[error("value expired")]
278    ValueExpired,
279    #[error("unsupported ttl")]
280    UnsupportedTtl,
281    #[error("invalid key")]
282    InvalidKey,
283    #[error("invalid signature")]
284    InvalidSignature,
285    #[error("value too big")]
286    ValueTooBig,
287    #[error("invalid source")]
288    InvalidSource,
289}