1use std::fmt;
10use std::fmt::Display;
11use std::hash::Hash;
12use std::fmt::Debug;
13use float_cmp::*;
14use std::collections::HashMap;
15
16use rand::seq::SliceRandom;
17use rand::Rng;
18
19
/// Sampler for a finite discrete distribution, built with Vose's alias method.
///
/// Instances are created with `VoseAlias::new`. The private `_private` field
/// stops code outside this module from building one with a struct literal.
#[derive(Debug, Clone)]
pub struct VoseAlias<T: Display + Copy + Hash + Eq + Debug> {
    /// The elements that can be drawn, in the order they were supplied.
    pub elements: Vec<T>,
    /// Per column: the element handed out when the coin rejects the column's own element.
    pub alias: HashMap<T, T>,
    /// Per column: the probability of keeping the column's own element.
    pub prob: HashMap<T, f32>,
    /// Private marker field; prevents construction outside this module.
    _private: (),
}
45
46
47impl<T> VoseAlias<T>
48where T: Display + Copy + Hash + Eq + Debug {
49
50 pub fn new(element_vector:Vec<T>, probability_vector:Vec<f32>) -> VoseAlias<T> {
71 let size_p = probability_vector.len();
72 let size_e = element_vector.len();
73 if size_p != size_e {
75 panic!("Both vectors should contain the same number of elements");
76 }
77
78 let mut sum = 0.0;
79 for p in &probability_vector {
80 sum = sum + p;
81 }
82
83 if !approx_eq!(f32, sum, 1.0, ulps=4) {
84 panic!("Probability vector does not sum to 1");
85 }
86
87
88 let size = probability_vector.len();
90 let mut small:Vec<T> = Vec::new();
91 let mut large:Vec<T> = Vec::new();
92 let mut scaled_probability_vector:HashMap<T, f32> = HashMap::new();
93
94 let mut alias:HashMap<T, T> = HashMap::new();
95 let mut prob:HashMap<T, f32> = HashMap::new();
96
97 for i in 0..size {
99 let p:f32 = probability_vector[i];
100 let e:T = element_vector[i];
101 let scaled_proba = p * (size as f32);
102 scaled_probability_vector.insert(e, scaled_proba);
103
104 if scaled_proba < 1.0 {
105 small.push(e);
106 }
107 else {
108 large.push(e);
109 }
110 }
111
112 while !(small.is_empty() || large.is_empty()) {
114 if let (Some(l), Some(g)) = (small.pop(), large.pop()) {
116 alias.insert(l, g);
118 if let Some(p_l) = scaled_probability_vector.get(&l) {
120 prob.insert(l, *p_l);
122
123 if let Some(p_g) = scaled_probability_vector.get(&g) {
125 let new_p_g = (*p_g + *p_l) - 1.0;
126 scaled_probability_vector.insert(g, new_p_g);
128 if new_p_g < 1.0 {
129 small.push(g);
130 }
131 else {
132 large.push(g);
133 }
134 };
135
136 }
137 }
138 }
139
140 while !large.is_empty() {
142 if let Some(g) = large.pop() {
143 prob.insert(g, 1.0);
145 };
146 }
147
148 while !small.is_empty() {
149 if let Some(l) = small.pop() {
150 prob.insert(l, 1.0);
152 }
153 }
154
155 VoseAlias {
156 elements: element_vector,
157 alias: alias,
158 prob: prob,
159 _private: ()
160 }
161 }
162
163
164
165 pub fn sample(&self) -> T {
184 let (i, num) = self.roll_die_and_flip_coin();
185 return self.select_element(i, num);
186 }
187
188
189 fn roll_die_and_flip_coin(&self) -> (T, u16) {
191 let i:T;
192 match self.elements.choose(&mut rand::thread_rng()) {
193 Some(e) => i = *e,
194 None => panic!("Internal error. The element vector is empty. If this happened, please fill in an issue report."),
195 }
196 let num = rand::thread_rng().gen_range(0, 101);
197
198 return (i, num);
199
200 }
201
202
203 fn select_element(&self, die:T, coin:u16) -> T {
205 let p_i:f32;
207 match self.prob.get(&die) {
208 Some(p) => p_i = *p,
209 None => panic!("Internal error. The probability vector is empty. If this happened, please fill in an issue report."),
210 }
211 if (coin as f32) <= (p_i * 100.0) {
212 return die;
213 }
214 else {
215 match self.alias.get(&die) {
216 Some(alias_i) => return *alias_i,
217 None => panic!("Internal error. No alias found for element {:?}. If this happened, please fill in an issue report.", die),
218 }
219 };
220 }
221
222}
223
224
225impl <T> Display for VoseAlias<T>
229where T: Display + Copy + Hash + Eq + Debug {
230 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
231 let mut str_elements = String::from("[ ");
233 for e in &self.elements {
234 str_elements = str_elements + &e.to_string() + " ";
235 }
236 str_elements = str_elements + "]";
237
238 let mut str_alias = String::from("{ ");
240 for k in self.alias.keys() {
241 let a:T;
242 match self.alias.get(&k) {
243 Some(element) => a = *element,
244 None => panic!("Internal error. The alias map does not contain element for {}. If you encountered this error, please fill in an issue report.", k),
245 }
246 str_alias = str_alias + &String::from(format!("{}:{}, ", k, a));
247 }
248 str_alias = str_alias[..str_alias.len() - 2].to_string() + " }";
250
251 let mut str_prob = String::from("{");
253 for k in self.prob.keys() {
254 let p:f32;
255 match self.prob.get(&k) {
256 Some(element) => p = *element,
257 None => panic!("Internal error. The alias map does not contain element for {}. If you encountered this error, please fill in an issue report.", k),
258 }
259 str_prob = str_prob + &String::from(format!("{}:{:.2}, ", k, p));
260 }
261 str_prob = str_prob[..str_prob.len() - 2].to_string() + " }";
263
264 write!(f, "{{ elements: {}, alias: {}, prob: {}}}", str_elements, str_alias, str_prob)
266 }
267}
268
269impl<T> PartialEq for VoseAlias<T>
270where T:Display + Copy + Hash + Eq + Debug {
271 fn eq(&self, other: &Self) -> bool {
272 self.alias == other.alias
273 }
274
275}
276
277
278impl <T> Eq for VoseAlias<T>
279where T:Display + Copy + Hash + Eq + Debug{
280}
281
282
283
284
285
286
287
// Unit tests for construction, sampling, table lookup and equality of `VoseAlias`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Seven-colour distribution whose alias table the `select_element` tests pin down.
    fn colour_sampler() -> VoseAlias<&'static str> {
        VoseAlias::new(
            vec!["orange", "yellow", "green", "turquoise", "grey", "blue", "pink"],
            vec![0.125, 0.2, 0.1, 0.25, 0.1, 0.1, 0.125],
        )
    }

    #[test]
    fn construction_ok() {
        VoseAlias::new(vec![1, 2, 3, 4], vec![0.5, 0.2, 0.2, 0.1]);
    }

    #[test]
    #[should_panic(expected = "Both vectors should contain the same number of elements")]
    fn size_not_ok() {
        VoseAlias::new(vec![1, 2, 3], vec![0.5, 0.2, 0.2, 0.1]);
    }

    #[test]
    #[should_panic(expected = "Probability vector does not sum to 1")]
    fn sum_not_ok() {
        VoseAlias::new(vec![1, 2, 3, 4], vec![0.5, 0.2, 0.2, 0.]);
    }

    #[test]
    #[should_panic(expected = "Probability vector does not sum to 1")]
    fn new_empty_vectors() {
        let element_vector: Vec<u16> = Vec::new();
        let probability_vector: Vec<f32> = Vec::new();
        VoseAlias::new(element_vector, probability_vector);
    }

    #[test]
    fn test_roll_die_flip_coin() {
        let element_vector = vec![1, 2, 3, 4];
        let va = VoseAlias::new(element_vector.clone(), vec![0.5, 0.2, 0.2, 0.1]);
        let (die, coin) = va.roll_die_and_flip_coin();
        assert!(element_vector.contains(&die));
        assert!(coin <= 100);
    }

    #[test]
    fn test_sample_returns_known_elements() {
        let element_vector = vec![1, 2, 3, 4];
        let va = VoseAlias::new(element_vector.clone(), vec![0.5, 0.2, 0.2, 0.1]);
        for _ in 0..1000 {
            assert!(element_vector.contains(&va.sample()));
        }
    }

    #[test]
    fn test_sample_single_element() {
        let va = VoseAlias::new(vec![7], vec![1.0]);
        for _ in 0..100 {
            assert_eq!(va.sample(), 7);
        }
    }

    #[test]
    fn test_select_element_ok() {
        let va = colour_sampler();
        // (column rolled, coin flipped, element expected)
        let cases = [
            ("orange", 0, "orange"),
            ("orange", 87, "orange"),
            ("orange", 88, "yellow"),
            ("orange", 100, "yellow"),
            ("yellow", 0, "yellow"),
            ("yellow", 100, "yellow"),
            ("green", 0, "green"),
            ("green", 70, "green"),
            ("green", 71, "turquoise"),
            ("green", 100, "turquoise"),
            ("turquoise", 0, "turquoise"),
            ("turquoise", 72, "turquoise"),
            ("turquoise", 73, "yellow"),
            ("turquoise", 100, "yellow"),
            ("grey", 0, "grey"),
            ("grey", 70, "grey"),
            ("grey", 71, "turquoise"),
            ("grey", 100, "turquoise"),
            ("blue", 0, "blue"),
            ("blue", 70, "blue"),
            ("blue", 71, "turquoise"),
            ("blue", 100, "turquoise"),
            ("pink", 0, "pink"),
            ("pink", 87, "pink"),
            ("pink", 88, "turquoise"),
            ("pink", 100, "turquoise"),
        ];
        for &(die, coin, expected) in cases.iter() {
            assert_eq!(
                va.select_element(die, coin),
                expected,
                "die = {}, coin = {}",
                die,
                coin
            );
        }
    }

    #[test]
    #[should_panic(expected = "No alias found")]
    fn select_element_proba_too_high() {
        colour_sampler().select_element("yellow", 101);
    }

    #[test]
    #[should_panic]
    fn select_element_not_in_list() {
        colour_sampler().select_element("red", 100);
    }

    #[test]
    fn test_trait_equal() {
        let va = VoseAlias::new(vec![1, 2, 3, 4], vec![0.5, 0.2, 0.2, 0.1]);
        let va2 = VoseAlias::new(vec![1, 2, 3, 4], vec![0.5, 0.2, 0.2, 0.1]);
        assert_eq!(va, va2);
    }

    #[test]
    fn test_trait_not_equal() {
        let va = VoseAlias::new(vec![1, 2, 3, 4], vec![0.5, 0.2, 0.0, 0.3]);
        let va2 = VoseAlias::new(vec![1, 2, 3, 4], vec![0.5, 0.2, 0.2, 0.1]);
        assert_ne!(va, va2);
    }
}