1use hash_map::RandomState;
65use linked_hash_map::LinkedHashMap;
66#[cfg(feature = "metrics")]
67use prometheus::{
68 core::{AtomicU64, GenericCounter, GenericGauge},
69 Opts, Registry,
70};
71use std::{
72 collections::hash_map,
73 fmt,
74 hash::{BuildHasher, Hash},
75 num::NonZeroUsize,
76};
77
/// Types whose instances can report how much cache capacity they occupy.
///
/// `WeightCache` calls [`Weighable::measure`] once per inserted value and
/// stores the result alongside the value; the unit is whatever unit the
/// cache capacity was expressed in.
pub trait Weighable {
    /// Returns the weight of `value`.
    fn measure(value: &Self) -> usize;
}
83
/// Error returned when a single value weighs more than the cache's
/// configured maximum and therefore can never be stored.
#[derive(Debug)]
pub struct ValueTooBigError;

impl fmt::Display for ValueTooBigError {
    /// Writes the fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Value is bigger than the configured max size of the cache")
    }
}

impl std::error::Error for ValueTooBigError {}
97
/// A cached value bundled with the weight that was measured for it at
/// insertion time, so the cache can adjust its running total later without
/// measuring the value again.
struct ValueWithWeight<V> {
    /// The value as handed to the cache.
    value: V,
    /// Result of `Weighable::measure` for `value`, taken when it was inserted.
    weight: usize,
}
/// A cache bounded by the accumulated weight of its values rather than by
/// the number of entries.
///
/// Entries live in a `LinkedHashMap`; when the accumulated weight exceeds
/// the maximum, entries are removed from the front of the map until the
/// total fits again.
pub struct WeightCache<K, V, S = hash_map::RandomState> {
    /// Maximum accumulated weight the cache may hold.
    max: usize,
    /// Accumulated weight of all entries currently stored in `inner`.
    current: usize,
    /// Backing store mapping each key to its value and recorded weight.
    inner: LinkedHashMap<K, ValueWithWeight<V>, S>,
    /// Prometheus collectors tracking hits, misses, inserts and size.
    #[cfg(feature = "metrics")]
    metrics: Metrics,
}
113impl<K, V, S> fmt::Debug for WeightCache<K, V, S> {
114 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115 f.debug_struct("WeightCache")
116 .field("max", &self.max)
117 .field("current", &self.current)
118 .finish()
119 }
120}
121
122impl<K: Hash + Eq, V: Weighable> Default for WeightCache<K, V> {
123 fn default() -> Self {
124 WeightCache::<K, V, RandomState>::new(
125 NonZeroUsize::new(usize::max_value()).expect("MAX > 0"),
126 )
127 }
128}
129
/// Prometheus collectors maintained by a `WeightCache` when the `metrics`
/// feature is enabled.
#[cfg(feature = "metrics")]
struct Metrics {
    /// Incremented by `get` when the key is found.
    hits: GenericCounter<AtomicU64>,
    /// Incremented by `get` when the key is not found.
    misses: GenericCounter<AtomicU64>,
    /// Incremented by `put` when the value is accepted.
    inserts: GenericCounter<AtomicU64>,
    /// Incremented by `put` when the value is rejected as too heavy.
    inserts_fail: GenericCounter<AtomicU64>,
    /// Accumulated weight of the cached values.
    size: GenericGauge<AtomicU64>,
}
#[cfg(feature = "metrics")]
impl Metrics {
    /// Creates the full set of cache collectors, all starting at zero.
    ///
    /// When `namespace` is `Some`, it is applied to every collector's
    /// options, so the exported names carry that namespace; with `None` the
    /// bare names (`cache_size`, `cache_hit`, ...) are used.
    ///
    /// # Panics
    ///
    /// Panics if Prometheus rejects a collector's options. The metric names
    /// and help texts are fixed, so this is presumably only reachable through
    /// a `namespace` that does not form a valid metric name.
    fn new(namespace: Option<&str>) -> Self {
        // Single place where the optional namespace is applied, instead of
        // duplicating every constructor call for the `Some` and `None` cases.
        let opts = |name: &str, help: &str| -> Opts {
            let base = Opts::new(name, help);
            match namespace {
                Some(namespace) => base.namespace(namespace),
                None => base,
            }
        };
        let counter = |name: &str, help: &str| -> GenericCounter<AtomicU64> {
            GenericCounter::with_opts(opts(name, help)).unwrap()
        };

        let size: GenericGauge<AtomicU64> =
            GenericGauge::with_opts(opts("cache_size", "Current size of the cache")).unwrap();
        size.set(0);

        Self {
            hits: counter("cache_hit", "Number of cache hits"),
            misses: counter("cache_miss", "Number of cache misses"),
            inserts: counter("cache_insert", "Number of successful cache insertions"),
            inserts_fail: counter("cache_insert_fail", "Number of failed cache insertions"),
            size,
        }
    }
}
189
190impl<K: Hash + Eq, V: Weighable> WeightCache<K, V> {
191 pub fn new(capacity: NonZeroUsize) -> Self {
192 Self {
193 max: capacity.get(),
194 current: 0,
195 inner: LinkedHashMap::new(),
196 #[cfg(feature = "metrics")]
197 metrics: Metrics::new(None),
198 }
199 }
200 #[cfg(feature = "metrics")]
201 pub fn new_with_namespace(capacity: NonZeroUsize, metrics_namespace: Option<&str>) -> Self {
202 Self {
203 max: capacity.get(),
204 current: 0,
205 inner: LinkedHashMap::new(),
206 metrics: Metrics::new(metrics_namespace),
207 }
208 }
209}
210impl<K: Hash + Eq, V: Weighable, S: BuildHasher> WeightCache<K, V, S> {
211 pub fn with_hasher(capacity: NonZeroUsize, hasher: S) -> Self {
213 Self {
214 max: capacity.get(),
215 current: 0,
216 inner: LinkedHashMap::with_hasher(hasher),
217 #[cfg(feature = "metrics")]
218 metrics: Metrics::new(None),
219 }
220 }
221 #[cfg(feature = "metrics")]
222 pub fn register(&self, registry: &Registry) -> Result<(), prometheus::Error> {
224 registry.register(Box::new(self.metrics.hits.clone()))?;
225 registry.register(Box::new(self.metrics.misses.clone()))?;
226 registry.register(Box::new(self.metrics.inserts.clone()))?;
227 registry.register(Box::new(self.metrics.inserts_fail.clone()))?;
228 registry.register(Box::new(self.metrics.size.clone()))?;
229 Ok(())
230 }
231
232 pub fn get(&mut self, k: &K) -> Option<&V> {
235 if let Some(v) = self.inner.get_refresh(k) {
236 #[cfg(feature = "metrics")]
237 self.metrics.hits.inc();
238 Some(&v.value as &V)
239 } else {
240 #[cfg(feature = "metrics")]
241 self.metrics.misses.inc();
242 None
243 }
244 }
245
246 pub fn len(&self) -> usize {
248 self.inner.len()
249 }
250
251 pub fn is_empty(&self) -> bool {
253 self.inner.is_empty()
254 }
255
256 pub fn put(&mut self, key: K, value: V) -> Result<(), ValueTooBigError> {
259 let weight = V::measure(&value);
260 if weight > self.max {
261 #[cfg(feature = "metrics")]
262 self.metrics.inserts_fail.inc();
263 Err(ValueTooBigError)
264 } else {
265 self.current += weight;
266 if let Some(x) = self.inner.insert(key, ValueWithWeight { value, weight }) {
268 self.current -= x.weight;
269 }
270
271 self.shrink_to_fit();
273 #[cfg(feature = "metrics")]
274 self.metrics.inserts.inc();
275 Ok(())
276 }
277 }
278
279 fn shrink_to_fit(&mut self) {
280 while self.current > self.max && !self.inner.is_empty() {
281 let (_, v) = self.inner.pop_front().expect("Not empty");
282 self.current -= v.weight;
283 }
284 #[cfg(feature = "metrics")]
285 self.metrics.size.set(self.current as u64);
286 }
287}
288
289impl<K: Hash + Eq + 'static, V: Weighable + 'static, S: BuildHasher> WeightCache<K, V, S> {
290 pub fn consume(self) -> Box<dyn Iterator<Item = (K, V)> + 'static> {
293 #[cfg(feature = "metrics")]
294 self.metrics.size.set(0);
295 Box::new(self.inner.into_iter().map(|(k, v)| (k, v.value)))
296 }
297}
298
#[cfg(test)]
mod test {
    use std::convert::TryInto;

