weighted_sampling/
lib.rs

1//! Weighted reservoir sampling using Algorithm A-Chao.
2
3use rand::Rng;
4
/// Streaming sampler that retains a single item from a weighted stream.
///
/// Every positive-weight item offered so far is the retained one with
/// probability proportional to its weight.
///
/// `Debug` and `Clone` are derived and are available whenever `T` itself
/// implements them. `Default` is intentionally not derived, because a derived
/// impl would require `T: Default` even though an empty sampler holds no `T`.
#[derive(Debug, Clone)]
pub struct WeightedReservoirSampler<T> {
  /// Item currently held; `None` until a positive-weight item has been offered.
  selected: Option<T>,
  /// Running sum of the weights of all positive-weight items offered so far.
  total_weight: u64,
}
11
12impl<T> WeightedReservoirSampler<T> {
13  /// Create an empty sampler.
14  #[inline]
15  pub fn new() -> Self {
16    Self {
17      selected: None,
18      total_weight: 0,
19    }
20  }
21
22  /// Add a new item with a given weight.
23  ///
24  /// Items with weight 0 are ignored.
25  pub fn sample<R: Rng + ?Sized>(&mut self, weight: u64, item: T, rng: &mut R) {
26    if self.try_select(weight, rng) {
27      self.selected = Some(item);
28    }
29  }
30
31  /// Current selected item (if any).
32  #[inline]
33  pub fn selected(&self) -> Option<&T> {
34    self.selected.as_ref()
35  }
36
37  /// Take ownership of the selected item.
38  #[inline]
39  pub fn into_selected(self) -> Option<T> {
40    self.selected
41  }
42
43  /// True if no positive-weight items were added.
44  #[inline]
45  pub fn is_empty(&self) -> bool {
46    self.total_weight == 0
47  }
48
49  /// Sum of all positive weights processed so far.
50  #[inline]
51  pub fn total_weight(&self) -> u64 {
52    self.total_weight
53  }
54
55  /// Core of Algorithm A-Chao:
56  /// - First item is always selected.
57  /// - Later items replace the sample with probability `weight / total_weight`.
58  fn try_select<R: Rng + ?Sized>(&mut self, weight: u64, rng: &mut R) -> bool {
59    // Ignore zero-weight items.
60    if weight == 0 {
61      return false;
62    }
63
64    self.total_weight += weight;
65
66    // Select if this is the first item.
67    if self.total_weight == weight {
68      return true;
69    }
70
71    // Select if random draw is within item weight.
72    let draw = rng.random_range(1..=self.total_weight);
73    if draw <= weight {
74      return true;
75    }
76
77    false
78  }
79}
80
81impl<T> Default for WeightedReservoirSampler<T> {
82  fn default() -> Self {
83    Self::new()
84  }
85}