stochastic_data_structures/
lib.rs1extern crate rand;
2
3use rand::*;
4
/// Bookkeeping returned alongside every extracted outcome.
#[derive(Clone, Copy, Debug)]
pub struct ExtractStats {
    /// Number of rejection-loop passes that were needed before an outcome
    /// was accepted.
    pub loop_count: u32,
    /// Number of groups that were passed over before the selected group was
    /// found; `None` when the sampler has no notion of groups.
    pub group_iterations: Option<u32>,
}
10
11pub trait StatisticalMethod<T>
12where
13 T: Clone + Copy,
14{
15 fn add(&mut self, rate: f32, payload: T) -> Outcome<T>;
16 fn delete(&mut self, outcome: Outcome<T>);
17 fn update(&mut self, outcome: Outcome<T>, new_rate: f32) -> Outcome<T>;
18 fn extract<Random: Rng>(&self, rnd: &mut Random) -> (ExtractStats, T, f32);
19}
20
/// Handle to a registered outcome, handed back by `add` / `update`.
#[derive(Clone, Copy, Debug, Default)]
pub struct Outcome<T> {
    /// Position of the outcome inside the owning `RejectionMethod`'s vector
    /// at the time the handle was produced.
    idx: usize,
    /// Index of the group holding the outcome; only filled in by
    /// `CompositeRejectionMethod`, `None` otherwise.
    group_idx: Option<usize>,
    /// Sampling weight stored for the outcome.
    rate: f32,
    /// User data attached to the outcome.
    pub payload: T,
}
28
29#[derive(Clone, Debug)]
30pub struct RejectionMethod<T> {
31 max_rate: f32,
32 outcomes: Vec<Outcome<T>>,
33}
34
35impl<T> RejectionMethod<T> {
36 pub fn new(max_rate: f32) -> Self {
37 Self {
38 max_rate: max_rate,
39 outcomes: vec![],
40 }
41 }
42}
43
44impl<T> StatisticalMethod<T> for RejectionMethod<T>
45where
46 T: Copy + Clone,
47{
48 fn add(&mut self, mut rate: f32, payload: T) -> Outcome<T> {
49 if rate > self.max_rate {
50 panic!("Invalid rate provided in `add`");
52 }
53
54 if rate == 0.0 {
55 rate = 0.0001;
56 }
57
58 let outcome = Outcome {
59 payload: payload,
60 group_idx: None,
61 rate: rate,
62 idx: self.outcomes.len(),
63 };
64
65 self.outcomes.push(outcome);
66 outcome
67 }
68
69 fn delete(&mut self, outcome: Outcome<T>) {
70 self.outcomes.swap_remove(outcome.idx);
71 }
75
76 fn update(&mut self, outcome: Outcome<T>, new_rate: f32) -> Outcome<T> {
77 if new_rate == 0.0 {
83 panic!("Invalid rate provided in `update`");
84 }
85
86 let outcome = &mut self.outcomes[outcome.idx];
87 outcome.rate = new_rate;
88 *outcome
89 }
90
91 fn extract<Random: Rng>(&self, rng: &mut Random) -> (ExtractStats, T, f32) {
92 let mut loop_count = 0;
93 loop {
94 loop_count += 1;
95 let rand = rng.gen_range::<f32>(0.0, self.outcomes.len() as f32);
96 let rand_idx = rand.floor();
97 let rand_rate = (rand - rand_idx) * self.max_rate;
98
99 let outcome = &self.outcomes[rand_idx as usize];
100
101 if outcome.rate >= rand_rate {
102 return (
103 ExtractStats {
104 loop_count,
105 group_iterations: None,
106 },
107 outcome.payload,
108 outcome.rate,
109 );
110 }
111 }
112 }
113}
114
115#[derive(Debug)]
116pub struct CompositeRejectionMethod<T> {
117 groups: Vec<RejectionMethod<T>>,
118 sum_rates: Vec<f32>,
119 total_rate: f32,
120 constant: f32,
121 max: f32,
122}
123
124impl<T> CompositeRejectionMethod<T> {
125 pub fn new(max: f32, constant: f32) -> Self {
126 if constant <= 1.0 {
127 panic!("Invalid constant");
128 }
129
130 if max <= 1.0 {
131 panic!("Invalid max value");
132 }
133
134 let group_count = max.log(constant).ceil() as usize;
135 let mut groups = vec![];
136
137 for exponent in 0..group_count {
138 groups.push(RejectionMethod::new(max / constant.powf(exponent as f32)));
139 }
140
141 Self {
142 groups: groups,
143 sum_rates: vec![0.0; group_count],
144 total_rate: 0.0,
145 constant: constant,
146 max: max,
147 }
148 }
149
150 fn find_group_idx(&self, rate: f32) -> usize {
151 (self.max / rate.max(1.0)).log(self.constant).floor() as usize
154 }
155}
156
157impl<T> StatisticalMethod<T> for CompositeRejectionMethod<T>
158where
159 T: Copy + Clone,
160{
161 fn add(&mut self, rate: f32, payload: T) -> Outcome<T> {
162 if rate > self.max {
163 panic!("Rate out of range rate: {}, max rate: {}", rate, self.max);
164 }
165
166 let group_idx = self.find_group_idx(rate);
167
168 let mut outcome = self.groups[group_idx].add(rate, payload);
169 self.sum_rates[group_idx] += rate;
170 self.total_rate += rate;
171
172 outcome.group_idx = Some(group_idx);
173 outcome
174 }
175
176 fn delete(&mut self, outcome: Outcome<T>) {
177 if let Some(group_idx) = outcome.group_idx {
178 self.sum_rates[group_idx] -= outcome.rate;
179 self.total_rate -= outcome.rate;
180
181 self.groups[group_idx].delete(outcome);
182 }
183 }
184
185 fn update(&mut self, outcome: Outcome<T>, new_rate: f32) -> Outcome<T> {
186 let new_group_idx = self.find_group_idx(new_rate);
187
188 if let Some(old_group_idx) = outcome.group_idx {
189 let delta_rate = new_rate - outcome.rate;
190
191 self.total_rate += delta_rate;
192
193 let mut outcome = if new_group_idx == old_group_idx {
194 self.sum_rates[new_group_idx] += delta_rate;
196 self.groups[new_group_idx].update(outcome, new_rate)
197 } else {
198 self.sum_rates[old_group_idx] -= outcome.rate;
200 self.groups[old_group_idx].delete(outcome);
201
202 self.sum_rates[new_group_idx] += new_rate;
204 self.groups[new_group_idx].add(new_rate, outcome.payload)
205 };
206
207 outcome.group_idx = Some(new_group_idx);
208 return outcome;
209 } else {
210 panic!("Outcome must have a group idx set");
211 }
212 }
213
214 fn extract<Random: Rng>(&self, rng: &mut Random) -> (ExtractStats, T, f32) {
215 let u = rng.gen::<f32>();
216 let mut rand = u * self.total_rate;
217 let mut iterations = 0;
218 for (idx, g) in self.groups.iter().enumerate() {
219 if self.sum_rates[idx] > rand {
220 let mut r = g.extract(rng);
221 r.0.group_iterations = Some(iterations);
222 r.2 = self.total_rate / r.2;
223 return r;
224 }
225
226 iterations += 1;
227
228 rand = rand - self.sum_rates[idx];
229 }
230
231 panic!(
232 "Shouldn't be able to reach here, algorithm invariant breached {} {}",
233 u, iterations
234 );
235 }
236}
237
/// Packs `v` into a `u32` fixed-point value.
///
/// Note that, despite the name, the integer part of `v` is stored in the bits
/// above the lowest eight and the fractional part — scaled by 255 and capped
/// at 255 — occupies the lowest eight bits.
pub fn to_fixed_8_24(v: f32) -> u32 {
    let whole = v.floor();
    let int_bits = (whole as u32) << 8u64;
    let frac_bits = ((v - whole) * 255.0) as u32;

    int_bits + frac_bits.min(255)
}
248
/// Unpacks a value produced by `to_fixed_8_24`.
///
/// The bits above the lowest eight are read as the integer part; the lowest
/// eight bits, divided by 255, are read as the fractional part.
pub fn from_fixed_8_24(v: u32) -> f32 {
    let whole = (v >> 8u64) as f32;
    let fraction = (v & 0xff) as f32 / 255.0;

    whole + fraction
}
255
/// Alias table for drawing an index from a fixed discrete distribution with
/// two uniform random numbers and no loop.
#[derive(Clone, Debug)]
pub struct AliasMethod {
    /// For every slot, the index returned when the slot's own probability
    /// test fails.
    alias: Vec<u32>,
    /// For every slot, the threshold below which the slot's own index is
    /// returned.
    probability: Vec<f32>,
}

