1use std::io;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use tokio::sync::mpsc;
7
8use crate::cid::ConnectionId;
9use crate::peer::{ConnectionPeer, Peer};
10use crate::udp::AsyncUdpSocket;
11
/// Decides, packet by packet, whether a [`MockUdpSocket`] actually delivers what it sends.
trait LinkDecider {
    /// Returns `true` if the packet currently being sent should be delivered, or `false` if it
    /// should be dropped (the sender is still told the send succeeded).
    ///
    /// Implementations may keep state across calls (hence `&mut self`), so this must be called
    /// exactly once per packet.
    fn should_send(&mut self) -> bool;
}
19
/// In-memory stand-in for a UDP socket that exchanges packets with exactly one peer over
/// unbounded tokio channels.
///
/// Every packet passed to `send_to` is first checked against `link`. Packets the link rejects
/// are dropped while the sender still receives `Ok`; packets that get through arrive intact
/// (the receiver panics rather than truncating).
#[derive(Debug)]
pub struct MockUdpSocket<Link> {
    // Sending half of the channel whose receiving half belongs to the paired socket
    // (as wired up by `build_link_pair`).
    outbound: mpsc::UnboundedSender<Vec<u8>>,
    // Receiving half of the channel the paired socket sends into.
    inbound: mpsc::UnboundedReceiver<Vec<u8>>,
    /// The only peer this socket talks to: sending to any other peer panics, and every
    /// received packet is attributed to this peer.
    pub only_peer: char,
    /// Decides, per outbound packet, whether it is delivered or dropped.
    pub link: Link,
}
30
31#[derive(Clone)]
32pub struct ManualLinkDecider {
33 pub up_switch: Arc<AtomicBool>,
34}
35
36impl ManualLinkDecider {
37 fn new() -> Self {
38 Self {
39 up_switch: Arc::new(AtomicBool::new(true)),
40 }
41 }
42}
43
44impl LinkDecider for ManualLinkDecider {
45 fn should_send(&mut self) -> bool {
46 self.up_switch.load(Ordering::SeqCst)
47 }
48}
49
50pub struct LinkDropsFirstNSent {
51 target_drops: usize,
52 actual_drops: usize,
53}
54
55impl LinkDropsFirstNSent {
56 fn new(n: usize) -> Self {
57 Self {
58 target_drops: n,
59 actual_drops: 0,
60 }
61 }
62}
63
64impl LinkDecider for LinkDropsFirstNSent {
65 fn should_send(&mut self) -> bool {
66 if self.actual_drops < self.target_drops {
67 self.actual_drops += 1;
68 false
69 } else {
70 true
71 }
72 }
73}
74
75#[async_trait]
76impl<Link: LinkDecider + std::marker::Sync + std::marker::Send> AsyncUdpSocket<char>
77 for MockUdpSocket<Link>
78{
79 async fn send_to(&mut self, buf: &[u8], peer: &Peer<char>) -> io::Result<usize> {
84 if peer.id() != &self.only_peer {
85 panic!("MockUdpSocket only supports sending to one peer");
86 }
87 if !self.link.should_send() {
88 tracing::warn!("Dropping packet to {peer:?}: {buf:?}");
89 return Ok(buf.len());
90 }
91 if let Err(err) = self.outbound.send(buf.to_vec()) {
92 Err(io::Error::new(
93 io::ErrorKind::UnexpectedEof,
94 format!("channel closed: {err}"),
95 ))
96 } else {
97 Ok(buf.len())
98 }
99 }
100
101 async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, Peer<char>)> {
105 let packet = self
106 .inbound
107 .recv()
108 .await
109 .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "channel closed"))?;
110 if buf.len() < packet.len() {
111 panic!("buffer too small for perfect link");
112 }
113 let packet_len = packet.len();
114 buf[..packet_len].copy_from_slice(&packet[..]);
115 Ok((packet_len, Peer::new(self.only_peer)))
116 }
117}
118
119impl ConnectionPeer for char {
120 type Id = char;
121
122 fn id(&self) -> Self::Id {
123 *self
124 }
125
126 fn consolidate(a: Self, b: Self) -> Self {
127 assert!(a == b, "Consolidating non-equal peers");
128 a
129 }
130}
131
132fn build_link_pair<LinkAtoB, LinkBtoA>(
133 a_to_b_link: LinkAtoB,
134 b_to_a_link: LinkBtoA,
135) -> (MockUdpSocket<LinkAtoB>, MockUdpSocket<LinkBtoA>) {
136 let (peer_a, peer_b): (char, char) = ('A', 'B');
137 let (a_tx, a_rx) = mpsc::unbounded_channel();
138 let (b_tx, b_rx) = mpsc::unbounded_channel();
139 let a = MockUdpSocket {
140 outbound: a_tx,
141 inbound: b_rx,
142 only_peer: peer_b,
143 link: a_to_b_link,
144 };
145 let b = MockUdpSocket {
146 outbound: b_tx,
147 inbound: a_rx,
148 only_peer: peer_a,
149 link: b_to_a_link,
150 };
151 (a, b)
152}
153
154fn build_connection_id_pair<LinkAtoB, LinkBtoA>(
155 socket_a: &MockUdpSocket<LinkAtoB>,
156 socket_b: &MockUdpSocket<LinkBtoA>,
157) -> (ConnectionId<char>, ConnectionId<char>) {
158 build_connection_id_pair_starting_at(socket_a, socket_b, 100)
159}
160
161fn build_connection_id_pair_starting_at<LinkAtoB, LinkBtoA>(
162 socket_a: &MockUdpSocket<LinkAtoB>,
163 socket_b: &MockUdpSocket<LinkBtoA>,
164 lower_id: u16,
165) -> (ConnectionId<char>, ConnectionId<char>) {
166 let higher_id = lower_id.wrapping_add(1);
167 let a_cid = ConnectionId {
168 send: higher_id,
169 recv: lower_id,
170 peer_id: socket_a.only_peer,
171 };
172 let b_cid = ConnectionId {
173 send: lower_id,
174 recv: higher_id,
175 peer_id: socket_b.only_peer,
176 };
177 (a_cid, b_cid)
178}
179
180#[allow(clippy::type_complexity)]
182pub fn build_manually_linked_pair() -> (
183 (MockUdpSocket<ManualLinkDecider>, ConnectionId<char>),
184 (MockUdpSocket<ManualLinkDecider>, ConnectionId<char>),
185) {
186 let (socket_a, socket_b) = build_link_pair(ManualLinkDecider::new(), ManualLinkDecider::new());
187 let (a_cid, b_cid) = build_connection_id_pair(&socket_a, &socket_b);
188 ((socket_a, a_cid), (socket_b, b_cid))
189}
190
191#[allow(clippy::type_complexity)]
196pub fn build_link_drops_first_n_sent_pair(
197 n: usize,
198) -> (
199 (MockUdpSocket<ManualLinkDecider>, ConnectionId<char>),
200 (MockUdpSocket<LinkDropsFirstNSent>, ConnectionId<char>),
201) {
202 let link_a_to_b = ManualLinkDecider::new();
203 let link_b_to_a = LinkDropsFirstNSent::new(n);
204 let (socket_a, socket_b) = build_link_pair(link_a_to_b, link_b_to_a);
205 let (a_cid, b_cid) = build_connection_id_pair(&socket_a, &socket_b);
206 ((socket_a, a_cid), (socket_b, b_cid))
207}