weight_cache/
lib.rs

1//! A cache that holds a limited number of key-value pairs. When the capacity of
2//! the cache is exceeded, the least-recently-used (where "used" means a look-up
3//! or putting the pair into the cache) pair is automatically removed.
4//!
5//! Contrary to the [lru-cache](https://crates.io/crates/lru-cache) crate (which
6//! this crate is heavily inspired by!), the capacity is not the number of items
7//! in the cache, but can be given by an arbitrary criterion by implementing
8//! [`Weighable`] for the value type V. A straight-forward example of this would
9//! be to use the allocated size of the object, and provide a total capacity
10//! which must not be exceeded by the cache.
11//!
12//! # Examples
13//!```
14//! use weight_cache::{Weighable, WeightCache};
15//! use std::num::NonZeroUsize;
16//!
17//! #[derive(PartialEq, Debug)]
18//! enum Food {
19//!     Milk { milliliters: usize },
20//!     Cucumber { pieces: usize },
21//!     Meat { grams: usize },
22//!     Potato { pieces: usize },
23//!     Crab { grams: usize },
24//! }
25//!
26//! impl Weighable for Food {
27//!     fn measure(value: &Self) -> usize {
28//!         match value {
29//!             Food::Milk { milliliters } => milliliters * 104 / 100,
30//!             Food::Cucumber { pieces } => pieces * 158,
31//!             Food::Meat { grams } => *grams,
32//!             Food::Potato { pieces } => pieces * 175,
33//!             Food::Crab { grams } => *grams,
34//!         }
35//!     }
36//! }
37//!
38//! let mut cache = WeightCache::new(NonZeroUsize::new(500).unwrap());
39//!
40//! // Can't put too much in!
41//! assert!(cache.put(0, Food::Meat { grams: 600 }).is_err());
42//! assert!(cache.is_empty());
43//!
44//! cache.put(1, Food::Milk { milliliters: 100 }).unwrap();
45//! assert!(!cache.is_empty());
46//! assert_eq!(*cache.get(&1).unwrap(), Food::Milk { milliliters: 100 });
47//!
48//! cache.put(2, Food::Crab { grams: 300 }).unwrap();
49//! assert_eq!(*cache.get(&2).unwrap(), Food::Crab { grams: 300 });
50//! assert_eq!(*cache.get(&1).unwrap(), Food::Milk { milliliters: 100 });
51//!
52//! cache.put(3, Food::Potato { pieces: 2 }).unwrap();
53//! assert_eq!(*cache.get(&3).unwrap(), Food::Potato { pieces: 2});
54//! assert!(cache.get(&2).is_none()); // 1 has been touched last
55//! assert_eq!(*cache.get(&1).unwrap(), Food::Milk { milliliters: 100 });
56//!```
57//!
58//! # Feature flags
59//!
60//! * `metrics`: Enables metric gathering on the cache. Register a
61//! [`prometheus::Registry`] with a call to [`WeightCache::register`]; set a
62//! custom metric namespace with [`WeightCache::new_with_namespace`]
63
64use hash_map::RandomState;
65use linked_hash_map::LinkedHashMap;
66#[cfg(feature = "metrics")]
67use prometheus::{
68    core::{AtomicU64, GenericCounter, GenericGauge},
69    Opts, Registry,
70};
71use std::{
72    collections::hash_map,
73    fmt,
74    hash::{BuildHasher, Hash},
75    num::NonZeroUsize,
76};
77
/// A trait to be implemented for the value type stored in a [`WeightCache`],
/// providing a way to [`Weighable::measure`] a value.
pub trait Weighable {
    /// Returns the weight of `value`, expressed in the same unit as the
    /// capacity the cache was created with.
    fn measure(value: &Self) -> usize;
}
83
/// An error indicating that the to-be-inserted value is bigger than the max
/// size of the cache.
#[derive(Debug)]
pub struct ValueTooBigError;

impl fmt::Display for ValueTooBigError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The message is a fixed string, so no formatting machinery is needed.
        formatter.write_str("Value is bigger than the configured max size of the cache")
    }
}

