use super::*;
/// Re-emits every item it is given, each one gated behind `#[cfg(feature = "std")]`.
macro_rules! cfg_std_feature {
    ($($gated:item)*) => {
        $( #[cfg(feature = "std")] $gated )*
    };
}
/// Re-emits every item it is given, each one gated behind `#[cfg(not(feature = "std"))]`.
macro_rules! cfg_not_std_feature {
    ($($gated:item)*) => {
        $( #[cfg(not(feature = "std"))] $gated )*
    };
}
cfg_not_std_feature! {
    /// Bounds every map key must satisfy in `no_std` builds: cloneable and totally
    /// ordered, as required by the `BTreeMap`/`BTreeSet` based storage.
    pub trait GenericKey: Clone + Eq + Ord {}

    impl<T> GenericKey for T where T: Clone + Eq + Ord {}
}
cfg_std_feature! {
    /// Bounds every map key must satisfy in `std` builds: the `no_std` bounds plus
    /// `Hash`, so keys can also live in the hash-based storage variants.
    pub trait GenericKey: Clone + Eq + Ord + Hash {}

    impl<T> GenericKey for T where T: Clone + Eq + Ord + Hash {}
}
/// Storage backend behind a `TimedMap`; each variant wraps one concrete map type.
#[derive(Debug)]
#[allow(clippy::enum_variant_names)] // variant names intentionally mirror the wrapped collection
enum GenericMap<K, V> {
    /// Ordered storage; the only backend compiled in without `std`.
    BTreeMap(BTreeMap<K, V>),
    /// Hash-based storage from `std`.
    #[cfg(feature = "std")]
    HashMap(HashMap<K, V>),
    /// Hash-based storage using the `rustc-hash` hasher.
    #[cfg(all(feature = "std", feature = "rustc-hash"))]
    FxHashMap(FxHashMap<K, V>),
}

impl<K, V> Default for GenericMap<K, V> {
    /// Starts as an empty `BTreeMap`, which is available with or without `std`.
    fn default() -> Self {
        GenericMap::BTreeMap(BTreeMap::new())
    }
}
impl<K, V> GenericMap<K, V>
where
K: GenericKey,
{
#[inline(always)]
fn get(&self, k: &K) -> Option<&V> {
match self {
Self::BTreeMap(inner) => inner.get(k),
#[cfg(feature = "std")]
Self::HashMap(inner) => inner.get(k),
#[cfg(all(feature = "std", feature = "rustc-hash"))]
Self::FxHashMap(inner) => inner.get(k),
}
}
#[inline(always)]
fn get_mut(&mut self, k: &K) -> Option<&mut V> {
match self {
Self::BTreeMap(inner) => inner.get_mut(k),
#[cfg(feature = "std")]
Self::HashMap(inner) => inner.get_mut(k),
#[cfg(all(feature = "std", feature = "rustc-hash"))]
Self::FxHashMap(inner) => inner.get_mut(k),
}
}
#[inline(always)]
fn len(&self) -> usize {
match self {
Self::BTreeMap(inner) => inner.len(),
#[cfg(feature = "std")]
Self::HashMap(inner) => inner.len(),
#[cfg(all(feature = "std", feature = "rustc-hash"))]
Self::FxHashMap(inner) => inner.len(),
}
}
#[inline(always)]
fn keys(&self) -> Vec<K> {
match self {
Self::BTreeMap(inner) => inner.keys().cloned().collect(),
#[cfg(feature = "std")]
Self::HashMap(inner) => inner.keys().cloned().collect(),
#[cfg(all(feature = "std", feature = "rustc-hash"))]
Self::FxHashMap(inner) => inner.keys().cloned().collect(),
}
}
#[inline(always)]
fn is_empty(&self) -> bool {
match self {
Self::BTreeMap(inner) => inner.is_empty(),
#[cfg(feature = "std")]
Self::HashMap(inner) => inner.is_empty(),
#[cfg(all(feature = "std", feature = "rustc-hash"))]
Self::FxHashMap(inner) => inner.is_empty(),
}
}
#[inline(always)]
fn insert(&mut self, k: K, v: V) -> Option<V> {
match self {
Self::BTreeMap(inner) => inner.insert(k, v),
#[cfg(feature = "std")]
Self::HashMap(inner) => inner.insert(k, v),
#[cfg(all(feature = "std", feature = "rustc-hash"))]
Self::FxHashMap(inner) => inner.insert(k, v),
}
}
#[inline(always)]
fn clear(&mut self) {
match self {
Self::BTreeMap(inner) => inner.clear(),
#[cfg(feature = "std")]
Self::HashMap(inner) => inner.clear(),
#[cfg(all(feature = "std", feature = "rustc-hash"))]
Self::FxHashMap(inner) => inner.clear(),
}
}
#[inline(always)]
fn remove(&mut self, k: &K) -> Option<V> {
match self {
Self::BTreeMap(inner) => inner.remove(k),
#[cfg(feature = "std")]
Self::HashMap(inner) => inner.remove(k),
#[cfg(all(feature = "std", feature = "rustc-hash"))]
Self::FxHashMap(inner) => inner.remove(k),
}
}
}
/// Storage backend to use for a `TimedMap` built with `TimedMap::new_with_map_kind`.
#[cfg(feature = "std")]
#[allow(clippy::enum_variant_names)] // variants are deliberately named after the collections they select
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MapKind {
    /// Ordered storage in a `BTreeMap`.
    BTreeMap,
    /// Hash-based storage in a `std` `HashMap`.
    HashMap,
    /// Hash-based storage in an `FxHashMap` (requires the `rustc-hash` feature).
    #[cfg(feature = "rustc-hash")]
    FxHashMap,
}
/// Map whose entries may carry an expiry time, measured in whole seconds on `clock`.
///
/// Expired entries are hidden by the checked accessors (`get`, `get_mut`, ...) and are
/// physically removed by `drop_expired_entries` or by the sweep that checked inserts run
/// once `expiration_tick` reaches `expiration_tick_cap`.
#[derive(Debug)]
pub struct TimedMap<C, K, V> {
/// Time source in `std` builds; the `C` type parameter is not stored in that case.
#[cfg(feature = "std")]
clock: StdClock,
/// Keeps the otherwise unused `C` parameter in the type in `std` builds.
#[cfg(feature = "std")]
marker: PhantomData<C>,
/// Caller-supplied time source in `no_std` builds.
#[cfg(not(feature = "std"))]
clock: C,
/// Key -> value storage; each value is wrapped with its expiry status.
map: GenericMap<K, ExpirableEntry<V>>,
/// Index from expiry timestamp (seconds) to the keys due at that time; drives sweeps
/// and `len_expired`.
expiries: BTreeMap<u64, BTreeSet<K>>,
/// Insertion counter; reset to 0 after each automatic sweep.
expiration_tick: u16,
/// Threshold `expiration_tick` must reach before a checked insert sweeps.
expiration_tick_cap: u16,
}
#[cfg(feature = "std")]
impl<C, K, V> Default for TimedMap<C, K, V> {
    /// Builds an empty, `BTreeMap`-backed map driven by a freshly created `StdClock`;
    /// the tick cap of 1 makes every checked insertion sweep expired entries.
    fn default() -> Self {
        let clock = StdClock::new();
        Self {
            clock,
            marker: PhantomData,
            map: GenericMap::BTreeMap(BTreeMap::new()),
            expiries: BTreeMap::new(),
            expiration_tick: 0,
            expiration_tick_cap: 1,
        }
    }
}
impl<C, K, V> TimedMap<C, K, V>
where
C: Clock,
K: GenericKey,
{
#[cfg(feature = "std")]
pub fn new() -> Self {
Self::default()
}
#[cfg(feature = "std")]
pub fn new_with_map_kind(map_kind: MapKind) -> Self {
let map = match map_kind {
MapKind::BTreeMap => GenericMap::<K, ExpirableEntry<V>>::BTreeMap(BTreeMap::default()),
MapKind::HashMap => GenericMap::HashMap(HashMap::default()),
#[cfg(feature = "rustc-hash")]
MapKind::FxHashMap => GenericMap::FxHashMap(FxHashMap::default()),
};
Self {
map,
clock: StdClock::new(),
expiries: BTreeMap::default(),
#[cfg(feature = "std")]
marker: PhantomData,
expiration_tick: 0,
expiration_tick_cap: 1,
}
}
#[cfg(not(feature = "std"))]
pub fn new(clock: C) -> Self {
Self {
clock,
map: GenericMap::default(),
expiries: BTreeMap::default(),
expiration_tick: 0,
expiration_tick_cap: 1,
}
}
#[inline(always)]
pub fn expiration_tick_cap(mut self, expiration_tick_cap: u16) -> Self {
self.expiration_tick_cap = expiration_tick_cap;
self
}
pub fn get(&self, k: &K) -> Option<&V> {
self.map
.get(k)
.filter(|v| !v.is_expired(self.clock.elapsed_seconds_since_creation()))
.map(|v| v.value())
}
pub fn get_mut(&mut self, k: &K) -> Option<&mut V> {
self.map
.get_mut(k)
.filter(|v| !v.is_expired(self.clock.elapsed_seconds_since_creation()))
.map(|v| v.value_mut())
}
#[inline(always)]
pub fn get_unchecked(&self, k: &K) -> Option<&V> {
self.map.get(k).map(|v| v.value())
}
#[inline(always)]
pub fn get_mut_unchecked(&mut self, k: &K) -> Option<&mut V> {
self.map.get_mut(k).map(|v| v.value_mut())
}
pub fn get_remaining_duration(&self, k: &K) -> Option<Duration> {
match self.map.get(k) {
Some(v) => {
let now = self.clock.elapsed_seconds_since_creation();
if v.is_expired(now) {
return None;
}
v.remaining_duration(now)
}
None => None,
}
}
#[inline(always)]
pub fn len(&self) -> usize {
self.map.len() - self.len_expired()
}
#[inline(always)]
pub fn len_expired(&self) -> usize {
let now = self.clock.elapsed_seconds_since_creation();
self.expiries
.iter()
.filter_map(
|(exp, keys)| {
if exp <= &now {
Some(keys.len())
} else {
None
}
},
)
.sum()
}
#[inline(always)]
pub fn len_unchecked(&self) -> usize {
self.map.len()
}
#[inline(always)]
pub fn keys(&self) -> Vec<K> {
self.map.keys()
}
#[inline(always)]
pub fn is_empty(&self) -> bool {
self.map.is_empty()
}
#[inline(always)]
fn insert(&mut self, k: K, v: V, expires_at: Option<u64>) -> Option<V> {
let entry = ExpirableEntry::new(v, expires_at);
self.map.insert(k, entry).map(|v| v.owned_value())
}
pub fn insert_expirable(&mut self, k: K, v: V, duration: Duration) -> Option<V> {
self.expiration_tick += 1;
let now = self.clock.elapsed_seconds_since_creation();
let expires_at = now + duration.as_secs();
let res = self.insert(k.clone(), v, Some(expires_at));
let expiry_keys = self.expiries.entry(expires_at).or_default();
expiry_keys.insert(k);
if self.expiration_tick >= self.expiration_tick_cap {
self.drop_expired_entries_inner(now);
self.expiration_tick = 0;
}
res
}
pub fn insert_expirable_unchecked(&mut self, k: K, v: V, duration: Duration) -> Option<V> {
let now = self.clock.elapsed_seconds_since_creation();
let expires_at = now + duration.as_secs();
self.insert(k, v, Some(expires_at))
}
pub fn insert_constant(&mut self, k: K, v: V) -> Option<V> {
self.expiration_tick += 1;
let res = self.insert(k, v, None);
let now = self.clock.elapsed_seconds_since_creation();
if self.expiration_tick >= self.expiration_tick_cap {
self.drop_expired_entries_inner(now);
self.expiration_tick = 0;
}
res
}
pub fn insert_constant_unchecked(&mut self, k: K, v: V) -> Option<V> {
self.expiration_tick += 1;
self.insert(k, v, None)
}
#[inline(always)]
pub fn remove(&mut self, k: &K) -> Option<V> {
self.map
.remove(k)
.filter(|v| {
if let EntryStatus::ExpiresAtSeconds(expires_at_seconds) = v.status() {
self.expiries.remove(expires_at_seconds);
}
!v.is_expired(self.clock.elapsed_seconds_since_creation())
})
.map(|v| v.owned_value())
}
#[inline(always)]
pub fn remove_unchecked(&mut self, k: &K) -> Option<V> {
self.map
.remove(k)
.filter(|v| {
if let EntryStatus::ExpiresAtSeconds(expires_at_seconds) = v.status() {
self.expiries.remove(expires_at_seconds);
}
true
})
.map(|v| v.owned_value())
}
#[inline(always)]
pub fn clear(&mut self) {
self.map.clear()
}
#[inline(always)]
pub fn drop_expired_entries(&mut self) {
let now = self.clock.elapsed_seconds_since_creation();
self.drop_expired_entries_inner(now);
}
fn drop_expired_entries_inner(&mut self, now: u64) {
while let Some((exp, keys)) = self.expiries.pop_first() {
if exp > now {
self.expiries.insert(exp, keys);
break;
}
for key in keys {
self.map.remove(&key);
}
}
}
}
#[cfg(test)]
#[cfg(not(feature = "std"))]
mod tests {
    use super::*;

    /// Clock whose current time is set directly by each test.
    struct MockClock {
        current_time: u64,
    }

    impl Clock for MockClock {
        fn elapsed_seconds_since_creation(&self) -> u64 {
            self.current_time
        }
    }

    #[test]
    fn nostd_insert_and_get_constant_entry() {
        let clock = MockClock { current_time: 1000 };
        let mut map: TimedMap<MockClock, u32, &str> = TimedMap::new(clock);

        map.insert_constant(1, "constant value");

        assert_eq!(map.get(&1), Some(&"constant value"));
        assert_eq!(map.get_remaining_duration(&1), None);
    }

    #[test]
    fn nostd_insert_and_get_expirable_entry() {
        let clock = MockClock { current_time: 1000 };
        let mut map: TimedMap<MockClock, u32, &str> = TimedMap::new(clock);
        let duration = Duration::from_secs(60);

        map.insert_expirable(1, "expirable value", duration);

        assert_eq!(map.get(&1), Some(&"expirable value"));
        assert_eq!(map.get_remaining_duration(&1), Some(duration));
    }

    #[test]
    fn nostd_expired_entry() {
        let clock = MockClock { current_time: 1000 };
        let mut map: TimedMap<MockClock, u32, &str> = TimedMap::new(clock);
        let duration = Duration::from_secs(60);

        map.insert_expirable(1, "expirable value", duration);

        let clock = MockClock { current_time: 1070 };
        map.clock = clock;

        assert_eq!(map.get(&1), None);
        assert_eq!(map.get_remaining_duration(&1), None);
    }

    #[test]
    fn nostd_remove_entry() {
        let clock = MockClock { current_time: 1000 };
        let mut map: TimedMap<MockClock, u32, &str> = TimedMap::new(clock);

        map.insert_constant(1, "constant value");

        assert_eq!(map.remove(&1), Some("constant value"));
        assert_eq!(map.get(&1), None);
    }

    #[test]
    fn nostd_drop_expired_entries() {
        let clock = MockClock { current_time: 1000 };
        let mut map: TimedMap<MockClock, u32, &str> = TimedMap::new(clock);

        map.insert_expirable(1, "expirable value1", Duration::from_secs(50));
        map.insert_expirable(2, "expirable value2", Duration::from_secs(70));
        map.insert_constant(3, "constant value");

        let clock = MockClock { current_time: 1055 };
        map.clock = clock;

        assert_eq!(map.get(&1), None);
        assert_eq!(map.get(&2), Some(&"expirable value2"));
        assert_eq!(map.get(&3), Some(&"constant value"));

        // Expired entries stay in storage until a sweep physically removes them.
        assert_eq!(map.len_unchecked(), 3);
        assert_eq!(map.len_expired(), 1);
        map.drop_expired_entries();
        assert_eq!(map.len_unchecked(), 2);
        assert_eq!(map.len_expired(), 0);

        let clock = MockClock { current_time: 1071 };
        map.clock = clock;

        assert_eq!(map.get(&1), None);
        assert_eq!(map.get(&2), None);
        assert_eq!(map.get(&3), Some(&"constant value"));

        map.drop_expired_entries();
        assert_eq!(map.len_unchecked(), 1);
        assert_eq!(map.len(), 1);
    }

    #[test]
    fn nostd_update_existing_entry() {
        let clock = MockClock { current_time: 1000 };
        let mut map: TimedMap<MockClock, u32, &str> = TimedMap::new(clock);

        map.insert_constant(1, "initial value");
        assert_eq!(map.get(&1), Some(&"initial value"));

        map.insert_expirable(1, "updated value", Duration::from_secs(15));
        assert_eq!(map.get(&1), Some(&"updated value"));

        let clock = MockClock { current_time: 1016 };
        map.clock = clock;

        assert_eq!(map.get(&1), None);
    }
}
#[cfg(feature = "std")]
#[cfg(test)]
mod std_tests {
    use super::*;

    #[test]
    fn std_expirable_and_constant_entries() {
        let mut map: TimedMap<StdClock, u32, &str> = TimedMap::new();

        map.insert_constant(1, "constant value");
        map.insert_expirable(2, "expirable value", Duration::from_secs(2));

        assert_eq!(map.get(&1), Some(&"constant value"));
        assert_eq!(map.get(&2), Some(&"expirable value"));
        assert_eq!(map.get_remaining_duration(&1), None);
        assert!(map.get_remaining_duration(&2).is_some());
    }

    #[test]
    fn std_expired_entry_removal() {
        let mut map: TimedMap<StdClock, u32, &str> = TimedMap::new();
        let duration = Duration::from_secs(2);

        map.insert_expirable(1, "expirable value", duration);

        std::thread::sleep(Duration::from_secs(3));

        assert_eq!(map.get(&1), None);
        assert_eq!(map.get_remaining_duration(&1), None);
    }

    #[test]
    fn std_remove_entry() {
        let mut map: TimedMap<StdClock, _, _> = TimedMap::new();

        map.insert_constant(1, "constant value");
        map.insert_expirable(2, "expirable value", Duration::from_secs(2));

        assert_eq!(map.remove(&1), Some("constant value"));
        assert_eq!(map.remove(&2), Some("expirable value"));

        assert_eq!(map.get(&1), None);
        assert_eq!(map.get(&2), None);
    }

    #[test]
    fn std_drop_expired_entries() {
        let mut map: TimedMap<StdClock, u32, &str> = TimedMap::new();

        map.insert_expirable(1, "expirable value1", Duration::from_secs(2));
        map.insert_expirable(2, "expirable value2", Duration::from_secs(4));

        std::thread::sleep(Duration::from_secs(3));

        assert_eq!(map.get(&1), None);
        assert_eq!(map.get(&2), Some(&"expirable value2"));

        // Only the entry whose deadline has passed is physically removed.
        map.drop_expired_entries();
        assert_eq!(map.len_unchecked(), 1);
        assert_eq!(map.get(&2), Some(&"expirable value2"));
    }

    #[test]
    fn std_update_existing_entry() {
        let mut map: TimedMap<StdClock, u32, &str> = TimedMap::new();

        map.insert_constant(1, "initial value");
        assert_eq!(map.get(&1), Some(&"initial value"));

        map.insert_expirable(1, "updated value", Duration::from_secs(1));
        assert_eq!(map.get(&1), Some(&"updated value"));

        std::thread::sleep(Duration::from_secs(2));

        assert_eq!(map.get(&1), None);
    }

    #[test]
    fn std_insert_constant_and_expirable_combined() {
        let mut map: TimedMap<StdClock, u32, &str> = TimedMap::new();

        map.insert_constant(1, "constant value");
        map.insert_expirable(2, "expirable value", Duration::from_secs(2));

        assert_eq!(map.get(&1), Some(&"constant value"));
        assert_eq!(map.get(&2), Some(&"expirable value"));

        std::thread::sleep(Duration::from_secs(3));

        assert_eq!(map.get(&1), Some(&"constant value"));
        assert_eq!(map.get(&2), None);
    }

    #[test]
    fn std_expirable_entry_still_valid_before_expiration() {
        let mut map: TimedMap<StdClock, u32, &str> = TimedMap::new();

        map.insert_expirable(1, "expirable value", Duration::from_secs(3));

        std::thread::sleep(Duration::from_secs(2));

        assert_eq!(map.get(&1), Some(&"expirable value"));
        assert!(map.get_remaining_duration(&1).unwrap().as_secs() == 1);
    }

    #[test]
    fn std_length_functions() {
        let mut map: TimedMap<StdClock, u32, &str> = TimedMap::new();

        map.insert_expirable(1, "expirable value", Duration::from_secs(1));
        map.insert_expirable(2, "expirable value", Duration::from_secs(1));
        map.insert_expirable(3, "expirable value", Duration::from_secs(3));
        map.insert_expirable(4, "expirable value", Duration::from_secs(3));
        map.insert_expirable(5, "expirable value", Duration::from_secs(3));
        map.insert_expirable(6, "expirable value", Duration::from_secs(3));

        std::thread::sleep(Duration::from_secs(2) + Duration::from_millis(1));

        assert_eq!(map.len(), 4);
        assert_eq!(map.len_expired(), 2);
        assert_eq!(map.len_unchecked(), 6);
    }
}