Skip to main content

strontium_core/
net.rs

1use crate::rng::Rng;
2use std::cmp::Reverse;
3use std::collections::{BTreeMap, BinaryHeap, HashSet};
4use std::hash::Hash;
5use std::time::Duration;
6
/// Behaviour of a network link: how long events take to arrive and how often they are lost.
///
/// The [`Default`] value is an ideal link: zero latency, zero jitter, no drops.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct LinkConfig {
    /// Fixed delay added to every event sent over the link.
    pub base_latency: Duration,
    /// Inclusive upper bound on the extra random delay added on top of `base_latency`;
    /// `Duration::ZERO` disables jitter.
    pub jitter: Duration,
    /// Probability in `[0.0, 1.0]` that an event is dropped instead of delivered;
    /// values at or below `0.0` disable dropping.
    pub drop_rate: f64,
}
23
/// Handle returned by `Network::partition`; pass it to `Network::heal_partition`
/// to lift that particular partition again.
pub type PartitionId = u64;
25
26struct PartitionRule<K> {
27    id: PartitionId,
28    side_a: HashSet<K>,
29    side_b: HashSet<K>,
30}
31
32impl<K: Eq + Hash> PartitionRule<K> {
33    fn covers(&self, from: &K, to: &K) -> bool {
34        (self.side_a.contains(from) && self.side_b.contains(to))
35            || (self.side_b.contains(from) && self.side_a.contains(to))
36    }
37}
38
/// An event waiting in the in-flight heap.
///
/// `delivery` holds `(delivery_time, sequence_number)` wrapped in [`Reverse`], so the
/// max-heap `BinaryHeap` surfaces the earliest delivery time first and, among equal
/// times, the lowest sequence number.
struct TimedEvent<E> {
    /// Heap key: reversed `(delivery_time, sequence_number)`.
    delivery: Reverse<(Duration, u64)>,
    /// Payload handed back once the event is drained.
    event: E,
}

// Equality and ordering look only at the heap key and ignore the payload, so `E`
// needs no comparison traits of its own.
impl<E> PartialEq for TimedEvent<E> {
    fn eq(&self, other: &Self) -> bool {
        let Reverse(mine) = &self.delivery;
        let Reverse(theirs) = &other.delivery;
        mine == theirs
    }
}

impl<E> Eq for TimedEvent<E> {}

