tokio_netstring_trait/lib.rs
// Crate-wide lints: warn about public types without `Debug`, undocumented public items,
// code that does not follow 2018-edition idioms, and `pub` items unreachable from the crate root.
#![warn(missing_debug_implementations)]
#![warn(missing_docs)]
#![warn(rust_2018_idioms)]
#![warn(unreachable_pub)]
7
8//! # NOTICE
//! This is the very first release and my first project in Rust. Feedback is appreciated.
10
11use async_trait::async_trait;
12use log::trace;
13use std::io;
14use std::io::{Cursor, ErrorKind, Write};
15use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
16
// Maximum number of decimal digits accepted in a netstring length prefix.
// `u32::MAX` (4_294_967_295) has 10 decimal digits, so any longer prefix is certainly larger
// than `u32::MAX`. The intent is that messages larger than `u32::MAX` are faulty and are not
// processed.
// NOTE: ten digits can still encode values above `u32::MAX` (up to 9_999_999_999); this
// constant only bounds the digit count, so rejecting such values is up to the parser.
const MAX_NETSTRING_LENGTH_DEC: usize = 10;
21
22async fn tag<T: AsyncRead + Unpin + ?Sized>(expected: u8, reader: &mut T) -> io::Result<()> {
23 let received = reader.read_u8().await?;
24 if expected != received {
25 Err(ErrorKind::InvalidData.into())
26 } else {
27 Ok(())
28 }
29}
30
31async fn read_netstring_length<T: AsyncRead + Unpin + ?Sized>(reader: &mut T) -> io::Result<usize> {
32 let mut buffer = [0u8; MAX_NETSTRING_LENGTH_DEC];
33 let mut read_buffer_len = 0usize;
34
35 for i in buffer.iter_mut() {
36 match reader.read_u8().await? {
37 b':' => break,
38 byte @ (b'0'..=b'9') => {
39 *i = byte;
40 read_buffer_len += 1;
41 }
42 _ => return Err(ErrorKind::InvalidData.into()),
43 }
44 }
45
46 if read_buffer_len == MAX_NETSTRING_LENGTH_DEC {
47 tag(b':', reader).await?;
48 }
49
50 // SAFETY: The validation was already performed when writing into the buffer
51 // that this is a valid string that contains only numbers.
52 unsafe {
53 Ok(std::str::from_utf8_unchecked(&buffer[..read_buffer_len])
54 .parse()
55 .unwrap())
56 }
57}
58
/// Discards the payload of a netstring that does not fit into the caller's buffer, together
/// with its trailing `,`, and then reports the condition as `ErrorKind::BrokenPipe`.
///
/// `size` is the payload length announced by the (already consumed) length prefix. Consuming
/// the terminator as well leaves the reader positioned at the start of the next netstring.
/// Errors while draining (e.g. a read error, or `InvalidData` for a wrong terminator) are
/// returned instead of `BrokenPipe`.
#[cfg(err_drop_message)]
async fn drop_message<T: AsyncRead + Unpin + ?Sized>(
    reader: &mut T,
    mut size: usize,
) -> io::Result<usize> {
    const INTERN_BUFFER_SIZE: usize = 4096;

    let mut intern_buffer = [0u8; INTERN_BUFFER_SIZE];
    while size > 0 {
        let chunk = size.min(INTERN_BUFFER_SIZE);
        reader.read_exact(&mut intern_buffer[..chunk]).await?;
        size -= chunk;
    }

    // The `,` terminator belongs to the dropped netstring. Without consuming it, the next
    // length parse would see `,` as its first byte and fail with InvalidData.
    tag(b',', reader).await?;

    Err(ErrorKind::BrokenPipe.into())
}
74
75#[cfg(not(err_drop_message))]
76async fn drop_message<T: AsyncRead + Unpin + ?Sized>(
77 _reader: &mut T,
78 _size: usize,
79) -> io::Result<usize> {
80 Err(ErrorKind::BrokenPipe.into())
81}
82
83/// The `AsyncNetstringRead` trait allows you to read one netstring at a time from any stream
84/// that has `AsyncRead` implemented. No implementation is thread-safe and multiple simultaneous
85/// reads can corrupt the message stream irreparably.
86#[async_trait]
87pub trait AsyncNetstringRead: AsyncRead + Unpin {
88 /// This method allows to read one netstring into the buffer given. It is advised to use
89 /// this Trait on a [tokio::io::BufReader] to avoid repeated system calls during parsing.
90 ///
91 /// # Usage
92 /// ```no_exec
93 /// use tokio_netstring::NetstringReader;
94 ///
95 /// let buf = [0; 1024];
96 /// let len: usize = stream.read_netstring(&mut buf).await.unwrap();
97 /// let buf: &[u8] = &buf[..len];
98 /// ```
99 ///
100 /// # Errors
101 /// This method returns a `tokio::io::Result` which is a re-export from `std::io::Result`.
102 ///
103 /// ## ErrorKind::UnexpectedEof
104 /// This error kind is returned, if the stream got closed, before a Netstring could be fully read.
105 ///
106 /// ## ErrorKind::BrokenPipe
107 /// This error type indicates that the buffer provided is to small for the netstring to fit in.
108 /// In the current implementation this error is irrecoverable as it has corrupted the stream.
109 /// Future implementations may allow to recover from this.
110 ///
111 /// Is the feature `err_drop_message` set, then the netstring will be dropped. Therefor is the
112 /// stream afterwards in a known stream an can be further used.
113 ///
114 /// ## ErrorKind::InvalidData
115 /// This error can be returned on three occasions:
116 ///
117 /// 1. The size provided is to big. The length of the netstring is stored as a `usize`. Should
118 /// the message provide a longer value, it is most likely an error and will be returned as such.
119 ///
120 /// 1. The Separator between length and the netstring is not `b':'`.
121 ///
122 /// 1. The Netstring does not end with a `b','`.
123 ///
124 /// In all cases the stream is irreparably corrupted and the connection should therefor be dropped.
125 async fn read_netstring(&mut self, buffer: &mut [u8]) -> io::Result<usize> {
126 let length = read_netstring_length(self).await?;
127
128 if buffer.len() >= length {
129 self.read_exact(&mut buffer[..length]).await?;
130 } else {
131 return drop_message(self, length).await;
132 }
133
134 trace!(
135 "READING NETSTRING: {}:{},",
136 length,
137 std::str::from_utf8(&buffer[..length]).unwrap()
138 );
139
140 tag(b',', self).await?;
141
142 return Ok(length);
143 }
144
145 /// This method allows to read one netstring. It returns the netstring as a `Vec<u8>` and
146 /// allocates the memory itself, therefore avoiding a to small buffer.
147 ///
148 /// While this may be use full during development, it should be avoided in production, since it
149 /// can allocate memory and a DDOS attack is therefore easily possible.
150 ///
151 /// # Usage
152 /// ```no_exec
153 /// use tokio_netstring::NetstringReader;
154 ///
155 /// let netstring: Vec<u8> = stream.read_netstring_alloc(&mut buf).await.unwrap();
156 /// ```
157 ///
158 /// # Errors
159 /// It returns the same errors as [AsyncNetstringRead::read_netstring], but can't fail because
160 /// the buffer is to small.
161 ///
162 async fn read_netstring_alloc(&mut self) -> io::Result<Vec<u8>> {
163 let length = read_netstring_length(self).await?;
164 let mut buffer = Vec::with_capacity(length);
165
166 // SAFETY: The buffer has capacity `length` therefore indexing it until there is safe.
167 let buffer_slice = unsafe { buffer.get_unchecked_mut(..length) };
168
169 self.read_exact(buffer_slice).await?;
170
171 tag(b',', self).await?;
172
173 // SAFETY: We have read all the bytes from the source into the Vec. At that point all
174 // values up until length have to be initialized.
175 unsafe { buffer.set_len(length) };
176
177 return Ok(buffer);
178 }
179}
180
181impl<Reader: AsyncRead + Unpin + ?Sized> AsyncNetstringRead for Reader {}
182
183/// The `NetstringWriter` trait allows to write a slice of bytes as a netstring to any stream that
184/// implements `AsyncWrite`
185#[async_trait]
186pub trait AsyncNetstringWrite: AsyncWrite + Unpin {
187 /// Write the slice as a netstring to the stream.
188 ///
189 /// # Usage
190 /// ```no_exec
191 /// use tokio_netstring::NetstringWriter;
192 ///
193 /// let msg = "Hello, World!";
194 /// stream.write_netstring(&msg.as_bytes());
195 /// ```
196 ///
197 /// # Errors
198 /// This method returns a `tokio::io::Result` which is a re-export from `std::io::Result`. It
199 /// returns `ErrorKind::WriteZero` if the stream was closed an no more data can be sent.
200 ///
201 async fn write_netstring(&mut self, data: &[u8]) -> io::Result<()> {
202 let mut buffer = [0u8; 2 * MAX_NETSTRING_LENGTH_DEC + 1];
203 let len = {
204 let mut writer = Cursor::new(&mut buffer[..]);
205 write!(writer, "{}", data.len())?;
206 writer.position() as usize
207 };
208 buffer[len] = b':';
209
210 trace!(
211 "WRITING NETSTRING: {}{},",
212 std::str::from_utf8(&buffer[..len + 1]).unwrap(),
213 std::str::from_utf8(data).unwrap()
214 );
215
216 self.write_all(&buffer[..len + 1]).await?;
217 self.write_all(data).await?;
218 self.write_all(b",").await?;
219 self.flush().await
220 }
221}
222
223impl<Writer: AsyncWrite + Unpin + ?Sized> AsyncNetstringWrite for Writer {}