streaming_algorithms/
count_min.rs

// This file includes source code from https://github.com/jedisct1/rust-count-min-sketch/blob/088274e22a3decc986dec928c92cc90a709a0274/src/lib.rs under the following MIT License:

// Copyright (c) 2016 Frank Denis

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:

// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

23use serde::{Deserialize, Serialize};
24use std::{
25	borrow::Borrow, cmp::max, convert::TryFrom, fmt, hash::{Hash, Hasher}, marker::PhantomData, ops
26};
27use twox_hash::XxHash;
28
29use super::f64_to_usize;
30use crate::traits::{Intersect, IntersectPlusUnionIsPlus, New, UnionAssign};
31
/// An implementation of a [count-min sketch](https://en.wikipedia.org/wiki/Count–min_sketch) data structure with *conservative updating* for increased accuracy.
///
/// This data structure is also known as a [counting Bloom filter](https://en.wikipedia.org/wiki/Bloom_filter#Counting_filters).
///
/// See [*An Improved Data Stream Summary: The Count-Min Sketch and its Applications*](http://dimacs.rutgers.edu/~graham/pubs/papers/cm-full.pdf) and [*New Directions in Traffic Measurement and Accounting*](http://pages.cs.wisc.edu/~suman/courses/740/papers/estan03tocs.pdf) for background on the count-min sketch with conservative updating.
///
/// `K` is the key type: keys are only hashed, never stored. `C` is the counter type; every
/// counter is created with `C::new` from the stored config. Conservative updating is used by
/// `push` when `<C as IntersectPlusUnionIsPlus>::VAL` is `false`; otherwise `push` adds the
/// value to every counter the key maps to.
#[derive(Serialize, Deserialize)]
#[serde(bound(
	serialize = "C: Serialize, <C as New>::Config: Serialize",
	deserialize = "C: Deserialize<'de>, <C as New>::Config: Deserialize<'de>"
))]
pub struct CountMinSketch<K: ?Sized, C: New> {
	// `k_num` rows of `width` counters each; a key touches exactly one counter per row.
	counters: Vec<Vec<C>>,
	// Scratch buffer holding one column index per row, to avoid malloc/free each push.
	// Note: the derived `Serialize` writes it out like any other field.
	offsets: Vec<usize>,
	// `width - 1`, where `width` is a power of two, so `hash & mask` is a valid column index.
	mask: usize,
	// Number of rows in `counters`.
	k_num: usize,
	// Handed to `C::new` whenever counters are created or reset.
	config: <C as New>::Config,
	// Ties the sketch to `K` without storing one. `fn(K)` (unlike plain `K`) keeps the
	// sketch's auto traits such as `Send`/`Sync` independent of `K`.
	marker: PhantomData<fn(K)>,
}
50
51impl<K: ?Sized, C> CountMinSketch<K, C>
52where
53	K: Hash,
54	C: New + for<'a> UnionAssign<&'a C> + Intersect,
55{
56	/// Create an empty `CountMinSketch` data structure with the specified error tolerance.
57	pub fn new(probability: f64, tolerance: f64, config: C::Config) -> Self {
58		let width = Self::optimal_width(tolerance);
59		let k_num = Self::optimal_k_num(probability);
60		let counters: Vec<Vec<C>> = (0..k_num)
61			.map(|_| (0..width).map(|_| C::new(&config)).collect())
62			.collect();
63		let offsets = vec![0; k_num];
64		Self {
65			counters,
66			offsets,
67			mask: Self::mask(width),
68			k_num,
69			config,
70			marker: PhantomData,
71		}
72	}
73
74	/// "Visit" an element.
75	pub fn push<Q: ?Sized, V: ?Sized>(&mut self, key: &Q, value: &V) -> C
76	where
77		Q: Hash,
78		K: Borrow<Q>,
79		C: for<'a> ops::AddAssign<&'a V> + IntersectPlusUnionIsPlus,
80	{
81		if !<C as IntersectPlusUnionIsPlus>::VAL {
82			let offsets = self.offsets(key);
83			self.offsets
84				.iter_mut()
85				.zip(offsets)
86				.for_each(|(offset, offset_new)| {
87					*offset = offset_new;
88				});
89			let mut lowest = C::intersect(
90				self.offsets
91					.iter()
92					.enumerate()
93					.map(|(k_i, &offset)| &self.counters[k_i][offset]),
94			)
95			.unwrap();
96			lowest += value;
97			self.counters
98				.iter_mut()
99				.zip(self.offsets.iter())
100				.for_each(|(counters, &offset)| {
101					counters[offset].union_assign(&lowest);
102				});
103			lowest
104		} else {
105			let offsets = self.offsets(key);
106			C::intersect(
107				self.counters
108					.iter_mut()
109					.zip(offsets)
110					.map(|(counters, offset)| {
111						counters[offset] += value;
112						&counters[offset]
113					}),
114			)
115			.unwrap()
116		}
117	}
118
119	/// Union the aggregated value for `key` with `value`.
120	pub fn union_assign<Q: ?Sized>(&mut self, key: &Q, value: &C)
121	where
122		Q: Hash,
123		K: Borrow<Q>,
124	{
125		let offsets = self.offsets(key);
126		self.counters
127			.iter_mut()
128			.zip(offsets)
129			.for_each(|(counters, offset)| {
130				counters[offset].union_assign(value);
131			})
132	}
133
134	/// Retrieve an estimate of the aggregated value for `key`.
135	pub fn get<Q: ?Sized>(&self, key: &Q) -> C
136	where
137		Q: Hash,
138		K: Borrow<Q>,
139	{
140		C::intersect(
141			self.counters
142				.iter()
143				.zip(self.offsets(key))
144				.map(|(counters, offset)| &counters[offset]),
145		)
146		.unwrap()
147	}
148
149	// pub fn estimate_memory(
150	// 	probability: f64, tolerance: f64,
151	// ) -> Result<usize, &'static str> {
152	// 	let width = Self::optimal_width(tolerance);
153	// 	let k_num = Self::optimal_k_num(probability);
154	// 	Ok(width * mem::size_of::<C>() * k_num)
155	// }
156
157	/// Clears the `CountMinSketch` data structure, as if it was new.
158	pub fn clear(&mut self) {
159		let config = &self.config;
160		self.counters
161			.iter_mut()
162			.flat_map(|x| x.iter_mut())
163			.for_each(|counter| {
164				*counter = C::new(config);
165			})
166	}
167
168	fn optimal_width(tolerance: f64) -> usize {
169		let e = tolerance;
170		let width = f64_to_usize((2.0 / e).round());
171		max(2, width)
172			.checked_next_power_of_two()
173			.expect("Width would be way too large")
174	}
175
176	fn mask(width: usize) -> usize {
177		assert!(width > 1);
178		assert_eq!(width & (width - 1), 0);
179		width - 1
180	}
181
182	fn optimal_k_num(probability: f64) -> usize {
183		max(
184			1,
185			f64_to_usize(((1.0 - probability).ln() / 0.5_f64.ln()).floor()),
186		)
187	}
188
189	fn offsets<Q: ?Sized>(&self, key: &Q) -> impl Iterator<Item = usize>
190	where
191		Q: Hash,
192		K: Borrow<Q>,
193	{
194		let mask = self.mask;
195		hashes(key).map(move |hash| usize::try_from(hash & u64::try_from(mask).unwrap()).unwrap())
196	}
197}
198
199fn hashes<Q: ?Sized>(key: &Q) -> impl Iterator<Item = u64>
200where
201	Q: Hash,
202{
203	#[allow(missing_copy_implementations, missing_debug_implementations)]
204	struct X(XxHash);
205	impl Iterator for X {
206		type Item = u64;
207		fn next(&mut self) -> Option<Self::Item> {
208			let ret = self.0.finish();
209			self.0.write(&[123]);
210			Some(ret)
211		}
212	}
213	let mut hasher = XxHash::default();
214	key.hash(&mut hasher);
215	X(hasher)
216}
217
218impl<K: ?Sized, C: New + Clone> Clone for CountMinSketch<K, C> {
219	fn clone(&self) -> Self {
220		Self {
221			counters: self.counters.clone(),
222			offsets: vec![0; self.offsets.len()],
223			mask: self.mask,
224			k_num: self.k_num,
225			config: self.config.clone(),
226			marker: PhantomData,
227		}
228	}
229}
230impl<K: ?Sized, C: New> fmt::Debug for CountMinSketch<K, C> {
231	fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
232		fmt.debug_struct("CountMinSketch")
233			// .field("counters", &self.counters)
234			.finish()
235	}
236}
237
#[cfg(test)]
mod tests {
	type CountMinSketch8<K> = super::CountMinSketch<K, u8>;
	type CountMinSketch16<K> = super::CountMinSketch<K, u16>;
	type CountMinSketch64<K> = super::CountMinSketch<K, u64>;

