tokio_netstring_trait/
lib.rs

1#![warn(
2    missing_debug_implementations,
3    missing_docs,
4    rust_2018_idioms,
5    unreachable_pub
6)]
7
8//! # NOTICE
9//! This is the very first release and my first project in rust. Feedback is appreciated.
10
11use async_trait::async_trait;
12use log::trace;
13use std::io;
14use std::io::{Cursor, ErrorKind, Write};
15use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
16
// The length prefix of a netstring is encoded in decimal and is limited to 10 digits, the
// width of `u32::MAX` (4294967295). Prefixes with more digits are treated as faulty packages
// and rejected. Note that the limit is on the number of digits, not on the value: prefixes up
// to 9_999_999_999 are still accepted wherever they fit into a `usize`.
// The constant is also used to size the header buffer in `write_netstring`.
const MAX_NETSTRING_LENGTH_DEC: usize = 10;
21
22async fn tag<T: AsyncRead + Unpin + ?Sized>(expected: u8, reader: &mut T) -> io::Result<()> {
23    let received = reader.read_u8().await?;
24    if expected != received {
25        Err(ErrorKind::InvalidData.into())
26    } else {
27        Ok(())
28    }
29}
30
31async fn read_netstring_length<T: AsyncRead + Unpin + ?Sized>(reader: &mut T) -> io::Result<usize> {
32    let mut buffer = [0u8; MAX_NETSTRING_LENGTH_DEC];
33    let mut read_buffer_len = 0usize;
34
35    for i in buffer.iter_mut() {
36        match reader.read_u8().await? {
37            b':' => break,
38            byte @ (b'0'..=b'9') => {
39                *i = byte;
40                read_buffer_len += 1;
41            }
42            _ => return Err(ErrorKind::InvalidData.into()),
43        }
44    }
45
46    if read_buffer_len == MAX_NETSTRING_LENGTH_DEC {
47        tag(b':', reader).await?;
48    }
49
50    // SAFETY: The validation was already performed when writing into the buffer
51    // that this is a valid string that contains only numbers.
52    unsafe {
53        Ok(std::str::from_utf8_unchecked(&buffer[..read_buffer_len])
54            .parse()
55            .unwrap())
56    }
57}
58
/// Reads and discards a netstring payload of `size` bytes together with its trailing `,`,
/// then reports that the caller's buffer was too small.
///
/// Always returns an error: `ErrorKind::BrokenPipe` after a successful discard, or the
/// error from the reader / the trailing-tag check if discarding fails.
#[cfg(err_drop_message)]
async fn drop_message<T: AsyncRead + Unpin + ?Sized>(
    reader: &mut T,
    mut size: usize,
) -> io::Result<usize> {
    const INTERN_BUFFER_SIZE: usize = 4096;

    let mut intern_buffer = [0u8; INTERN_BUFFER_SIZE];
    while size > 0 {
        let chunk = size.min(INTERN_BUFFER_SIZE);
        reader.read_exact(&mut intern_buffer[..chunk]).await?;
        size -= chunk;
    }

    // The terminating `,` belongs to the dropped netstring too; leaving it in the stream
    // would make the next length parse fail with `InvalidData`.
    tag(b',', reader).await?;

    Err(ErrorKind::BrokenPipe.into())
}
74
75#[cfg(not(err_drop_message))]
76async fn drop_message<T: AsyncRead + Unpin + ?Sized>(
77    _reader: &mut T,
78    _size: usize,
79) -> io::Result<usize> {
80    Err(ErrorKind::BrokenPipe.into())
81}
82
83/// The `AsyncNetstringRead` trait allows you to read one netstring at a time from any stream
84/// that has `AsyncRead` implemented. No implementation is thread-safe and multiple simultaneous
85/// reads can corrupt the message stream irreparably.
86#[async_trait]
87pub trait AsyncNetstringRead: AsyncRead + Unpin {
88    /// This method allows to read one netstring into the buffer given. It is advised to use
89    /// this Trait on a [tokio::io::BufReader] to avoid repeated system calls during parsing.
90    ///
91    /// # Usage
92    /// ```no_exec
93    /// use tokio_netstring::NetstringReader;
94    ///
95    /// let buf = [0; 1024];
96    /// let len: usize = stream.read_netstring(&mut buf).await.unwrap();
97    /// let buf: &[u8] = &buf[..len];
98    /// ```
99    ///
100    /// # Errors
101    /// This method returns a `tokio::io::Result` which is a re-export from `std::io::Result`.
102    ///
103    /// ## ErrorKind::UnexpectedEof
104    /// This error kind is returned, if the stream got closed, before a Netstring could be fully read.
105    ///
106    /// ## ErrorKind::BrokenPipe
107    /// This error type indicates that the buffer provided is to small for the netstring to fit in.
108    /// In the current implementation this error is irrecoverable as it has corrupted the stream.
109    /// Future implementations may allow to recover from this.
110    ///
111    /// Is the feature `err_drop_message` set, then the netstring will be dropped. Therefor is the
112    /// stream afterwards in a known stream an can be further used.
113    ///
114    /// ## ErrorKind::InvalidData
115    /// This error can be returned on three occasions:
116    ///
117    /// 1. The size provided is to big. The length of the netstring is stored as a `usize`. Should
118    /// the message provide a longer value, it is most likely an error and will be returned as such.
119    ///
120    /// 1. The Separator between length and the netstring is not `b':'`.
121    ///
122    /// 1. The Netstring does not end with a `b','`.
123    ///
124    /// In all cases the stream is irreparably corrupted and the connection should therefor be dropped.
125    async fn read_netstring(&mut self, buffer: &mut [u8]) -> io::Result<usize> {
126        let length = read_netstring_length(self).await?;
127
128        if buffer.len() >= length {
129            self.read_exact(&mut buffer[..length]).await?;
130        } else {
131            return drop_message(self, length).await;
132        }
133
134        trace!(
135            "READING NETSTRING: {}:{},",
136            length,
137            std::str::from_utf8(&buffer[..length]).unwrap()
138        );
139
140        tag(b',', self).await?;
141
142        return Ok(length);
143    }
144
145    /// This method allows to read one netstring. It returns the netstring as a `Vec<u8>` and
146    /// allocates the memory itself, therefore avoiding a to small buffer.
147    ///
148    /// While this may be use full during development, it should be avoided in production, since it
149    /// can allocate memory and a DDOS attack is therefore easily possible.
150    ///
151    /// # Usage
152    /// ```no_exec
153    /// use tokio_netstring::NetstringReader;
154    ///
155    /// let netstring: Vec<u8> = stream.read_netstring_alloc(&mut buf).await.unwrap();
156    /// ```
157    ///
158    /// # Errors
159    /// It returns the same errors as [AsyncNetstringRead::read_netstring], but can't fail because
160    /// the buffer is to small.
161    ///
162    async fn read_netstring_alloc(&mut self) -> io::Result<Vec<u8>> {
163        let length = read_netstring_length(self).await?;
164        let mut buffer = Vec::with_capacity(length);
165
166        // SAFETY: The buffer has capacity `length` therefore indexing it until there is safe.
167        let buffer_slice = unsafe { buffer.get_unchecked_mut(..length) };
168
169        self.read_exact(buffer_slice).await?;
170
171        tag(b',', self).await?;
172
173        // SAFETY: We have read all the bytes from the source into the Vec. At that point all
174        // values up until length have to be initialized.
175        unsafe { buffer.set_len(length) };
176
177        return Ok(buffer);
178    }
179}
180
181impl<Reader: AsyncRead + Unpin + ?Sized> AsyncNetstringRead for Reader {}
182
183/// The `NetstringWriter` trait allows to write a slice of bytes as a netstring to any stream that
184/// implements `AsyncWrite`
185#[async_trait]
186pub trait AsyncNetstringWrite: AsyncWrite + Unpin {
187    /// Write the slice as a netstring to the stream.
188    ///
189    /// # Usage
190    /// ```no_exec
191    /// use tokio_netstring::NetstringWriter;
192    ///
193    /// let msg = "Hello, World!";
194    /// stream.write_netstring(&msg.as_bytes());
195    /// ```
196    ///
197    /// # Errors
198    /// This method returns a `tokio::io::Result` which is a re-export from `std::io::Result`. It
199    /// returns `ErrorKind::WriteZero` if the stream was closed an no more data can be sent.
200    ///
201    async fn write_netstring(&mut self, data: &[u8]) -> io::Result<()> {
202        let mut buffer = [0u8; 2 * MAX_NETSTRING_LENGTH_DEC + 1];
203        let len = {
204            let mut writer = Cursor::new(&mut buffer[..]);
205            write!(writer, "{}", data.len())?;
206            writer.position() as usize
207        };
208        buffer[len] = b':';
209
210        trace!(
211            "WRITING NETSTRING: {}{},",
212            std::str::from_utf8(&buffer[..len + 1]).unwrap(),
213            std::str::from_utf8(data).unwrap()
214        );
215
216        self.write_all(&buffer[..len + 1]).await?;
217        self.write_all(data).await?;
218        self.write_all(b",").await?;
219        self.flush().await
220    }
221}
222
223impl<Writer: AsyncWrite + Unpin + ?Sized> AsyncNetstringWrite for Writer {}