1use std::convert::TryFrom;
2use std::error::Error;
3use std::net::IpAddr;
4
5use async_channel::Sender;
6use async_trait::async_trait;
7use tokio::net;
8use tokio::time::{timeout as tokio_timeout, Duration};
9
10use crate::io::{AsyncReader, AsyncWriter, Stream};
11
12#[async_trait]
13pub trait Listener {
14 async fn accept_clients(&mut self, new_clients: Sender<Stream>) -> Result<(), Box<dyn Error>>;
15}
16
17pub struct TcpListener {
18 listener: net::TcpListener,
19}
20
21impl TcpListener {
22 pub async fn new(local_address: IpAddr, local_port: u16) -> Result<Self, Box<dyn Error>> {
23 let listener_address = format!("{}:{}", local_address, local_port);
24 log::info!("start listening on {}", listener_address);
25 let listener = net::TcpListener::bind(listener_address).await?;
26 Ok(Self { listener })
27 }
28}
29
30#[async_trait]
31impl Listener for TcpListener {
32 async fn accept_clients(&mut self, new_clients: Sender<Stream>) -> Result<(), Box<dyn Error>> {
33 while let Ok((client_stream, client_address)) = self.listener.accept().await {
34 log::debug!("got connection from {}", client_address);
35 let (client_reader, client_writer) = client_stream.into_split();
36 new_clients
37 .send(Stream::new(client_reader, client_writer))
38 .await?;
39 }
40 Ok(())
41 }
42}
43
/// Largest packet size, in bytes, whose length still fits in a `u16` (65535).
pub const MAX_UDP_PACKET_SIZE: usize = u16::MAX as usize;

/// Size, in bytes, of the length header that precedes every streamed packet.
///
/// The header is a single `u16`, hence two bytes.
pub const STREAMED_UDP_PACKET_HEADER_SIZE: usize = std::mem::size_of::<u16>();
46
47pub async fn stream_udp_packet(payload: &[u8], size: usize, writer: &mut Box<dyn AsyncWriter>) {
48 if payload.len() < size {
49 log::error!(
50 "payload {:?} is too small (expecting size {})",
51 payload,
52 size
53 );
54 return;
55 }
56
57 let size_u16 = match u16::try_from(size) {
58 Ok(s) => s,
59 Err(e) => {
60 log::error!("size {} can't fit in a u16: {}", size, e);
61 return;
62 }
63 };
64
65 if let Err(e) = writer.write(&size_u16.to_be_bytes()).await {
66 log::error!("failed to write header: {}", e);
67 return;
68 };
69
70 if let Err(e) = writer.write(&payload[..size]).await {
71 log::error!("failed to write payload: {}", e);
72 return;
73 };
74}
75
/// Outcome of reading one length-prefixed packet with `unstream_udp_packet`.
#[derive(Debug, PartialEq)]
pub enum UnstreamPacketResult {
    /// Reading failed, or an unexpected number of bytes was read.
    Error,
    /// The header did not arrive within the requested timeout.
    Timeout,
    /// The packet was read in full; carries its payload bytes.
    Payload(Vec<u8>),
}
82
83pub async fn unstream_udp_packet(
84 reader: &mut Box<dyn AsyncReader>,
85 timeout: Option<Duration>,
86) -> UnstreamPacketResult {
87 let mut header_bytes = [0; STREAMED_UDP_PACKET_HEADER_SIZE];
88 let read_header_future = reader.read_exact(&mut header_bytes);
89 let header_size_result = match timeout {
90 None => read_header_future.await,
91 Some(duration) => match tokio_timeout(duration, read_header_future).await {
92 Ok(size_result) => size_result,
93 Err(_) => return UnstreamPacketResult::Timeout,
94 },
95 };
96
97 let header_size = match header_size_result {
98 Ok(size) => size,
99 Err(e) => {
100 log::error!("failed to read header: {}", e);
101 return UnstreamPacketResult::Error;
102 }
103 };
104
105 if header_size != STREAMED_UDP_PACKET_HEADER_SIZE {
106 log::error!("got unexpected header size in bytes {}", header_size);
107 return UnstreamPacketResult::Error;
108 }
109
110 let header = u16::from_be_bytes(header_bytes);
111 let header_usize = header as usize;
112 let mut payload = vec![0; header_usize];
113 let size = match reader.read_exact(&mut payload).await {
114 Ok(size) => size,
115 Err(e) => {
116 log::error!("failed to read payload: {}", e);
117 return UnstreamPacketResult::Error;
118 }
119 };
120
121 if size != header_usize {
122 log::error!("got unexpected data size in bytes {}", header_size);
123 return UnstreamPacketResult::Error;
124 }
125
126 UnstreamPacketResult::Payload(payload)
127}
128
#[cfg(test)]
mod tests {
    use std::io::ErrorKind;