	/// Pushing a `u8` counter past `u8::MAX` must panic on overflow.
	///
	/// Integer overflow only panics when overflow checks are enabled. By default these follow
	/// `debug_assertions`, so the test is ignored only in builds without them (e.g. `--release`)
	/// instead of always.
	#[cfg_attr(not(debug_assertions), ignore)]
	#[test]
	#[should_panic]
	fn test_overflow() {
		let mut cms = CountMinSketch8::<&str>::new(0.95, 10.0 / 100.0, ());
		for _ in 0..300 {
			let _ = cms.push("key", &1);
		}
	}

	#[test]
	fn test_increment() {
		let mut cms = CountMinSketch16::<&str>::new(0.95, 10.0 / 100.0, ());
		for _ in 0..300 {
			let _ = cms.push("key", &1);
		}
		assert_eq!(cms.get("key"), 300);
	}

	/// Each of the 100 keys is pushed exactly 10_000 times. A count-min sketch never
	/// underestimates, so every estimate must be at least 10_000. After `clear`, every
	/// estimate must be back to zero.
	#[test]
	fn test_increment_multi() {
		let mut cms = CountMinSketch64::<u64>::new(0.99, 2.0 / 100.0, ());
		for i in 0..1_000_000 {
			let _ = cms.push(&(i % 100), &1);
		}
		for key in 0..100 {
			assert!(cms.get(&key) >= 10_000);
		}
		cms.clear();
		for key in 0..100 {
			assert_eq!(cms.get(&key), 0);
		}
	}
}