use chrono::{offset::Utc, DateTime};
use log::debug;
use sled::IVec;
use std::{collections::HashSet, marker::PhantomData};
use crate::{
encoding::Encoding,
error::Result,
structured_tree::{
StructuredBatch, StructuredIter, StructuredTransactionalTree, StructuredTree,
},
};
/// A structured key/value tree whose keys carry an expiration timestamp.
///
/// Expiration bookkeeping lives in two side trees: `expires_at` maps each key to its
/// expiration time, and `expires_at_inverse` maps the textual form of a timestamp to the
/// set of keys expiring at that instant, so expired keys can be found with a range scan.
#[derive(Clone)]
pub struct ExpiringTree<V, E, F> {
/// The stored values, encoded with `F`.
data: StructuredTree<V, F>,
/// key -> expiration time.
expires_at: StructuredTree<DateTime<Utc>, E>,
/// `DateTime::to_string()` bytes -> set of keys expiring at that time.
expires_at_inverse: StructuredTree<HashSet<IVec>, E>,
/// When set, write operations push the written key's expiration forward.
extend_on_update: bool,
/// When set, read operations push the read key's expiration forward.
extend_on_fetch: bool,
/// Added to the current time to compute a new expiration.
expiration_length: chrono::Duration,
}
/// Builder for [`ExpiringTree`].
///
/// Defaults: both extension flags off and a 12-hour expiration length.
pub struct ExpiringTreeBuilder<V, E, F> {
/// Database the three backing trees are opened in.
db: sled::Db,
/// Name of the data tree; the expiration trees are named with it as a prefix.
data: String,
/// Copied into [`ExpiringTree`]'s `extend_on_update`.
extend_on_update: bool,
/// Copied into [`ExpiringTree`]'s `extend_on_fetch`.
extend_on_fetch: bool,
/// Copied into [`ExpiringTree`]'s `expiration_length`.
expiration_length: chrono::Duration,
// Type-level markers only; no values of these types are stored.
value: PhantomData<V>,
encoding: PhantomData<E>,
data_encoding: PhantomData<F>,
}
/// Iterator over `(key, value)` pairs of an [`ExpiringTree`].
///
/// Each yielded key's expiration is extended when the tree has `extend_on_fetch` set.
/// Field 0 is the underlying data iterator; field 1 is the owning tree.
pub struct ExpiringIter<'a, V, E, F>(StructuredIter<V, F>, &'a ExpiringTree<V, E, F>);
/// A batch of writes for an [`ExpiringTree`].
///
/// Field 0 holds the staged data writes. Field 1 holds the keys inserted (and not later
/// removed) through this batch; their expirations are extended when the batch is applied
/// to a tree with `extend_on_update` set.
#[derive(Clone, Debug, Default)]
pub struct ExpiringBatch<V, F>(StructuredBatch<V, F>, HashSet<IVec>);
/// Transactional view of an [`ExpiringTree`]'s data tree.
///
/// Field 0 is the transactional data tree; field 1 is the owning tree, used for
/// expiration bookkeeping.
/// NOTE(review): that bookkeeping goes through the non-transactional expiration trees, so
/// it is presumably not rolled back if the transaction aborts or retries — confirm.
#[derive(Clone)]
pub struct ExpiringTransactionalTree<'a, V, E, F>(
StructuredTransactionalTree<'a, V, F>,
&'a ExpiringTree<V, E, F>,
);
impl<V, E, F> ExpiringTree<V, E, F>
where
    E: Encoding<HashSet<IVec>> + Encoding<DateTime<Utc>> + 'static,
    F: Encoding<V> + 'static,
{
    /// Returns a new handle whose three backing trees come from `StructuredTree::cloned`,
    /// with the same extension flags and expiration length.
    pub fn cloned(&self) -> Self {
        ExpiringTree {
            data: self.data.cloned(),
            expires_at: self.expires_at.cloned(),
            expires_at_inverse: self.expires_at_inverse.cloned(),
            extend_on_update: self.extend_on_update,
            extend_on_fetch: self.extend_on_fetch,
            expiration_length: self.expiration_length,
        }
    }

    /// Runs `g` inside a transaction on the data tree.
    ///
    /// Only the data tree is transactional. Expiration bookkeeping performed through the
    /// [`ExpiringTransactionalTree`] handle writes to the expiration trees directly.
    /// NOTE(review): it is therefore presumably not rolled back (and may be repeated) if
    /// the transaction aborts or retries — confirm.
    pub fn transaction<G, R>(&self, g: G) -> sled::TransactionResult<Result<R>>
    where
        G: Fn(ExpiringTransactionalTree<V, E, F>) -> sled::TransactionResult<Result<R>>,
    {
        self.data
            .transaction(move |trans_tree| g(ExpiringTransactionalTree(trans_tree, self)))
    }

    /// Applies `batch` to the data tree.
    ///
    /// When `extend_on_update` is set, the expiration of every key inserted through the
    /// batch is then extended.
    ///
    /// # Errors
    /// Propagates errors from the data tree or from expiration bookkeeping.
    pub fn apply_batch(&self, batch: ExpiringBatch<V, F>) -> Result<()> {
        let keys = batch.1;
        self.data.apply_batch(batch.0)?;
        if self.extend_on_update {
            let now = Utc::now();
            for key in keys {
                self.update_expires_at(key, now)?;
            }
        }
        Ok(())
    }

    /// Compare-and-swap on the data tree.
    ///
    /// On success:
    /// - a deletion (`old` is `Some`, `new` is `None`) clears the key's expiration;
    /// - any write of a `Some` value extends it when `extend_on_update` is set.
    ///
    /// Returns the inner `Err(current)` unchanged when the comparison fails.
    pub fn cas<K>(
        &self,
        key: K,
        old: Option<V>,
        new: Option<V>,
    ) -> Result<std::result::Result<(), Option<V>>>
    where
        K: AsRef<[u8]>,
    {
        let to_delete = old.is_some() && new.is_none();
        let to_update = new.is_some();
        let ivec = IVec::from(key.as_ref());
        let res = self.data.cas(key, old, new)?;
        let success = res.is_ok();
        if to_delete && success {
            self.remove_expires_at(ivec)?;
        } else if to_update && success && self.extend_on_update {
            self.update_expires_at(ivec, Utc::now())?;
        }
        Ok(res)
    }

    /// Fetches the value for `key`.
    ///
    /// When the key exists and `extend_on_fetch` is set, its expiration is extended.
    pub fn get<K>(&self, key: K) -> Result<Option<V>>
    where
        K: AsRef<[u8]>,
    {
        let ivec = IVec::from(key.as_ref());
        let opt = self.data.get(key)?;
        // Only extend keys that exist: extending a missing key would create an orphaned
        // expiration entry that `expired` would later report.
        if opt.is_some() && self.extend_on_fetch {
            self.update_expires_at(ivec, Utc::now())?;
        }
        Ok(opt)
    }

    /// Inserts `value` under `key`, returning the previous value.
    ///
    /// The key's expiration is extended when `extend_on_update` is set.
    pub fn insert<K>(&self, key: K, value: V) -> Result<Option<V>>
    where
        IVec: From<K>,
        K: AsRef<[u8]>,
    {
        let ivec: IVec = key.as_ref().into();
        let opt = self.data.insert::<K>(key, value)?;
        if self.extend_on_update {
            self.update_expires_at(ivec, Utc::now())?;
        }
        Ok(opt)
    }

    /// Removes `key` and its expiration entry, returning the previous value.
    pub fn remove<K>(&self, key: K) -> Result<Option<V>>
    where
        K: AsRef<[u8]>,
    {
        let ivec = IVec::from(key.as_ref());
        let opt = self.data.remove(key)?;
        self.remove_expires_at(ivec)?;
        Ok(opt)
    }

    /// Atomically replaces the value for `key` with `f(current)` and returns the new value.
    ///
    /// Expiration handling depends on the result:
    /// - `None` (key deleted): the expiration is cleared.
    /// - `Some` with `extend_on_update` set: the expiration is extended.
    /// - `Some` otherwise: the existing expiration is left untouched.
    pub fn update_and_fetch<K>(
        &self,
        key: K,
        f: impl Fn(Option<V>) -> Option<V>,
    ) -> Result<Option<V>>
    where
        K: AsRef<[u8]>,
    {
        let ivec = IVec::from(key.as_ref());
        let opt = self.data.update_and_fetch(key, f)?;
        if opt.is_none() {
            self.remove_expires_at(ivec)?;
        } else if self.extend_on_update {
            self.update_expires_at(ivec, Utc::now())?;
        }
        Ok(opt)
    }

    /// Atomically replaces the value for `key` with `f(current)` and returns the *previous*
    /// value.
    ///
    /// Expiration handling is decided by the newly stored value, using the same rules as
    /// [`ExpiringTree::update_and_fetch`].
    pub fn fetch_and_update<K>(
        &self,
        key: K,
        f: impl Fn(Option<V>) -> Option<V>,
    ) -> Result<Option<V>>
    where
        K: AsRef<[u8]>,
    {
        let ivec = IVec::from(key.as_ref());
        // The return value is the old value, so record whether the value that was written
        // is present. NOTE(review): the closure may run more than once if the underlying
        // compare-and-swap retries; this assumes the last call's result is the one stored,
        // as in sled's CAS loop — confirm for StructuredTree::fetch_and_update.
        let stored_some = std::cell::Cell::new(false);
        let old = self.data.fetch_and_update(key, |current| {
            let next = f(current);
            stored_some.set(next.is_some());
            next
        })?;
        if !stored_some.get() {
            self.remove_expires_at(ivec)?;
        } else if self.extend_on_update {
            self.update_expires_at(ivec, Utc::now())?;
        }
        Ok(old)
    }

    /// Flushes the data tree and both expiration trees, in that order.
    pub fn flush(&self) -> Result<()> {
        self.data.flush()?;
        self.expires_at.flush()?;
        self.expires_at_inverse.flush()?;
        Ok(())
    }

    /// Returns whether `key` is present in the data tree.
    ///
    /// Does not extend the key's expiration.
    pub fn contains_key<K>(&self, key: K) -> Result<bool>
    where
        K: AsRef<[u8]>,
    {
        self.data.contains_key(key)
    }

    /// Iterates over all entries.
    ///
    /// Yielded keys are extended when `extend_on_fetch` is set.
    pub fn iter<'a>(&'a self) -> ExpiringIter<'a, V, E, F> {
        ExpiringIter(self.data.iter(), self)
    }

    /// Iterates over the entries whose keys fall in `range`.
    ///
    /// Yielded keys are extended when `extend_on_fetch` is set.
    pub fn range<'a, K, R>(&'a self, range: R) -> ExpiringIter<'a, V, E, F>
    where
        K: AsRef<[u8]>,
        R: std::ops::RangeBounds<K>,
    {
        ExpiringIter(self.data.range(range), self)
    }

    /// Returns the entry with the greatest key strictly less than `key`.
    ///
    /// The found key is extended when `extend_on_fetch` is set.
    pub fn get_lt<K>(&self, key: K) -> Result<Option<(IVec, V)>>
    where
        K: AsRef<[u8]>,
    {
        if let Some((k, v)) = self.data.get_lt(key)? {
            if self.extend_on_fetch {
                self.update_expires_at(k.clone(), Utc::now())?;
            }
            return Ok(Some((k, v)));
        }
        Ok(None)
    }

    /// Returns the entry with the smallest key strictly greater than `key`.
    ///
    /// The found key is extended when `extend_on_fetch` is set.
    pub fn get_gt<K>(&self, key: K) -> Result<Option<(IVec, V)>>
    where
        K: AsRef<[u8]>,
    {
        if let Some((k, v)) = self.data.get_gt(key)? {
            if self.extend_on_fetch {
                self.update_expires_at(k.clone(), Utc::now())?;
            }
            return Ok(Some((k, v)));
        }
        Ok(None)
    }

    /// Iterates over the entries whose keys start with `prefix`.
    ///
    /// Yielded keys are extended when `extend_on_fetch` is set.
    pub fn scan_prefix<'a, P>(&'a self, prefix: P) -> ExpiringIter<'a, V, E, F>
    where
        P: AsRef<[u8]>,
    {
        ExpiringIter(self.data.scan_prefix(prefix), self)
    }

    /// Removes and returns the entry with the greatest key, clearing its expiration.
    pub fn pop_max(&self) -> Result<Option<(IVec, V)>> {
        if let Some((k, v)) = self.data.pop_max()? {
            self.remove_expires_at(k.clone())?;
            return Ok(Some((k, v)));
        }
        Ok(None)
    }

    /// Removes and returns the entry with the smallest key, clearing its expiration.
    pub fn pop_min(&self) -> Result<Option<(IVec, V)>> {
        if let Some((k, v)) = self.data.pop_min()? {
            self.remove_expires_at(k.clone())?;
            return Ok(Some((k, v)));
        }
        Ok(None)
    }

    /// Number of entries in the data tree.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Whether the data tree is empty.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Clears the data tree and both expiration trees.
    pub fn clear(&self) -> Result<()> {
        self.data.clear()?;
        self.expires_at.clear()?;
        self.expires_at_inverse.clear()?;
        Ok(())
    }

    /// Name of the data tree.
    pub fn name(&self) -> String {
        self.data.name()
    }

    /// Yields the keys whose recorded expiration time sorts before the current time.
    ///
    /// The keys are only reported; nothing is removed. Inverse-index entries that fail to
    /// decode are silently skipped. Timestamps are compared as their `to_string()` bytes.
    /// NOTE(review): this relies on that textual form sorting chronologically.
    pub fn expired<'a>(&'a self) -> impl 'a + Iterator<Item = IVec> {
        let now: IVec = Utc::now().to_string().into_bytes().into();
        debug!("now: {:?}", now);
        self.expires_at_inverse
            .range(..now)
            .values()
            .filter_map(Result::ok)
            .flatten()
    }

    /// Deletes `key`'s expiration entry and unlinks it from the inverse index.
    ///
    /// Does nothing if the key has no expiration.
    fn remove_expires_at(&self, key: IVec) -> Result<()> {
        if let Some(prev) = self.expires_at.remove(key.clone())? {
            self.remove_from_inverse(prev, &key)?;
        }
        Ok(())
    }

    /// Sets `key`'s expiration to `now + expiration_length`.
    ///
    /// Moves the key from its previous inverse-index bucket (if any) into the new one.
    fn update_expires_at(&self, key: IVec, now: DateTime<Utc>) -> Result<()> {
        let expires_at = now + self.expiration_length;
        if let Some(prev) = self.expires_at.insert(key.clone(), expires_at)? {
            self.remove_from_inverse(prev, &key)?;
        }
        self.expires_at_inverse
            .update_and_fetch(expires_at.to_string().into_bytes(), |opt| {
                let mut hs = opt.unwrap_or_default();
                hs.insert(key.clone());
                Some(hs)
            })?;
        Ok(())
    }

    /// Removes `key` from the inverse-index bucket for timestamp `at`.
    ///
    /// The bucket is deleted once it becomes empty.
    fn remove_from_inverse(&self, at: DateTime<Utc>, key: &IVec) -> Result<()> {
        self.expires_at_inverse
            .update_and_fetch(at.to_string().into_bytes(), |opt| {
                opt.and_then(|mut hs| {
                    hs.remove(key);
                    if hs.is_empty() {
                        None
                    } else {
                        Some(hs)
                    }
                })
            })?;
        Ok(())
    }
}
impl<V, E, F> ExpiringTreeBuilder<V, E, F>
where
E: Encoding<HashSet<IVec>> + Encoding<DateTime<Utc>> + 'static,
F: Encoding<V> + 'static,
{
pub(crate) fn new(db: &sled::Db, data: &str) -> Self {
ExpiringTreeBuilder {
db: db.clone(),
data: data.to_owned(),
extend_on_update: false,
extend_on_fetch: false,
expiration_length: chrono::Duration::hours(12),
value: PhantomData,
encoding: PhantomData,
data_encoding: PhantomData,
}
}
pub fn extend_on_update(&mut self) -> &mut Self {
self.extend_on_update = true;
self
}
pub fn extend_on_fetch(&mut self) -> &mut Self {
self.extend_on_fetch = true;
self
}
pub fn expiration_length(&mut self, expiration_length: chrono::Duration) -> &mut Self {
self.expiration_length = expiration_length;
self
}
pub fn build(&self) -> Result<ExpiringTree<V, E, F>> {
Ok(ExpiringTree {
data: StructuredTree::new(&self.db, &self.data)?,
expires_at: StructuredTree::new(&self.db, &format!("{}-expires-at", self.data))?,
expires_at_inverse: StructuredTree::new(
&self.db,
&format!("{}-expires-at-inverse", self.data),
)?,
extend_on_update: self.extend_on_update,
extend_on_fetch: self.extend_on_fetch,
expiration_length: self.expiration_length,
})
}
}
impl<'a, V, E, F> ExpiringIter<'a, V, E, F>
where
    E: Encoding<HashSet<IVec>> + Encoding<DateTime<Utc>> + 'static,
    F: Encoding<V> + 'static,
{
    /// Consumes the iterator and yields only the keys.
    ///
    /// Errors pass through unchanged, and expiration extension still happens per entry.
    pub fn keys(self) -> impl 'a + DoubleEndedIterator<Item = Result<IVec>> {
        self.map(|entry| entry.map(|pair| pair.0))
    }

    /// Consumes the iterator and yields only the values.
    ///
    /// Errors pass through unchanged, and expiration extension still happens per entry.
    pub fn values(self) -> impl 'a + DoubleEndedIterator<Item = Result<V>> {
        self.map(|entry| entry.map(|pair| pair.1))
    }
}
impl<V, F> ExpiringBatch<V, F>
where
    F: Encoding<V> + 'static,
{
    /// Stages an insert of `value` under `key`.
    ///
    /// The key's expiration is extended when the batch is applied to a tree with
    /// `extend_on_update` set.
    ///
    /// # Errors
    /// Returns the error from staging the write (e.g. encoding `value`). In that case the
    /// key is not added to the extension set, so a failed insert cannot create an
    /// expiration for a key that was never written.
    pub fn insert<K>(&mut self, key: K, value: V) -> Result<()>
    where
        IVec: From<K>,
    {
        let k = IVec::from(key);
        self.0.insert::<IVec>(k.clone(), value)?;
        self.1.insert(k);
        Ok(())
    }

    /// Stages a removal of `key` and drops it from the extension set.
    ///
    /// NOTE(review): `ExpiringTree::apply_batch` only extends tracked keys; it does not
    /// appear to clear existing expiration entries for keys removed here.
    pub fn remove<K>(&mut self, key: K)
    where
        IVec: From<K>,
    {
        let k = IVec::from(key);
        self.1.remove(&k);
        self.0.remove::<IVec>(k)
    }
}
impl<'a, V, E, F> ExpiringTransactionalTree<'a, V, E, F>
where
    E: Encoding<HashSet<IVec>> + Encoding<DateTime<Utc>> + 'static,
    F: Encoding<V> + 'static,
{
    /// Transactionally inserts `value` under `key`.
    ///
    /// Extends the key's expiration when `extend_on_update` is set; a bookkeeping error is
    /// returned as the inner `Err`.
    /// NOTE(review): the expiration update is not part of the transaction — confirm intent.
    pub fn insert<K>(&self, key: K, value: V) -> sled::TransactionResult<Result<Option<V>>>
    where
        IVec: From<K>,
        K: AsRef<[u8]>,
    {
        let k = IVec::from(key);
        let r = self.0.insert::<IVec>(k.clone(), value)?;
        if self.1.extend_on_update {
            if let Err(e) = self.1.update_expires_at(k, Utc::now()) {
                return Ok(Err(e));
            }
        }
        Ok(r)
    }

    /// Transactionally removes `key` and clears its expiration entry.
    ///
    /// A bookkeeping error is returned as the inner `Err`.
    pub fn remove<K>(&self, key: K) -> sled::TransactionResult<Result<Option<V>>>
    where
        IVec: From<K>,
        K: AsRef<[u8]>,
    {
        let k = IVec::from(key);
        let r = self.0.remove::<IVec>(k.clone())?;
        if let Err(e) = self.1.remove_expires_at(k) {
            return Ok(Err(e));
        }
        Ok(r)
    }

    /// Transactionally reads `key`.
    ///
    /// When the read succeeds, the key exists, and `extend_on_fetch` is set, the key's
    /// expiration is extended. Misses and read errors do not touch the expiration trees,
    /// so no orphaned expiration entries are created.
    pub fn get<K>(&self, key: K) -> sled::TransactionResult<Result<Option<V>>>
    where
        K: AsRef<[u8]>,
    {
        let k = IVec::from(key.as_ref());
        let r = self.0.get(key)?;
        let found = matches!(r, Ok(Some(_)));
        if found && self.1.extend_on_fetch {
            if let Err(e) = self.1.update_expires_at(k, Utc::now()) {
                return Ok(Err(e));
            }
        }
        Ok(r)
    }

    /// Applies `batch` inside the transaction.
    ///
    /// When `extend_on_update` is set, extends the expiration of every key inserted
    /// through the batch, stopping at the first bookkeeping error.
    pub fn apply_batch(&self, batch: ExpiringBatch<V, F>) -> sled::TransactionResult<Result<()>> {
        let keys = batch.1;
        self.0.apply_batch(batch.0)?;
        if self.1.extend_on_update {
            let now = Utc::now();
            for key in keys {
                if let Err(e) = self.1.update_expires_at(key, now) {
                    return Ok(Err(e));
                }
            }
        }
        Ok(Ok(()))
    }
}
impl<'a, V, E, F> Iterator for ExpiringIter<'a, V, E, F>
where
    E: Encoding<HashSet<IVec>> + Encoding<DateTime<Utc>> + 'static,
    F: Encoding<V> + 'static,
{
    type Item = Result<(IVec, V)>;

    /// Yields the next entry from the front.
    ///
    /// Extends its expiration when the tree has `extend_on_fetch` set. A bookkeeping
    /// failure replaces the entry with that error.
    fn next(&mut self) -> Option<Self::Item> {
        let tree = self.1;
        let (key, value) = match self.0.next()? {
            Ok(pair) => pair,
            Err(e) => return Some(Err(e)),
        };
        if !tree.extend_on_fetch {
            return Some(Ok((key, value)));
        }
        let outcome = match tree.update_expires_at(key.clone(), Utc::now()) {
            Ok(()) => Ok((key, value)),
            Err(e) => Err(e),
        };
        Some(outcome)
    }
}
impl<'a, V, E, F> DoubleEndedIterator for ExpiringIter<'a, V, E, F>
where
    E: Encoding<HashSet<IVec>> + Encoding<DateTime<Utc>> + 'static,
    F: Encoding<V> + 'static,
{
    /// Yields the next entry from the back.
    ///
    /// Extends its expiration when the tree has `extend_on_fetch` set. A bookkeeping
    /// failure replaces the entry with that error.
    fn next_back(&mut self) -> Option<<Self as Iterator>::Item> {
        let tree = self.1;
        let (key, value) = match self.0.next_back()? {
            Ok(pair) => pair,
            Err(e) => return Some(Err(e)),
        };
        if !tree.extend_on_fetch {
            return Some(Ok((key, value)));
        }
        let outcome = match tree.update_expires_at(key.clone(), Utc::now()) {
            Ok(()) => Ok((key, value)),
            Err(e) => Err(e),
        };
        Some(outcome)
    }
}