impl AliasMethod {
    /// Builds the alias table for the weights in `list`.
    ///
    /// The weights are normalised so that they sum to `list.len()`; slots
    /// with a normalised weight below `1.0` are then paired with slots that
    /// have surplus weight. Slots that are left unpaired keep probability
    /// `1.0` and alias to themselves.
    pub fn new(mut list: Vec<f32>) -> AliasMethod {
        let len = list.len();

        let sum = list.iter().fold(0.0, |acc, &weight| acc + weight);
        let scale = len as f32 / sum;
        for weight in list.iter_mut() {
            *weight *= scale;
        }

        // Work stacks, filled back-to-front so that the lowest indices are
        // popped first.
        let mut small: Vec<usize> = Vec::with_capacity(len);
        let mut large: Vec<usize> = Vec::with_capacity(len);
        for i in (0..len).rev() {
            if list[i] < 1.0 {
                small.push(i);
            } else {
                large.push(i);
            }
        }

        // Every slot starts as "always return myself"; only slots that get
        // paired below are overwritten.
        let mut table = AliasMethod {
            alias: (0..len).map(|i| i as u32).collect(),
            probability: vec![1.0; len],
        };

        loop {
            // Once either stack runs dry the remaining slots keep their
            // initial probability of 1.0, so dropping a popped index is fine.
            let (a, g) = match (small.pop(), large.pop()) {
                (Some(a), Some(g)) => (a, g),
                _ => break,
            };

            table.probability[a] = list[a];
            table.alias[a] = g as u32;

            // Evaluation order kept as `(g + a) - 1` on purpose: regrouping
            // would change the floating-point rounding.
            list[g] = list[g] + list[a] - 1.0;

            if list[g] < 1.0 {
                small.push(g);
            } else {
                large.push(g);
            }
        }

        table
    }

    /// Maps two uniform samples to an index of the original weight list.
    ///
    /// `u0` selects the slot and `u1` decides between the slot itself and
    /// its alias. Both are expected to lie in `[0, 1)`; a `u0` that reaches
    /// `1.0` is clamped to the last slot instead of indexing out of bounds.
    ///
    /// # Panics
    ///
    /// Panics if the table was built from an empty list.
    pub fn find_index(&self, u0: f32, u1: f32) -> usize {
        let len = self.probability.len();
        assert!(len > 0, "AliasMethod::find_index called on an empty table");

        let idx = ((len as f32 * u0) as usize).min(len - 1);
        if u1 < self.probability[idx] {
            idx
        } else {
            self.alias[idx] as usize
        }
    }
}