1use std::time::{Duration, Instant};
2
3use async_io::Timer;
4use async_task::Task;
5use async_trait::async_trait;
6use futures_lite::FutureExt as _;
7use futures_util::{AsyncReadExt, AsyncWriteExt};
8use rand::{Rng, RngCore};
9use sillad::{Pipe, dialer::Dialer, listener::Listener};
10
/// A [`Dialer`] wrapper that checks each new connection with echo pings before returning it.
///
/// Each ping is a big-endian `u16` length followed by that many random bytes, and the peer
/// must echo it back verbatim. A zero length ends the test. The peer is expected to speak
/// the same protocol as [`ConnTestListener`].
///
/// The `Dialer` bound lives on the `Dialer` impl rather than on the struct, so the type can be
/// named and constructed generically without repeating the bound.
#[derive(Clone, Debug)]
pub struct ConnTestDialer<D> {
    /// The dialer that establishes the underlying connection.
    pub inner: D,
    /// Number of ping round trips that must succeed before `dial` returns the pipe.
    pub ping_count: usize,
}
16
17#[async_trait]
18impl<D: Dialer> Dialer for ConnTestDialer<D> {
19 type P = D::P;
20
21 async fn dial(&self) -> std::io::Result<Self::P> {
22 let mut pipe = self.inner.dial().await?;
23 for index in 0..self.ping_count {
24 let start = Instant::now();
25 let size = rand::rng().random_range(1..1000u16);
27 pipe.write_all(&size.to_be_bytes()).await?;
29 let mut buf = vec![0u8; size as usize];
31 rand::rng().fill_bytes(&mut buf);
32 pipe.write_all(&buf).await?;
33 let mut echo = vec![0u8; size as usize];
35 pipe.read_exact(&mut echo).await?;
36 let remote_addr = pipe.remote_addr();
37 tracing::debug!(
38 elapsed = debug(start.elapsed()),
39 total_count = self.ping_count,
40 index,
41 remote_addr = debug(remote_addr),
42 "ping completed"
43 );
44 if buf != echo {
45 return Err(std::io::Error::new(
46 std::io::ErrorKind::InvalidData,
47 "ping returned incorrect data",
48 ));
49 }
50 }
51 pipe.write_all(&[0u8; 2]).await?;
53 Ok(pipe)
54 }
55}
56
/// A [`Listener`] wrapper that only yields connections which have completed the echo-ping
/// test sent by [`ConnTestDialer`].
///
/// Connections that fail or abandon the test are never yielded.
pub struct ConnTestListener<L: Listener> {
    /// Receives connections whose ping test has completed, fed by the background task.
    recv_conn: tachyonix::Receiver<L::P>,
    /// Background accept loop, held so it lives as long as this listener.
    /// NOTE(review): per the async_task docs, dropping a `Task` cancels it, so dropping the
    /// listener stops accepting new connections.
    _task: Task<()>,
}
62
63impl<L: Listener> ConnTestListener<L> {
64 pub fn new(mut listener: L) -> Self {
65 let (send_conn, recv_conn) = tachyonix::channel(1);
67 let task = smolscale::spawn(async move {
69 loop {
70 let mut conn = match listener.accept().await {
72 Ok(c) => c,
73 Err(e) => {
74 tracing::warn!("Failed to accept connection: {:?}", e);
75 async_io::Timer::after(Duration::from_secs(1)).await;
76 continue;
77 }
78 };
79 let send_conn = send_conn.clone();
80 smolscale::spawn::<std::io::Result<()>>(async move {
82 let inner = async {
83 loop {
84 let mut size_buf = [0u8; 2];
85 conn.read_exact(&mut size_buf).await?;
86 let size = u16::from_be_bytes(size_buf);
87 if size == 0 {
89 let _ = send_conn.send(conn).await;
90 return Ok(());
91 }
92 let mut payload = vec![0u8; size as usize];
93 conn.read_exact(&mut payload).await?;
94 conn.write_all(&payload).await?;
95 }
96 };
97 inner
98 .or(async {
99 Timer::after(Duration::from_secs(30)).await;
100 Ok(())
101 })
102 .await
103 })
104 .detach();
105 }
106 });
107 Self {
108 recv_conn,
109 _task: task,
110 }
111 }
112}
113
114#[async_trait]
115impl<L: Listener> Listener for ConnTestListener<L> {
116 type P = L::P;
117
118 async fn accept(&mut self) -> std::io::Result<Self::P> {
119 self.recv_conn.recv().await.map_err(|_| {
121 std::io::Error::new(std::io::ErrorKind::BrokenPipe, "background task is done")
122 })
123 }
124}
125
#[cfg(test)]
mod tests {
    //! End-to-end tests for `ConnTestDialer` and `ConnTestListener` over loopback TCP.

