stochastic_data_structures/
lib.rs

1extern crate rand;
2
3use rand::*;
4
/// Diagnostic counters returned together with every extracted sample.
#[derive(Clone, Copy, Debug)]
pub struct ExtractStats {
    /// Number of passes through the rejection loop, including the accepted one.
    pub loop_count: u32,
    /// Number of groups passed over before the sampled group was reached;
    /// `None` when the sampler does not use groups.
    pub group_iterations: Option<u32>,
}
10
11pub trait StatisticalMethod<T>
12where
13    T: Clone + Copy,
14{
15    fn add(&mut self, rate: f32, payload: T) -> Outcome<T>;
16    fn delete(&mut self, outcome: Outcome<T>);
17    fn update(&mut self, outcome: Outcome<T>, new_rate: f32) -> Outcome<T>;
18    fn extract<Random: Rng>(&self, rnd: &mut Random) -> (ExtractStats, T, f32);
19}
20
/// Handle to a registered outcome, returned by the samplers in this file.
///
/// The handle is a plain copy of the sampler's bookkeeping at the time it was
/// returned; only `payload` is public.
#[derive(Debug, Default, Clone, Copy)]
pub struct Outcome<T> {
    /// Position of the outcome inside its sampler's outcome list.
    idx: usize,
    /// Index of the owning group, if the sampler uses groups.
    group_idx: Option<usize>,
    /// Rate recorded for the outcome when this handle was produced.
    rate: f32,
    /// User data attached to the outcome.
    pub payload: T,
}
28
29#[derive(Clone, Debug)]
30pub struct RejectionMethod<T> {
31    max_rate: f32,
32    outcomes: Vec<Outcome<T>>,
33}
34
35impl<T> RejectionMethod<T> {
36    pub fn new(max_rate: f32) -> Self {
37        Self {
38            max_rate: max_rate,
39            outcomes: vec![],
40        }
41    }
42}
43
44impl<T> StatisticalMethod<T> for RejectionMethod<T>
45where
46    T: Copy + Clone,
47{
48    fn add(&mut self, mut rate: f32, payload: T) -> Outcome<T> {
49        if rate > self.max_rate {
50            // todo: clamp rate, or something
51            panic!("Invalid rate provided in `add`");
52        }
53
54        if rate == 0.0 {
55            rate = 0.0001;
56        }
57
58        let outcome = Outcome {
59            payload: payload,
60            group_idx: None,
61            rate: rate,
62            idx: self.outcomes.len(),
63        };
64
65        self.outcomes.push(outcome);
66        outcome
67    }
68
69    fn delete(&mut self, outcome: Outcome<T>) {
70        self.outcomes.swap_remove(outcome.idx);
71        /*if self.outcomes.len() != 0 {
72            self.outcomes[outcome.idx].idx = outcome.idx;
73        }*/
74    }
75
76    fn update(&mut self, outcome: Outcome<T>, new_rate: f32) -> Outcome<T> {
77        // if new_rate == 0.0 && self.outcomes[outcome_idx].rate > 0.0 {
78        //     self.delete(outcome_idx);
79        // } else if (new_rate > 0.0 && self.outcomes[outcome_idx].rate == 0.0) {
80        // }
81
82        if new_rate == 0.0 {
83            panic!("Invalid rate provided in `update`");
84        }
85
86        let outcome = &mut self.outcomes[outcome.idx];
87        outcome.rate = new_rate;
88        *outcome
89    }
90
91    fn extract<Random: Rng>(&self, rng: &mut Random) -> (ExtractStats, T, f32) {
92        let mut loop_count = 0;
93        loop {
94            loop_count += 1;
95            let rand = rng.gen_range::<f32>(0.0, self.outcomes.len() as f32);
96            let rand_idx = rand.floor();
97            let rand_rate = (rand - rand_idx) * self.max_rate;
98
99            let outcome = &self.outcomes[rand_idx as usize];
100
101            if outcome.rate >= rand_rate {
102                return (
103                    ExtractStats {
104                        loop_count,
105                        group_iterations: None,
106                    },
107                    outcome.payload,
108                    outcome.rate,
109                );
110            }
111        }
112    }
113}
114
115#[derive(Debug)]
116pub struct CompositeRejectionMethod<T> {
117    groups: Vec<RejectionMethod<T>>,
118    sum_rates: Vec<f32>,
119    total_rate: f32,
120    constant: f32,
121    max: f32,
122}
123
124impl<T> CompositeRejectionMethod<T> {
125    pub fn new(max: f32, constant: f32) -> Self {
126        if constant <= 1.0 {
127            panic!("Invalid constant");
128        }
129
130        if max <= 1.0 {
131            panic!("Invalid max value");
132        }
133
134        let group_count = max.log(constant).ceil() as usize;
135        let mut groups = vec![];
136
137        for exponent in 0..group_count {
138            groups.push(RejectionMethod::new(max / constant.powf(exponent as f32)));
139        }
140
141        Self {
142            groups: groups,
143            sum_rates: vec![0.0; group_count],
144            total_rate: 0.0,
145            constant: constant,
146            max: max,
147        }
148    }
149
150    fn find_group_idx(&self, rate: f32) -> usize {
151        // clamp rate to 1.0 on the lower end so all the rates between 0
152        // and 1 fall into the very first bucket
153        (self.max / rate.max(1.0)).log(self.constant).floor() as usize
154    }
155}
156
157impl<T> StatisticalMethod<T> for CompositeRejectionMethod<T>
158where
159    T: Copy + Clone,
160{
161    fn add(&mut self, rate: f32, payload: T) -> Outcome<T> {
162        if rate > self.max {
163            panic!("Rate out of range rate: {}, max rate: {}", rate, self.max);
164        }
165
166        let group_idx = self.find_group_idx(rate);
167
168        let mut outcome = self.groups[group_idx].add(rate, payload);
169        self.sum_rates[group_idx] += rate;
170        self.total_rate += rate;
171
172        outcome.group_idx = Some(group_idx);
173        outcome
174    }
175
176    fn delete(&mut self, outcome: Outcome<T>) {
177        if let Some(group_idx) = outcome.group_idx {
178            self.sum_rates[group_idx] -= outcome.rate;
179            self.total_rate -= outcome.rate;
180
181            self.groups[group_idx].delete(outcome);
182        }
183    }
184
185    fn update(&mut self, outcome: Outcome<T>, new_rate: f32) -> Outcome<T> {
186        let new_group_idx = self.find_group_idx(new_rate);
187
188        if let Some(old_group_idx) = outcome.group_idx {
189            let delta_rate = new_rate - outcome.rate;
190
191            self.total_rate += delta_rate;
192
193            let mut outcome = if new_group_idx == old_group_idx {
194                // group stayed the same, just update
195                self.sum_rates[new_group_idx] += delta_rate;
196                self.groups[new_group_idx].update(outcome, new_rate)
197            } else {
198                // group changed, remove from old group
199                self.sum_rates[old_group_idx] -= outcome.rate;
200                self.groups[old_group_idx].delete(outcome);
201
202                // add to new group
203                self.sum_rates[new_group_idx] += new_rate;
204                self.groups[new_group_idx].add(new_rate, outcome.payload)
205            };
206
207            outcome.group_idx = Some(new_group_idx);
208            return outcome;
209        } else {
210            panic!("Outcome must have a group idx set");
211        }
212    }
213
214    fn extract<Random: Rng>(&self, rng: &mut Random) -> (ExtractStats, T, f32) {
215        let u = rng.gen::<f32>();
216        let mut rand = u * self.total_rate;
217        let mut iterations = 0;
218        for (idx, g) in self.groups.iter().enumerate() {
219            if self.sum_rates[idx] > rand {
220                let mut r = g.extract(rng);
221                r.0.group_iterations = Some(iterations);
222                r.2 = self.total_rate / r.2;
223                return r;
224            }
225
226            iterations += 1;
227
228            rand = rand - self.sum_rates[idx];
229        }
230
231        panic!(
232            "Shouldn't be able to reach here, algorithm invariant breached {} {}",
233            u, iterations
234        );
235    }
236}
237
/// Packs `v` into a `u32` with the integer part shifted left by 8 bits and
/// the fractional part, scaled by 255 and truncated, in the low 8 bits.
///
/// The float-to-integer casts saturate, so negative integer parts become 0.
pub fn to_fixed_8_24(v: f32) -> u32 {
    let whole = v.floor();
    let int_bits = (whole as u32) << 8;
    let frac_bits = (((v - whole) * 255.0) as u32).min(255);

    int_bits + frac_bits
}
248
/// Decodes a value packed as "integer part above bit 8, fraction in the low
/// 8 bits": the low byte is divided by 255 and added to the upper bits.
pub fn from_fixed_8_24(v: u32) -> f32 {
    let whole = (v >> 8) as f32;
    let fraction = (v & 0xff) as f32 / 255.0;

    whole + fraction
}
255
/// Alias table for sampling an index from a discrete weight distribution
/// with two uniform random numbers.
#[derive(Clone, Debug)]
pub struct AliasMethod {
    /// For each slot, the index returned when the slot's own index is rejected.
    alias: Vec<u32>,
    /// For each slot, the threshold below which the slot's own index is kept.
    probability: Vec<f32>,
}

