use std::fmt;
use std::fmt::Display;
use std::hash::Hash;
use std::fmt::Debug;
use float_cmp::*;
use std::collections::HashMap;
use rand::seq::SliceRandom;
use rand::Rng;
/// Precomputed alias table for drawing elements of a finite distribution.
///
/// Instances are built with `VoseAlias::new`; the private `_private` field prevents
/// struct-literal construction from outside this module.
#[derive(Debug, Clone)]
pub struct VoseAlias<T: Display + Copy + Hash + Eq + Debug> {
    /// The elements, in the order they were supplied to the constructor.
    pub elements: Vec<T>,
    /// Maps an element to the element that fills the remainder of its column.
    pub alias: HashMap<T, T>,
    /// Maps an element to the probability of keeping it once the die lands on it.
    pub prob: HashMap<T, f32>,
    /// Zero-sized guard field; see the type-level docs.
    _private: (),
}
impl<T> VoseAlias<T>
where T: Display + Copy + Hash + Eq + Debug {
    /// Builds an alias table that samples `element_vector[i]` with probability
    /// `probability_vector[i]` (Vose's alias method).
    ///
    /// # Panics
    ///
    /// * the two vectors have different lengths;
    /// * a probability is negative or NaN;
    /// * the probabilities do not sum to 1 within a rounding tolerance that grows with
    ///   the number of elements (empty vectors therefore panic, their sum being 0);
    /// * an element occurs more than once in `element_vector`.
    pub fn new(element_vector: Vec<T>, probability_vector: Vec<f32>) -> VoseAlias<T> {
        let size = probability_vector.len();
        if size != element_vector.len() {
            panic!("Both vectors should contain the same number of elements");
        }
        // Without this check a negative weight would be silently turned into
        // "never picked" by the table construction below.
        if probability_vector.iter().any(|&p| p.is_nan() || p < 0.0) {
            panic!("Probabilities must be non-negative numbers");
        }
        let sum: f32 = probability_vector.iter().sum();
        // Summing n non-negative f32 values that total 1 can accumulate rounding error
        // of roughly n/2 * EPSILON, so a fixed 4-ULP margin can reject valid inputs
        // with many elements. Accept a value within either margin.
        let tolerance = size as f32 * f32::EPSILON;
        if !approx_eq!(f32, sum, 1.0, epsilon = tolerance, ulps = 4) {
            panic!("Probability vector does not sum to 1");
        }

        // Scale every probability by n so that an average column holds exactly 1.
        let mut scaled: HashMap<T, f32> = HashMap::with_capacity(size);
        let mut small: Vec<T> = Vec::new();
        let mut large: Vec<T> = Vec::new();
        for (&e, &p) in element_vector.iter().zip(&probability_vector) {
            let scaled_p = p * size as f32;
            // All tables are keyed by element, so a repeated element would overwrite
            // its earlier entry and silently distort the distribution.
            if scaled.insert(e, scaled_p).is_some() {
                panic!("Element {:?} occurs more than once in the element vector", e);
            }
            if scaled_p < 1.0 {
                small.push(e);
            } else {
                large.push(e);
            }
        }

        let mut alias: HashMap<T, T> = HashMap::new();
        let mut prob: HashMap<T, f32> = HashMap::with_capacity(size);
        // Pair an under-full column with an over-full one until either list runs out.
        // Peek before popping so that no element is lost when only one list is empty.
        while let (Some(l), Some(g)) = (small.last().copied(), large.last().copied()) {
            small.pop();
            large.pop();
            let p_l = scaled[&l];
            alias.insert(l, g);
            prob.insert(l, p_l);
            // g donates (1 - p_l) of its mass to fill l's column.
            let p_g = (scaled[&g] + p_l) - 1.0;
            scaled.insert(g, p_g);
            if p_g < 1.0 {
                small.push(g);
            } else {
                large.push(g);
            }
        }
        // Whatever remains owns (up to rounding) a full column and needs no alias.
        for e in large.into_iter().chain(small) {
            prob.insert(e, 1.0);
        }

        VoseAlias {
            elements: element_vector,
            alias,
            prob,
            _private: (),
        }
    }

    /// Draws one element at random according to the distribution given to `new`.
    ///
    /// Rolls a die uniformly over `elements`, keeps the rolled element with
    /// probability `prob[die]`, and otherwise returns its alias.
    pub fn sample(&self) -> T {
        let mut rng = rand::thread_rng();
        let die = match self.elements.choose(&mut rng) {
            Some(e) => *e,
            None => panic!("Internal error. The element vector is empty. If this happened, please fill in an issue report."),
        };
        // A continuous coin in [0, 1) keeps the die with probability exactly prob[die].
        // The integer 0..=100 coin used by `select_element` keeps it with probability
        // (floor(100 * p) + 1) / 101, so even a zero-probability element could be drawn.
        let coin: f32 = rng.gen_range(0.0f32, 1.0f32);
        self.resolve(die, |p| coin < p)
    }

    /// Test helper: rolls the die and flips a coarse integer coin in `0..=100`.
    ///
    /// Kept only for the unit tests; `sample` uses a continuous coin instead.
    #[cfg(test)]
    fn roll_die_and_flip_coin(&self) -> (T, u16) {
        let die = match self.elements.choose(&mut rand::thread_rng()) {
            Some(e) => *e,
            None => panic!("Internal error. The element vector is empty. If this happened, please fill in an issue report."),
        };
        let num = rand::thread_rng().gen_range(0, 101);
        (die, num)
    }

    /// Test helper: keeps `die` when `coin <= 100 * prob[die]`, otherwise returns its alias.
    ///
    /// # Panics
    ///
    /// If `die` has no probability entry, or the coin rejects a die that has no alias.
    #[cfg(test)]
    fn select_element(&self, die: T, coin: u16) -> T {
        self.resolve(die, |p| (coin as f32) <= p * 100.0)
    }

    /// Returns `die` when `keep_die(prob[die])` is true, otherwise `die`'s alias.
    ///
    /// # Panics
    ///
    /// If `die` has no probability entry, or `keep_die` rejects a die that has no alias.
    fn resolve<F>(&self, die: T, keep_die: F) -> T
    where
        F: FnOnce(f32) -> bool,
    {
        let p_die = match self.prob.get(&die) {
            Some(p) => *p,
            None => panic!("Internal error. The probability vector is empty. If this happened, please fill in an issue report."),
        };
        if keep_die(p_die) {
            return die;
        }
        match self.alias.get(&die) {
            Some(alias_die) => *alias_die,
            None => panic!("Internal error. No alias found for element {:?}. If this happened, please fill in an issue report.", die),
        }
    }
}
impl<T> Display for VoseAlias<T>
where T: Display + Copy + Hash + Eq + Debug {
    /// Formats as `{ elements: [ a b ], alias: { k:v, ... }, prob: { k:0.50, ... }}`.
    ///
    /// Map entries appear in `HashMap` iteration order (unspecified); an empty map is
    /// rendered as `{ }`. Probabilities are shown with two decimals.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Wraps rendered entries as `{ a, b }`, or `{ }` when there are none.
        // (Previously an empty map rendered as " }" and an empty `prob` map panicked
        // on a usize underflow while trimming the trailing separator.)
        let braced = |entries: Vec<String>| -> String {
            if entries.is_empty() {
                String::from("{ }")
            } else {
                format!("{{ {} }}", entries.join(", "))
            }
        };
        let elements: String = self.elements.iter().map(|e| format!("{} ", e)).collect();
        let alias = braced(self.alias.iter().map(|(k, a)| format!("{}:{}", k, a)).collect());
        let prob = braced(self.prob.iter().map(|(k, p)| format!("{}:{:.2}", k, p)).collect());
        write!(f, "{{ elements: [ {}], alias: {}, prob: {}}}", elements, alias, prob)
    }
}
impl<T> PartialEq for VoseAlias<T>
where T: Display + Copy + Hash + Eq + Debug {
    /// Structural equality: same elements (in the same order) and identical alias and
    /// probability tables.
    ///
    /// Comparing `alias` alone is not enough: every uniform distribution has an empty
    /// alias map, and different probabilities can produce the same alias pairs.
    fn eq(&self, other: &Self) -> bool {
        self.elements == other.elements && self.alias == other.alias && self.prob == other.prob
    }
}
/// Marker impl: `VoseAlias` equality is an equivalence relation.
impl<T: Display + Copy + Hash + Eq + Debug> Eq for VoseAlias<T> {}
#[cfg(test)]
mod tests {
    use super::*;