    use super::*;

    use futures_lite::{AsyncReadExt, AsyncWriteExt};

    use sillad::tcp::{TcpDialer, TcpListener};
    use smolscale::spawn;
    use std::io;
    use std::net::SocketAddr;

    /// A dialer and listener that both speak the ping protocol should connect. The returned
    /// pipe should then carry ordinary application data in both directions.
    #[test]
    fn test_successful_ping() -> io::Result<()> {
        async_io::block_on(async {
            // Port 0: bind to any free port, then read the actual address back below.
            let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
            let tcp_listener = TcpListener::bind(addr).await?;
            let local_addr = tcp_listener.local_addr().await;

            let mut conn_test_listener = ConnTestListener::new(tcp_listener);

            // Server side: take one verified connection and echo raw bytes until EOF.
            let server_handle = spawn(async move {
                let mut conn = conn_test_listener.accept().await?;
                let mut buf = [0u8; 1024];
                loop {
                    let n = conn.read(&mut buf).await?;
                    if n == 0 {
                        break; // client closed its end
                    }
                    conn.write_all(&buf[..n]).await?;
                }
                Ok::<(), io::Error>(())
            });

            let tcp_dialer = TcpDialer {
                dest_addr: local_addr,
            };

            let conn_test_dialer = ConnTestDialer {
                inner: tcp_dialer,
                ping_count: 3,
            };

            // Connects and performs three ping round trips before returning the pipe.
            let mut client_pipe = conn_test_dialer.dial().await?;

            // After the ping test, the pipe should behave like a plain connection.
            let test_message = b"hello, unit test!";
            client_pipe.write_all(test_message).await?;
            let mut buf = vec![0u8; test_message.len()];
            client_pipe.read_exact(&mut buf).await?;
            assert_eq!(
                &buf, test_message,
                "the echoed message should match the sent message"
            );

            // Dropping the client lets the server's read loop see EOF and finish.
            drop(client_pipe);
            server_handle.await?;
            Ok(())
        })
    }

    /// A server that corrupts every ping echo must make `ConnTestDialer::dial` fail.
    #[test]
    fn test_failed_ping() -> io::Result<()> {
        async_io::block_on(async {
            let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
            let mut tcp_listener = TcpListener::bind(addr).await?;
            let local_addr = tcp_listener.local_addr().await;

            // Raw TCP server that speaks the ping framing but increments the first byte of
            // each echoed payload.
            let server_handle = spawn(async move {
                let mut conn = tcp_listener.accept().await?;
                loop {
                    let mut size_buf = [0u8; 2];
                    // A read error (e.g. the dialer dropped its end after the mismatch) ends the loop.
                    if conn.read_exact(&mut size_buf).await.is_err() {
                        break;
                    }
                    let size = u16::from_be_bytes(size_buf);
                    if size == 0 {
                        break; // end-of-test marker
                    }
                    let mut payload = vec![0u8; size as usize];
                    conn.read_exact(&mut payload).await?;
                    if !payload.is_empty() {
                        // Corrupt the echo so the dialer's comparison fails.
                        payload[0] = payload[0].wrapping_add(1);
                    }
                    conn.write_all(&payload).await?;
                }
                Ok::<(), io::Error>(())
            });

            let tcp_dialer = TcpDialer {
                dest_addr: local_addr,
            };

            let conn_test_dialer = ConnTestDialer {
                inner: tcp_dialer,
                ping_count: 3,
            };

            let result = conn_test_dialer.dial().await;
            assert!(
                result.is_err(),
                "dialing should fail due to corrupted ping echoes"
            );

            // The server's own result is irrelevant here; awaiting just lets it wind down.
            let _ = server_handle.await;
            Ok(())
        })
    }
}