    use tokio::io;
    use tokio_test::io::Builder;

    use crate::io::{AsyncReadWrapper, AsyncWriteWrapper};

    use super::*;

    /// Builds the scripted mock and wraps it as the boxed writer that
    /// `stream_udp_packet` expects.
    fn mock_writer(script: &mut Builder) -> Box<dyn AsyncWriter> {
        Box::new(AsyncWriteWrapper::new(script.build()))
    }

    /// Builds the scripted mock and wraps it as the boxed reader that
    /// `unstream_udp_packet` expects.
    fn mock_reader(script: &mut Builder) -> Box<dyn AsyncReader> {
        Box::new(AsyncReadWrapper::new(script.build()))
    }

    // --- stream_udp_packet -------------------------------------------------

    /// A payload shorter than the requested size must not trigger any write.
    #[tokio::test]
    async fn stream_udp_packet_payload_too_small() {
        let payload = vec![1, 2, 3];
        let mut writer = mock_writer(&mut Builder::new());

        stream_udp_packet(&payload, 7, &mut writer).await;
    }

    /// A size that overflows `u16` must not trigger any write.
    #[tokio::test]
    async fn stream_udp_packet_size_not_fit_in_u16() {
        let payload = vec![0; u16::MAX as usize + 7];
        let mut writer = mock_writer(&mut Builder::new());

        stream_udp_packet(&payload, payload.len(), &mut writer).await;
    }

    /// A failing header write stops the function before the payload write.
    #[tokio::test]
    async fn stream_udp_packet_write_header_failed() {
        let payload = vec![1, 2, 3];
        let mut writer = mock_writer(
            Builder::new().write_error(io::Error::new(ErrorKind::Other, "oh no!")),
        );

        stream_udp_packet(&payload, payload.len(), &mut writer).await;
    }

    /// The header is written, then the payload write fails.
    #[tokio::test]
    async fn stream_udp_packet_write_payload_failed() {
        let payload = vec![1, 2, 3];
        let mut writer = mock_writer(
            Builder::new()
                .write(&[0u8, 3])
                .write_error(io::Error::new(ErrorKind::Other, "oh no!")),
        );

        stream_udp_packet(&payload, payload.len(), &mut writer).await;
    }

    /// Header (big-endian length) followed by the payload bytes.
    #[tokio::test]
    async fn stream_udp_packet_success() {
        let payload = vec![1, 2, 3];
        let mut writer = mock_writer(Builder::new().write(&[0u8, 3]).write(payload.as_slice()));

        stream_udp_packet(&payload, payload.len(), &mut writer).await;
    }

    // --- unstream_udp_packet -----------------------------------------------

    /// The header does not arrive before the timeout elapses.
    #[tokio::test]
    async fn unstream_udp_packet_timeout() {
        let mut reader = mock_reader(Builder::new().wait(Duration::from_secs(5)));

        let res = unstream_udp_packet(&mut reader, Some(Duration::from_millis(1))).await;
        assert_eq!(res, UnstreamPacketResult::Timeout);
    }

    /// The reader has no data at all, so the header read fails.
    #[tokio::test]
    async fn unstream_udp_packet_read_header_failed() {
        let mut reader = mock_reader(&mut Builder::new());

        let res = unstream_udp_packet(&mut reader, None).await;
        assert_eq!(res, UnstreamPacketResult::Error);
    }

    /// The header announces three bytes but no payload follows.
    #[tokio::test]
    async fn unstream_udp_packet_read_payload_failed() {
        let mut reader = mock_reader(Builder::new().read(&[0u8, 3]));

        let res = unstream_udp_packet(&mut reader, None).await;
        assert_eq!(res, UnstreamPacketResult::Error);
    }

    /// Header and matching payload are both available.
    #[tokio::test]
    async fn unstream_udp_packet_success() {
        let payload = vec![1u8, 2, 3];
        let mut reader = mock_reader(Builder::new().read(&[0u8, 3]).read(payload.as_slice()));

        let res = unstream_udp_packet(&mut reader, None).await;
        assert_eq!(res, UnstreamPacketResult::Payload(payload));
    }
}