1use chrono::{offset::Utc, DateTime};
2use log::debug;
3use sled::IVec;
4use std::{collections::HashSet, marker::PhantomData};
5
6use crate::{
7 encoding::Encoding,
8 error::Result,
9 structured_tree::{
10 CompareAndSwapError, StructuredBatch, StructuredIter, StructuredTransactionalTree,
11 StructuredTree,
12 },
13};
14
/// A key/value tree that additionally records, for every tracked key, the time
/// at which that key expires.
///
/// `V` is the stored value type, `F` the encoding used for values and `E` the
/// encoding used for the expiration bookkeeping.
#[derive(Clone)]
pub struct ExpiringTree<V, E, F> {
    /// The user-visible key -> value data.
    data: StructuredTree<V, F>,
    /// Maps each tracked key to the time at which it expires.
    expires_at: StructuredTree<DateTime<Utc>, E>,
    /// Inverse index: maps the string form of an expiration time to the set of
    /// keys expiring at that time.
    expires_at_inverse: StructuredTree<HashSet<IVec>, E>,
    /// When set, writes push a key's expiration forward.
    extend_on_update: bool,
    /// When set, reads push a key's expiration forward.
    extend_on_fetch: bool,
    /// How far past "now" a freshly computed expiration lies.
    expiration_length: chrono::Duration,
}
27
/// Collects the configuration for an [`ExpiringTree`] before it is opened.
pub struct ExpiringTreeBuilder<V, E, F> {
    /// Database the trees are opened in.
    db: sled::Db,
    /// Name of the data tree; the bookkeeping tree names are derived from it.
    data: String,
    /// Whether writes should extend a key's expiration.
    extend_on_update: bool,
    /// Whether reads should extend a key's expiration.
    extend_on_fetch: bool,
    /// Lifetime granted to a key each time its expiration is (re)computed.
    expiration_length: chrono::Duration,
    // The markers below only tie the builder to the type parameters of the
    // tree it produces; they hold no data.
    value: PhantomData<V>,
    encoding: PhantomData<E>,
    data_encoding: PhantomData<F>,
}
42
/// An iterator over the entries of an [`ExpiringTree`].
///
/// Field `0` is the iterator over the underlying data tree; field `1` is the
/// tree it came from, which is consulted to extend expirations while iterating.
pub struct ExpiringIter<'a, V, E, F>(StructuredIter<V, F>, &'a ExpiringTree<V, E, F>);
45
/// A batch of writes destined for an [`ExpiringTree`].
///
/// Field `0` is the batch handed to the data tree; field `1` is the set of keys
/// inserted through this batch, used for expiration bookkeeping when the batch
/// is applied.
#[derive(Clone, Debug, Default)]
pub struct ExpiringBatch<V, F>(StructuredBatch<V, F>, HashSet<IVec>);
49
/// The view of an [`ExpiringTree`] handed to a transaction closure.
///
/// Field `0` is the transactional handle on the data tree; field `1` is the
/// owning tree, used for expiration bookkeeping.
#[derive(Clone)]
pub struct ExpiringTransactionalTree<'a, V, E, F>(
    StructuredTransactionalTree<'a, V, F>,
    &'a ExpiringTree<V, E, F>,
);
56
57impl<V, E, F> ExpiringTree<V, E, F>
58where
59 E: Encoding<HashSet<IVec>> + Encoding<DateTime<Utc>> + 'static,
60 F: Encoding<V> + 'static,
61{
62 pub fn cloned(&self) -> Self {
64 ExpiringTree {
65 data: self.data.cloned(),
66 expires_at: self.expires_at.cloned(),
67 expires_at_inverse: self.expires_at_inverse.cloned(),
68 extend_on_update: self.extend_on_update,
69 extend_on_fetch: self.extend_on_fetch,
70 expiration_length: self.expiration_length,
71 }
72 }
73
74 pub fn transaction<G, R>(&self, g: G) -> sled::TransactionResult<Result<R>>
80 where
81 G: Fn(ExpiringTransactionalTree<V, E, F>) -> sled::ConflictableTransactionResult<Result<R>>,
82 {
83 self.data
84 .transaction(move |trans_tree| (g)(ExpiringTransactionalTree(trans_tree, &self)))
85 }
86
87 pub fn apply_batch(&self, batch: ExpiringBatch<V, F>) -> Result<()> {
91 let keys = batch.1;
92 self.data.apply_batch(batch.0)?;
93
94 if self.extend_on_update {
95 let now = Utc::now();
96 for key in keys {
97 self.update_expires_at(key, now)?;
98 }
99 }
100
101 Ok(())
102 }
103
104 pub fn compare_and_swap<K>(
115 &self,
116 key: K,
117 old: Option<V>,
118 new: Option<V>,
119 ) -> Result<std::result::Result<(), CompareAndSwapError<V>>>
120 where
121 K: AsRef<[u8]>,
122 {
123 let to_delete = old.is_some() && new.is_none();
124 let to_update = new.is_some();
125
126 let ivec = IVec::from(key.as_ref());
127
128 let res = self.data.compare_and_swap(key, old, new)?;
129
130 let success = res.is_ok();
131
132 if to_delete && success {
133 self.remove_expires_at(ivec)?;
134 } else if to_update && success && self.extend_on_update {
135 self.update_expires_at(ivec, Utc::now())?;
136 }
137
138 Ok(res)
139 }
140
141 pub fn get<K>(&self, key: K) -> Result<Option<V>>
143 where
144 K: AsRef<[u8]>,
145 {
146 let ivec = IVec::from(key.as_ref());
147 let opt = self.data.get(key)?;
148
149 if self.extend_on_fetch {
150 self.update_expires_at(ivec, Utc::now())?;
151 }
152
153 Ok(opt)
154 }
155
156 pub fn insert<K>(&self, key: K, value: V) -> Result<Option<V>>
158 where
159 IVec: From<K>,
160 K: AsRef<[u8]>,
161 {
162 let ivec: IVec = key.as_ref().into();
163 let opt = self.data.insert::<K>(key, value)?;
164
165 if self.extend_on_update {
166 self.update_expires_at(ivec, Utc::now())?;
167 }
168
169 Ok(opt)
170 }
171
172 pub fn remove<K>(&self, key: K) -> Result<Option<V>>
174 where
175 K: AsRef<[u8]>,
176 {
177 let ivec = IVec::from(key.as_ref());
178 let opt = self.data.remove(key)?;
179
180 self.remove_expires_at(ivec)?;
181
182 Ok(opt)
183 }
184
185 pub fn update_and_fetch<K>(
191 &self,
192 key: K,
193 f: impl Fn(Option<V>) -> Option<V>,
194 ) -> Result<Option<V>>
195 where
196 K: AsRef<[u8]>,
197 {
198 let ivec = IVec::from(key.as_ref());
199 let opt = self.data.update_and_fetch(key, f)?;
200
201 if opt.is_some() && self.extend_on_update {
202 self.update_expires_at(ivec, Utc::now())?;
203 } else {
204 self.remove_expires_at(ivec)?;
205 }
206
207 Ok(opt)
208 }
209
210 pub fn fetch_and_update<K>(
215 &self,
216 key: K,
217 f: impl Fn(Option<V>) -> Option<V>,
218 ) -> Result<Option<V>>
219 where
220 K: AsRef<[u8]>,
221 {
222 let ivec = IVec::from(key.as_ref());
223 let opt = self.data.fetch_and_update(key, f)?;
224
225 if opt.is_some() && self.extend_on_update {
226 self.update_expires_at(ivec, Utc::now())?;
227 } else {
228 self.remove_expires_at(ivec)?;
229 }
230
231 Ok(opt)
232 }
233
234 pub fn flush(&self) -> Result<()> {
238 self.data.flush()?;
239 self.expires_at.flush()?;
240 self.expires_at_inverse.flush()?;
241 Ok(())
242 }
243
244 pub fn contains_key<K>(&self, key: K) -> Result<bool>
246 where
247 K: AsRef<[u8]>,
248 {
249 self.data.contains_key(key)
250 }
251
252 pub fn iter<'a>(&'a self) -> ExpiringIter<'a, V, E, F> {
254 ExpiringIter(self.data.iter(), &self)
255 }
256
257 pub fn range<'a, K, R>(&'a self, range: R) -> ExpiringIter<'a, V, E, F>
260 where
261 K: AsRef<[u8]>,
262 R: std::ops::RangeBounds<K>,
263 {
264 ExpiringIter(self.data.range(range), &self)
265 }
266
267 pub fn get_lt<K>(&self, key: K) -> Result<Option<(IVec, V)>>
269 where
270 K: AsRef<[u8]>,
271 {
272 if let Some((k, v)) = self.data.get_lt(key)? {
273 if self.extend_on_fetch {
274 self.update_expires_at(k.clone(), Utc::now())?;
275 }
276
277 return Ok(Some((k, v)));
278 }
279
280 Ok(None)
281 }
282
283 pub fn get_gt<K>(&self, key: K) -> Result<Option<(IVec, V)>>
292 where
293 K: AsRef<[u8]>,
294 {
295 if let Some((k, v)) = self.data.get_gt(key)? {
296 if self.extend_on_fetch {
297 self.update_expires_at(k.clone(), Utc::now())?;
298 }
299
300 return Ok(Some((k, v)));
301 }
302
303 Ok(None)
304 }
305
306 pub fn scan_prefix<'a, P>(&'a self, prefix: P) -> ExpiringIter<'a, V, E, F>
309 where
310 P: AsRef<[u8]>,
311 {
312 ExpiringIter(self.data.scan_prefix(prefix), &self)
313 }
314
315 pub fn pop_max(&self) -> Result<Option<(IVec, V)>> {
317 if let Some((k, v)) = self.data.pop_max()? {
318 self.remove_expires_at(k.clone())?;
319
320 return Ok(Some((k, v)));
321 }
322
323 Ok(None)
324 }
325
326 pub fn pop_min(&self) -> Result<Option<(IVec, V)>> {
328 if let Some((k, v)) = self.data.pop_min()? {
329 self.remove_expires_at(k.clone())?;
330
331 return Ok(Some((k, v)));
332 }
333
334 Ok(None)
335 }
336
337 pub fn len(&self) -> usize {
341 self.data.len()
342 }
343
344 pub fn is_empty(&self) -> bool {
346 self.data.is_empty()
347 }
348
349 pub fn clear(&self) -> Result<()> {
353 self.data.clear()?;
354 self.expires_at.clear()?;
355 self.expires_at_inverse.clear()?;
356 Ok(())
357 }
358
359 pub fn name(&self) -> String {
361 self.data.name()
362 }
363
364 pub fn expired<'a>(&'a self) -> impl 'a + Iterator<Item = IVec> {
366 let now: IVec = Utc::now().to_string().into_bytes().into();
367 debug!("now: {:?}", now);
368
369 self.expires_at_inverse
370 .range(..now)
371 .values()
372 .filter_map(|res| res.ok())
373 .flat_map(|res| res.into_iter())
374 }
375
376 fn remove_expires_at(&self, key: IVec) -> Result<()> {
377 if let Some(prev) = self.expires_at.remove(key.clone())? {
378 self.expires_at_inverse
379 .update_and_fetch(prev.to_string().into_bytes(), |opt| {
380 opt.and_then(|mut hs| {
381 hs.remove(&key);
382 if hs.is_empty() {
383 None
384 } else {
385 Some(hs)
386 }
387 })
388 })?;
389 }
390
391 Ok(())
392 }
393
394 fn update_expires_at(&self, key: IVec, now: DateTime<Utc>) -> Result<()> {
395 let expires_at = now + self.expiration_length;
396
397 if let Some(prev) = self.expires_at.insert(key.clone(), expires_at)? {
398 self.expires_at_inverse
399 .update_and_fetch(prev.to_string().into_bytes(), |opt| {
400 opt.and_then(|mut hs| {
401 hs.remove(&key);
402 if hs.is_empty() {
403 None
404 } else {
405 Some(hs)
406 }
407 })
408 })?;
409 }
410
411 self.expires_at_inverse
412 .update_and_fetch(expires_at.to_string().into_bytes(), |opt| {
413 let mut hs = opt.unwrap_or(HashSet::new());
414 hs.insert(key.clone());
415 Some(hs)
416 })?;
417
418 Ok(())
419 }
420}
421
422impl<V, E, F> ExpiringTreeBuilder<V, E, F>
423where
424 E: Encoding<HashSet<IVec>> + Encoding<DateTime<Utc>> + 'static,
425 F: Encoding<V> + 'static,
426{
427 pub(crate) fn new(db: &sled::Db, data: &str) -> Self {
428 ExpiringTreeBuilder {
429 db: db.clone(),
430 data: data.to_owned(),
431 extend_on_update: false,
432 extend_on_fetch: false,
433 expiration_length: chrono::Duration::hours(12),
434 value: PhantomData,
435 encoding: PhantomData,
436 data_encoding: PhantomData,
437 }
438 }
439
440 pub fn extend_on_update(&mut self) -> &mut Self {
442 self.extend_on_update = true;
443 self
444 }
445
446 pub fn extend_on_fetch(&mut self) -> &mut Self {
448 self.extend_on_fetch = true;
449 self
450 }
451
452 pub fn expiration_length(&mut self, expiration_length: chrono::Duration) -> &mut Self {
454 self.expiration_length = expiration_length;
455 self
456 }
457
458 pub fn build(&self) -> Result<ExpiringTree<V, E, F>> {
460 Ok(ExpiringTree {
461 data: StructuredTree::new(&self.db, &self.data)?,
462 expires_at: StructuredTree::new(&self.db, &format!("{}-expires-at", self.data))?,
463 expires_at_inverse: StructuredTree::new(
464 &self.db,
465 &format!("{}-expires-at-inverse", self.data),
466 )?,
467 extend_on_update: self.extend_on_update,
468 extend_on_fetch: self.extend_on_fetch,
469 expiration_length: self.expiration_length,
470 })
471 }
472}
473
474impl<'a, V, E, F> ExpiringIter<'a, V, E, F>
475where
476 E: Encoding<HashSet<IVec>> + Encoding<DateTime<Utc>> + 'static,
477 F: Encoding<V> + 'static,
478{
479 pub fn keys(self) -> impl 'a + DoubleEndedIterator<Item = Result<IVec>> {
481 self.map(|res| res.map(|(key, _)| key))
482 }
483
484 pub fn values(self) -> impl 'a + DoubleEndedIterator<Item = Result<V>> {
486 self.map(|res| res.map(|(_, v)| v))
487 }
488}
489
490impl<V, F> ExpiringBatch<V, F>
491where
492 F: Encoding<V> + 'static,
493{
494 pub fn insert<K>(&mut self, key: K, value: V) -> Result<()>
496 where
497 IVec: From<K>,
498 {
499 let k = IVec::from(key);
500 self.1.insert(k.clone());
501 self.0.insert::<IVec>(k, value)
502 }
503
504 pub fn remove<K>(&mut self, key: K)
506 where
507 IVec: From<K>,
508 {
509 let k = IVec::from(key);
510 self.1.remove(&k);
511 self.0.remove::<IVec>(k)
512 }
513}
514
515impl<'a, V, E, F> ExpiringTransactionalTree<'a, V, E, F>
516where
517 E: Encoding<HashSet<IVec>> + Encoding<DateTime<Utc>> + 'static,
518 F: Encoding<V> + 'static,
519{
520 pub fn insert<K>(
522 &self,
523 key: K,
524 value: V,
525 ) -> sled::ConflictableTransactionResult<Result<Option<V>>>
526 where
527 IVec: From<K>,
528 K: AsRef<[u8]>,
529 {
530 let k = IVec::from(key);
531 let r = self.0.insert::<IVec>(k.clone(), value)?;
532
533 if self.1.extend_on_update {
534 if let Err(e) = self.1.update_expires_at(k, Utc::now()) {
535 return Ok(Err(e));
536 }
537 }
538
539 Ok(r)
540 }
541
542 pub fn remove<K>(&self, key: K) -> sled::ConflictableTransactionResult<Result<Option<V>>>
544 where
545 IVec: From<K>,
546 K: AsRef<[u8]>,
547 {
548 let k = IVec::from(key);
549 let r = self.0.remove::<IVec>(k.clone())?;
550
551 if let Err(e) = self.1.remove_expires_at(k) {
552 return Ok(Err(e));
553 }
554
555 Ok(r)
556 }
557
558 pub fn get<K>(&self, key: K) -> sled::ConflictableTransactionResult<Result<Option<V>>>
560 where
561 K: AsRef<[u8]>,
562 {
563 let k = key.as_ref().to_vec();
564
565 let r = self.0.get(key)?;
566
567 if self.1.extend_on_fetch {
568 if let Err(e) = self.1.update_expires_at(k.into(), Utc::now()) {
569 return Ok(Err(e));
570 }
571 }
572
573 Ok(r)
574 }
575
576 pub fn apply_batch(
578 &self,
579 batch: ExpiringBatch<V, F>,
580 ) -> sled::ConflictableTransactionResult<Result<()>> {
581 let keys = batch.1;
582 self.0.apply_batch(batch.0)?;
583
584 if self.1.extend_on_update {
585 let now = Utc::now();
586 for key in keys {
587 if let Err(e) = self.1.update_expires_at(key, now) {
588 return Ok(Err(e));
589 }
590 }
591 }
592
593 Ok(Ok(()))
594 }
595}
596
597impl<'a, V, E, F> Iterator for ExpiringIter<'a, V, E, F>
598where
599 E: Encoding<HashSet<IVec>> + Encoding<DateTime<Utc>> + 'static,
600 F: Encoding<V> + 'static,
601{
602 type Item = Result<(IVec, V)>;
603
604 fn next(&mut self) -> Option<Self::Item> {
605 Some(self.0.next()?.and_then(move |(k, v)| {
606 if self.1.extend_on_fetch {
607 self.1
608 .update_expires_at(k.clone(), Utc::now())
609 .map(move |_| (k, v))
610 } else {
611 Ok((k, v))
612 }
613 }))
614 }
615}
616
617impl<'a, V, E, F> DoubleEndedIterator for ExpiringIter<'a, V, E, F>
618where
619 E: Encoding<HashSet<IVec>> + Encoding<DateTime<Utc>> + 'static,
620 F: Encoding<V> + 'static,
621{
622 fn next_back(&mut self) -> Option<<Self as Iterator>::Item> {
623 Some(self.0.next_back()?.and_then(move |(k, v)| {
624 if self.1.extend_on_fetch {
625 self.1
626 .update_expires_at(k.clone(), Utc::now())
627 .map(move |_| (k, v))
628 } else {
629 Ok((k, v))
630 }
631 }))
632 }
633}