impl AliasMethod {
    /// Builds the table from the weights in `list`.
    ///
    /// Weights are rescaled so that their mean is 1. Slots below 1 are then
    /// paired with slots at or above 1, which donate their excess. Alias
    /// indices are stored as `u32`, so lists longer than `u32::MAX` entries
    /// are not supported.
    pub fn new(mut list: Vec<f32>) -> AliasMethod {
        let len = list.len();
        let sum = list.iter().fold(0.0f32, |acc, &p| acc + p);
        let scale = len as f32 / sum;

        for p in list.iter_mut() {
            *p *= scale;
        }

        // Work stacks of slot indices; filled back to front so that the
        // lowest indices are processed first.
        let mut small = Vec::with_capacity(len);
        let mut large = Vec::with_capacity(len);

        for i in (0..len).rev() {
            if list[i] < 1.0 {
                small.push(i);
            } else {
                large.push(i);
            }
        }

        let mut table = AliasMethod {
            alias: vec![0; len],
            probability: vec![0.0; len],
        };

        loop {
            let (a, g) = match (small.last(), large.last()) {
                (Some(&a), Some(&g)) => (a, g),
                _ => break,
            };
            small.pop();
            large.pop();

            table.probability[a] = list[a];
            table.alias[a] = g as u32;

            // The large slot gives up what the small slot was missing.
            list[g] = list[g] + list[a] - 1.0;

            if list[g] < 1.0 {
                small.push(g);
            } else {
                large.push(g);
            }
        }

        // Whatever is left over (on either stack) always keeps its own index.
        for &i in large.iter().chain(small.iter()) {
            table.probability[i] = 1.0;
        }

        table
    }

    /// Maps two uniform samples to an index: `u0` selects the slot and `u1`
    /// decides between the slot's own index and its alias.
    ///
    /// `u0` is expected in `[0, 1]`; a value that would select a slot past
    /// the end (such as exactly `1.0`) is clamped to the last slot.
    ///
    /// # Panics
    /// Panics if the table was built from an empty list.
    pub fn find_index(&self, u0: f32, u1: f32) -> usize {
        assert!(
            !self.alias.is_empty(),
            "find_index called on an empty AliasMethod"
        );

        let last = self.alias.len() - 1;
        let idx = ((self.alias.len() as f32 * u0) as usize).min(last);

        if u1 < self.probability[idx] {
            idx
        } else {
            self.alias[idx] as usize
        }
    }
}