tunneler_core/general/
traits.rs

1use async_trait::async_trait;
2
3use tokio::io::{AsyncReadExt, AsyncWriteExt};
4
5use crate::message::Message;
6
7/// Used to read from an actual TCP-Connection
8#[async_trait]
9pub trait ConnectionReader {
10    /// Reads an arbitrary amount of bytes from the Connection
11    /// to at most fill the buffer
12    async fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize>;
13
14    /// Reads from the Connection until the Buffer is filled
15    async fn read_full(&mut self, buf: &mut [u8]) -> std::io::Result<usize>;
16
17    /// Reads the next `size` amount of bytes from the connection
18    /// and throws them away
19    async fn drain(&mut self, size: usize) {
20        let mut buf = vec![0; size];
21        if let Err(e) = self.read_full(&mut buf).await {
22            error!("Draining: {}", e);
23        }
24    }
25}
26
27/// Used to write over an actual TCP-Connection
28#[async_trait]
29pub trait ConnectionWriter {
30    /// Attempts to write the entire buffer to the underlying
31    /// Connection
32    async fn write_full(&mut self, buf: &[u8]) -> std::io::Result<()>;
33
34    /// Attempts to write the message to the underlying connection
35    async fn write_msg(&mut self, msg: &Message, tmp_buf: &mut [u8; 13]) -> std::io::Result<()> {
36        let data = msg.serialize(tmp_buf);
37        if let Err(e) = self.write_full(tmp_buf).await {
38            return Err(e);
39        }
40        if let Err(e) = self.write_full(data).await {
41            return Err(e);
42        }
43
44        Ok(())
45    }
46}
47
48#[async_trait]
49impl<T> ConnectionReader for T
50where
51    T: AsyncReadExt + Send + Unpin,
52{
53    async fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
54        self.read(buf).await
55    }
56
57    async fn read_full(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
58        self.read_exact(buf).await
59    }
60}
61
62#[async_trait]
63impl<T> ConnectionWriter for T
64where
65    T: AsyncWriteExt + Send + Unpin,
66{
67    async fn write_full(&mut self, buf: &[u8]) -> std::io::Result<()> {
68        self.write_all(buf).await
69    }
70}