    /// The seven-colour distribution shared by the `select_element` tests.
    fn rainbow() -> VoseAlias<&'static str> {
        VoseAlias::new(
            vec!["orange", "yellow", "green", "turquoise", "grey", "blue", "pink"],
            vec![0.125, 0.2, 0.1, 0.25, 0.1, 0.1, 0.125],
        )
    }

    #[test]
    fn construction_ok() {
        VoseAlias::new(vec![1, 2, 3, 4], vec![0.5, 0.2, 0.2, 0.1]);
    }

    #[test]
    #[should_panic]
    fn size_not_ok() {
        VoseAlias::new(vec![1, 2, 3], vec![0.5, 0.2, 0.2, 0.1]);
    }

    #[test]
    #[should_panic]
    fn sum_not_ok() {
        VoseAlias::new(vec![1, 2, 3, 4], vec![0.5, 0.2, 0.2, 0.]);
    }

    #[test]
    #[should_panic]
    fn new_empty_vectors() {
        let no_elements: Vec<u16> = Vec::new();
        let no_probabilities: Vec<f32> = Vec::new();
        VoseAlias::new(no_elements, no_probabilities);
    }

    #[test]
    fn test_roll_die_flip_coin() {
        let elements = vec![1, 2, 3, 4];
        let va = VoseAlias::new(elements.clone(), vec![0.5, 0.2, 0.2, 0.1]);
        let (die, coin) = va.roll_die_and_flip_coin();
        assert!(elements.contains(&die));
        assert!(coin <= 100);
    }

    #[test]
    fn test_select_element_ok() {
        let va = rainbow();
        // (die, coin, expected): the die is kept up to its threshold and replaced by
        // its alias above it.
        let cases: &[(&str, u16, &str)] = &[
            ("orange", 0, "orange"),
            ("orange", 87, "orange"),
            ("orange", 88, "yellow"),
            ("orange", 100, "yellow"),
            ("yellow", 0, "yellow"),
            ("yellow", 100, "yellow"),
            ("green", 0, "green"),
            ("green", 70, "green"),
            ("green", 71, "turquoise"),
            ("green", 100, "turquoise"),
            ("turquoise", 0, "turquoise"),
            ("turquoise", 72, "turquoise"),
            ("turquoise", 73, "yellow"),
            ("turquoise", 100, "yellow"),
            ("grey", 0, "grey"),
            ("grey", 70, "grey"),
            ("grey", 71, "turquoise"),
            ("grey", 100, "turquoise"),
            ("blue", 0, "blue"),
            ("blue", 70, "blue"),
            ("blue", 71, "turquoise"),
            ("blue", 100, "turquoise"),
            ("pink", 0, "pink"),
            ("pink", 87, "pink"),
            ("pink", 88, "turquoise"),
            ("pink", 100, "turquoise"),
        ];
        for &(die, coin, expected) in cases {
            assert_eq!(va.select_element(die, coin), expected, "die = {}, coin = {}", die, coin);
        }
    }

    #[test]
    #[should_panic]
    fn select_element_proba_too_high() {
        let va = rainbow();
        va.select_element("yellow", 101);
    }

    #[test]
    #[should_panic]
    fn select_element_not_in_list() {
        let va = rainbow();
        va.select_element("red", 100);
    }

    #[test]
    fn test_trait_equal() {
        let first = VoseAlias::new(vec![1, 2, 3, 4], vec![0.5, 0.2, 0.2, 0.1]);
        let second = VoseAlias::new(vec![1, 2, 3, 4], vec![0.5, 0.2, 0.2, 0.1]);
        assert!(first == second);
    }

    #[test]
    fn test_trait_not_equali() {
        let first = VoseAlias::new(vec![1, 2, 3, 4], vec![0.5, 0.2, 0.0, 0.3]);
        let second = VoseAlias::new(vec![1, 2, 3, 4], vec![0.5, 0.2, 0.2, 0.1]);
        assert!(first != second);
    }
}