use std::iter::FromIterator;
use rand::Rng;
/// A list of items, each paired with an integer weight, from which
/// `get_random` draws an item using those weights to bias the choice.
#[derive(Debug, Clone)]
pub struct WeightedRandomList<T> {
    /// Running total of every weight stored in `data`; kept in sync by `push`,
    /// `Entry::set_weight` and `Entry::delete` so sampling never re-sums.
    sum_weights: usize,
    /// `(weight, item)` pairs in storage order.
    data: Vec<(usize, T)>,
}
impl<T> WeightedRandomList<T> {
    /// Creates an empty list with a total weight of zero.
    pub fn new() -> WeightedRandomList<T> {
        WeightedRandomList {
            sum_weights: 0,
            data: Vec::new(),
        }
    }

    /// Appends `new` to the end of the list with the given `weight`.
    ///
    /// # Panics
    /// Panics if the running total of weights would overflow `usize`.
    pub fn push(&mut self, weight: usize, new: T) {
        self.sum_weights = self
            .sum_weights
            .checked_add(weight)
            .expect("WeightedRandomList.push() sum of weights exceided data type");
        self.data.push((weight, new))
    }

    /// Picks an element at random, with probability `weight / sum_weights`.
    ///
    /// Returns `None` when nothing can be selected: the list is empty, or every
    /// stored weight is zero. Zero-weight elements are never returned.
    pub fn get_random(&self) -> Option<&T> {
        // An empty list also has `sum_weights == 0`. Bailing out here avoids
        // asking for a sample from the empty range [0, 0), which panics.
        if self.sum_weights == 0 {
            return None;
        }
        // `r` is uniform in [0, sum_weights). Element i owns the half-open
        // interval [prefix_i, prefix_i + weight_i), so the test must be strict
        // (`r < local_sum`): with `>=` a zero-weight element could be hit and
        // every boundary value would go to the wrong neighbour.
        let r = rand::thread_rng().gen_range(0, self.sum_weights);
        let mut local_sum = 0;
        for (weight, data) in &self.data {
            local_sum += weight;
            if r < local_sum {
                return Some(data);
            }
        }
        None
    }
}

impl<T> Default for WeightedRandomList<T> {
    /// Same as [`WeightedRandomList::new`]: an empty list.
    fn default() -> Self {
        Self::new()
    }
}
impl<T> WeightedRandomList<T>
where
    T: std::cmp::PartialEq + Clone,
{
    /// Returns an [`Entry`] view for the first stored element equal to `needle`.
    ///
    /// Elements are scanned in storage order. On a match the result is
    /// `Entry::Occupied` holding that element's index. Otherwise it is
    /// `Entry::Vacant` carrying a clone of `needle`, ready to be inserted.
    pub fn entry_first(&mut self, needle: &T) -> Entry<T> {
        let found = self.data.iter().position(|(_, item)| needle == item);
        match found {
            Some(index) => Entry::Occupied { arena: self, index },
            None => Entry::Vacant {
                arena: self,
                needle: needle.clone(),
            },
        }
    }
}
/// A mutable view onto one logical slot of a [`WeightedRandomList`],
/// as returned by `entry_first`.
///
/// The view holds the list's mutable borrow for as long as it lives.
pub enum Entry<'a, T> {
    /// No stored element matched. `needle` is an owned copy of the searched
    /// value, which `or_insert_with_weight` can push into `arena`.
    Vacant {
        arena: &'a mut WeightedRandomList<T>,
        needle: T,
    },
    /// A matching element is stored at position `index` of `arena`'s data.
    Occupied {
        arena: &'a mut WeightedRandomList<T>,
        index: usize,
    },
}
impl<'container, T> Entry<'container, T> {
    /// Replaces the weight of the element this entry points at and updates the
    /// list's running total to match.
    ///
    /// # Panics
    /// * If the entry is `Vacant`, because there is no element to re-weight.
    /// * If the new total weight would overflow `usize`. This is the same
    ///   check `push` performs. The list is left unchanged in that case.
    pub fn set_weight(&mut self, new_weight: usize) {
        use Entry::*;
        match self {
            Vacant { .. } => {
                unimplemented!("set_weight on variant Vacant");
            }
            Occupied { arena, index } => {
                let (weight, _data) = arena
                    .data
                    .get_mut(*index)
                    .expect("unable to directly access data from Entry view");
                // `sum_weights` always includes `*weight`, so the subtraction
                // cannot underflow. The addition is checked, and the new total is
                // computed before anything is written, so a panic here leaves
                // the list consistent.
                let new_sum = (arena.sum_weights - *weight)
                    .checked_add(new_weight)
                    .expect("Entry.set_weight() sum of weights exceeded data type");
                arena.sum_weights = new_sum;
                *weight = new_weight;
            }
        }
    }