    use super::*;
    use quickcheck::{Arbitrary, Gen};
    use quickcheck_macros::quickcheck;

    /// Test value whose weight is the wrapped number.
    #[derive(Clone, Debug, PartialEq)]
    struct HeavyWeight(usize);
    impl Weighable for HeavyWeight {
        fn measure(v: &Self) -> usize {
            v.0
        }
    }
    impl Arbitrary for HeavyWeight {
        fn arbitrary(g: &mut Gen) -> Self {
            Self(usize::arbitrary(g))
        }
        fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
            Box::new(usize::shrink(&self.0).map(HeavyWeight))
        }
    }

    /// Test value that always weighs exactly one unit.
    #[derive(Clone, Debug, PartialEq)]
    struct UnitWeight;
    impl Weighable for UnitWeight {
        fn measure(_: &Self) -> usize {
            1
        }
    }
    impl Arbitrary for UnitWeight {
        fn arbitrary(_: &mut Gen) -> Self {
            Self
        }
    }

    /// Everything fits into a maximally sized cache, so nothing is evicted
    /// and the insertion order is preserved.
    #[test]
    fn should_not_evict_under_max_size() {
        let xs: Vec<_> = (0..10000).map(HeavyWeight).collect();
        let mut cache =
            WeightCache::<usize, HeavyWeight>::new(usize::max_value().try_into().unwrap());
        for (k, v) in xs.iter().enumerate() {
            cache.put(k, v.clone()).expect("empty")
        }
        let cached = cache.consume().map(|x| x.1).collect::<Vec<_>>();

        assert_eq!(xs, cached);
    }

    /// Fills a capacity-3 cache with five unit-weight entries, reads all five
    /// keys back and checks every exported metric, with or without namespace.
    #[cfg(feature = "metrics")]
    fn metrics_test(namespace: Option<&str>) {
        let mut cache =
            WeightCache::<usize, UnitWeight>::new_with_namespace(3.try_into().unwrap(), namespace);
        let registry = Registry::new();
        cache.register(&registry).unwrap();
        for i in 0usize..5 {
            cache.put(i, UnitWeight).unwrap();
        }
        for i in 0usize..5 {
            cache.get(&i);
        }

        // Exported names are "<namespace>_<name>", or just "<name>" when no
        // namespace was configured.
        let prefix = namespace.map(|ns| format!("{}_", ns)).unwrap_or_default();
        for metric in registry.gather() {
            let name = metric.get_name();
            let sample = &metric.get_metric()[0];
            println!("{} {:?}", name, sample);
            match name.strip_prefix(prefix.as_str()) {
                Some("cache_size") => assert_eq!(3, sample.get_gauge().get_value() as usize),
                Some("cache_insert") => assert_eq!(5, sample.get_counter().get_value() as usize),
                Some("cache_insert_fail") => {
                    assert_eq!(0, sample.get_counter().get_value() as usize)
                }
                // Keys 0 and 1 were evicted, keys 2, 3 and 4 are still cached.
                Some("cache_hit") => assert_eq!(3, sample.get_counter().get_value() as usize),
                Some("cache_miss") => assert_eq!(2, sample.get_counter().get_value() as usize),
                _ => panic!("unknown metrics {}", name),
            }
        }
    }

    #[cfg(feature = "metrics")]
    #[test]
    fn should_gather_metrics() {
        metrics_test(None);
        metrics_test(Some("test"));
    }

    /// A value is rejected exactly when it weighs more than the capacity; a
    /// value that weighs exactly the capacity is accepted.
    #[quickcheck]
    fn should_reject_too_heavy_values(total_size: NonZeroUsize, input: HeavyWeight) -> bool {
        let weight = input.0;
        let mut cache = WeightCache::<usize, HeavyWeight>::new(total_size);
        match cache.put(42, input) {
            Ok(()) => weight <= total_size.get(),
            Err(ValueTooBigError) => weight > total_size.get(),
        }
    }

    /// Once the cache is full, every further unit-weight insert evicts
    /// exactly one entry, so the length stops growing.
    #[quickcheck]
    fn should_evict_once_the_size_target_is_hit(
        input: Vec<UnitWeight>,
        max_size: NonZeroUsize,
    ) -> bool {
        let mut cache_size = 0usize;
        let mut cache = WeightCache::<usize, UnitWeight>::new(max_size);
        for (k, v) in input.into_iter().enumerate() {
            let weight = UnitWeight::measure(&v);
            cache_size += weight;
            let len_before = cache.len();
            cache.put(k, v).unwrap();
            let len_after = cache.len();
            if cache_size > max_size.get() {
                assert_eq!(len_before, len_after);
                cache_size -= weight;
            } else {
                assert_eq!(len_before + 1, len_after);
            }
        }

        true
    }
}