use hash_map::RandomState;
use linked_hash_map::LinkedHashMap;
use std::{
collections::hash_map,
fmt,
hash::{BuildHasher, Hash},
num::NonZeroUsize,
};
/// A value that can report its own weight.
///
/// `WeightCache` calls [`Weigheable::measure`] when a value is inserted and
/// again when it is replaced or evicted, and compares the running total
/// against the configured capacity. The unit of the weight is up to the
/// implementor.
pub trait Weigheable {
/// Returns the weight of `value`.
///
/// NOTE(review): the cache re-measures a value when it leaves the cache, so
/// this presumably needs to return the same number for the same value every
/// time — confirm with implementors.
fn measure(value: &Self) -> usize;
}
/// Error returned when a single value weighs more than the whole cache is
/// allowed to hold, so it could never be stored.
#[derive(Debug)]
pub struct ValueTooBigError;

impl fmt::Display for ValueTooBigError {
    /// Writes a fixed, human-readable description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The message has no arguments, so it can be written out directly.
        f.write_str("Value is bigger than the configured max size of the cache")
    }
}

// No underlying cause: the defaults provided by `Error` are sufficient.
impl std::error::Error for ValueTooBigError {}
/// A cache whose capacity is expressed as a total weight rather than an
/// entry count.
///
/// Each value's weight is obtained from its `Weigheable` implementation. When
/// an insertion pushes the total above the limit, entries are popped from the
/// front of the underlying `LinkedHashMap` until the total fits again.
pub struct WeightCache<K, V, S = hash_map::RandomState> {
/// Maximum total weight; taken from the `NonZeroUsize` capacity given to the
/// constructors.
max: usize,
/// Running total of the weights of the entries, as maintained by `put`.
current: usize,
/// The entries themselves; lookups through `get`/`get_mut` use
/// `get_refresh` on this map.
inner: LinkedHashMap<K, V, S>,
}
/// Debug output shows only the weight accounting (`max` and `current`); the
/// entries are left out, so neither `K` nor `V` has to implement `Debug`.
impl<K, V, S> fmt::Debug for WeightCache<K, V, S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { max, current, .. } = self;
        let mut out = f.debug_struct("WeightCache");
        out.field("max", max);
        out.field("current", current);
        out.finish()
    }
}
impl<K: Hash + Eq, V: Weigheable> Default for WeightCache<K, V> {
fn default() -> Self {
WeightCache::<K, V, RandomState>::new(NonZeroUsize::new(usize::MAX).expect("MAX > 0"))
}
}
impl<K: Hash + Eq, V: Weigheable> WeightCache<K, V> {
    /// Creates an empty cache that holds at most `capacity` units of weight,
    /// using the default hasher.
    pub fn new(capacity: NonZeroUsize) -> Self {
        let max = capacity.get();
        let inner = LinkedHashMap::new();
        Self {
            max,
            current: 0,
            inner,
        }
    }
}
impl<K: Hash + Eq, V: Weigheable, S: BuildHasher> WeightCache<K, V, S> {
    /// Creates an empty cache that holds at most `capacity` units of weight
    /// and hashes keys with `hasher`.
    pub fn with_hasher(capacity: NonZeroUsize, hasher: S) -> Self {
        Self {
            max: capacity.get(),
            current: 0,
            inner: LinkedHashMap::with_hasher(hasher),
        }
    }

    /// Looks up `k` and returns a shared reference to its value, if present.
    ///
    /// The lookup goes through `get_refresh`, so a hit also refreshes the
    /// entry's position in the map; that is why `&mut self` is required.
    pub fn get(&mut self, k: &K) -> Option<&V> {
        self.inner.get_refresh(k).map(|v| &*v)
    }

    /// Looks up `k` and returns a mutable reference to its value, if present,
    /// refreshing the entry's position like [`get`](Self::get).
    ///
    /// The value is not re-measured afterwards: a mutation that changes its
    /// weight is not reflected in the cache's running total.
    pub fn get_mut(&mut self, k: &K) -> Option<&mut V> {
        self.inner.get_refresh(k)
    }