impl<E> PartialOrd for TimedEvent<E> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl<E> Ord for TimedEvent<E> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Same result as comparing the `Reverse` wrappers: compare `other` against `self`.
        let Reverse(mine) = &self.delivery;
        let Reverse(theirs) = &other.delivery;
        theirs.cmp(mine)
    }
}
63
64pub struct Network<K, E> {
65    rng: Rng,
66    in_flight: BinaryHeap<TimedEvent<E>>,
67    next_seq: u64,
68    partitions: Vec<PartitionRule<K>>,
69    link_configs: BTreeMap<(K, K), LinkConfig>,
70    default_config: LinkConfig,
71    next_partition_id: PartitionId,
72}
73
74impl<K, E> Network<K, E>
75where
76    K: Clone + Eq + Hash + Ord,
77{
78    pub fn new(seed: u64) -> Self {
79        Self {
80            rng: Rng::new(seed),
81            in_flight: BinaryHeap::with_capacity(256),
82            next_seq: 0,
83            partitions: Vec::new(),
84            link_configs: BTreeMap::new(),
85            default_config: LinkConfig::default(),
86            next_partition_id: 0,
87        }
88    }
89
90    pub fn set_default_config(&mut self, config: LinkConfig) {
91        self.default_config = config;
92    }
93
94    pub fn set_link_config(&mut self, from: K, to: K, config: LinkConfig) {
95        self.link_configs.insert((from, to), config);
96    }
97
98    pub fn enqueue(&mut self, event: E, route: Option<(&K, &K)>, now: Duration) -> bool {
99        if let Some((from, to)) = route {
100            if self.is_partition_blocked(from, to) {
101                return false;
102            }
103        }
104
105        let config = route
106            .and_then(|(from, to)| self.link_configs.get(&(from.clone(), to.clone())).cloned())
107            .unwrap_or_else(|| self.default_config.clone());
108
109        if self.should_drop(&config) {
110            return false;
111        }
112
113        let delivery_time = self.delivery_time(&config, now);
114        let seq = self.next_seq;
115        self.next_seq += 1;
116        self.in_flight.push(TimedEvent {
117            delivery: Reverse((delivery_time, seq)),
118            event,
119        });
120
121        true
122    }
123
124    pub fn drain_ready(&mut self, up_to: Duration) -> Vec<E> {
125        let mut ready = Vec::new();
126        while let Some(delivery_time) = self.in_flight.peek().map(|timed| (timed.delivery.0).0) {
127            if delivery_time > up_to {
128                break;
129            }
130            let timed = self.in_flight.pop().expect("peeked event must exist");
131            ready.push(timed.event);
132        }
133        ready
134    }
135
136    pub fn partition(
137        &mut self,
138        side_a: impl IntoIterator<Item = K>,
139        side_b: impl IntoIterator<Item = K>,
140    ) -> PartitionId {
141        let id = self.next_partition_id;
142        self.next_partition_id += 1;
143        self.partitions.push(PartitionRule {
144            id,
145            side_a: side_a.into_iter().collect(),
146            side_b: side_b.into_iter().collect(),
147        });
148        id
149    }
150
151    pub fn heal_partition(&mut self, id: PartitionId) {
152        self.partitions.retain(|r| r.id != id);
153    }
154
155    pub fn heal_all(&mut self) {
156        self.partitions.clear();
157    }
158
159    pub fn in_flight_count(&self) -> usize {
160        self.in_flight.len()
161    }
162
163    fn is_partition_blocked(&self, from: &K, to: &K) -> bool {
164        self.partitions.iter().any(|rule| rule.covers(from, to))
165    }
166
167    fn should_drop(&mut self, config: &LinkConfig) -> bool {
168        if config.drop_rate <= 0.0 {
169            return false;
170        }
171        let roll = self.rng.next_u64() as f64 / u64::MAX as f64;
172        roll < config.drop_rate
173    }
174
175    fn delivery_time(&mut self, config: &LinkConfig, now: Duration) -> Duration {
176        let jitter_nanos = if config.jitter.as_nanos() > 0 {
177            self.rng.next_u64() % (config.jitter.as_nanos() as u64 + 1)
178        } else {
179            0
180        };
181        now + config.base_latency + Duration::from_nanos(jitter_nanos)
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::{LinkConfig, Network};
    use std::time::Duration;

    /// A config with only the base latency set; everything else is the ideal default.
    fn latency(millis: u64) -> LinkConfig {
        LinkConfig {
            base_latency: Duration::from_millis(millis),
            ..LinkConfig::default()
        }
    }

    #[test]
    fn partition_blocks_routed_events() {
        let mut net = Network::<u64, u64>::new(1);
        let part = net.partition([1], [2]);
        assert!(!net.enqueue(7, Some((&1, &2)), Duration::ZERO));
        assert!(
            !net.enqueue(7, Some((&2, &1)), Duration::ZERO),
            "partitions block both directions"
        );
        net.heal_partition(part);
        assert!(net.enqueue(7, Some((&1, &2)), Duration::ZERO));
    }

    #[test]
    fn unrouted_events_ignore_partitions() {
        let mut net = Network::<u64, u64>::new(1);
        net.partition([1], [2]);
        assert!(net.enqueue(7, None, Duration::ZERO));
        assert_eq!(net.in_flight_count(), 1);
    }

    #[test]
    fn heal_all_removes_every_partition() {
        let mut net = Network::<u64, u64>::new(1);
        let first = net.partition([1], [2]);
        let second = net.partition([1], [3]);
        assert_ne!(first, second);
        net.heal_all();
        assert!(net.enqueue(7, Some((&1, &2)), Duration::ZERO));
        assert!(net.enqueue(8, Some((&3, &1)), Duration::ZERO));
    }

    #[test]
    fn base_latency_delays_delivery() {
        let mut net = Network::<u64, u64>::new(2);
        net.set_default_config(latency(5));
        net.enqueue(9, None, Duration::ZERO);
        assert!(net.drain_ready(Duration::from_millis(4)).is_empty());
        assert_eq!(net.drain_ready(Duration::from_millis(5)), vec![9]);
    }

    #[test]
    fn jitter_stays_within_configured_bound() {
        let mut net = Network::<u64, u64>::new(3);
        net.set_default_config(LinkConfig {
            base_latency: Duration::from_millis(5),
            jitter: Duration::from_millis(3),
            drop_rate: 0.0,
        });
        net.enqueue(9, None, Duration::ZERO);
        assert!(net.drain_ready(Duration::from_millis(4)).is_empty());
        assert_eq!(net.drain_ready(Duration::from_millis(8)), vec![9]);
    }

    #[test]
    fn equal_delivery_times_preserve_enqueue_order() {
        let mut net = Network::<u64, u64>::new(4);
        for event in 0..5 {
            assert!(net.enqueue(event, None, Duration::ZERO));
        }
        assert_eq!(net.drain_ready(Duration::ZERO), vec![0, 1, 2, 3, 4]);
        assert_eq!(net.in_flight_count(), 0);
    }

    #[test]
    fn link_config_overrides_default_in_one_direction() {
        let mut net = Network::<u64, u64>::new(5);
        net.set_link_config(1, 2, latency(10));
        net.enqueue(12, Some((&1, &2)), Duration::ZERO);
        net.enqueue(21, Some((&2, &1)), Duration::ZERO);
        assert_eq!(net.drain_ready(Duration::ZERO), vec![21]);
        assert_eq!(net.drain_ready(Duration::from_millis(10)), vec![12]);
    }

    #[test]
    fn full_drop_rate_drops_everything() {
        let mut net = Network::<u64, u64>::new(6);
        net.set_default_config(LinkConfig {
            drop_rate: 1.0,
            ..LinkConfig::default()
        });
        for event in 0..32 {
            assert!(!net.enqueue(event, None, Duration::ZERO));
        }
        assert_eq!(net.in_flight_count(), 0);
    }
}