rust_tls_duplex_stream/lib.rs
1//! Full duplex stream wrapper around rust-tls
2
3#![deny(clippy::correctness)]
4#![warn(
5 clippy::perf,
6 clippy::complexity,
7 clippy::style,
8 clippy::nursery,
9 clippy::pedantic,
10 clippy::clone_on_ref_ptr,
11 clippy::decimal_literal_representation,
12 clippy::float_cmp_const,
13 clippy::missing_docs_in_private_items,
14 clippy::multiple_inherent_impl,
15 clippy::unwrap_used,
16 clippy::cargo_common_metadata,
17 clippy::used_underscore_binding
18)]
19
20mod queue;
21mod read_pipe;
22mod write_pipe;
23use crate::queue::Queue;
24use crate::read_pipe::ReadPipe;
25use crate::write_pipe::WritePipe;
26use rustls::{ConnectionCommon, StreamOwned};
27use std::fmt::{Arguments, Debug};
28use std::io::{ErrorKind, Read, Write};
29use std::ops::{Deref, DerefMut};
30use std::sync::atomic::AtomicBool;
31use std::sync::atomic::Ordering::SeqCst;
32use std::sync::{Arc, LockResult, Mutex};
33use std::time::Duration;
34use std::{io, thread};
35
36#[derive(Debug)]
37pub struct RustTlsDuplexStream<C, S>
38where
39 C: DerefMut + Deref<Target = ConnectionCommon<S>> + Send,
40{
41 /// Flag for non blocking read.
42 non_blocking_read: AtomicBool,
43 /// Read timeout
44 read_timeout: Mutex<Option<Duration>>,
45 /// Write timeout
46 write_timeout: Mutex<Option<Duration>>,
47 /// Inner rust-tls pseudo connection
48 connection: Mutex<StreamOwned<C, CombinedPipe>>,
49 /// Read queue connected to the thread that reads data from the actual connection
50 read_q: Arc<Queue>,
51 /// Write queue connected to the thread that writes data to the actual connection
52 write_q: Arc<Queue>,
53 /// Guard mutex that prevents concurrent writes.
54 write_mutex: Mutex<()>,
55 /// Guard mutex that prevents concurrent reads.
56 read_mutex: Mutex<()>,
57}
58
59impl<C, S> RustTlsDuplexStream<C, S>
60where
61 C: DerefMut + Deref<Target = ConnectionCommon<S>> + Send,
62 S: rustls::SideData,
63{
64 ///
65 /// Creates a new 'unpooled' Tls stream wrapper.
66 /// This is a good choice for an application such as a client
67 /// that does not create connections and doesn't have a thread pool.
68 ///
69 /// This fn will spawn 2 new threads using `thread::Builder::new().spawn(...)`.
70 /// The threads will terminate when the returned stream is dropped and the read/write errors out.
71 ///
72 /// # Resource Leaks
73 /// Be aware that a Read which "blocks" forever will cause the thread that's reading to stay alive forever.
74 /// Either set a reasonable connection read timeout so that your Read will eventually return
75 /// or call your connections shutdown fn like `TcpStream::shutdown` if such a method exists
76 /// to ensure that all threads are stopped and no resources are leaked.
77 ///
78 /// # Errors
79 /// if `thread::Builder::new().spawn` fails to spawn 2 threads.
80 ///
81 pub fn new_unpooled<R, W>(con: C, read: R, write: W) -> io::Result<Self>
82 where
83 R: Read + Send + 'static,
84 W: Write + Send + 'static,
85 {
86 Self::new(con, read, write, |task| {
87 thread::Builder::new().spawn(task).map(|_| {})
88 })
89 }
90
91 ///
92 /// Creates a new Tls stream wrapper.
93 ///
94 /// This fn will spawn 2 new threads using the provided thread spawner function.
95 /// The thread spawner function is called exactly twice. If the first call yields an error then
96 /// it's not called again. The functions ("threads") will end when the returned stream wrapper is dropped.
97 ///
98 /// # Resource Leaks
99 /// Be aware that a Read which "blocks" forever will cause the thread that's reading to stay alive forever.
100 /// Either set a reasonable connection read timeout so that your Read will eventually return
101 /// or call your connections shutdown fn like `TcpStream::shutdown` if such a method exists
102 /// to ensure that all threads are stopped and no resources are leaked.
103 ///
104 /// # Errors
105 /// propagated from the spawner fn.
106 ///
107 pub fn new<R, W, T>(con: C, read: R, write: W, spawner: T) -> io::Result<Self>
108 where
109 R: Read + Send + 'static,
110 W: Write + Send + 'static,
111 T: FnMut(Box<dyn FnOnce() + Send>) -> io::Result<()>,
112 {
113 Self::new_with_initial_data(con, read, write, spawner, Vec::new())
114 }
115 ///
116 /// Creates a new Tls stream wrapper.
117 ///
118 /// This fn will spawn 2 new threads using the provided thread spawner function.
119 /// The thread spawner function is called exactly twice. If the first call yields an error then
120 /// it's not called again. The functions ("threads") will end when the returned stream wrapper is dropped.
121 ///
122 /// # Resource Leaks
123 /// Be aware that a Read which "blocks" forever will cause the thread that's reading to stay alive forever.
124 /// Either set a reasonable connection read timeout so that your Read will eventually return
125 /// or call your connections shutdown fn like `TcpStream::shutdown` if such a method exists
126 /// to ensure that all threads are stopped and no resources are leaked.
127 ///
128 /// # Arguments
129 /// * `initial_data` contains data that will be consumed before data is read from the provided `io::Read`.
130 /// this data may be the result of some previous operation that first detected if a connection is even TLS or not.
131 /// Usually the `initial_data` Vec would contain the TLS hello message or the start of it.
132 ///
133 /// # Errors
134 /// propagated from the spawner fn.
135 ///
136 pub fn new_with_initial_data<R, W, T>(
137 con: C,
138 read: R,
139 write: W,
140 spawner: T,
141 initial_data: Vec<u8>,
142 ) -> io::Result<Self>
143 where
144 R: Read + Send + 'static,
145 W: Write + Send + 'static,
146 T: FnMut(Box<dyn FnOnce() + Send>) -> io::Result<()>,
147 {
148 let pipe = CombinedPipe::new(read, write, spawner, initial_data)?;
149 let read_q = pipe.0.dup_queue();
150 let write_q = pipe.1.dup_queue();
151
152 Ok(Self {
153 non_blocking_read: AtomicBool::new(false),
154 read_q,
155 write_q,
156 write_mutex: Mutex::new(()),
157 read_mutex: Mutex::new(()),
158 connection: Mutex::new(StreamOwned::new(con, pipe)),
159 read_timeout: Mutex::new(None),
160 write_timeout: Mutex::new(None),
161 })
162 }
163
164 /// see `Write::write`
165 /// # Errors
166 /// propagated from `Write::write` once subsequent writes/flushes turn into `BrokenPipe`
167 pub fn write(&self, buffer: &[u8]) -> io::Result<usize> {
168 let _outer_guard = unwrap_poison(self.write_mutex.lock())?; //make writes block other writes
169 let timeout_copy = unwrap_poison(self.write_timeout.lock())?
170 .deref()
171 .as_ref()
172 .copied();
173 self.write_q.flush_low(timeout_copy)?;
174 unwrap_poison(self.connection.lock())?.write(buffer)
175 }
176
177 /// see `Write::flush`
178 /// # Errors
179 /// propagated from `Write::flush` once subsequent writes/flushes turn into `BrokenPipe`
180 pub fn flush(&self) -> io::Result<()> {
181 let _outer_guard = unwrap_poison(self.write_mutex.lock())?; //make writes block other writes
182 unwrap_poison(self.connection.lock())?.flush()?;
183 self.write_q.flush_zero()
184 }
185
186 /// see `Read::read`
187 /// # Errors
188 /// propagated from `Read::read` once subsequent reads turn into `BrokenPipe`
189 pub fn read(&self, buffer: &mut [u8]) -> io::Result<usize> {
190 let _outer_guard = unwrap_poison(self.read_mutex.lock())?; //make reads block other reads
191 loop {
192 let mut guard = unwrap_poison(self.connection.lock())?;
193 guard.sock.0.nb(true); //Return instantly if no data.
194 let res = guard.read(buffer);
195 guard.sock.0.nb(false); //We must clear this flag or writes may go ballistic.
196 return match res {
197 Ok(count) => {
198 drop(guard);
199 Ok(count)
200 }
201 Err(err) => {
202 if self.non_blocking_read.load(SeqCst) {
203 return Err(err);
204 }
205 if err.kind() == ErrorKind::WouldBlock {
206 //We have entered the fun zone where reads would block writes
207 let timeout_copy = unwrap_poison(self.read_timeout.lock())?
208 .deref()
209 .as_ref()
210 .copied();
211 //This drops guard as soon as a handle to read_q is acquired and will return once trying to read again is meaningful.
212 self.read_q.await_pop(guard, timeout_copy)?;
213 continue;
214 }
215
216 drop(guard);
217 Err(err)
218 }
219 };
220 }
221 }
222
223 /// sets the timeout for the writing operation.
224 /// This has no effect on the underlying connection and purely deals with internal writing semantics.
225 /// Calls to fns that writs data will return `TimedOut` if no plain text data could be written.
226 /// Cause of this is likely to be that the underlying connection does not read data fast enough.
227 /// This is never caused by writing too much data.
228 /// # Errors
229 /// In case of poisoned mutex
230 ///
231 pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
232 *unwrap_poison(self.read_timeout.lock())? = timeout;
233 Ok(())
234 }
235
236 /// Returns the current read timeout if any
237 /// # Errors
238 /// In case of poisoned mutex
239 pub fn read_timeout(&self) -> io::Result<Option<Duration>> {
240 Ok(unwrap_poison(self.read_timeout.lock())?.as_ref().copied())
241 }
242
243 /// sets non-blocking mode for read.
244 /// This has no effect on the underlying connection and purely deals with internal reading semantics.
245 /// Calls to fns that read data will return `WouldBlock` immediately if no plain text data is available to be read.
246 /// # Errors
247 /// In case of poisoned mutex
248 pub fn set_read_non_block(&self, on: bool) -> io::Result<()> {
249 self.non_blocking_read.store(on, SeqCst);
250 Ok(())
251 }
252
253 /// sets the timeout for the writing operation.
254 /// This has no effect on the underlying connection and purely deals with internal writing semantics.
255 /// Calls to fns that writs data will return `TimedOut` if no plain text data could be written.
256 /// Cause of this is likely to be that the underlying connection does not send data fast enough.
257 /// This is never caused by reading too much data.
258 /// # Errors
259 /// In case of poisoned mutex
260 pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
261 *unwrap_poison(self.write_timeout.lock())? = timeout;
262 Ok(())
263 }
264
265 /// Returns the current write timeout if any
266 /// # Errors
267 /// In case of poisoned mutex
268 pub fn write_timeout(&self) -> io::Result<Option<Duration>> {
269 Ok(unwrap_poison(self.write_timeout.lock())?.as_ref().copied())
270 }
271
272 /// See `Read::read_to_end`
273 /// # Errors
274 /// propagated
275 pub fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
276 Read::read_to_end(&mut &*self, buf)
277 }
278
279 /// See `Read::read_to_string`
280 /// # Errors
281 /// propagated
282 pub fn read_to_string(&self, buf: &mut String) -> io::Result<usize> {
283 Read::read_to_string(&mut &*self, buf)
284 }
285
286 /// See `Read::read_exact`
287 /// # Errors
288 /// propagated
289 pub fn read_exact(&self, buf: &mut [u8]) -> io::Result<()> {
290 Read::read_exact(&mut &*self, buf)
291 }
292
293 /// See `Write::write_all`
294 /// # Errors
295 /// propagated
296 pub fn write_all(&self, buf: &[u8]) -> io::Result<()> {
297 Write::write_all(&mut &*self, buf)
298 }
299
300 /// See `Write::write_fmt`
301 /// # Errors
302 /// propagated
303 pub fn write_fmt(&self, fmt: Arguments<'_>) -> io::Result<()> {
304 Write::write_fmt(&mut &*self, fmt)
305 }
306}
307
308impl<C, S> Read for RustTlsDuplexStream<C, S>
309where
310 C: DerefMut + Deref<Target = ConnectionCommon<S>> + Send,
311 S: rustls::SideData,
312{
313 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
314 Self::read(self, buf)
315 }
316}
317
318impl<C, S> Read for &RustTlsDuplexStream<C, S>
319where
320 C: DerefMut + Deref<Target = ConnectionCommon<S>> + Send,
321 S: rustls::SideData,
322{
323 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
324 RustTlsDuplexStream::<C, S>::read(self, buf)
325 }
326}
327
328impl<C, S> Write for RustTlsDuplexStream<C, S>
329where
330 C: DerefMut + Deref<Target = ConnectionCommon<S>> + Send,
331 S: rustls::SideData,
332{
333 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
334 Self::write(self, buf)
335 }
336
337 fn flush(&mut self) -> io::Result<()> {
338 Self::flush(self)
339 }
340}
341
342impl<C, S> Write for &RustTlsDuplexStream<C, S>
343where
344 C: DerefMut + Deref<Target = ConnectionCommon<S>> + Send,
345 S: rustls::SideData,
346{
347 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
348 RustTlsDuplexStream::<C, S>::write(self, buf)
349 }
350
351 fn flush(&mut self) -> io::Result<()> {
352 RustTlsDuplexStream::<C, S>::flush(self)
353 }
354}
355
356/// Read+Write combiner that is fed into rust-tls and delegates to our special ReadPipe/WritePipe
357/// that have careful blocking semantics
358#[derive(Debug)]
359struct CombinedPipe(ReadPipe, WritePipe);
360
361impl CombinedPipe {
362 ///Constructor for `CombinedPipe`
363 pub fn new<
364 R: Read + Send + 'static,
365 W: Write + Send + 'static,
366 T: FnMut(Box<dyn FnOnce() + Send>) -> io::Result<()>,
367 >(
368 read: R,
369 write: W,
370 mut spawner: T,
371 initial_data: Vec<u8>,
372 ) -> io::Result<Self> {
373 Ok(Self(
374 ReadPipe::new(read, &mut spawner, initial_data)?,
375 WritePipe::new(write, &mut spawner)?,
376 ))
377 }
378}
379
380impl Read for CombinedPipe {
381 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
382 self.0.read(buf)
383 }
384}
385
386impl Write for CombinedPipe {
387 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
388 self.1.write(buf)
389 }
390
391 fn flush(&mut self) -> io::Result<()> {
392 self.1.flush()
393 }
394}
395
396/// Poison error to `io::Error`
397pub(crate) fn unwrap_poison<T>(result: LockResult<T>) -> io::Result<T> {
398 result.map_err(|_| io::Error::new(ErrorKind::Other, "Poisoned Mutex"))
399}