1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
use super::{ChildLoadBalancingPolicy, LoadBalancingPolicy, Plan, Statement};
use crate::transport::{cluster::ClusterData, node::Node};
use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc,
};
use tracing::trace;

/// A round-robin load balancing policy.
///
/// The policy keeps a single counter that is incremented every time a plan is produced
/// (by [`LoadBalancingPolicy::plan`] or [`ChildLoadBalancingPolicy::apply_child_policy`]).
/// The counter value taken for a call determines how far the node list is rotated for
/// that plan, so successive plans start from different nodes.
#[derive(Debug)]
pub struct RoundRobinPolicy {
    /// Number of plans requested so far. `AtomicUsize::fetch_add` wraps on overflow, so
    /// this never panics.
    index: AtomicUsize,
}

impl RoundRobinPolicy {
    /// Creates a policy whose plan counter starts at zero.
    pub fn new() -> Self {
        Self {
            index: AtomicUsize::new(0),
        }
    }
}

impl Default for RoundRobinPolicy {
    /// Equivalent to [`RoundRobinPolicy::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Memory ordering used for every access to the plan counter.
///
/// `Relaxed` is sufficient: the counter value is only used to pick a rotation offset and
/// is not used to publish any other data between threads, so only the atomicity of
/// `fetch_add` matters, not its ordering relative to other memory operations.
const ORDER_TYPE: Ordering = Ordering::Relaxed;

impl LoadBalancingPolicy for RoundRobinPolicy {
    /// Returns every node of `cluster`, rotated left by an offset derived from this
    /// policy's counter; the counter is bumped once per call. The statement is ignored.
    fn plan<'a>(&self, _statement: &Statement, cluster: &'a ClusterData) -> Plan<'a> {
        let all_nodes = &cluster.all_nodes;
        let ticket = self.index.fetch_add(1, ORDER_TYPE);
        let offset = super::compute_rotation(ticket, all_nodes.len());
        let rotated = super::slice_rotated_left(all_nodes, offset).cloned();

        // Log the chosen order as a comma-separated list of node addresses.
        trace!(
            nodes = rotated
                .clone()
                .map(|n| n.address.to_string())
                .collect::<Vec<String>>()
                .join(",")
                .as_str(),
            "RoundRobin"
        );

        Box::new(rotated)
    }

    /// Identifier of this policy.
    fn name(&self) -> String {
        String::from("RoundRobinPolicy")
    }
}

impl ChildLoadBalancingPolicy for RoundRobinPolicy {
    /// Rotates the supplied `plan` left by an offset derived from this policy's counter
    /// (bumped once per call) and returns an owning iterator over the result.
    fn apply_child_policy(
        &self,
        mut plan: Vec<Arc<Node>>,
    ) -> Box<dyn Iterator<Item = Arc<Node>> + Send + Sync> {
        let ticket = self.index.fetch_add(1, ORDER_TYPE);

        // The length is read up front: writing `plan.len()` inside the
        // `plan.rotate_left(...)` call conflicts with the mutable borrow that call takes.
        let plan_len = plan.len();
        let shift = super::compute_rotation(ticket, plan_len);

        plan.rotate_left(shift);
        Box::new(plan.into_iter())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::transport::load_balancing::tests;
    use std::collections::HashSet;

    // Node holds a ConnectionKeeper, which needs a Tokio runtime context to exist,
    // hence `tokio::test` rather than a plain `#[test]`.
    #[tokio::test]
    async fn test_round_robin_policy() {
        let cluster = tests::mock_cluster_data_for_round_robin_tests();
        let policy = RoundRobinPolicy::new();

        // Ask for more plans than there are nodes and keep only the distinct ones.
        let mut plans = HashSet::new();
        for _ in 0..16 {
            let plan = tests::get_plan_and_collect_node_identifiers(
                &policy,
                &tests::EMPTY_STATEMENT,
                &cluster,
            );
            plans.insert(plan);
        }

        // The distinct plans must be exactly the five rotations of the node list.
        let expected_plans: HashSet<Vec<_>> = vec![
            vec![1, 2, 3, 4, 5],
            vec![2, 3, 4, 5, 1],
            vec![3, 4, 5, 1, 2],
            vec![4, 5, 1, 2, 3],
            vec![5, 1, 2, 3, 4],
        ]
        .into_iter()
        .collect();

        assert_eq!(expected_plans, plans);
    }
}