Skip to main content

rust_tls_duplex_stream/
lib.rs

1//! Full duplex stream wrapper around rust-tls
2
3#![deny(clippy::correctness)]
4#![warn(
5    clippy::perf,
6    clippy::complexity,
7    clippy::style,
8    clippy::nursery,
9    clippy::pedantic,
10    clippy::clone_on_ref_ptr,
11    clippy::decimal_literal_representation,
12    clippy::float_cmp_const,
13    clippy::missing_docs_in_private_items,
14    clippy::multiple_inherent_impl,
15    clippy::unwrap_used,
16    clippy::cargo_common_metadata,
17    clippy::used_underscore_binding
18)]
19
20mod queue;
21mod read_pipe;
22mod write_pipe;
23use crate::queue::Queue;
24use crate::read_pipe::ReadPipe;
25use crate::write_pipe::WritePipe;
26use rustls::{ConnectionCommon, StreamOwned};
27use std::fmt::{Arguments, Debug};
28use std::io::{ErrorKind, Read, Write};
29use std::ops::{Deref, DerefMut};
30use std::sync::atomic::AtomicBool;
31use std::sync::atomic::Ordering::SeqCst;
32use std::sync::{Arc, LockResult, Mutex};
33use std::time::Duration;
34use std::{io, thread};
35
36#[derive(Debug)]
37pub struct RustTlsDuplexStream<C, S>
38where
39    C: DerefMut + Deref<Target = ConnectionCommon<S>> + Send,
40{
41    /// Flag for non blocking read.
42    non_blocking_read: AtomicBool,
43    /// Read timeout
44    read_timeout: Mutex<Option<Duration>>,
45    /// Write timeout
46    write_timeout: Mutex<Option<Duration>>,
47    /// Inner rust-tls pseudo connection
48    connection: Mutex<StreamOwned<C, CombinedPipe>>,
49    /// Read queue connected to the thread that reads data from the actual connection
50    read_q: Arc<Queue>,
51    /// Write queue connected to the thread that writes data to the actual connection
52    write_q: Arc<Queue>,
53    /// Guard mutex that prevents concurrent writes.
54    write_mutex: Mutex<()>,
55    /// Guard mutex that prevents concurrent reads.
56    read_mutex: Mutex<()>,
57}
58
59impl<C, S> RustTlsDuplexStream<C, S>
60where
61    C: DerefMut + Deref<Target = ConnectionCommon<S>> + Send,
62    S: rustls::SideData,
63{
64    ///
65    /// Creates a new 'unpooled' Tls stream wrapper.
66    /// This is a good choice for an application such as a client
67    /// that does not create connections and doesn't have a thread pool.
68    ///
69    /// This fn will spawn 2 new threads using `thread::Builder::new().spawn(...)`.
70    /// The threads will terminate when the returned stream is dropped and the read/write errors out.
71    ///
72    /// # Resource Leaks
73    /// Be aware that a Read which "blocks" forever will cause the thread that's reading to stay alive forever.
74    /// Either set a reasonable connection read timeout so that your Read will eventually return
75    /// or call your connections shutdown fn like `TcpStream::shutdown` if such a method exists
76    /// to ensure that all threads are stopped and no resources are leaked.
77    ///
78    /// # Errors
79    /// if `thread::Builder::new().spawn` fails to spawn 2 threads.
80    ///
81    pub fn new_unpooled<R, W>(con: C, read: R, write: W) -> io::Result<Self>
82    where
83        R: Read + Send + 'static,
84        W: Write + Send + 'static,
85    {
86        Self::new(con, read, write, |task| {
87            thread::Builder::new().spawn(task).map(|_| {})
88        })
89    }
90
91    ///
92    /// Creates a new Tls stream wrapper.
93    ///
94    /// This fn will spawn 2 new threads using  the provided thread spawner function.
95    /// The thread spawner function is called exactly twice. If the first call yields an error then
96    /// it's not called again. The functions ("threads") will end when the returned stream wrapper is dropped.
97    ///
98    /// # Resource Leaks
99    /// Be aware that a Read which "blocks" forever will cause the thread that's reading to stay alive forever.
100    /// Either set a reasonable connection read timeout so that your Read will eventually return
101    /// or call your connections shutdown fn like `TcpStream::shutdown` if such a method exists
102    /// to ensure that all threads are stopped and no resources are leaked.
103    ///
104    /// # Errors
105    /// propagated from the spawner fn.
106    ///
107    pub fn new<R, W, T>(con: C, read: R, write: W, spawner: T) -> io::Result<Self>
108    where
109        R: Read + Send + 'static,
110        W: Write + Send + 'static,
111        T: FnMut(Box<dyn FnOnce() + Send>) -> io::Result<()>,
112    {
113        Self::new_with_initial_data(con, read, write, spawner, Vec::new())
114    }
115    ///
116    /// Creates a new Tls stream wrapper.
117    ///
118    /// This fn will spawn 2 new threads using  the provided thread spawner function.
119    /// The thread spawner function is called exactly twice. If the first call yields an error then
120    /// it's not called again. The functions ("threads") will end when the returned stream wrapper is dropped.
121    ///
122    /// # Resource Leaks
123    /// Be aware that a Read which "blocks" forever will cause the thread that's reading to stay alive forever.
124    /// Either set a reasonable connection read timeout so that your Read will eventually return
125    /// or call your connections shutdown fn like `TcpStream::shutdown` if such a method exists
126    /// to ensure that all threads are stopped and no resources are leaked.
127    ///
128    /// # Arguments
129    /// * `initial_data` contains data that will be consumed before data is read from the provided `io::Read`.
130    ///   this data may be the result of some previous operation that first detected if a connection is even TLS or not.
131    ///   Usually the `initial_data` Vec would contain the TLS hello message or the start of it.
132    ///
133    /// # Errors
134    /// propagated from the spawner fn.
135    ///
136    pub fn new_with_initial_data<R, W, T>(
137        con: C,
138        read: R,
139        write: W,
140        spawner: T,
141        initial_data: Vec<u8>,
142    ) -> io::Result<Self>
143    where
144        R: Read + Send + 'static,
145        W: Write + Send + 'static,
146        T: FnMut(Box<dyn FnOnce() + Send>) -> io::Result<()>,
147    {
148        let pipe = CombinedPipe::new(read, write, spawner, initial_data)?;
149        let read_q = pipe.0.dup_queue();
150        let write_q = pipe.1.dup_queue();
151
152        Ok(Self {
153            non_blocking_read: AtomicBool::new(false),
154            read_q,
155            write_q,
156            write_mutex: Mutex::new(()),
157            read_mutex: Mutex::new(()),
158            connection: Mutex::new(StreamOwned::new(con, pipe)),
159            read_timeout: Mutex::new(None),
160            write_timeout: Mutex::new(None),
161        })
162    }
163
164    /// see `Write::write`
165    /// # Errors
166    /// propagated from `Write::write` once subsequent writes/flushes turn into `BrokenPipe`
167    pub fn write(&self, buffer: &[u8]) -> io::Result<usize> {
168        let _outer_guard = unwrap_poison(self.write_mutex.lock())?; //make writes block other writes
169        let timeout_copy = unwrap_poison(self.write_timeout.lock())?
170            .deref()
171            .as_ref()
172            .copied();
173        self.write_q.flush_low(timeout_copy)?;
174        unwrap_poison(self.connection.lock())?.write(buffer)
175    }
176
177    /// see `Write::flush`
178    /// # Errors
179    /// propagated from `Write::flush` once subsequent writes/flushes turn into `BrokenPipe`
180    pub fn flush(&self) -> io::Result<()> {
181        let _outer_guard = unwrap_poison(self.write_mutex.lock())?; //make writes block other writes
182        unwrap_poison(self.connection.lock())?.flush()?;
183        self.write_q.flush_zero()
184    }
185
186    /// see `Read::read`
187    /// # Errors
188    /// propagated from `Read::read` once subsequent reads turn into `BrokenPipe`
189    pub fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
190        let _outer_guard = unwrap_poison(self.read_mutex.lock())?; //make reads block other reads
191        loop {
192            let mut guard = unwrap_poison(self.connection.lock())?;
193            guard.sock.0.nb(true); //Return instantly if no data.
194            let res = guard.read(buffer);
195            guard.sock.0.nb(false); //We must clear this flag or writes may go ballistic.
196            return match res {
197                Ok(count) => {
198                    drop(guard);
199                    Ok(count)
200                }
201                Err(err) => {
202                    if self.non_blocking_read.load(SeqCst) {
203                        return Err(err);
204                    }
205                    if err.kind() == ErrorKind::WouldBlock {
206                        //We have entered the fun zone where reads would block writes
207                        let timeout_copy = unwrap_poison(self.read_timeout.lock())?
208                            .deref()
209                            .as_ref()
210                            .copied();
211                        //This drops guard as soon as a handle to read_q is acquired and will return once trying to read again is meaningful.
212                        self.read_q.await_pop(guard, timeout_copy)?;
213                        continue;
214                    }
215
216                    drop(guard);
217                    Err(err)
218                }
219            };
220        }
221    }
222
223    /// sets the timeout for the writing operation.
224    /// This has no effect on the underlying connection and purely deals with internal writing semantics.
225    /// Calls to fns that writs data will return `TimedOut` if no plain text data could be written.
226    /// Cause of this is likely to be that the underlying connection does not read data fast enough.
227    /// This is never caused by writing too much data.
228    /// # Errors
229    /// In case of poisoned mutex
230    ///
231    pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
232        *unwrap_poison(self.read_timeout.lock())? = timeout;
233        Ok(())
234    }
235
236    /// Returns the current read timeout if any
237    /// # Errors
238    /// In case of poisoned mutex
239    pub fn read_timeout(&self) -> io::Result<Option<Duration>> {
240        Ok(unwrap_poison(self.read_timeout.lock())?.as_ref().copied())
241    }
242
243    /// sets non-blocking mode for read.
244    /// This has no effect on the underlying connection and purely deals with internal reading semantics.
245    /// Calls to fns that read data will return `WouldBlock` immediately if no plain text data is available to be read.
246    /// # Errors
247    /// In case of poisoned mutex
248    pub fn set_read_non_block(&self, on: bool) -> io::Result<()> {
249        self.non_blocking_read.store(on, SeqCst);
250        Ok(())
251    }
252
253    /// sets the timeout for the writing operation.
254    /// This has no effect on the underlying connection and purely deals with internal writing semantics.
255    /// Calls to fns that writs data will return `TimedOut` if no plain text data could be written.
256    /// Cause of this is likely to be that the underlying connection does not send data fast enough.
257    /// This is never caused by reading too much data.
258    /// # Errors
259    /// In case of poisoned mutex
260    pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
261        *unwrap_poison(self.write_timeout.lock())? = timeout;
262        Ok(())
263    }
264
265    /// Returns the current write timeout if any
266    /// # Errors
267    /// In case of poisoned mutex
268    pub fn write_timeout(&self) -> io::Result<Option<Duration>> {
269        Ok(unwrap_poison(self.write_timeout.lock())?.as_ref().copied())
270    }
271
272    /// See `Read::read_to_end`
273    /// # Errors
274    /// propagated
275    pub fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
276        Read::read_to_end(&mut &*self, buf)
277    }
278
279    /// See `Read::read_to_string`
280    /// # Errors
281    /// propagated
282    pub fn read_to_string(&self, buf: &mut String) -> io::Result<usize> {
283        Read::read_to_string(&mut &*self, buf)
284    }
285
286    /// See `Read::read_exact`
287    /// # Errors
288    /// propagated
289    pub fn read_exact(&self, buf: &mut [u8]) -> io::Result<()> {
290        Read::read_exact(&mut &*self, buf)
291    }
292
293    /// See `Write::write_all`
294    /// # Errors
295    /// propagated
296    pub fn write_all(&self, buf: &[u8]) -> io::Result<()> {
297        Write::write_all(&mut &*self, buf)
298    }
299
300    /// See `Write::write_fmt`
301    /// # Errors
302    /// propagated
303    pub fn write_fmt(&self, fmt: Arguments<'_>) -> io::Result<()> {
304        Write::write_fmt(&mut &*self, fmt)
305    }
306}
307
308impl<C, S> Read for RustTlsDuplexStream<C, S>
309where
310    C: DerefMut + Deref<Target = ConnectionCommon<S>> + Send,
311    S: rustls::SideData,
312{
313    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
314        Self::read(self, buf)
315    }
316}
317
318impl<C, S> Read for &RustTlsDuplexStream<C, S>
319where
320    C: DerefMut + Deref<Target = ConnectionCommon<S>> + Send,
321    S: rustls::SideData,
322{
323    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
324        RustTlsDuplexStream::<C, S>::read(self, buf)
325    }
326}
327
328impl<C, S> Write for RustTlsDuplexStream<C, S>
329where
330    C: DerefMut + Deref<Target = ConnectionCommon<S>> + Send,
331    S: rustls::SideData,
332{
333    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
334        Self::write(self, buf)
335    }
336
337    fn flush(&mut self) -> io::Result<()> {
338        Self::flush(self)
339    }
340}
341
342impl<C, S> Write for &RustTlsDuplexStream<C, S>
343where
344    C: DerefMut + Deref<Target = ConnectionCommon<S>> + Send,
345    S: rustls::SideData,
346{
347    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
348        RustTlsDuplexStream::<C, S>::write(self, buf)
349    }
350
351    fn flush(&mut self) -> io::Result<()> {
352        RustTlsDuplexStream::<C, S>::flush(self)
353    }
354}
355
356/// Read+Write combiner that is fed into rust-tls and delegates to our special ReadPipe/WritePipe
357/// that have careful blocking semantics
358#[derive(Debug)]
359struct CombinedPipe(ReadPipe, WritePipe);
360
361impl CombinedPipe {
362    ///Constructor for `CombinedPipe`
363    pub fn new<
364        R: Read + Send + 'static,
365        W: Write + Send + 'static,
366        T: FnMut(Box<dyn FnOnce() + Send>) -> io::Result<()>,
367    >(
368        read: R,
369        write: W,
370        mut spawner: T,
371        initial_data: Vec<u8>,
372    ) -> io::Result<Self> {
373        Ok(Self(
374            ReadPipe::new(read, &mut spawner, initial_data)?,
375            WritePipe::new(write, &mut spawner)?,
376        ))
377    }
378}
379
380impl Read for CombinedPipe {
381    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
382        self.0.read(buf)
383    }
384}
385
386impl Write for CombinedPipe {
387    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
388        self.1.write(buf)
389    }
390
391    fn flush(&mut self) -> io::Result<()> {
392        self.1.flush()
393    }
394}
395
/// Converts a `LockResult` into an `io::Result`.
///
/// A poisoned lock becomes an `ErrorKind::Other` error with the message "Poisoned Mutex".
/// Whatever the `PoisonError` carried (for a mutex, its guard) is dropped.
pub(crate) fn unwrap_poison<T>(result: LockResult<T>) -> io::Result<T> {
    match result {
        Ok(inner) => Ok(inner),
        Err(_) => Err(io::Error::new(ErrorKind::Other, "Poisoned Mutex")),
    }
}