// streaming_algorithms/top.rs

1use serde::{Deserialize, Serialize};
2use std::{
3	cmp, collections::{hash_map::Entry, HashMap}, fmt::{self, Debug}, hash::Hash, iter, ops
4};
5use twox_hash::RandomXxHashBuilder;
6
7use crate::{
8	count_min::CountMinSketch, ordered_linked_list::{OrderedLinkedList, OrderedLinkedListIndex, OrderedLinkedListIter}, traits::{Intersect, New, UnionAssign}, IntersectPlusUnionIsPlus
9};
10
11/// This probabilistic data structure tracks the `n` top keys given a stream of `(key,value)` tuples, ordered by the sum of the values for each key (the "aggregated value"). It uses only `O(n)` space.
12///
13/// Its implementation is two parts:
14///
15/// * a doubly linked hashmap, mapping the top `n` keys to their aggregated values, and ordered by their aggregated values. This is used to keep a more precise track of the aggregated value of the top `n` keys, and reduce collisions in the count-min sketch.
16/// * a [count-min sketch](https://en.wikipedia.org/wiki/Count–min_sketch) to track all of the keys outside the top `n`. This data structure is also known as a [counting Bloom filter](https://en.wikipedia.org/wiki/Bloom_filter#Counting_filters). It uses *conservative updating* for increased accuracy.
17///
18/// The algorithm is as follows:
19///
20/// ```text
21/// while a key and value from the input stream arrive:
22///     if H[key] exists
23///         increment aggregated value associated with H[key]
24///     elsif number of items in H < k
25///         put H[key] into map with its associated value
26///     else
27///         add C[key] into the count-min sketch with its associated value
28///         if aggregated value associated with C[key] is > the lowest aggregated value in H
29///             move the lowest key and value from H into C
30///             move C[key] and value from C into H
31/// endwhile
32/// ```
33///
34/// See [*An Improved Data Stream Summary: The Count-Min Sketch and its Applications*](http://dimacs.rutgers.edu/~graham/pubs/papers/cm-full.pdf) and [*New Directions in Traffic Measurement and Accounting*](http://pages.cs.wisc.edu/~suman/courses/740/papers/estan03tocs.pdf) for background on the count-min sketch with conservative updating.
35#[derive(Clone, Serialize, Deserialize)]
36#[serde(bound(
37	serialize = "A: Hash + Eq + Serialize, C: Serialize, <C as New>::Config: Serialize",
38	deserialize = "A: Hash + Eq + Deserialize<'de>, C: Deserialize<'de>, <C as New>::Config: Deserialize<'de>"
39))]
40pub struct Top<A, C: New> {
41	map: HashMap<A, OrderedLinkedListIndex<'static>, RandomXxHashBuilder>,
42	list: OrderedLinkedList<Node<A, C>>,
43	count_min: CountMinSketch<A, C>,
44	config: <C as New>::Config,
45}
46impl<A: Hash + Eq + Clone, C: Ord + New + for<'a> UnionAssign<&'a C> + Intersect> Top<A, C> {
47	/// Create an empty `Top` data structure with the specified `n` capacity.
48	pub fn new(n: usize, probability: f64, tolerance: f64, config: <C as New>::Config) -> Self {
49		Self {
50			map: HashMap::with_capacity_and_hasher(n, RandomXxHashBuilder::default()),
51			list: OrderedLinkedList::new(n),
52			count_min: CountMinSketch::new(probability, tolerance, config.clone()),
53			config,
54		}
55	}
56	fn assert(&self) {
57		if !cfg!(feature = "assert") {
58			return;
59		}
60		for (k, &v) in &self.map {
61			assert!(&self.list[v].0 == k);
62		}
63		let mut cur = &self.list[self.list.head().unwrap()].1;
64		for &Node(_, ref count) in self.list.iter() {
65			assert!(cur >= count);
66			cur = count;
67		}
68	}
69	/// The `n` most frequent elements we have capacity to track.
70	pub fn capacity(&self) -> usize {
71		self.list.capacity()
72	}
73	/// "Visit" an element.
74	pub fn push<V: ?Sized>(&mut self, item: A, value: &V)
75	where
76		C: for<'a> ops::AddAssign<&'a V> + IntersectPlusUnionIsPlus,
77	{
78		match self.map.entry(item.clone()) {
79			Entry::Occupied(entry) => {
80				let offset = *entry.get();
81				self.list.mutate(offset, |Node(t, mut count)| {
82					count += value;
83					Node(t, count)
84				});
85			}
86			Entry::Vacant(entry) => {
87				if self.list.len() < self.list.capacity() {
88					let mut x = C::new(&self.config);
89					x += value;
90					let new = self.list.push_back(Node(item, x));
91					let new = unsafe { new.staticify() };
92					let _ = entry.insert(new);
93				} else {
94					let score = self.count_min.push(&item, value);
95					if score > self.list[self.list.tail().unwrap()].1 {
96						let old = self.list.pop_back();
97						let new = self.list.push_back(Node(item, score));
98						let new = unsafe { new.staticify() };
99						let _ = entry.insert(new);
100						let _ = self.map.remove(&old.0).unwrap();
101						self.count_min.union_assign(&old.0, &old.1);
102					}
103				}
104			}
105		}
106		self.assert();
107	}
108	/// Clears the `Top` data structure, as if it was new.
109	pub fn clear(&mut self) {
110		self.map.clear();
111		self.list.clear();
112		self.count_min.clear();
113	}
114	/// An iterator visiting all elements and their counts in descending order of frequency. The iterator element type is (&'a A, usize).
115	pub fn iter(&self) -> TopIter<'_, A, C> {
116		TopIter {
117			list_iter: self.list.iter(),
118		}
119	}
120}
121impl<
122		A: Hash + Eq + Clone + Debug,
123		C: Ord + New + Clone + for<'a> UnionAssign<&'a C> + Intersect + Debug,
124	> Debug for Top<A, C>
125{
126	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
127		f.debug_list().entries(self.iter()).finish()
128	}
129}
130
131/// An iterator over the entries and counts in a [`Top`] datastructure.
132///
133/// This struct is created by the [`iter`](Top::iter()) method on [`Top`]. See its documentation for more.
134pub struct TopIter<'a, A: Hash + Eq + Clone + 'a, C: Ord + 'a> {
135	list_iter: OrderedLinkedListIter<'a, Node<A, C>>,
136}
137impl<'a, A: Hash + Eq + Clone, C: Ord + 'a> Clone for TopIter<'a, A, C> {
138	fn clone(&self) -> Self {
139		Self {
140			list_iter: self.list_iter.clone(),
141		}
142	}
143}
144impl<'a, A: Hash + Eq + Clone, C: Ord + 'a> Iterator for TopIter<'a, A, C> {
145	type Item = (&'a A, &'a C);
146	fn next(&mut self) -> Option<(&'a A, &'a C)> {
147		self.list_iter.next().map(|x| (&x.0, &x.1))
148	}
149}
150impl<'a, A: Hash + Eq + Clone + Debug, C: Ord + Debug + 'a> Debug for TopIter<'a, A, C> {
151	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
152		f.debug_list().entries(self.clone()).finish()
153	}
154}
155
156impl<
157		A: Hash + Eq + Clone,
158		C: Ord
159			+ New
160			+ Clone
161			+ for<'a> ops::AddAssign<&'a C>
162			+ for<'a> UnionAssign<&'a C>
163			+ Intersect
164			+ IntersectPlusUnionIsPlus,
165	> iter::Sum<Top<A, C>> for Option<Top<A, C>>
166{
167	fn sum<I>(mut iter: I) -> Self
168	where
169		I: Iterator<Item = Top<A, C>>,
170	{
171		let mut total = iter.next()?;
172		for sample in iter {
173			total += sample;
174		}
175		Some(total)
176	}
177}
178impl<
179		A: Hash + Eq + Clone,
180		C: Ord
181			+ New
182			+ Clone
183			+ for<'a> ops::AddAssign<&'a C>
184			+ for<'a> UnionAssign<&'a C>
185			+ Intersect
186			+ IntersectPlusUnionIsPlus,
187	> ops::Add for Top<A, C>
188{
189	type Output = Self;
190	fn add(mut self, other: Self) -> Self {
191		self += other;
192		self
193	}
194}
195impl<
196		A: Hash + Eq + Clone,
197		C: Ord
198			+ New
199			+ Clone
200			+ for<'a> ops::AddAssign<&'a C>
201			+ for<'a> UnionAssign<&'a C>
202			+ Intersect
203			+ IntersectPlusUnionIsPlus,
204	> ops::AddAssign for Top<A, C>
205{
206	fn add_assign(&mut self, other: Self) {
207		assert_eq!(self.capacity(), other.capacity());
208
209		let mut scores = HashMap::<_, C>::new();
210		for (url, count) in self.iter() {
211			*scores
212				.entry(url.clone())
213				.or_insert_with(|| C::new(&self.config)) += count;
214		}
215		for (url, count) in other.iter() {
216			*scores
217				.entry(url.clone())
218				.or_insert_with(|| C::new(&self.config)) += count;
219		}
220		let mut top = self.clone();
221		top.clear();
222		for (url, count) in scores {
223			top.push(url.clone(), &count);
224		}
225		*self = top;
226	}
227}
228
229#[derive(Clone, Serialize, Deserialize)]
230struct Node<T, C>(T, C);
231impl<T, C: Ord> Ord for Node<T, C> {
232	#[inline(always)]
233	fn cmp(&self, other: &Self) -> cmp::Ordering {
234		self.1.cmp(&other.1)
235	}
236}
237impl<T, C: PartialOrd> PartialOrd for Node<T, C> {
238	#[inline(always)]
239	fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
240		self.1.partial_cmp(&other.1)
241	}
242}
243impl<T, C: PartialEq> PartialEq for Node<T, C> {
244	#[inline(always)]
245	fn eq(&self, other: &Self) -> bool {
246		self.1.eq(&other.1)
247	}
248}
249impl<T, C: Eq> Eq for Node<T, C> {}
250
#[cfg(test)]
mod test {
	use super::*;
	use crate::{distinct::HyperLogLog, traits::IntersectPlusUnionIsPlus};
	use rand::{self, Rng, SeedableRng};
	use std::time;