impl std::error::Error for ValueTooBigError {}
97
98struct ValueWithWeight<V> {
99    value: V,
100    weight: usize,
101}
102/// A cache that holds a limited number of key-value pairs. When the capacity of
103/// the cache is exceeded, the least-recently-used (where "used" means a look-up
104/// or putting the pair into the cache) pairs are automatically removed until the
105/// size limit is met again.
106pub struct WeightCache<K, V, S = hash_map::RandomState> {
107    max: usize,
108    current: usize,
109    inner: LinkedHashMap<K, ValueWithWeight<V>, S>,
110    #[cfg(feature = "metrics")]
111    metrics: Metrics,
112}
113impl<K, V, S> fmt::Debug for WeightCache<K, V, S> {
114    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115        f.debug_struct("WeightCache")
116            .field("max", &self.max)
117            .field("current", &self.current)
118            .finish()
119    }
120}
121
122impl<K: Hash + Eq, V: Weighable> Default for WeightCache<K, V> {
123    fn default() -> Self {
124        WeightCache::<K, V, RandomState>::new(
125            NonZeroUsize::new(usize::max_value()).expect("MAX > 0"),
126        )
127    }
128}
129
/// Prometheus collectors describing the activity of one cache instance.
#[cfg(feature = "metrics")]
struct Metrics {
    /// `cache_hit`: number of cache hits.
    hits: GenericCounter<AtomicU64>,
    /// `cache_miss`: number of cache misses.
    misses: GenericCounter<AtomicU64>,
    /// `cache_insert`: number of successful cache insertions.
    inserts: GenericCounter<AtomicU64>,
    /// `cache_insert_fail`: number of failed cache insertions.
    inserts_fail: GenericCounter<AtomicU64>,
    /// `cache_size`: current size of the cache.
    size: GenericGauge<AtomicU64>,
}

