1use bytes::Bytes;
7use std::cell::RefCell;
8use std::collections::{HashMap, VecDeque};
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12use super::clock::SimClock;
13use super::fault::{FaultInjector, FaultType};
14use super::rng::DeterministicRng;
15use crate::constants::{
16 NETWORK_JITTER_MS_DEFAULT, NETWORK_LATENCY_MS_DEFAULT, NETWORK_LATENCY_MS_MAX,
17};
18
19#[derive(Debug, Clone)]
21pub struct NetworkMessage {
22 pub from: String,
24 pub to: String,
26 pub payload: Bytes,
28 pub deliver_at_ms: u64,
30}
31
32#[derive(Debug, Clone, thiserror::Error)]
34pub enum NetworkError {
35 #[error("network partition between {from} and {to}")]
37 Partitioned {
38 from: String,
40 to: String,
42 },
43
44 #[error("packet loss fault injected")]
46 PacketLoss,
47
48 #[error("connection timeout")]
50 Timeout,
51
52 #[error("connection refused")]
54 ConnectionRefused,
55}
56
57pub struct SimNetwork {
65 messages: Arc<RwLock<HashMap<String, VecDeque<NetworkMessage>>>>,
67 partitions: Arc<RwLock<Vec<(String, String)>>>,
69 clock: SimClock,
71 fault_injector: Arc<FaultInjector>,
73 rng: RefCell<DeterministicRng>,
75 base_latency_ms: u64,
77 latency_jitter_ms: u64,
79}
80
81impl SimNetwork {
82 #[must_use]
86 pub fn new(clock: SimClock, rng: DeterministicRng, fault_injector: Arc<FaultInjector>) -> Self {
87 Self {
88 messages: Arc::new(RwLock::new(HashMap::new())),
89 partitions: Arc::new(RwLock::new(Vec::new())),
90 clock,
91 fault_injector,
92 rng: RefCell::new(rng),
93 base_latency_ms: NETWORK_LATENCY_MS_DEFAULT,
94 latency_jitter_ms: NETWORK_JITTER_MS_DEFAULT,
95 }
96 }
97
98 #[must_use]
103 pub fn with_latency(mut self, base_ms: u64, jitter_ms: u64) -> Self {
104 assert!(
106 base_ms <= NETWORK_LATENCY_MS_MAX,
107 "base_latency_ms {} exceeds max {}",
108 base_ms,
109 NETWORK_LATENCY_MS_MAX
110 );
111
112 self.base_latency_ms = base_ms;
113 self.latency_jitter_ms = jitter_ms;
114 self
115 }
116
117 pub async fn send(&self, from: &str, to: &str, payload: Bytes) -> bool {
121 assert!(!from.is_empty(), "from node ID cannot be empty");
123 assert!(!to.is_empty(), "to node ID cannot be empty");
124
125 {
127 let partitions = self.partitions.read().await;
128 if partitions
129 .iter()
130 .any(|(a, b)| (a == from && b == to) || (a == to && b == from))
131 {
132 tracing::debug!(from = from, to = to, "Message dropped: network partition");
133 return false;
134 }
135 }
136
137 if let Some(fault) = self.fault_injector.should_inject("network_send") {
139 match fault {
140 FaultType::NetworkTimeout
141 | FaultType::NetworkConnectionRefused
142 | FaultType::NetworkReset => {
143 tracing::debug!(from = from, to = to, fault = ?fault, "Message dropped: fault");
144 return false;
145 }
146 _ => {}
147 }
148 }
149
150 let latency = self.calculate_latency();
152 let deliver_at_ms = self.clock.now_ms() + latency;
153
154 let message = NetworkMessage {
155 from: from.to_string(),
156 to: to.to_string(),
157 payload,
158 deliver_at_ms,
159 };
160
161 let mut messages = self.messages.write().await;
163 messages
164 .entry(to.to_string())
165 .or_default()
166 .push_back(message);
167
168 true
169 }
170
171 pub async fn receive(&self, node_id: &str) -> Vec<NetworkMessage> {
175 assert!(!node_id.is_empty(), "node_id cannot be empty");
177
178 let current_time = self.clock.now_ms();
179 let mut messages = self.messages.write().await;
180
181 let queue = match messages.get_mut(node_id) {
182 Some(q) => q,
183 None => return Vec::new(),
184 };
185
186 let mut ready = Vec::new();
188 let mut remaining = VecDeque::new();
189
190 while let Some(msg) = queue.pop_front() {
191 if msg.deliver_at_ms <= current_time {
192 ready.push(msg);
193 } else {
194 remaining.push_back(msg);
195 }
196 }
197
198 *queue = remaining;
199
200 if !ready.is_empty() {
202 if let Some(FaultType::NetworkPartialWrite) =
203 self.fault_injector.should_inject("network_receive")
204 {
205 self.rng.borrow_mut().shuffle(&mut ready);
206 tracing::debug!(node_id = node_id, "Messages reordered by fault");
207 }
208 }
209
210 ready
211 }
212
213 pub async fn partition(&self, node_a: &str, node_b: &str) {
217 assert!(!node_a.is_empty(), "node_a cannot be empty");
219 assert!(!node_b.is_empty(), "node_b cannot be empty");
220 assert_ne!(node_a, node_b, "cannot partition node with itself");
221
222 let mut partitions = self.partitions.write().await;
223 partitions.push((node_a.to_string(), node_b.to_string()));
224
225 tracing::info!(
226 node_a = node_a,
227 node_b = node_b,
228 "Network partition created"
229 );
230 }
231
232 pub async fn heal(&self, node_a: &str, node_b: &str) {
234 let mut partitions = self.partitions.write().await;
235 partitions.retain(|(a, b)| !((a == node_a && b == node_b) || (a == node_b && b == node_a)));
236
237 tracing::info!(node_a = node_a, node_b = node_b, "Network partition healed");
238 }
239
240 pub async fn heal_all(&self) {
242 let mut partitions = self.partitions.write().await;
243 partitions.clear();
244
245 tracing::info!("All network partitions healed");
246 }
247
248 pub async fn is_partitioned(&self, node_a: &str, node_b: &str) -> bool {
250 let partitions = self.partitions.read().await;
251 partitions
252 .iter()
253 .any(|(a, b)| (a == node_a && b == node_b) || (a == node_b && b == node_a))
254 }
255
256 pub async fn pending_count(&self, node_id: &str) -> usize {
258 let messages = self.messages.read().await;
259 messages.get(node_id).map(|q| q.len()).unwrap_or(0)
260 }
261
262 pub async fn total_pending(&self) -> usize {
264 let messages = self.messages.read().await;
265 messages.values().map(|q| q.len()).sum()
266 }
267
268 pub async fn clear(&self) {
270 let mut messages = self.messages.write().await;
271 messages.clear();
272 }
273
274 #[must_use]
276 pub fn clock(&self) -> &SimClock {
277 &self.clock
278 }
279
280 fn calculate_latency(&self) -> u64 {
282 let jitter = if self.latency_jitter_ms > 0 {
283 self.rng
284 .borrow_mut()
285 .next_usize(0, self.latency_jitter_ms as usize) as u64
286 } else {
287 0
288 };
289 self.base_latency_ms + jitter
290 }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dst::fault::FaultInjectorBuilder;

    /// Builds a zero-latency network whose fault injector comes from an
    /// unconfigured `FaultInjectorBuilder`.
    fn create_network() -> SimNetwork {
        let clock = SimClock::new();
        let mut rng = DeterministicRng::new(42);
        let fault_injector = Arc::new(FaultInjectorBuilder::new(rng.fork()).build());
        SimNetwork::new(clock, rng, fault_injector).with_latency(0, 0)
    }

    #[tokio::test]
    async fn test_send_and_receive() {
        let network = create_network();

        let sent = network.send("node-1", "node-2", Bytes::from("hello")).await;
        assert!(sent);

        let messages = network.receive("node-2").await;
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].payload, Bytes::from("hello"));
        assert_eq!(messages[0].from, "node-1");
        assert_eq!(messages[0].to, "node-2");
    }

    #[tokio::test]
    async fn test_receive_unknown_node_is_empty() {
        let network = create_network();

        assert!(network.receive("node-9").await.is_empty());
        assert_eq!(network.pending_count("node-9").await, 0);
    }

    #[tokio::test]
    async fn test_partition() {
        let network = create_network();

        network.partition("node-1", "node-2").await;
        assert!(network.is_partitioned("node-1", "node-2").await);
        assert!(network.is_partitioned("node-2", "node-1").await);

        let sent = network.send("node-1", "node-2", Bytes::from("hello")).await;
        assert!(!sent);

        network.heal("node-1", "node-2").await;
        assert!(!network.is_partitioned("node-1", "node-2").await);

        let sent = network.send("node-1", "node-2", Bytes::from("hello")).await;
        assert!(sent);
    }

    #[tokio::test]
    async fn test_repeated_partition_healed_by_single_heal() {
        let network = create_network();

        // Same pair, both orders: one heal must fully restore the link.
        network.partition("node-1", "node-2").await;
        network.partition("node-2", "node-1").await;
        assert!(network.is_partitioned("node-1", "node-2").await);

        network.heal("node-1", "node-2").await;
        assert!(!network.is_partitioned("node-1", "node-2").await);
        assert!(network.send("node-1", "node-2", Bytes::from("hello")).await);
    }

    #[tokio::test]
    async fn test_latency() {
        let clock = SimClock::new();
        let mut rng = DeterministicRng::new(42);
        let fault_injector = Arc::new(FaultInjectorBuilder::new(rng.fork()).build());
        let network = SimNetwork::new(clock.clone(), rng, fault_injector).with_latency(100, 0);

        assert!(network.send("node-1", "node-2", Bytes::from("hello")).await);

        let messages = network.receive("node-2").await;
        assert!(messages.is_empty());
        assert_eq!(network.pending_count("node-2").await, 1);

        // One millisecond short of the delivery time: still in flight.
        clock.advance_ms(99);
        assert!(network.receive("node-2").await.is_empty());

        clock.advance_ms(1);
        let messages = network.receive("node-2").await;
        assert_eq!(messages.len(), 1);
        assert_eq!(network.pending_count("node-2").await, 0);
    }

    #[tokio::test]
    async fn test_multiple_messages() {
        let network = create_network();

        network.send("node-1", "node-2", Bytes::from("msg1")).await;
        network.send("node-1", "node-2", Bytes::from("msg2")).await;
        network.send("node-3", "node-2", Bytes::from("msg3")).await;

        assert_eq!(network.pending_count("node-2").await, 3);
        assert_eq!(network.total_pending().await, 3);

        let messages = network.receive("node-2").await;
        assert_eq!(messages.len(), 3);
        // Without reordering faults, delivery follows send order.
        let payloads: Vec<Bytes> = messages.iter().map(|m| m.payload.clone()).collect();
        assert_eq!(
            payloads,
            vec![Bytes::from("msg1"), Bytes::from("msg2"), Bytes::from("msg3")]
        );
        assert_eq!(network.pending_count("node-2").await, 0);
    }

    #[tokio::test]
    async fn test_heal_all() {
        let network = create_network();

        network.partition("node-1", "node-2").await;
        network.partition("node-2", "node-3").await;
        network.partition("node-1", "node-3").await;

        assert!(network.is_partitioned("node-1", "node-2").await);
        assert!(network.is_partitioned("node-2", "node-3").await);

        network.heal_all().await;

        assert!(!network.is_partitioned("node-1", "node-2").await);
        assert!(!network.is_partitioned("node-2", "node-3").await);
        assert!(!network.is_partitioned("node-1", "node-3").await);
    }

    #[tokio::test]
    async fn test_clear() {
        let network = create_network();

        network.send("node-1", "node-2", Bytes::from("msg1")).await;
        network.send("node-1", "node-2", Bytes::from("msg2")).await;

        assert_eq!(network.total_pending().await, 2);

        network.clear().await;

        assert_eq!(network.total_pending().await, 0);
    }

    #[test]
    #[should_panic(expected = "from node ID cannot be empty")]
    fn test_send_empty_from() {
        let network = create_network();
        let _ = tokio_test::block_on(network.send("", "node-2", Bytes::from("hello")));
    }

    #[test]
    #[should_panic(expected = "node_id cannot be empty")]
    fn test_receive_empty_node_id() {
        let network = create_network();
        let _ = tokio_test::block_on(network.receive(""));
    }

    #[test]
    #[should_panic(expected = "cannot partition node with itself")]
    fn test_partition_self() {
        let network = create_network();
        let _ = tokio_test::block_on(network.partition("node-1", "node-1"));
    }
}