Skip to main content

sillad_conntest/
lib.rs

1use std::time::{Duration, Instant};
2
3use async_io::Timer;
4use async_task::Task;
5use async_trait::async_trait;
6use futures_lite::FutureExt as _;
7use futures_util::{AsyncReadExt, AsyncWriteExt};
8use rand::{Rng, RngCore};
9use sillad::{Pipe, dialer::Dialer, listener::Listener};
10
11/// Wraps an underlying dialer with a connection quality test.
12pub struct ConnTestDialer<D: Dialer> {
13    pub inner: D,
14    pub ping_count: usize,
15}
16
17#[async_trait]
18impl<D: Dialer> Dialer for ConnTestDialer<D> {
19    type P = D::P;
20
21    async fn dial(&self) -> std::io::Result<Self::P> {
22        let mut pipe = self.inner.dial().await?;
23        for index in 0..self.ping_count {
24            let start = Instant::now();
25            // Pick a random payload size (nonzero)
26            let size = rand::rng().random_range(1..1000u16);
27            // Tell the server the payload size.
28            pipe.write_all(&size.to_be_bytes()).await?;
29            // Prepare and send a random payload.
30            let mut buf = vec![0u8; size as usize];
31            rand::rng().fill_bytes(&mut buf);
32            pipe.write_all(&buf).await?;
33            // Read back the echoed payload.
34            let mut echo = vec![0u8; size as usize];
35            pipe.read_exact(&mut echo).await?;
36            let remote_addr = pipe.remote_addr();
37            tracing::debug!(
38                elapsed = debug(start.elapsed()),
39                total_count = self.ping_count,
40                index,
41                remote_addr = debug(remote_addr),
42                "ping completed"
43            );
44            if buf != echo {
45                return Err(std::io::Error::new(
46                    std::io::ErrorKind::InvalidData,
47                    "ping returned incorrect data",
48                ));
49            }
50        }
51        // Termination message: a 0 length indicates end of testing.
52        pipe.write_all(&[0u8; 2]).await?;
53        Ok(pipe)
54    }
55}
56
57/// Wraps an underlying listener with a connection quality test.
58pub struct ConnTestListener<L: Listener> {
59    recv_conn: tachyonix::Receiver<L::P>,
60    _task: Task<()>,
61}
62
63impl<L: Listener> ConnTestListener<L> {
64    pub fn new(mut listener: L) -> Self {
65        // Create a channel for passing successfully tested connections.
66        let (send_conn, recv_conn) = tachyonix::channel(1);
67        // Spawn a background task that loops over accepted connections.
68        let task = smolscale::spawn(async move {
69            loop {
70                // Accept a new connection from the underlying listener.
71                let mut conn = match listener.accept().await {
72                    Ok(c) => c,
73                    Err(e) => {
74                        tracing::warn!("Failed to accept connection: {:?}", e);
75                        async_io::Timer::after(Duration::from_secs(1)).await;
76                        continue;
77                    }
78                };
79                let send_conn = send_conn.clone();
80                // For each accepted connection, spawn a task to perform the ping test.
81                smolscale::spawn::<std::io::Result<()>>(async move {
82                    let inner = async {
83                        loop {
84                            let mut size_buf = [0u8; 2];
85                            conn.read_exact(&mut size_buf).await?;
86                            let size = u16::from_be_bytes(size_buf);
87                            // A zero size means the client has finished pinging.
88                            if size == 0 {
89                                let _ = send_conn.send(conn).await;
90                                return Ok(());
91                            }
92                            let mut payload = vec![0u8; size as usize];
93                            conn.read_exact(&mut payload).await?;
94                            conn.write_all(&payload).await?;
95                        }
96                    };
97                    inner
98                        .or(async {
99                            Timer::after(Duration::from_secs(30)).await;
100                            Ok(())
101                        })
102                        .await
103                })
104                .detach();
105            }
106        });
107        Self {
108            recv_conn,
109            _task: task,
110        }
111    }
112}
113
114#[async_trait]
115impl<L: Listener> Listener for ConnTestListener<L> {
116    type P = L::P;
117
118    async fn accept(&mut self) -> std::io::Result<Self::P> {
119        // Wait for a connection that passed the ping test.
120        self.recv_conn.recv().await.map_err(|_| {
121            std::io::Error::new(std::io::ErrorKind::BrokenPipe, "background task is done")
122        })
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    use futures_lite::{AsyncReadExt, AsyncWriteExt};

    use sillad::tcp::{TcpDialer, TcpListener};
    use smolscale::spawn;
    use std::io;
    use std::net::SocketAddr;

    /// This test creates a TCP listener, wrapped by `ConnTestListener`, that echoes incoming
    /// data. The client uses `ConnTestDialer` to perform several ping rounds before using the
    /// connection. The test then verifies that a test message is echoed back correctly.
    #[test]
    fn test_successful_ping() -> io::Result<()> {
        async_io::block_on(async {
            // Bind a TCP listener to an ephemeral port on localhost.
            let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
            let tcp_listener = TcpListener::bind(addr).await?;
            let local_addr = tcp_listener.local_addr().await;

            // Wrap the TCP listener with ConnTestListener.
            let mut conn_test_listener = ConnTestListener::new(tcp_listener);

            // Once the ping test is complete, echo any additional messages.
            let server_handle = spawn(async move {
                // Accept the connection that passed the ping test.
                let mut conn = conn_test_listener.accept().await?;
                let mut buf = [0u8; 1024];
                loop {
                    let n = conn.read(&mut buf).await?;
                    if n == 0 {
                        break; // Connection closed.
                    }
                    conn.write_all(&buf[..n]).await?;
                }
                Ok::<(), io::Error>(())
            });

            // Create a TCP dialer pointed at the server's address.
            let tcp_dialer = TcpDialer {
                dest_addr: local_addr,
            };

            // Wrap the TCP dialer with ConnTestDialer, performing 3 ping rounds.
            let conn_test_dialer = ConnTestDialer {
                inner: tcp_dialer,
                ping_count: 3,
            };

            // Dial the server. This performs the ping test internally.
            let mut client_pipe = conn_test_dialer.dial().await?;

            // Send a test message and expect an echo.
            let test_message = b"hello, unit test!";
            client_pipe.write_all(test_message).await?;
            let mut buf = vec![0u8; test_message.len()];
            client_pipe.read_exact(&mut buf).await?;
            assert_eq!(
                &buf, test_message,
                "the echoed message should match the sent message"
            );

            // Clean up.
            drop(client_pipe);
            server_handle.await?;
            Ok(())
        })
    }

    /// This test simulates a server that deliberately corrupts the ping echo. The
    /// `ConnTestDialer` must detect the corrupted data and fail with `InvalidData`, not with
    /// some unrelated I/O error.
    #[test]
    fn test_failed_ping() -> io::Result<()> {
        async_io::block_on(async {
            // Bind a TCP listener to an ephemeral port.
            let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
            let mut tcp_listener = TcpListener::bind(addr).await?;
            let local_addr = tcp_listener.local_addr().await;

            // Spawn a server task that corrupts each ping echo.
            let server_handle = spawn(async move {
                let mut conn = tcp_listener.accept().await?;
                loop {
                    // Read the two-byte payload size.
                    let mut size_buf = [0u8; 2];
                    if conn.read_exact(&mut size_buf).await.is_err() {
                        break; // Connection closed.
                    }
                    let size = u16::from_be_bytes(size_buf);
                    if size == 0 {
                        break; // Termination message.
                    }
                    // Read the payload.
                    let mut payload = vec![0u8; size as usize];
                    conn.read_exact(&mut payload).await?;
                    // Corrupt the payload (change the first byte, if any).
                    if !payload.is_empty() {
                        payload[0] = payload[0].wrapping_add(1);
                    }
                    // Send the corrupted payload back.
                    conn.write_all(&payload).await?;
                }
                Ok::<(), io::Error>(())
            });

            // Create a TCP dialer pointed at the server's address.
            let tcp_dialer = TcpDialer {
                dest_addr: local_addr,
            };

            // Wrap the TCP dialer with ConnTestDialer, using 3 ping rounds.
            let conn_test_dialer = ConnTestDialer {
                inner: tcp_dialer,
                ping_count: 3,
            };

            // The server corrupts the echoed pings, so dialing must fail, specifically with
            // the dialer's data-mismatch error.
            let err = match conn_test_dialer.dial().await {
                Ok(_) => panic!("dialing should fail due to corrupted ping echoes"),
                Err(e) => e,
            };
            assert_eq!(
                err.kind(),
                io::ErrorKind::InvalidData,
                "dialing should fail because the echoed ping data was corrupted, got: {err:?}"
            );

            let _ = server_handle.await;
            Ok(())
        })
    }
}