#[cfg(feature = "metrics")]
impl Metrics {
    /// Creates all collectors, optionally placing them in `namespace`.
    ///
    /// The collectors are only created here, not registered anywhere.
    fn new(namespace: Option<&str>) -> Self {
        // Builds the options for one collector, attaching the namespace only
        // when one was requested.
        let opts = |name: &str, help: &str| {
            let plain = Opts::new(name, help);
            match namespace {
                Some(ns) => plain.namespace(ns),
                None => plain,
            }
        };

        let size =
            GenericGauge::with_opts(opts("cache_size", "Current size of the cache")).unwrap();
        size.set(0);
        let hits = GenericCounter::with_opts(opts("cache_hit", "Number of cache hits")).unwrap();
        let misses =
            GenericCounter::with_opts(opts("cache_miss", "Number of cache misses")).unwrap();
        let inserts = GenericCounter::with_opts(opts(
            "cache_insert",
            "Number of successful cache insertions",
        ))
        .unwrap();
        let inserts_fail = GenericCounter::with_opts(opts(
            "cache_insert_fail",
            "Number of failed cache insertions",
        ))
        .unwrap();

        Self {
            hits,
            misses,
            inserts,
            inserts_fail,
            size,
        }
    }
}
189
190impl<K: Hash + Eq, V: Weighable> WeightCache<K, V> {
191    pub fn new(capacity: NonZeroUsize) -> Self {
192        Self {
193            max: capacity.get(),
194            current: 0,
195            inner: LinkedHashMap::new(),
196            #[cfg(feature = "metrics")]
197            metrics: Metrics::new(None),
198        }
199    }
200    #[cfg(feature = "metrics")]
201    pub fn new_with_namespace(capacity: NonZeroUsize, metrics_namespace: Option<&str>) -> Self {
202        Self {
203            max: capacity.get(),
204            current: 0,
205            inner: LinkedHashMap::new(),
206            metrics: Metrics::new(metrics_namespace),
207        }
208    }
209}
210impl<K: Hash + Eq, V: Weighable, S: BuildHasher> WeightCache<K, V, S> {
211    /// Create a [`WeightCache`] with a custom hasher.
212    pub fn with_hasher(capacity: NonZeroUsize, hasher: S) -> Self {
213        Self {
214            max: capacity.get(),
215            current: 0,
216            inner: LinkedHashMap::with_hasher(hasher),
217            #[cfg(feature = "metrics")]
218            metrics: Metrics::new(None),
219        }
220    }
221    #[cfg(feature = "metrics")]
222    /// Registers metrics with a [`prometheus::Registry`]
223    pub fn register(&self, registry: &Registry) -> Result<(), prometheus::Error> {
224        registry.register(Box::new(self.metrics.hits.clone()))?;
225        registry.register(Box::new(self.metrics.misses.clone()))?;
226        registry.register(Box::new(self.metrics.inserts.clone()))?;
227        registry.register(Box::new(self.metrics.inserts_fail.clone()))?;
228        registry.register(Box::new(self.metrics.size.clone()))?;
229        Ok(())
230    }
231
232    /// Returns a reference to the value corresponding to the given key, if it
233    /// exists.
234    pub fn get(&mut self, k: &K) -> Option<&V> {
235        if let Some(v) = self.inner.get_refresh(k) {
236            #[cfg(feature = "metrics")]
237            self.metrics.hits.inc();
238            Some(&v.value as &V)
239        } else {
240            #[cfg(feature = "metrics")]
241            self.metrics.misses.inc();
242            None
243        }
244    }
245
246    /// Returns the number of key-value pairs in the cache.
247    pub fn len(&self) -> usize {
248        self.inner.len()
249    }
250
251    /// Returns `true` if the cache contains no key-value pairs.
252    pub fn is_empty(&self) -> bool {
253        self.inner.is_empty()
254    }
255
256    /// Inserts a key-value pair into the cache. Returns an error if the value is
257    /// bigger than the cache's configured max size.
258    pub fn put(&mut self, key: K, value: V) -> Result<(), ValueTooBigError> {
259        let weight = V::measure(&value);
260        if weight > self.max {
261            #[cfg(feature = "metrics")]
262            self.metrics.inserts_fail.inc();
263            Err(ValueTooBigError)
264        } else {
265            self.current += weight;
266            // did we remove an element?
267            if let Some(x) = self.inner.insert(key, ValueWithWeight { value, weight }) {
268                self.current -= x.weight;
269            }
270
271            // remove elements until we're below the size boundary again
272            self.shrink_to_fit();
273            #[cfg(feature = "metrics")]
274            self.metrics.inserts.inc();
275            Ok(())
276        }
277    }
278
279    fn shrink_to_fit(&mut self) {
280        while self.current > self.max && !self.inner.is_empty() {
281            let (_, v) = self.inner.pop_front().expect("Not empty");
282            self.current -= v.weight;
283        }
284        #[cfg(feature = "metrics")]
285        self.metrics.size.set(self.current as u64);
286    }
287}
288
289impl<K: Hash + Eq + 'static, V: Weighable + 'static, S: BuildHasher> WeightCache<K, V, S> {
290    /// Returns an iterator over the cache's key-value pairs in least- to
291    /// most-recently-used order consuming the cache.
292    pub fn consume(self) -> Box<dyn Iterator<Item = (K, V)> + 'static> {
293        #[cfg(feature = "metrics")]
294        self.metrics.size.set(0);
295        Box::new(self.inner.into_iter().map(|(k, v)| (k, v.value)))
296    }
297}
298
#[cfg(test)]
mod test {
    use std::convert::TryInto;

    use super::*;
    use quickcheck::{Arbitrary, Gen};
    use quickcheck_macros::quickcheck;

