sampling_tree/
lib.rs

1mod sampling;
2
3pub type SimpleSamplingTree<C> = sampling::Tree<sampling::UnstableNode<C>>;
4pub use sampling::UnstableNode;
5
#[cfg(test)]
mod tests {
    use std::{fmt::Write, sync::Arc};

    use super::*;
    use human_units::{FormatSize, Size};
    use indicatif::{self, ProgressState};
    use rand::{rngs::StdRng, Rng, SeedableRng};

    /// Smoke test for the basic API.
    ///
    /// Builds a small tree from random `u32` values, draws one sample, prints
    /// that sample's contribution, calls `update` on it with `0`, and prints
    /// the tree before and after. Nothing is asserted explicitly: the test
    /// fails only if one of the unwrapped calls returns an error/`None`.
    #[test]
    fn it_works() {
        let mut rng = rand::thread_rng();
        let n = 100;
        let range = 10000u32;
        // Lazy iterator; the values are generated while the tree is built.
        let data = (0..n).map(|_| rng.gen_range(0..range));
        let mut sampling_tree: SimpleSamplingTree<_> =
            SimpleSamplingTree::from_iterable(data).unwrap();
        println!("{:?}", sampling_tree);

        let sample_idx = sampling_tree.sample(&mut rng).unwrap();
        println!(
            "{:?}, {:?}",
            sample_idx,
            sampling_tree.contribution(sample_idx).unwrap()
        );

        sampling_tree.update(sample_idx, 0).unwrap();
        println!("{:?}", sampling_tree);

        // Informational only: in-memory size of a single node.
        println!(
            "Size of node: {}",
            std::mem::size_of::<UnstableNode<u64>>()
        );
    }

    /// Drives `sample` + `contribution` from several threads at once and shows
    /// the rate on a progress bar.
    ///
    /// The tree is built once from random `u64` values and then shared
    /// read-only between the workers through an `Arc`. There are no assertions;
    /// the test fails only if a call returns an error/`None` or a worker
    /// thread panics.
    #[test]
    fn test_throughput() {
        let mut rng = rand::thread_rng();
        let n = 1_000_000;
        let range = 10000u64;
        let data = (0..n).map(|_| rng.gen_range(0..range));
        let sampling_tree: SimpleSamplingTree<_> = SimpleSamplingTree::from_iterable(data).unwrap();

        // Total number of draws, split evenly across the worker threads.
        let num_samples = 1_000_000;
        let num_threads = 4;
        let samples_per_thread = num_samples / num_threads;

        let mp = indicatif::MultiProgress::new();
        let sty_main = indicatif::ProgressStyle::default_bar()
            .template("[{elapsed_precise}] {bar:40.cyan/blue} {percent}% {iter_per_sec}")
            .unwrap()
            .progress_chars("##-")
            .with_key(
                "iter_per_sec",
                // Custom template key: the bar's per-second rate, rendered
                // through `human_units`' size formatter.
                |state: &ProgressState, w: &mut dyn Write| {
                    let speed = state.per_sec() as u64;
                    write!(w, "{}", Size(speed).format_size()).unwrap()
                },
            );

        // `ProgressBar` is itself a clonable handle, so each worker gets its
        // own clone instead of going through an extra `Arc`.
        let pb = mp.add(indicatif::ProgressBar::new(num_samples));
        pb.set_style(sty_main);
        let shared_tree: Arc<SimpleSamplingTree<u64>> = Arc::new(sampling_tree);

        let handles: Vec<_> = (0..num_threads)
            .map(|_| {
                let pb = pb.clone();
                // Each worker owns an independently seeded RNG.
                let mut rng = StdRng::from_entropy();
                let shared_tree = Arc::clone(&shared_tree);

                std::thread::spawn(move || {
                    for _ in 0..samples_per_thread {
                        let sample_idx = shared_tree.sample(&mut rng).unwrap();
                        let _ = shared_tree.contribution(sample_idx).unwrap();
                        pb.inc(1);
                    }
                })
            })
            .collect();

        // Propagate any worker panic into the test result.
        for handle in handles {
            handle.join().unwrap();
        }
    }
}