rust_tls_duplex_stream/
lib.rs

1//! Full duplex stream wrapper around rust-tls
2
3#![deny(clippy::correctness)]
4#![warn(
5    clippy::perf,
6    clippy::complexity,
7    clippy::style,
8    clippy::nursery,
9    clippy::pedantic,
10    clippy::clone_on_ref_ptr,
11    clippy::decimal_literal_representation,
12    clippy::float_cmp_const,
13    clippy::missing_docs_in_private_items,
14    clippy::multiple_inherent_impl,
15    clippy::unwrap_used,
16    clippy::cargo_common_metadata,
17    clippy::used_underscore_binding
18)]
19
20mod queue;
21mod read_pipe;
22mod write_pipe;
23use crate::queue::Queue;
24use crate::read_pipe::ReadPipe;
25use crate::write_pipe::WritePipe;
26use rustls::{ConnectionCommon, StreamOwned};
27use std::fmt::{Arguments, Debug};
28use std::io::{ErrorKind, Read, Write};
29use std::ops::{Deref, DerefMut};
30use std::sync::atomic::AtomicBool;
31use std::sync::atomic::Ordering::SeqCst;
32use std::sync::{Arc, LockResult, Mutex};
33use std::time::Duration;
34use std::{io, thread};
35
36#[derive(Debug)]
37pub struct RustTlsDuplexStream<C, S>
38where
39    C: DerefMut + Deref<Target = ConnectionCommon<S>> + Send,
40{
41    /// Flag for non blocking read.
42    non_blocking_read: AtomicBool,
43    /// Read timeout
44    read_timeout: Mutex<Option<Duration>>,
45    /// Write timeout
46    write_timeout: Mutex<Option<Duration>>,
47    /// Inner rust-tls pseudo connection
48    connection: Mutex<StreamOwned<C, CombinedPipe>>,
49    /// Read queue connected to the thread that reads data from the actual connection
50    read_q: Arc<Queue>,
51    /// Write queue connected to the thread that writes data to the actual connection
52    write_q: Arc<Queue>,
53    /// Guard mutex that prevents concurrent writes.
54    write_mutex: Mutex<()>,
55    /// Guard mutex that prevents concurrent reads.
56    read_mutex: Mutex<()>,
57}
58
59impl<C, S> RustTlsDuplexStream<C, S>
60where
61    C: DerefMut + Deref<Target = ConnectionCommon<S>> + Send,
62    S: rustls::SideData,
63{
64    ///
65    /// Creates a new 'unpooled' Tls stream wrapper.
66    /// This is a good choice for an application such as a client
67    /// that does not create connections and doesn't have a thread pool.
68    ///
69    /// This fn will spawn 2 new threads using `thread::Builder::new().spawn(...)`.
70    /// The threads will terminate when the returned stream is dropped and the read/write errors out.
71    ///
72    /// # Resource Leaks
73    /// Be aware that a Read which "blocks" forever will cause the thread that's reading to stay alive forever.
74    /// Either set a reasonable connection read timeout so that your Read will eventually return
75    /// or call your connections shutdown fn like `TcpStream::shutdown` if such a method exists
76    /// to ensure that all threads are stopped and no resources are leaked.
77    ///
78    /// # Errors
79    /// if `thread::Builder::new().spawn` fails to spawn 2 threads.
80    ///
81    pub fn new_unpooled<R, W>(con: C, read: R, write: W) -> io::Result<Self>
82    where
83        R: Read + Send + 'static,
84        W: Write + Send + 'static,
85    {
86        Self::new(con, read, write, |task| {
87            thread::Builder::new().spawn(task).map(|_| {})
88        })
89    }
90
91    ///
92    /// Creates a new Tls stream wrapper.
93    ///
94    /// This fn will spawn 2 new threads using  the provided thread spawner function.
95    /// The thread spawner function is called exactly twice. If the first call yields an error then
96    /// it's not called again. The functions ("threads") will end when the returned stream wrapper is dropped.
97    ///
98    /// # Resource Leaks
99    /// Be aware that a Read which "blocks" forever will cause the thread that's reading to stay alive forever.
100    /// Either set a reasonable connection read timeout so that your Read will eventually return
101    /// or call your connections shutdown fn like `TcpStream::shutdown` if such a method exists
102    /// to ensure that all threads are stopped and no resources are leaked.
103    ///
104    /// # Errors
105    /// propagated from the spawner fn.
106    ///
107    pub fn new<R, W, T>(con: C, read: R, write: W, spawner: T) -> io::Result<Self>
108    where
109        R: Read + Send + 'static,
110        W: Write + Send + 'static,
111        T: FnMut(Box<dyn FnOnce() + Send>) -> io::Result<()>,
112    {
113        let pipe = CombinedPipe::new(read, write, spawner)?;
114        let read_q = pipe.0.dup_queue();
115        let write_q = pipe.1.dup_queue();
116
117        Ok(Self {
118            non_blocking_read: AtomicBool::new(false),
119            read_q,
120            write_q,
121            write_mutex: Mutex::new(()),
122            read_mutex: Mutex::new(()),
123            connection: Mutex::new(StreamOwned::new(con, pipe)),
124            read_timeout: Mutex::new(None),
125            write_timeout: Mutex::new(None),
126        })
127    }
128
129    /// see `Write::write`
130    /// # Errors
131    /// propagated from `Write::write` once subsequent writes/flushes turn into `BrokenPipe`
132    pub fn write(&self, buffer: &[u8]) -> io::Result<usize> {
133        let _outer_guard = unwrap_poison(self.write_mutex.lock())?; //make writes block other writes
134        let timeout_copy = unwrap_poison(self.write_timeout.lock())?
135            .deref()
136            .as_ref()
137            .copied();
138        self.write_q.flush_low(timeout_copy)?;
139        unwrap_poison(self.connection.lock())?.write(buffer)
140    }
141
142    /// see `Write::flush`
143    /// # Errors
144    /// propagated from `Write::flush` once subsequent writes/flushes turn into `BrokenPipe`
145    pub fn flush(&self) -> io::Result<()> {
146        let _outer_guard = unwrap_poison(self.write_mutex.lock())?; //make writes block other writes
147        unwrap_poison(self.connection.lock())?.flush()?;
148        self.write_q.flush_zero()
149    }
150
151    /// see `Read::read`
152    /// # Errors
153    /// propagated from `Read::read` once subsequent reads turn into `BrokenPipe`
154    pub fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
155        let _outer_guard = unwrap_poison(self.read_mutex.lock())?; //make reads block other reads
156        loop {
157            let mut guard = unwrap_poison(self.connection.lock())?;
158            guard.sock.0.nb(true); //Return instantly if no data.
159            let res = guard.read(buffer);
160            guard.sock.0.nb(false); //We must clear this flag or writes may go ballistic.
161            return match res {
162                Ok(count) => {
163                    drop(guard);
164                    Ok(count)
165                }
166                Err(err) => {
167                    if self.non_blocking_read.load(SeqCst) {
168                        return Err(err);
169                    }
170                    if err.kind() == ErrorKind::WouldBlock {
171                        //We have entered the fun zone where reads would block writes
172                        let timeout_copy = unwrap_poison(self.read_timeout.lock())?
173                            .deref()
174                            .as_ref()
175                            .copied();
176                        //This drops guard as oon as a handle to read_q is acquired and will return once trying to read again is meaningful.
177                        self.read_q.await_pop(guard, timeout_copy)?;
178                        continue;
179                    }
180
181                    drop(guard);
182                    Err(err)
183                }
184            };
185        }
186    }
187
188    /// sets the timeout for the writing operation.
189    /// This has no effect on the underlying connection and purely deals with internal writing semantics.
190    /// Calls to fns that writs data will return `TimedOut` if no plain text data could be written.
191    /// Cause of this is likely to be that the underlying connection does not read data fast enough.
192    /// This is never caused by writing too much data.
193    /// # Errors
194    /// In case of poisoned mutex
195    ///
196    pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
197        *unwrap_poison(self.read_timeout.lock())? = timeout;
198        Ok(())
199    }
200
201    /// Returns the current read timeout if any
202    /// # Errors
203    /// In case of poisoned mutex
204    pub fn read_timeout(&self) -> io::Result<Option<Duration>> {
205        Ok(unwrap_poison(self.read_timeout.lock())?.as_ref().copied())
206    }
207
208    /// sets non-blocking mode for read.
209    /// This has no effect on the underlying connection and purely deals with internal reading semantics.
210    /// Calls to fns that read data will return `WouldBlock` immediately if no plain text data is available to be read.
211    /// # Errors
212    /// In case of poisoned mutex
213    pub fn set_read_non_block(&self, on: bool) -> io::Result<()> {
214        self.non_blocking_read.store(on, SeqCst);
215        Ok(())
216    }
217
218    /// sets the timeout for the writing operation.
219    /// This has no effect on the underlying connection and purely deals with internal writing semantics.
220    /// Calls to fns that writs data will return `TimedOut` if no plain text data could be written.
221    /// Cause of this is likely to be that the underlying connection does not send data fast enough.
222    /// This is never caused by reading too much data.
223    /// # Errors
224    /// In case of poisoned mutex
225    pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
226        *unwrap_poison(self.write_timeout.lock())? = timeout;
227        Ok(())
228    }
229
230    /// Returns the current write timeout if any
231    /// # Errors
232    /// In case of poisoned mutex
233    pub fn write_timeout(&self) -> io::Result<Option<Duration>> {
234        Ok(unwrap_poison(self.write_timeout.lock())?.as_ref().copied())
235    }
236
237    /// See `Read::read_to_end`
238    /// # Errors
239    /// propagated
240    pub fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
241        Read::read_to_end(&mut &*self, buf)
242    }
243
244    /// See `Read::read_to_string`
245    /// # Errors
246    /// propagated
247    pub fn read_to_string(&self, buf: &mut String) -> io::Result<usize> {
248        Read::read_to_string(&mut &*self, buf)
249    }
250
251    /// See `Read::read_exact`
252    /// # Errors
253    /// propagated
254    pub fn read_exact(&self, buf: &mut [u8]) -> io::Result<()> {
255        Read::read_exact(&mut &*self, buf)
256    }
257
258    /// See `Write::write_all`
259    /// # Errors
260    /// propagated
261    pub fn write_all(&self, buf: &[u8]) -> io::Result<()> {
262        Write::write_all(&mut &*self, buf)
263    }
264
265    /// See `Write::write_fmt`
266    /// # Errors
267    /// propagated
268    pub fn write_fmt(&self, fmt: Arguments<'_>) -> io::Result<()> {
269        Write::write_fmt(&mut &*self, fmt)
270    }
271}
272
273impl<C, S> Read for RustTlsDuplexStream<C, S>
274where
275    C: DerefMut + Deref<Target = ConnectionCommon<S>> + Send,
276    S: rustls::SideData,
277{
278    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
279        Self::read(self, buf)
280    }
281}
282
283impl<C, S> Read for &RustTlsDuplexStream<C, S>
284where
285    C: DerefMut + Deref<Target = ConnectionCommon<S>> + Send,
286    S: rustls::SideData,
287{
288    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
289        RustTlsDuplexStream::<C, S>::read(self, buf)
290    }
291}
292
293impl<C, S> Write for RustTlsDuplexStream<C, S>
294where
295    C: DerefMut + Deref<Target = ConnectionCommon<S>> + Send,
296    S: rustls::SideData,
297{
298    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
299        Self::write(self, buf)
300    }
301
302    fn flush(&mut self) -> io::Result<()> {
303        Self::flush(self)
304    }
305}
306
307impl<C, S> Write for &RustTlsDuplexStream<C, S>
308where
309    C: DerefMut + Deref<Target = ConnectionCommon<S>> + Send,
310    S: rustls::SideData,
311{
312    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
313        RustTlsDuplexStream::<C, S>::write(self, buf)
314    }
315
316    fn flush(&mut self) -> io::Result<()> {
317        RustTlsDuplexStream::<C, S>::flush(self)
318    }
319}
320
321/// Read+Write combiner that is fed into rust-tls and delegates to our special ReadPipe/WritePipe
322/// that have careful blocking semantics
323#[derive(Debug)]
324struct CombinedPipe(ReadPipe, WritePipe);
325
326impl CombinedPipe {
327    ///Constructor for `CombinedPipe`
328    pub fn new<
329        R: Read + Send + 'static,
330        W: Write + Send + 'static,
331        T: FnMut(Box<dyn FnOnce() + Send>) -> io::Result<()>,
332    >(
333        read: R,
334        write: W,
335        mut spawner: T,
336    ) -> io::Result<Self> {
337        Ok(Self(
338            ReadPipe::new(read, &mut spawner)?,
339            WritePipe::new(write, &mut spawner)?,
340        ))
341    }
342}
343
344impl Read for CombinedPipe {
345    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
346        self.0.read(buf)
347    }
348}
349
350impl Write for CombinedPipe {
351    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
352        self.1.write(buf)
353    }
354
355    fn flush(&mut self) -> io::Result<()> {
356        self.1.flush()
357    }
358}
359
/// Turns a poisoned-lock result into an `io::Error` so it can be propagated with `?`.
///
/// The poisoned guard is discarded. Callers only see an `ErrorKind::Other` error whose
/// message is "Poisoned Mutex".
pub(crate) fn unwrap_poison<T>(result: LockResult<T>) -> io::Result<T> {
    match result {
        Ok(value) => Ok(value),
        Err(_) => Err(io::Error::new(ErrorKind::Other, "Poisoned Mutex")),
    }
}