    /// Test value whose weight is the wrapped number.
    #[derive(Clone, Debug, PartialEq)]
    struct HeavyWeight(usize);
    impl Weighable for HeavyWeight {
        fn measure(v: &Self) -> usize {
            v.0
        }
    }
    impl Arbitrary for HeavyWeight {
        fn arbitrary(g: &mut Gen) -> Self {
            Self(usize::arbitrary(g))
        }
        fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
            Box::new(usize::shrink(&self.0).map(HeavyWeight))
        }
    }

    /// Test value that always weighs exactly one unit.
    #[derive(Clone, Debug, PartialEq)]
    struct UnitWeight;
    impl Weighable for UnitWeight {
        fn measure(_: &Self) -> usize {
            1
        }
    }
    impl Arbitrary for UnitWeight {
        fn arbitrary(_: &mut Gen) -> Self {
            Self
        }
    }

    #[test]
    fn should_not_evict_under_max_size() {
        let xs: Vec<_> = (0..10000).map(HeavyWeight).collect();
        let mut cache = WeightCache::<usize, HeavyWeight>::new(usize::MAX.try_into().unwrap());
        for (k, v) in xs.iter().enumerate() {
            cache.put(k, v.clone()).expect("empty")
        }
        let cached = cache.consume().map(|x| x.1).collect::<Vec<_>>();

        assert_eq!(xs, cached);
    }

    #[test]
    fn should_accept_value_exactly_at_capacity() {
        // `put` only rejects values that are strictly heavier than the capacity.
        let mut cache = WeightCache::<usize, HeavyWeight>::new(NonZeroUsize::new(10).unwrap());
        assert!(cache.put(0, HeavyWeight(10)).is_ok());
        assert_eq!(1, cache.len());
        assert!(cache.put(1, HeavyWeight(11)).is_err());
        assert_eq!(1, cache.len());
    }

    /// Fills a cache of capacity 3 with 5 unit-weight entries, looks all of
    /// them up and checks every gathered metric against the expected count.
    #[cfg(feature = "metrics")]
    fn metrics_test(namespace: Option<&str>) {
        let mut cache =
            WeightCache::<usize, UnitWeight>::new_with_namespace(3.try_into().unwrap(), namespace);
        let registry = Registry::new();
        cache.register(&registry).unwrap();
        for i in 0usize..5 {
            cache.put(i, UnitWeight).unwrap();
        }
        for i in 0usize..5 {
            cache.get(&i);
        }

        // Namespaced metrics are reported as `<namespace>_<name>`.
        let prefix = namespace.map(|ns| format!("{}_", ns)).unwrap_or_default();
        for metric in registry.gather() {
            let name = metric.get_name();
            let sample = &metric.get_metric()[0];
            println!("{} {:?}", name, sample);
            match name.strip_prefix(prefix.as_str()) {
                Some("cache_size") => {
                    assert_eq!(3, sample.get_gauge().get_value() as usize)
                }
                Some("cache_insert") => {
                    assert_eq!(5, sample.get_counter().get_value() as usize)
                }
                Some("cache_insert_fail") => {
                    assert_eq!(0, sample.get_counter().get_value() as usize)
                }
                Some("cache_hit") => {
                    assert_eq!(3, sample.get_counter().get_value() as usize)
                }
                Some("cache_miss") => {
                    assert_eq!(2, sample.get_counter().get_value() as usize)
                }
                _ => panic!("unknown metrics {}", name),
            }
        }
    }

    #[cfg(feature = "metrics")]
    #[test]
    fn should_gather_metrics() {
        metrics_test(None);
        metrics_test(Some("test"));
    }

    #[quickcheck]
    fn should_reject_too_heavy_values(total_size: NonZeroUsize, input: HeavyWeight) -> bool {
        let mut cache = WeightCache::<usize, HeavyWeight>::new(total_size);
        let weight = input.0;
        let accepted = cache.put(42, input).is_ok();
        // A value is accepted iff it fits, which includes weighing exactly as
        // much as the whole cache.
        accepted == (weight <= total_size.get())
    }

    #[quickcheck]
    fn should_evict_once_the_size_target_is_hit(
        input: Vec<UnitWeight>,
        max_size: NonZeroUsize,
    ) -> bool {
        let mut cache_size = 0usize;
        let mut cache = WeightCache::<usize, UnitWeight>::new(max_size);
        for (k, v) in input.into_iter().enumerate() {
            let weight = UnitWeight::measure(&v);
            cache_size += weight;
            let len_before = cache.len();
            cache.put(k, v).unwrap();
            let len_after = cache.len();
            if cache_size > max_size.get() {
                // Cache was full: one entry got evicted for the new one.
                assert_eq!(len_before, len_after);
                cache_size -= weight;
            } else {
                assert_eq!(len_before + 1, len_after);
            }
        }

        true
    }
}
453}