    /// Makes sure the entry is occupied.
    ///
    /// A `Vacant` entry pushes its `needle` onto the end of the list with
    /// `weight` and becomes `Occupied` at that new index. An `Occupied` entry
    /// is returned as-is; its existing weight is *not* changed.
    ///
    /// # Panics
    /// Panics if inserting would overflow the total weight (via `push`).
    pub fn or_insert_with_weight(self, weight: usize) -> Self {
        use Entry::*;
        match self {
            Occupied { .. } => self,
            Vacant { arena, needle } => {
                // `push` appends, so the new element lands at the current length.
                let index = arena.data.len();
                arena.push(weight, needle);
                Occupied { arena, index }
            }
        }
    }

    /// Removes the element this entry points at and subtracts its weight from
    /// the list's total.
    ///
    /// The last element is moved into the freed slot (`Vec::swap_remove`).
    /// This is O(1), but the remaining elements do not keep their order.
    ///
    /// # Errors
    /// Returns `Err(())` if the entry is `Vacant`.
    pub fn delete(self) -> Result<(), ()> {
        use Entry::*;
        match self {
            Vacant { .. } => Err(()),
            Occupied { arena, index } => {
                let (weight, _data) = arena.data.swap_remove(index);
                arena.sum_weights -= weight;
                Ok(())
            }
        }
    }
}
impl<T> FromIterator<(usize, T)> for WeightedRandomList<T> {
    /// Builds a list from `(weight, item)` pairs, keeping iteration order.
    ///
    /// # Panics
    /// Panics if the total weight overflows `usize` (each pair goes through `push`).
    fn from_iter<I: IntoIterator<Item = (usize, T)>>(iter: I) -> Self {
        iter.into_iter()
            .fold(WeightedRandomList::new(), |mut list, (weight, item)| {
                list.push(weight, item);
                list
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::hash::Hash;

    /// Draws `samples` items from `list` and counts how often each one was picked.
    fn histogram<T>(list: &WeightedRandomList<T>, samples: usize) -> HashMap<T, usize>
    where
        T: Clone + Eq + Hash,
    {
        let mut counts = HashMap::new();
        for _ in 0..samples {
            let picked = list
                .get_random()
                .cloned()
                .expect("sampling requires a non-empty list");
            *counts.entry(picked).or_insert(0) += 1;
        }
        counts
    }

    #[test]
    fn simple_inserts() {
        let mut l = WeightedRandomList::new();
        assert_eq!(None, l.get_random());
        l.push(100, 42usize);
        assert_eq!(Some(&42), l.get_random());
        l.push(0, 23);
        for _ in 0..1000 {
            assert_eq!(Some(&42), l.get_random());
        }
        l.push(900, 12);
        let results = histogram(&l, 1000);
        assert!(results[&42] < results[&12]);
    }

    #[test]
    fn modify_weight() {
        let mut l = WeightedRandomList::new();
        l.push(10, 1);
        l.push(10, 2);
        l.push(100, 3);
        assert_eq!(120, l.sum_weights);
        {
            let mut three = l.entry_first(&3);
            three.set_weight(80);
        }
        assert_eq!(100, l.sum_weights);
        let results = histogram(&l, 1000);
        assert!(results[&1] + results[&2] < results[&3]);
    }

    #[test]
    fn or_insert() {
        let mut l = WeightedRandomList::new();
        l.entry_first(&1).or_insert_with_weight(10);
        {
            assert_eq!(10, l.sum_weights);
            let mut two = l.entry_first(&2).or_insert_with_weight(1);
            two.set_weight(10);
            assert_eq!(20, l.sum_weights);
        }
        l.entry_first(&3).or_insert_with_weight(80);
        let results = histogram(&l, 1000);
        assert!(results[&1] + results[&2] < results[&3]);
    }

    #[test]
    fn delete_last() {
        let mut l = WeightedRandomList::new();
        l.push(100, 3);
        assert_eq!(1, l.data.len());
        assert_eq!(100, l.sum_weights);
        {
            let three = l.entry_first(&3);
            assert!(three.delete().is_ok());
        }
        assert_eq!(0, l.sum_weights);
        assert_eq!(0, l.data.len());
    }

    #[test]
    fn delete_mid() {
        let mut l = WeightedRandomList::new();
        assert_eq!(0, l.data.len());
        l.push(10, 1);
        assert_eq!(1, l.data.len());
        l.push(10, 2);
        assert_eq!(2, l.data.len());
        l.push(100, 3);
        assert_eq!(3, l.data.len());
        assert_eq!(120, l.sum_weights);
        {
            let two = l.entry_first(&2);
            assert!(two.delete().is_ok());
        }
        assert_eq!(110, l.sum_weights);
        assert_eq!(2, l.data.len());
        // The last element is moved into the removed element's slot.
        assert_eq!(vec![(10, 1), (100, 3)], l.data);
    }

    #[test]
    fn delete_end() {
        let mut l = WeightedRandomList::new();
        assert_eq!(0, l.data.len());
        l.push(10, 1);
        assert_eq!(1, l.data.len());
        l.push(10, 2);
        assert_eq!(2, l.data.len());
        l.push(100, 3);
        assert_eq!(3, l.data.len());
        assert_eq!(120, l.sum_weights);
        {
            let three = l.entry_first(&3);
            assert!(three.delete().is_ok());
        }
        assert_eq!(20, l.sum_weights);
        assert_eq!(2, l.data.len());
    }

    #[test]
    fn delete_not_found() {
        let mut l = WeightedRandomList::new();
        assert_eq!(0, l.data.len());
        l.push(10, 1);
        assert_eq!(1, l.data.len());
        l.push(10, 2);
        assert_eq!(2, l.data.len());
        l.push(100, 3);
        assert_eq!(3, l.data.len());
        assert_eq!(120, l.sum_weights);
        {
            let absent = l.entry_first(&22);
            assert!(absent.delete().is_err());
        }
        assert_eq!(120, l.sum_weights);
        assert_eq!(3, l.data.len());
    }

    #[test]
    fn collect() {
        let list = [
            (1, "https://source.example.net/archive"),
            (10, "https://mirror-slow0.example.net/archive"),
            (10, "https://mirror-slow1.example.net/archive"),
            (100, "https://mirror-fast.example.net/archive"),
        ];
        let collected = list
            .iter()
            .map(|(weight, url)| (*weight, url.to_string()))
            .collect::<WeightedRandomList<String>>();
        assert_eq!(list.len(), collected.data.len());
        for ((w, url), (wc, urlc)) in list.iter().zip(collected.data.iter()) {
            assert_eq!(w, wc);
            assert_eq!(url, urlc);
        }
        assert_eq!(
            list.iter().map(|(w, _)| w).sum::<usize>(),
            collected.sum_weights
        );
    }

    #[test]
    fn collect_empty() {
        let l: WeightedRandomList<i32> = std::iter::empty().collect();
        assert_eq!(0, l.sum_weights);
        assert_eq!(0, l.data.len());
        assert!(l.get_random().is_none());
    }
}