	/// A fixed-seed RNG so every run sees the same input stream.
	fn seeded_rng() -> rand::rngs::SmallRng {
		rand::rngs::SmallRng::from_seed([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
	}

	#[test]
	fn abc() {
		let mut rng = seeded_rng();
		let mut top = Top::<String, usize>::new(100, 0.99, 2.0 / 1000.0, ());
		let mut x = HashMap::new();
		for _ in 0..10_000 {
			let (a, b) = (rng.gen_range(0, 2) == 0, rng.gen_range(0, 2) == 0);
			let c = rng.gen_range(0, 50);
			let record = match (a, b) {
				(true, _) => format!("a{}", c),
				(false, true) => format!("b{}", c),
				(false, false) => format!("c{}", c),
			};
			top.push(record.clone(), &1);
			*x.entry(record).or_insert(0) += 1;
		}
		println!("{:#?}", top);

		// The list fills with the first distinct keys and never shrinks, and `iter` promises
		// descending order of aggregated value.
		let counts = top.iter().map(|(_, &count)| count).collect::<Vec<usize>>();
		assert_eq!(counts.len(), cmp::min(top.capacity(), x.len()));
		assert!(counts.windows(2).all(|pair| pair[0] >= pair[1]));

		let mut x = x.into_iter().collect::<Vec<_>>();
		x.sort_by_key(|x| cmp::Reverse(x.1));
		println!("{:#?}", x);
	}

	/// Wraps a `HyperLogLog` so it can serve as an aggregated value, ordered by its estimated
	/// cardinality.
	#[derive(Serialize, Deserialize)]
	#[serde(bound = "")]
	struct HLL<V>(HyperLogLog<V>);
	impl<V: Hash> Ord for HLL<V> {
		#[inline(always)]
		fn cmp(&self, other: &Self) -> cmp::Ordering {
			self.0
				.len()
				.partial_cmp(&other.0.len())
				.expect("HyperLogLog cardinality estimate was NaN")
		}
	}
	impl<V: Hash> PartialOrd for HLL<V> {
		#[inline(always)]
		fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
			self.0.len().partial_cmp(&other.0.len())
		}
	}
	impl<V: Hash> PartialEq for HLL<V> {
		#[inline(always)]
		fn eq(&self, other: &Self) -> bool {
			self.0.len().eq(&other.0.len())
		}
	}
	impl<V: Hash> Eq for HLL<V> {}
	impl<V: Hash> Clone for HLL<V> {
		fn clone(&self) -> Self {
			Self(self.0.clone())
		}
	}
	impl<V: Hash> New for HLL<V> {
		type Config = f64;
		fn new(config: &Self::Config) -> Self {
			Self(New::new(config))
		}
	}
	impl<V: Hash> Intersect for HLL<V> {
		fn intersect<'a>(iter: impl Iterator<Item = &'a Self>) -> Option<Self>
		where
			Self: Sized + 'a,
		{
			Intersect::intersect(iter.map(|x| &x.0)).map(Self)
		}
	}
	impl<'a, V: Hash> UnionAssign<&'a HLL<V>> for HLL<V> {
		fn union_assign(&mut self, rhs: &'a Self) {
			self.0.union_assign(&rhs.0)
		}
	}
	impl<'a, V: Hash> ops::AddAssign<&'a V> for HLL<V> {
		fn add_assign(&mut self, rhs: &'a V) {
			self.0.add_assign(rhs)
		}
	}
	impl<V: Hash> Debug for HLL<V> {
		fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
			self.0.fmt(fmt)
		}
	}
	impl<V> IntersectPlusUnionIsPlus for HLL<V> {
		const VAL: bool = <HyperLogLog<V> as IntersectPlusUnionIsPlus>::VAL;
	}

	#[test]
	fn top_hll() {
		let mut rng = seeded_rng();
		let mut top = Top::<String, HLL<String>>::new(1000, 0.99, 2.0 / 1000.0, 0.00408);
		for _ in 0..5_000 {
			let (a, b) = (rng.gen_range(0, 2) == 0, rng.gen_range(0, 2) == 0);
			let c = rng.gen_range(0, 800);
			let record = match (a, b) {
				(true, _) => (format!("a{}", c), format!("{}", rng.gen_range(0, 500))),
				(false, true) => (format!("b{}", c), format!("{}", rng.gen_range(0, 200))),
				(false, false) => (format!("c{}", c), format!("{}", rng.gen_range(0, 200))),
			};
			top.push(record.0, &record.1);
		}
		println!("{:#?}", top);
	}

	#[ignore] // takes too long on CI
	#[test]
	fn many() {
		let start = time::Instant::now();

		let mut rng = seeded_rng();
		let mut top = Top::<String, HLL<String>>::new(1000, 0.99, 2.0 / 1000.0, 0.05);
		for _ in 0..5_000_000 {
			let (a, b) = (rng.gen_range(0, 2) == 0, rng.gen_range(0, 2) == 0);
			let c = rng.gen_range(0, 800);
			let record = match (a, b) {
				(true, _) => (format!("a{}", c), format!("{}", rng.gen_range(0, 500))),
				(false, true) => (format!("b{}", c), format!("{}", rng.gen_range(0, 200))),
				(false, false) => (format!("c{}", c), format!("{}", rng.gen_range(0, 200))),
			};
			top.push(record.0, &record.1);
		}

		println!("{:?}", start.elapsed());
	}
}
398
399// mod merge {
400// 	// https://stackoverflow.com/questions/23039130/rust-implementing-merge-sorted-iterator/32020190#32020190
401// 	use std::{cmp::Ordering, iter::Peekable};
402
403// 	pub struct Merge<L, R>
404// 	where
405// 		L: Iterator<Item = R::Item>,
406// 		R: Iterator,
407// 	{
408// 		left: Peekable<L>,
409// 		right: Peekable<R>,
410// 	}
411// 	impl<L, R> Merge<L, R>
412// 	where
413// 		L: Iterator<Item = R::Item>,
414// 		R: Iterator,
415// 	{
416// 		pub fn new(left: L, right: R) -> Self {
417// 			Merge {
418// 				left: left.peekable(),
419// 				right: right.peekable(),
420// 			}
421// 		}
422// 	}
423
424// 	impl<L, R> Iterator for Merge<L, R>
425// 	where
426// 		L: Iterator<Item = R::Item>,
427// 		R: Iterator,
428// 		L::Item: Ord,
429// 	{
430// 		type Item = L::Item;
431
432// 		fn next(&mut self) -> Option<L::Item> {
433// 			let which = match (self.left.peek(), self.right.peek()) {
434// 				(Some(l), Some(r)) => Some(l.cmp(r)),
435// 				(Some(_), None) => Some(Ordering::Less),
436// 				(None, Some(_)) => Some(Ordering::Greater),
437// 				(None, None) => None,
438// 			};
439
440// 			match which {
441// 				Some(Ordering::Less) => self.left.next(),
442// 				Some(Ordering::Equal) => self.left.next(),
443// 				Some(Ordering::Greater) => self.right.next(),
444// 				None => None,
445// 			}
446// 		}
447// 	}
448// }