weighted_rs/
smooth_weight.rs

1use super::Weight;
2use std::{collections::HashMap, hash::Hash};
3
/// Per-entry state for the smooth weighted round-robin selector.
///
/// Each entry tracks the same three numbers as a peer in the Nginx
/// algorithm: its configured weight, its running score, and the weight that
/// is added to that score on every selection round.
#[derive(Debug, Clone)]
struct SmoothWeightItem<T> {
    /// Value returned to callers when this entry is selected.
    item: T,
    /// Weight supplied when the entry was added.
    weight: isize,
    /// Running score; the entry with the highest score wins a round.
    current_weight: isize,
    /// Amount added to `current_weight` in each selection round.
    effective_weight: isize,
}
11
12// SW (Smooth Weighted) is a struct that contains weighted items and provides methods to select a
13// weighted item. It is used for the smooth weighted round-robin balancing algorithm. This algorithm
14// is implemented in Nginx: https://github.com/phusion/nginx/commit/27e94984486058d73157038f7950a0a36ecc6e35.
15// Algorithm is as follows: on each peer selection we increase current_weight
16// of each eligible peer by its weight, select peer with greatest current_weight
17// and reduce its current_weight by total number of weight points distributed
18// among peers.
19// In case of { 5, 1, 1 } weights this gives the following sequence of
20// current_weight's: (a, a, b, a, c, a, a)
21#[derive(Default)]
22pub struct SmoothWeight<T> {
23    items: Vec<SmoothWeightItem<T>>,
24    n: isize,
25}
26
27impl<T: Clone + PartialEq + Eq + Hash> SmoothWeight<T> {
28    pub fn new() -> Self {
29        SmoothWeight {
30            items: Vec::new(),
31            n: 0,
32        }
33    }
34
35    //https://github.com/phusion/nginx/commit/27e94984486058d73157038f7950a0a36ecc6e35
36    fn next_smooth_weighted(&mut self) -> Option<SmoothWeightItem<T>> {
37        let mut total = 0;
38
39        let mut best = self.items[0].clone();
40        let mut best_index = 0;
41        let mut found = false;
42
43        let items_len = self.items.len();
44        for i in 0..items_len {
45            self.items[i].current_weight += self.items[i].effective_weight;
46            total += self.items[i].effective_weight;
47            if self.items[i].effective_weight < self.items[i].weight {
48                self.items[i].effective_weight += 1;
49            }
50
51            if !found || self.items[i].current_weight > best.current_weight {
52                best = self.items[i].clone();
53                found = true;
54                best_index = i;
55            }
56        }
57
58        if !found {
59            return None;
60        }
61
62        self.items[best_index].current_weight -= total;
63        Some(best)
64    }
65}
66
67impl<T: Clone + PartialEq + Eq + Hash> Weight for SmoothWeight<T> {
68    type Item = T;
69
70    fn next(&mut self) -> Option<T> {
71        if self.n == 0 {
72            return None;
73        }
74        if self.n == 1 {
75            return Some(self.items[0].item.clone());
76        }
77
78        let rt = self.next_smooth_weighted()?;
79        Some(rt.item)
80    }
81    // add adds a weighted item for selection.
82    fn add(&mut self, item: T, weight: isize) {
83        let weight_item = SmoothWeightItem {
84            item,
85            weight,
86            current_weight: 0,
87            effective_weight: weight,
88        };
89
90        self.items.push(weight_item);
91        self.n += 1;
92    }
93
94    // all returns all items.
95    fn all(&self) -> HashMap<T, isize> {
96        let mut rt: HashMap<T, isize> = HashMap::new();
97        for w in &self.items {
98            rt.insert(w.item.clone(), w.weight);
99        }
100        rt
101    }
102
103    // remove_all removes all weighted items.
104    fn remove_all(&mut self) {
105        self.items.clear();
106        self.n = 0;
107    }
108
109    // reset resets the balancing algorithm.
110    fn reset(&mut self) {
111        for w in &mut self.items {
112            w.current_weight = 0;
113            w.effective_weight = w.weight;
114        }
115    }
116}
117
#[cfg(test)]
mod tests {
    use crate::{SmoothWeight, Weight};
    use std::collections::HashMap;

    /// Builds a selector from `(item, weight)` pairs, added in order.
    fn build(pairs: &[(&'static str, isize)]) -> SmoothWeight<&'static str> {
        let mut sw = SmoothWeight::new();
        for &(item, weight) in pairs {
            sw.add(item, weight);
        }
        sw
    }

    /// Draws `count` consecutive selections.
    fn draw(sw: &mut SmoothWeight<&'static str>, count: usize) -> Vec<&'static str> {
        (0..count).map(|_| sw.next().unwrap()).collect()
    }

    #[test]
    fn test_smooth_weight() {
        let mut sw = build(&[("server1", 5), ("server2", 2), ("server3", 3)]);

        let mut results: HashMap<&str, usize> = HashMap::new();
        for _ in 0..100 {
            let s = sw.next().unwrap();
            *results.entry(s).or_insert(0) += 1;
        }

        assert_eq!(results["server1"], 50);
        assert_eq!(results["server2"], 20);
        assert_eq!(results["server3"], 30);
    }

    #[test]
    fn test_nginx_example_sequence() {
        // The { 5, 1, 1 } example from the Nginx commit message.
        let mut sw = build(&[("a", 5), ("b", 1), ("c", 1)]);
        assert_eq!(draw(&mut sw, 7), vec!["a", "a", "b", "a", "c", "a", "a"]);
        // After a full cycle the state is back at zero and the pattern repeats.
        assert_eq!(draw(&mut sw, 7), vec!["a", "a", "b", "a", "c", "a", "a"]);
    }

    #[test]
    fn test_empty_returns_none() {
        let mut sw: SmoothWeight<&str> = SmoothWeight::new();
        assert_eq!(sw.next(), None);
        assert!(sw.all().is_empty());
    }

    #[test]
    fn test_single_item_always_selected() {
        let mut sw = build(&[("only", 4)]);
        assert_eq!(draw(&mut sw, 5), vec!["only"; 5]);
    }

    #[test]
    fn test_reset_restarts_sequence() {
        let mut sw = build(&[("a", 5), ("b", 1), ("c", 1)]);
        assert_eq!(draw(&mut sw, 3), vec!["a", "a", "b"]);
        sw.reset();
        assert_eq!(draw(&mut sw, 3), vec!["a", "a", "b"]);
    }

    #[test]
    fn test_all_and_remove_all() {
        let mut sw = build(&[("a", 5), ("b", 1)]);

        let mut expected: HashMap<&str, isize> = HashMap::new();
        expected.insert("a", 5);
        expected.insert("b", 1);
        assert_eq!(sw.all(), expected);

        sw.remove_all();
        assert_eq!(sw.next(), None);
        assert!(sw.all().is_empty());
    }
}
142}