1use std::cell::RefCell;
2use std::sync::Arc;
3use std::time::Duration;
4
5use anyhow::Result;
6use bytes::{Bytes, BytesMut};
7use bytesize::ByteSize;
8use moka::Expiry;
9use moka::sync::{Cache, CacheBuilder};
10use tl_proto::TlWrite;
11use tycho_util::FastDashMap;
12use tycho_util::time::now_sec;
13
14use crate::proto::dht::{MergedValue, MergedValueRef, PeerValueRef, ValueRef};
15
16type DhtCache<S> = Cache<StorageKeyId, StoredValue, S>;
17type DhtCacheBuilder<S> = CacheBuilder<StorageKeyId, StoredValue, DhtCache<S>>;
18
/// Where a value being inserted into the DHT storage came from.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DhtValueSource {
    /// Produced by this node.
    Local,
    /// Received from another peer.
    Remote,
}
24
25pub trait DhtValueMerger: Send + Sync + 'static {
26 fn check_value(
27 &self,
28 source: DhtValueSource,
29 new: &MergedValueRef<'_>,
30 ) -> Result<(), StorageError>;
31
32 fn merge_value(
33 &self,
34 source: DhtValueSource,
35 new: &MergedValueRef<'_>,
36 stored: &mut MergedValue,
37 ) -> bool;
38}
39
40pub(crate) struct StorageBuilder {
41 cache_builder: DhtCacheBuilder<std::hash::RandomState>,
42 value_mergers: FastDashMap<[u8; 32], Arc<dyn DhtValueMerger>>,
43 max_ttl: Duration,
44}
45
46impl Default for StorageBuilder {
47 fn default() -> Self {
48 Self {
49 cache_builder: Default::default(),
50 value_mergers: Default::default(),
51 max_ttl: Duration::from_secs(3600),
52 }
53 }
54}
55
56impl StorageBuilder {
57 pub fn build(self) -> Storage {
58 fn weigher(_key: &StorageKeyId, value: &StoredValue) -> u32 {
59 std::mem::size_of::<StorageKeyId>() as u32
60 + std::mem::size_of::<StoredValue>() as u32
61 + value.data.len() as u32
62 }
63
64 Storage {
65 cache: self
66 .cache_builder
67 .time_to_live(self.max_ttl)
68 .weigher(weigher)
69 .expire_after(ValueExpiry)
70 .build_with_hasher(ahash::RandomState::default()),
71 value_mergers: self.value_mergers,
72 max_ttl_sec: self.max_ttl.as_secs().try_into().unwrap_or(u32::MAX),
73 }
74 }
75
76 #[allow(unused)]
77 pub fn with_value_merger(
78 self,
79 group_id: &[u8; 32],
80 value_merger: Arc<dyn DhtValueMerger>,
81 ) -> Self {
82 self.value_mergers.insert(*group_id, value_merger);
83 self
84 }
85
86 pub fn with_max_capacity(mut self, max_capacity: ByteSize) -> Self {
87 self.cache_builder = self.cache_builder.max_capacity(max_capacity.0);
88 self
89 }
90
91 pub fn with_max_ttl(mut self, ttl: Duration) -> Self {
92 self.max_ttl = ttl;
93 self
94 }
95
96 pub fn with_max_idle(mut self, duration: Duration) -> Self {
97 self.cache_builder = self.cache_builder.time_to_idle(duration);
98 self
99 }
100}
101
102pub(crate) struct Storage {
103 cache: DhtCache<ahash::RandomState>,
104 value_mergers: FastDashMap<[u8; 32], Arc<dyn DhtValueMerger>>,
105 max_ttl_sec: u32,
106}
107
108impl Storage {
109 pub fn builder() -> StorageBuilder {
110 StorageBuilder::default()
111 }
112
113 pub fn insert_merger(
114 &self,
115 group_id: &[u8; 32],
116 merger: Arc<dyn DhtValueMerger>,
117 ) -> Option<Arc<dyn DhtValueMerger>> {
118 self.value_mergers.insert(*group_id, merger)
119 }
120
121 pub fn remove_merger(&self, group_id: &[u8; 32]) -> Option<Arc<dyn DhtValueMerger>> {
122 self.value_mergers
123 .remove(group_id)
124 .map(|(_, merger)| merger)
125 }
126
127 pub fn get(&self, key: &[u8; 32]) -> Option<Bytes> {
128 let stored_value = self.cache.get(key)?;
129 (stored_value.expires_at > now_sec()).then_some(stored_value.data)
130 }
131
132 pub fn insert(
133 &self,
134 source: DhtValueSource,
135 value: &ValueRef<'_>,
136 ) -> Result<bool, StorageError> {
137 match value.expires_at().checked_sub(now_sec()) {
138 Some(0) | None => return Err(StorageError::ValueExpired),
139 Some(remaining_ttl) if remaining_ttl > self.max_ttl_sec => {
140 return Err(StorageError::UnsupportedTtl);
141 }
142 _ => {}
143 }
144
145 match value {
146 ValueRef::Peer(value) => self.insert_signed_value(value),
147 ValueRef::Merged(value) => self.insert_merged_value(source, value),
148 }
149 }
150
151 fn insert_signed_value(&self, value: &PeerValueRef<'_>) -> Result<bool, StorageError> {
152 let Some(public_key) = value.key.peer_id.as_public_key() else {
153 return Err(StorageError::InvalidSignature);
154 };
155
156 if !public_key.verify_tl(value, value.signature) {
157 return Err(StorageError::InvalidSignature);
158 }
159
160 Ok(self
161 .cache
162 .entry(tl_proto::hash(&value.key))
163 .or_insert_with_if(
164 || StoredValue::new(value, value.expires_at),
165 |prev| prev.expires_at < value.expires_at,
166 )
167 .is_fresh())
168 }
169
170 fn insert_merged_value(
171 &self,
172 source: DhtValueSource,
173 value: &MergedValueRef<'_>,
174 ) -> Result<bool, StorageError> {
175 let merger = match self.value_mergers.get(value.key.group_id) {
176 Some(merger) => merger.clone(),
177 None => return Ok(false),
178 };
179
180 merger.check_value(source, value)?;
181
182 enum MergedValueCow<'a, 'b> {
183 Borrowed(&'a MergedValueRef<'b>),
184 Owned(MergedValue),
185 }
186
187 impl MergedValueCow<'_, '_> {
188 fn make_stored_value(&self) -> StoredValue {
189 match self {
190 Self::Borrowed(value) => StoredValue::new(*value, value.expires_at),
191 Self::Owned(value) => StoredValue::new(value, value.expires_at),
192 }
193 }
194 }
195
196 let new_value = RefCell::new(MergedValueCow::Borrowed(value));
197
198 Ok(self
199 .cache
200 .entry(tl_proto::hash(&value.key))
201 .or_insert_with_if(
202 || {
203 let value = new_value.borrow();
204 value.make_stored_value()
205 },
206 |prev| {
207 let Ok(mut prev) = tl_proto::deserialize::<MergedValue>(&prev.data) else {
208 return true;
210 };
211
212 if merger.merge_value(source, value, &mut prev) {
213 *new_value.borrow_mut() = MergedValueCow::Owned(prev);
214 true
215 } else {
216 false
217 }
218 },
219 )
220 .is_fresh())
221 }
222}
223
224#[derive(Clone)]
225struct StoredValue {
226 expires_at: u32,
227 data: Bytes,
228}
229
230impl StoredValue {
231 fn new<T: TlWrite<Repr = tl_proto::Boxed>>(value: &T, expires_at: u32) -> Self {
232 let mut data = BytesMut::with_capacity(value.max_size_hint());
233 value.write_to(&mut data);
234
235 StoredValue {
236 expires_at,
237 data: data.freeze(),
238 }
239 }
240}
241
/// Cache expiry policy that expires every entry at its own `expires_at` timestamp.
struct ValueExpiry;
243
244impl Expiry<StorageKeyId, StoredValue> for ValueExpiry {
245 fn expire_after_create(
246 &self,
247 _key: &StorageKeyId,
248 value: &StoredValue,
249 _created_at: std::time::Instant,
250 ) -> Option<Duration> {
251 Some(ttl_since_now(value.expires_at))
252 }
253
254 fn expire_after_update(
255 &self,
256 _key: &StorageKeyId,
257 value: &StoredValue,
258 _updated_at: std::time::Instant,
259 _duration_until_expiry: Option<Duration>,
260 ) -> Option<Duration> {
261 Some(ttl_since_now(value.expires_at))
262 }
263}
264
/// Returns the time remaining from now until the Unix timestamp `expires_at`
/// (in seconds), or `Duration::ZERO` if that moment has already passed.
///
/// If the system clock reports a time before the Unix epoch, the current time
/// is treated as the epoch instead of panicking. This function runs inside the
/// cache expiry callbacks, so a clock misconfiguration must not crash cache
/// operations. The cache-wide time-to-live still bounds the entry in that case.
fn ttl_since_now(expires_at: u32) -> Duration {
    let now = std::time::SystemTime::now()
        .duration_since(std::time::SystemTime::UNIX_EPOCH)
        .unwrap_or_default();

    Duration::from_secs(u64::from(expires_at)).saturating_sub(now)
}
272
/// Cache key: the 32-byte TL hash of a DHT value key.
pub type StorageKeyId = [u8; 32];
274
275#[derive(Debug, thiserror::Error)]
276pub enum StorageError {
277 #[error("value expired")]
278 ValueExpired,
279 #[error("unsupported ttl")]
280 UnsupportedTtl,
281 #[error("invalid key")]
282 InvalidKey,
283 #[error("invalid signature")]
284 InvalidSignature,
285 #[error("value too big")]
286 ValueTooBig,
287 #[error("invalid source")]
288 InvalidSource,
289}