vose_alias/
lib.rs

//! This module is an implementation of the Vose-Alias method, used to sample an element from a list given a discrete probability distribution.
//!
//! This module contains functions to create the Probability and Alias tables and to sample from them.
//!
//! The algorithm implemented follows the explanation given on [this page](https://www.keithschwarz.com/darts-dice-coins/)
//!
7
8
9use std::fmt;
10use std::fmt::Display;
11use std::hash::Hash;
12use std::fmt::Debug;
13use float_cmp::*;
14use std::collections::HashMap;
15
16use rand::seq::SliceRandom;
17use rand::Rng;
18
19
20/////////////////////////////////////////////
21// Structure Definition and Implementation //
22/////////////////////////////////////////////
/// The tables required by the Vose-Alias sampling method.
///
/// The structure holds three pieces of data:
/// 1. `elements`: the vector of elements to sample from
/// 2. `alias`: the Alias table, built by the Vose-Alias initialization step
/// 3. `prob`: the Probability table, built by the Vose-Alias initialization step
///
/// Instances are created with `VoseAlias::new()`; see its documentation for more details.
/// The private unit field keeps code outside this module from building the structure
/// with a struct literal.
///
/// Internally, the elements are used as keys of `HashMap`s and stored in a `Vec`.
/// Therefore, the type `T` must implement the following traits:
/// - Display
/// - Copy
/// - Hash
/// - Eq
/// - Debug
#[derive(Debug, Clone)]
pub struct VoseAlias<T>
where
    T: Display + Copy + Hash + Eq + Debug,
{
    pub elements: Vec<T>,
    pub alias: HashMap<T, T>,
    pub prob: HashMap<T, f32>,
    _private: (),
}
45
46
47impl<T> VoseAlias<T>
48where T: Display + Copy + Hash + Eq + Debug {
49
50    /// Returns the Vose-Alias object containing the element vector as well as the alias and probability tables.
51    ///
52    /// The `element_vector` contains the list of elements that should be sampled from.
53    /// The `probability_vector` contains the probability distribution to be sampled with.
54    /// `element_vector` and `probability_vector` should have the same size and `probability_vector` should describe a well-formed probability distribution.
55    ///
56    /// # Panics
57    ///
58    /// The function panics in two casese:
59    /// 1. the `element_vector` and the `probability_vector` do not contain the same number of elements
60    /// 2. the sum of the elements in `probability_vector` is not equal to 1 (with a floating number precision of 0.0001), meaning that `probability_vector` does not describe a well formed probability distribution
61    ///
62    /// # Examples
63    /// ```
64    /// use vose_alias::VoseAlias;
65    /// 
66    /// // Creates a Vose-Alias object from a list of Integer elements
67    /// let va = VoseAlias::new(vec![1, 2, 3, 4], vec![0.5, 0.2, 0.2, 0.1]);
68    /// ```
69    
70    pub fn new(element_vector:Vec<T>, probability_vector:Vec<f32>) -> VoseAlias<T> {
71        let size_p = probability_vector.len();
72        let size_e = element_vector.len();
73        // some sanity checks
74        if size_p != size_e {
75            panic!("Both vectors should contain the same number of elements");
76        }
77
78        let mut sum = 0.0;
79        for p in &probability_vector {
80            sum = sum + p;
81        }
82
83        if !approx_eq!(f32, sum, 1.0, ulps=4) {
84            panic!("Probability vector does not sum to 1");
85        }
86
87        
88        // starting the actual init
89        let size = probability_vector.len();
90        let mut small:Vec<T> = Vec::new();
91        let mut large:Vec<T> = Vec::new();
92	let mut scaled_probability_vector:HashMap<T, f32> = HashMap::new();
93
94        let mut alias:HashMap<T, T> = HashMap::new();
95        let mut prob:HashMap<T, f32> = HashMap::new();
96
97        // multiply each proba by size
98        for i in 0..size {
99            let p:f32 = probability_vector[i];
100            let e:T = element_vector[i];
101            let scaled_proba = p * (size as f32);
102            scaled_probability_vector.insert(e, scaled_proba);
103
104            if scaled_proba < 1.0 {
105                small.push(e);
106            }
107            else {
108                large.push(e);
109            }
110        }
111
112	// emptying one column first
113        while !(small.is_empty() || large.is_empty()) {    
114	    // removing the element from small and large
115            if let (Some(l), Some(g)) = (small.pop(), large.pop()) {
116		// put g in the alias vector
117		alias.insert(l, g);
118		// getting the probability of the small element
119		if let Some(p_l) = scaled_probability_vector.get(&l) {
120		    // put it in the prob vector
121		    prob.insert(l, *p_l);
122
123		    // update the probability for g
124		    if let Some(p_g) = scaled_probability_vector.get(&g) { 
125			let new_p_g = (*p_g + *p_l) - 1.0;
126			// update scaled_probability_vector
127			scaled_probability_vector.insert(g, new_p_g);
128			if new_p_g < 1.0 {
129			    small.push(g);
130			}
131			else {
132			    large.push(g);
133			}
134		    };
135		    
136		}
137	    }
138        }
139
140	// finishing the init
141	while !large.is_empty() {
142	    if let Some(g) = large.pop() {
143		// println!("Last but not least: g = {}", g);
144		prob.insert(g, 1.0);
145	    };
146	}
147
148	while !small.is_empty() {
149	    if let Some(l) = small.pop() {
150		// println!("Last but not least: l = {}", l);
151		prob.insert(l, 1.0);
152	    }
153	}
154
155        VoseAlias {
156	    elements: element_vector,
157            alias: alias,
158            prob: prob,
159            _private: ()
160        }
161    }
162
163
164    
165    /// Returns a sampled element from a previously created Vose-Alias object.
166    ///
167    /// This function uses a `VoseAlias` object previously created using the method `vose_alias::new()` to sample in linear time an element of type `T`.
168    ///
169    /// # Panics
170    /// This function panics only if the lists created in `vose_alias::new()` are not correctly form, which would indicate a internal bug in the code.
171    /// If your code panics while using this function, please fill in an issue report.
172    ///
173    /// # Examples
174    /// ```
175    /// use vose_alias::VoseAlias;
176    ///
177    /// // Samples an integer from a list and prints it. 
178    /// let va = VoseAlias::new(vec![1, 2, 3, 4], vec![0.5, 0.2, 0.2, 0.1]);
179    /// let element = va.sample();
180    /// println!("{}", element);
181    /// 
182    /// ```
183    pub fn sample(&self) -> T {
184	let (i, num) = self.roll_die_and_flip_coin();
185	return self.select_element(i, num);
186    }
187
188
189    /// This function rolls the die and flip the coin to select the right element using `rand` usual RNG. It returns the generated number. This function is used by the `sample` function and has been decoupled from the `sample` function to allow unit tests on the `sample` function, using pre-determined series of numbers. 
190    fn roll_die_and_flip_coin(&self) -> (T, u16) {
191	let i:T;
192	match self.elements.choose(&mut rand::thread_rng()) {
193	    Some(e) => i = *e,
194	    None => panic!("Internal error. The element vector is empty. If this happened, please fill in an issue report."),
195	}
196	let num = rand::thread_rng().gen_range(0, 101);
197
198	return (i, num);
199	
200    }
201
202
203    /// This function selects an element from the VoseAlias table given a die (a column) and a coin (the element or its alias). This function has been separated from the `sample` function to allow unit testing, but should never be called by itself. 
204    fn select_element(&self, die:T, coin:u16) -> T {
205	// choose randomly an element from the element vector
206	let p_i:f32;
207	match self.prob.get(&die) {
208	    Some(p) => p_i = *p,
209	    None => panic!("Internal error. The probability vector is empty. If this happened, please fill in an issue report."),
210	}
211	if (coin as f32) <= (p_i * 100.0) {
212	    return die;
213	}
214	else {
215	    match self.alias.get(&die) {
216		Some(alias_i) => return *alias_i,
217		None => panic!("Internal error. No alias found for element {:?}. If this happened, please fill in an issue report.", die),
218	    }
219	};
220    }
221    
222}
223
224
225////////////////////////////
226// Traits Implementation  //
227////////////////////////////
228impl <T> Display for VoseAlias<T>
229where T: Display + Copy + Hash + Eq + Debug {
230    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
231	// format the elements
232	let mut str_elements = String::from("[ ");
233	for e in &self.elements {
234	    str_elements = str_elements + &e.to_string() + " ";
235	}
236	str_elements = str_elements + "]";
237
238	// format the alias table
239	let mut str_alias = String::from("{ ");
240	for k in self.alias.keys() {
241	    let a:T;
242	    match self.alias.get(&k) {
243		Some(element) => a = *element,
244		None => panic!("Internal error. The alias map does not contain element for {}. If you encountered this error, please fill in an issue report.", k),
245	    }
246	    str_alias = str_alias + &String::from(format!("{}:{}, ", k, a));
247	}
248	// remove the last two characters, that are not needed for the last element
249	str_alias = str_alias[..str_alias.len() - 2].to_string() + " }";
250
251	// fomat the probability table
252	let mut str_prob = String::from("{");
253	for k in self.prob.keys() {
254	    let p:f32;
255	    match self.prob.get(&k) {
256		Some(element) => p = *element,
257		None => panic!("Internal error. The alias map does not contain element for {}. If you encountered this error, please fill in an issue report.", k),
258	    }
259	    str_prob = str_prob + &String::from(format!("{}:{:.2}, ", k, p));
260	}
261	// remove the last two characters, that are not needed for the last element
262	str_prob = str_prob[..str_prob.len() - 2].to_string() + " }";
263
264	// return all of this in a nice string
265	write!(f, "{{ elements: {}, alias: {}, prob: {}}}", str_elements, str_alias, str_prob)
266    }
267}
268
269impl<T> PartialEq for VoseAlias<T>
270where T:Display + Copy + Hash + Eq + Debug {
271    fn eq(&self, other: &Self) -> bool {
272	self.alias == other.alias
273    }
274    
275}
276
277
278impl <T> Eq for VoseAlias<T>
279where T:Display + Copy + Hash + Eq + Debug{
280}
281
282
283
284
285
286
287
288///////////
289// Tests //
290///////////
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the seven-colour sampler shared by the `select_element` tests.
    fn colour_sampler() -> VoseAlias<&'static str> {
        VoseAlias::new(
            vec!["orange", "yellow", "green", "turquoise", "grey", "blue", "pink"],
            vec![0.125, 0.2, 0.1, 0.25, 0.1, 0.1, 0.125],
        )
    }

    ////////////////////////////////////////
    // Tests of the Struct Implementation //
    ////////////////////////////////////////
    #[test]
    fn construction_ok() {
        VoseAlias::new(vec![1, 2, 3, 4], vec![0.5, 0.2, 0.2, 0.1]);
    }

    #[test]
    #[should_panic]
    fn size_not_ok() {
        VoseAlias::new(vec![1, 2, 3], vec![0.5, 0.2, 0.2, 0.1]);
    }

    #[test]
    #[should_panic]
    fn sum_not_ok() {
        VoseAlias::new(vec![1, 2, 3, 4], vec![0.5, 0.2, 0.2, 0.]);
    }

    #[test]
    #[should_panic]
    fn new_empty_vectors() {
        let no_elements: Vec<u16> = Vec::new();
        let no_probabilities: Vec<f32> = Vec::new();
        VoseAlias::new(no_elements, no_probabilities);
    }

    #[test]
    fn test_roll_die_flip_coin() {
        let candidates = vec![1, 2, 3, 4];
        let va = VoseAlias::new(candidates.clone(), vec![0.5, 0.2, 0.2, 0.1]);
        let (die, coin) = va.roll_die_and_flip_coin();
        assert!(candidates.contains(&die));
        assert!(coin <= 100);
    }

    #[test]
    fn test_select_element_ok() {
        let va = colour_sampler();

        // (column, coin, expected element)
        let cases = [
            // column orange / alias yellow
            ("orange", 0, "orange"),
            ("orange", 87, "orange"),
            ("orange", 88, "yellow"),
            ("orange", 100, "yellow"),
            // column yellow / no alias
            ("yellow", 0, "yellow"),
            ("yellow", 100, "yellow"),
            // column green / alias turquoise
            ("green", 0, "green"),
            ("green", 70, "green"),
            ("green", 71, "turquoise"),
            ("green", 100, "turquoise"),
            // column turquoise / alias yellow
            ("turquoise", 0, "turquoise"),
            ("turquoise", 72, "turquoise"),
            ("turquoise", 73, "yellow"),
            ("turquoise", 100, "yellow"),
            // column grey / alias turquoise
            ("grey", 0, "grey"),
            ("grey", 70, "grey"),
            ("grey", 71, "turquoise"),
            ("grey", 100, "turquoise"),
            // column blue / alias turquoise
            ("blue", 0, "blue"),
            ("blue", 70, "blue"),
            ("blue", 71, "turquoise"),
            ("blue", 100, "turquoise"),
            // column pink / alias turquoise
            ("pink", 0, "pink"),
            ("pink", 87, "pink"),
            ("pink", 88, "turquoise"),
            ("pink", 100, "turquoise"),
        ];

        for &(die, coin, expected) in cases.iter() {
            assert_eq!(
                va.select_element(die, coin),
                expected,
                "column {} with coin {}",
                die,
                coin
            );
        }
    }

    #[test]
    #[should_panic]
    fn select_element_proba_too_high() {
        // a coin above 100 falls through to the alias lookup, and yellow has no alias
        colour_sampler().select_element("yellow", 101);
    }

    #[test]
    #[should_panic]
    fn select_element_not_in_list() {
        colour_sampler().select_element("red", 100);
    }

    ///////////////////////////////////////
    // Tests of the trait implementation //
    ///////////////////////////////////////
    #[test]
    fn test_trait_equal() {
        let first = VoseAlias::new(vec![1, 2, 3, 4], vec![0.5, 0.2, 0.2, 0.1]);
        let second = VoseAlias::new(vec![1, 2, 3, 4], vec![0.5, 0.2, 0.2, 0.1]);
        assert!(first == second);
    }

    #[test]
    fn test_trait_not_equali() {
        let first = VoseAlias::new(vec![1, 2, 3, 4], vec![0.5, 0.2, 0.0, 0.3]);
        let second = VoseAlias::new(vec![1, 2, 3, 4], vec![0.5, 0.2, 0.2, 0.1]);
        assert!(first != second);
    }
}