    /// Returns the number of entries currently stored.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` when the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Inserts `value` under `key`, evicting entries from the front of the
    /// map until the total weight fits within the capacity.
    ///
    /// If `key` was already present its previous value is replaced and no
    /// longer counts towards the total.
    ///
    /// # Errors
    ///
    /// Returns [`ValueTooBigError`] — leaving the cache untouched — when the
    /// value on its own weighs more than the configured capacity.
    pub fn put(&mut self, key: K, value: V) -> Result<(), ValueTooBigError> {
        let weight = V::measure(&value);
        if weight > self.max {
            return Err(ValueTooBigError);
        }

        // From here on `self.current` counts every entry *except* the new one;
        // its weight is added only after room has been made. Adding it up
        // front could overflow `usize` for capacities close to `usize::MAX`.
        if let Some(replaced) = self.inner.insert(key, value) {
            // Saturating: a value mutated through `get_mut`/`iter_mut` may
            // weigh more now than when it was accounted for.
            self.current = self.current.saturating_sub(V::measure(&replaced));
        }

        // Cannot underflow: `weight <= self.max` was checked above.
        let budget = self.max - weight;
        // `len() > 1` keeps the entry that was just inserted (it sits at the
        // back) from ever being evicted.
        while self.current > budget && self.inner.len() > 1 {
            if let Some((_, evicted)) = self.inner.pop_front() {
                self.current = self.current.saturating_sub(V::measure(&evicted));
            }
        }

        self.current = if self.inner.len() == 1 {
            // Only the new entry is left, so the total is exactly its weight.
            weight
        } else {
            // The loop ended with `current <= max - weight`: no overflow.
            self.current + weight
        };
        Ok(())
    }

    /// Returns an iterator over `(&K, &V)` pairs in the map's order, without
    /// refreshing any entry.
    pub fn iter(&self) -> Iter<'_, K, V> {
        Iter(self.inner.iter())
    }

    /// Returns an iterator over `(&K, &mut V)` pairs in the map's order.
    ///
    /// As with [`get_mut`](Self::get_mut), weight changes made through the
    /// yielded references are not tracked.
    pub fn iter_mut(&mut self) -> IterMut<'_, K, V> {
        IterMut(self.inner.iter_mut())
    }
}
/// Consumes the cache and yields its `(key, value)` pairs in the order of the
/// underlying map.
///
/// Only `S: BuildHasher` is required; the cache can be consumed regardless of
/// whether its hasher implements `Default`.
impl<K: Hash + Eq, V, S: BuildHasher> IntoIterator for WeightCache<K, V, S> {
    type Item = (K, V);
    type IntoIter = IntoIter<K, V>;

    fn into_iter(self) -> Self::IntoIter {
        IntoIter(self.inner.into_iter())
    }
}
/// Owning iterator over the entries of a `WeightCache`.
///
/// A thin wrapper that forwards every call to `linked_hash_map::IntoIter`.
#[derive(Clone)]
pub struct IntoIter<K, V>(linked_hash_map::IntoIter<K, V>);

impl<K, V> Iterator for IntoIter<K, V> {
    type Item = (K, V);

    fn next(&mut self) -> Option<Self::Item> {
        let Self(entries) = self;
        entries.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let Self(entries) = self;
        entries.size_hint()
    }
}

impl<K, V> DoubleEndedIterator for IntoIter<K, V> {
    fn next_back(&mut self) -> Option<Self::Item> {
        let Self(entries) = self;
        entries.next_back()
    }
}

impl<K, V> ExactSizeIterator for IntoIter<K, V> {
    fn len(&self) -> usize {
        let Self(entries) = self;
        entries.len()
    }
}
/// Borrowing iterator over the entries of a `WeightCache`.
///
/// A thin wrapper that forwards every call to `linked_hash_map::Iter`.
pub struct Iter<'a, K: 'a, V: 'a>(linked_hash_map::Iter<'a, K, V>);

// Implemented by hand rather than derived so that cloning the iterator does
// not require `K: Clone` or `V: Clone`.
impl<'a, K, V> Clone for Iter<'a, K, V> {
    fn clone(&self) -> Self {
        let Self(entries) = self;
        Iter(entries.clone())
    }
}

impl<'a, K, V> Iterator for Iter<'a, K, V> {
    type Item = (&'a K, &'a V);

    fn next(&mut self) -> Option<Self::Item> {
        let Self(entries) = self;
        entries.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let Self(entries) = self;
        entries.size_hint()
    }
}

impl<'a, K, V> DoubleEndedIterator for Iter<'a, K, V> {
    fn next_back(&mut self) -> Option<Self::Item> {
        let Self(entries) = self;
        entries.next_back()
    }
}

impl<'a, K, V> ExactSizeIterator for Iter<'a, K, V> {
    fn len(&self) -> usize {
        let Self(entries) = self;
        entries.len()
    }
}
/// Mutably borrowing iterator over the entries of a `WeightCache`.
///
/// A thin wrapper that forwards every call to `linked_hash_map::IterMut`.
pub struct IterMut<'a, K: 'a, V: 'a>(linked_hash_map::IterMut<'a, K, V>);

impl<'a, K, V> Iterator for IterMut<'a, K, V> {
    type Item = (&'a K, &'a mut V);

    fn next(&mut self) -> Option<Self::Item> {
        let Self(entries) = self;
        entries.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let Self(entries) = self;
        entries.size_hint()
    }
}

impl<'a, K, V> DoubleEndedIterator for IterMut<'a, K, V> {
    fn next_back(&mut self) -> Option<Self::Item> {
        let Self(entries) = self;
        entries.next_back()
    }
}

impl<'a, K, V> ExactSizeIterator for IterMut<'a, K, V> {
    fn len(&self) -> usize {
        let Self(entries) = self;
        entries.len()
    }
}
#[cfg(test)]
mod test {
    use std::convert::TryInto;

    use super::*;
    use quickcheck::{Arbitrary, Gen};
    use quickcheck_macros::quickcheck;

    /// Test value whose weight is the wrapped number.
    #[derive(Clone, Debug, PartialEq)]
    struct HeavyWeight(usize);

    impl Weigheable for HeavyWeight {
        fn measure(v: &Self) -> usize {
            v.0
        }
    }

    impl Arbitrary for HeavyWeight {
        fn arbitrary(g: &mut Gen) -> Self {
            Self(usize::arbitrary(g))
        }

        fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
            Box::new(usize::shrink(&self.0).map(HeavyWeight))
        }
    }

    /// Test value that always weighs exactly one unit.
    #[derive(Clone, Debug, PartialEq)]
    struct UnitWeight;

    impl Weigheable for UnitWeight {
        fn measure(_: &Self) -> usize {
            1
        }
    }

    impl Arbitrary for UnitWeight {
        fn arbitrary(_: &mut Gen) -> Self {
            Self
        }
    }

    /// With an effectively unlimited capacity nothing is evicted and the
    /// entries come back out in insertion order.
    #[test]
    fn should_not_evict_under_max_size() {
        let xs: Vec<_> = (0..10000).map(HeavyWeight).collect();
        let mut cache = WeightCache::<usize, HeavyWeight>::new(usize::MAX.try_into().unwrap());
        for (k, v) in xs.iter().enumerate() {
            cache.put(k, v.clone()).expect("empty")
        }
        let cached = cache.into_iter().map(|x| x.1).collect::<Vec<_>>();
        assert_eq!(xs, cached);
    }

    /// `put` succeeds exactly when the value's weight does not exceed the
    /// capacity; a value that weighs exactly the capacity still fits.
    #[quickcheck]
    fn should_reject_too_heavy_values(total_size: NonZeroUsize, input: HeavyWeight) -> bool {
        let mut cache = WeightCache::<usize, HeavyWeight>::new(total_size);
        let fits = input.0 <= total_size.get();
        cache.put(42, input).is_ok() == fits
    }

    /// Once the capacity is reached, every further unit-weight insertion
    /// evicts exactly one entry, so the length stops growing.
    #[quickcheck]
    fn should_evict_once_the_size_target_is_hit(
        input: Vec<UnitWeight>,
        max_size: NonZeroUsize,
    ) -> bool {
        let mut cache_size = 0usize;
        let mut cache = WeightCache::<usize, UnitWeight>::new(max_size);
        for (k, v) in input.into_iter().enumerate() {
            let weight = UnitWeight::measure(&v);
            cache_size += weight;
            let len_before = cache.len();
            cache.put(k, v).unwrap();
            let len_after = cache.len();
            if cache_size > max_size.get() {
                // One entry went in and one was evicted.
                assert_eq!(len_before, len_after);
                cache_size -= weight;
            } else {
                assert_eq!(len_before + 1, len_after);
            }
        }
        true
    }
}