// ws_tool/stream.rs

/// Blocking (`std::io::Read` / `std::io::Write`) stream types.
#[cfg(feature = "sync")]
mod blocking {
    use std::{
        io::{BufReader, BufWriter, Read, Write},
        net::TcpStream,
    };

    use crate::codec::Split;

    /// Shorthand trait for any type that is both [`Read`] and [`Write`].
    pub trait RW: Read + Write {}

    impl<S: Read + Write> RW for S {}

    /// Read/write halves that share a single TLS stream behind an `Arc<Mutex<_>>`.
    #[cfg(any(feature = "sync_tls_rustls", feature = "sync_tls_native"))]
    mod split {
        use std::{
            io::{ErrorKind, Read, Write},
            sync::{Arc, Mutex},
        };

        use crate::codec::Split;

        /// reader part of a stream
        ///
        /// Shares the stream with the matching [`WriteHalf`]. The mutex is held for the whole
        /// duration of each read call, so a read that blocks also blocks writes on the other half
        /// until it returns.
        pub struct ReadHalf<T> {
            /// inner stream
            pub inner: Arc<Mutex<T>>,
        }

        /// writer part of a stream
        ///
        /// Shares the stream with the matching [`ReadHalf`]. The mutex is held for the whole
        /// duration of each write/flush call, so it contends with reads on the other half.
        pub struct WriteHalf<T> {
            /// inner stream
            pub inner: Arc<Mutex<T>>,
        }

        // Locks `$lock` and evaluates to the guard. If the mutex is poisoned (a previous holder
        // panicked while using the stream), returns an `io::Error` from the enclosing function
        // instead of propagating the panic.
        macro_rules! try_lock {
            ($lock:expr) => {
                match $lock.lock() {
                    Ok(guard) => guard,
                    Err(_) => {
                        return Err(std::io::Error::new(ErrorKind::BrokenPipe, "lock poisoned"));
                    }
                }
            };
        }

        /// Puts `stream` behind one shared mutex and returns a read half and a write half for it.
        fn share<T>(stream: T) -> (ReadHalf<T>, WriteHalf<T>) {
            let inner = Arc::new(Mutex::new(stream));
            (
                ReadHalf {
                    inner: Arc::clone(&inner),
                },
                WriteHalf { inner },
            )
        }

        impl<T: Read> Read for ReadHalf<T> {
            fn read_vectored(
                &mut self,
                bufs: &mut [std::io::IoSliceMut<'_>],
            ) -> std::io::Result<usize> {
                try_lock!(self.inner).read_vectored(bufs)
            }

            fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
                try_lock!(self.inner).read(buf)
            }
        }

        impl<T: Write> Write for WriteHalf<T> {
            // Delegated so the inner stream's vectored write is used, mirroring
            // `ReadHalf::read_vectored`.
            fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> {
                try_lock!(self.inner).write_vectored(bufs)
            }

            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
                try_lock!(self.inner).write(buf)
            }

            fn flush(&mut self) -> std::io::Result<()> {
                try_lock!(self.inner).flush()
            }
        }

        #[cfg(feature = "sync_tls_rustls")]
        impl<S: Read + Write> Split for rustls_connector::TlsStream<S> {
            type R = ReadHalf<rustls_connector::TlsStream<S>>;

            type W = WriteHalf<rustls_connector::TlsStream<S>>;

            fn split(self) -> (Self::R, Self::W) {
                share(self)
            }
        }

        #[cfg(feature = "sync_tls_native")]
        impl<S: Read + Write> Split for native_tls::TlsStream<S> {
            type R = ReadHalf<native_tls::TlsStream<S>>;

            type W = WriteHalf<native_tls::TlsStream<S>>;

            fn split(self) -> (Self::R, Self::W) {
                share(self)
            }
        }
    }

    // Declares a stream enum with a `Raw` variant plus feature-gated `Rustls` / `NativeTls`
    // variants. It also emits a `Debug` impl that prints only the variant name; the wrapped
    // stream itself is not formatted.
    macro_rules! def {
        ($name:ident, $raw:ty, $rustls:ty, $native:ty, $doc:literal) => {
            #[doc=$doc]
            pub enum $name {
                /// raw tcp stream
                Raw($raw),
                /// rustls wrapped stream
                #[cfg(feature = "sync_tls_rustls")]
                Rustls($rustls),
                /// native tls wrapped stream
                #[cfg(feature = "sync_tls_native")]
                NativeTls($native),
            }

            impl std::fmt::Debug for $name {
                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                    match self {
                        Self::Raw(_) => f.debug_tuple("Raw").finish(),
                        #[cfg(feature = "sync_tls_rustls")]
                        Self::Rustls(_) => f.debug_tuple("Rustls").finish(),
                        #[cfg(feature = "sync_tls_native")]
                        Self::NativeTls(_) => f.debug_tuple("NativeTls").finish(),
                    }
                }
            }
        };
    }

    def!(
        SyncStreamRead,
        TcpStream,
        split::ReadHalf<rustls_connector::TlsStream<TcpStream>>,
        split::ReadHalf<native_tls::TlsStream<TcpStream>>,
        "read half of a split [`SyncStream`]"
    );

    def!(
        SyncStreamWrite,
        TcpStream,
        split::WriteHalf<rustls_connector::TlsStream<TcpStream>>,
        split::WriteHalf<native_tls::TlsStream<TcpStream>>,
        "write half of a split [`SyncStream`]"
    );

    def!(
        SyncStream,
        TcpStream,
        rustls_connector::TlsStream<TcpStream>,
        native_tls::TlsStream<TcpStream>,
        "a wrapper of most common use raw/ssl tcp based stream"
    );

    // Implements `Read` for a `def!`-generated enum by forwarding to the active variant.
    macro_rules! impl_read {
        ($name:ty) => {
            impl Read for $name {
                fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
                    match self {
                        Self::Raw(s) => s.read(buf),
                        #[cfg(feature = "sync_tls_rustls")]
                        Self::Rustls(s) => s.read(buf),
                        #[cfg(feature = "sync_tls_native")]
                        Self::NativeTls(s) => s.read(buf),
                    }
                }

                fn read_vectored(
                    &mut self,
                    bufs: &mut [std::io::IoSliceMut<'_>],
                ) -> std::io::Result<usize> {
                    match self {
                        Self::Raw(s) => s.read_vectored(bufs),
                        #[cfg(feature = "sync_tls_rustls")]
                        Self::Rustls(s) => s.read_vectored(bufs),
                        #[cfg(feature = "sync_tls_native")]
                        Self::NativeTls(s) => s.read_vectored(bufs),
                    }
                }
            }
        };
    }

    impl_read!(SyncStream);
    impl_read!(SyncStreamRead);

    // Implements `Write` for a `def!`-generated enum by forwarding to the active variant.
    macro_rules! impl_write {
        ($item:ty) => {
            impl Write for $item {
                fn write_vectored(
                    &mut self,
                    bufs: &[std::io::IoSlice<'_>],
                ) -> std::io::Result<usize> {
                    match self {
                        Self::Raw(s) => s.write_vectored(bufs),
                        #[cfg(feature = "sync_tls_rustls")]
                        Self::Rustls(s) => s.write_vectored(bufs),
                        #[cfg(feature = "sync_tls_native")]
                        Self::NativeTls(s) => s.write_vectored(bufs),
                    }
                }

                fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
                    match self {
                        Self::Raw(s) => s.write(buf),
                        #[cfg(feature = "sync_tls_rustls")]
                        Self::Rustls(s) => s.write(buf),
                        #[cfg(feature = "sync_tls_native")]
                        Self::NativeTls(s) => s.write(buf),
                    }
                }

                fn flush(&mut self) -> std::io::Result<()> {
                    match self {
                        Self::Raw(s) => s.flush(),
                        #[cfg(feature = "sync_tls_rustls")]
                        Self::Rustls(s) => s.flush(),
                        #[cfg(feature = "sync_tls_native")]
                        Self::NativeTls(s) => s.flush(),
                    }
                }
            }
        };
    }

    impl_write!(SyncStream);
    impl_write!(SyncStreamWrite);

    impl Split for SyncStream {
        type R = SyncStreamRead;

        type W = SyncStreamWrite;

        /// Split into read and write halves.
        ///
        /// TLS variants reuse the `Split` impls from the `split` module, so both halves share one
        /// mutex-guarded stream.
        fn split(self) -> (Self::R, Self::W) {
            match self {
                Self::Raw(s) => {
                    let (read, write) = s.split();
                    (SyncStreamRead::Raw(read), SyncStreamWrite::Raw(write))
                }
                #[cfg(feature = "sync_tls_rustls")]
                Self::Rustls(s) => {
                    let (read, write) = Split::split(s);
                    (SyncStreamRead::Rustls(read), SyncStreamWrite::Rustls(write))
                }
                #[cfg(feature = "sync_tls_native")]
                Self::NativeTls(s) => {
                    let (read, write) = Split::split(s);
                    (SyncStreamRead::NativeTls(read), SyncStreamWrite::NativeTls(write))
                }
            }
        }
    }

    /// a buffered stream
    ///
    /// Reads go through a [`BufReader`] and writes through a [`BufWriter`]. Written bytes stay in
    /// the write buffer until it fills up or [`Write::flush`] is called; reading does not flush
    /// them.
    pub struct BufStream<S: Read + Write>(pub BufReader<WrappedWriter<S>>);

    impl<S: Read + Write> BufStream<S> {
        /// create buf stream with default buffer size
        pub fn new(stream: S) -> Self {
            Self(BufReader::new(WrappedWriter(BufWriter::new(stream))))
        }

        /// Create a buf stream with explicit read and write buffer capacities.
        pub fn with_capacity(read: usize, write: usize, stream: S) -> Self {
            let writer = BufWriter::with_capacity(write, stream);
            let reader = BufReader::with_capacity(read, WrappedWriter(writer));
            Self(reader)
        }

        /// Get a mutable reference to the underlying stream.
        ///
        /// This bypasses both buffers: buffered reads and unflushed writes are not touched.
        pub fn get_mut(&mut self) -> &mut S {
            self.0.get_mut().0.get_mut()
        }
    }

    impl<S: Read + Write> std::fmt::Debug for BufStream<S> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("BufStream").finish()
        }
    }

    impl<S: Read + Write> Read for BufStream<S> {
        fn read_vectored(
            &mut self,
            bufs: &mut [std::io::IoSliceMut<'_>],
        ) -> std::io::Result<usize> {
            self.0.read_vectored(bufs)
        }

        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            self.0.read(buf)
        }
    }

    impl<S: Read + Write> Write for BufStream<S> {
        fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> {
            self.0.get_mut().write_vectored(bufs)
        }

        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.0.get_mut().write(buf)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            self.0.get_mut().flush()
        }
    }

    /// simple wrapper of buf writer
    ///
    /// Its `Read` impl reads straight from the inner stream, bypassing the write buffer.
    pub struct WrappedWriter<S: Write>(pub BufWriter<S>);

    impl<S: Read + Write> Read for WrappedWriter<S> {
        fn read_vectored(
            &mut self,
            bufs: &mut [std::io::IoSliceMut<'_>],
        ) -> std::io::Result<usize> {
            self.0.get_mut().read_vectored(bufs)
        }

        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            self.0.get_mut().read(buf)
        }
    }

    impl<S: Write> Write for WrappedWriter<S> {
        fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> {
            self.0.write_vectored(bufs)
        }

        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.0.write(buf)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            self.0.flush()
        }
    }

    impl<S, R, W> crate::codec::Split for BufStream<S>
    where
        R: Read,
        W: Write,
        S: Read + Write + crate::codec::Split<R = R, W = W> + std::fmt::Debug,
    {
        type R = BufReader<R>;

        type W = BufWriter<W>;

        /// Split into independently buffered read and write halves with the same capacities.
        ///
        /// Pending writes are flushed first. If that flush fails, the method does not panic.
        /// Instead, the bytes that were not written are moved into the new writer, so they are
        /// retried on its next flush. The original flush error itself is not reported.
        ///
        /// NOTE(review): `BufReader::into_inner` discards any bytes already read ahead into the
        /// read buffer. Data the peer sent right after the last consumed message is lost here;
        /// keeping it would require a different `R` type.
        fn split(self) -> (Self::R, Self::W) {
            let read_cap = self.0.capacity();
            let write_cap = self.0.get_ref().0.capacity();
            let buf_writer = self.0.into_inner().0;
            let (stream, unflushed) = match buf_writer.into_inner() {
                Ok(stream) => (stream, Vec::new()),
                Err(err) => {
                    // `into_parts` never flushes, so it cannot fail; it hands back the stream
                    // together with whatever the failed flush left unwritten.
                    let (stream, pending) = err.into_inner().into_parts();
                    (stream, pending.unwrap_or_else(|panicked| panicked.into_inner()))
                }
            };
            let (r, w) = stream.split();
            let reader = BufReader::with_capacity(read_cap, r);
            let writer = if unflushed.is_empty() {
                BufWriter::with_capacity(write_cap, w)
            } else {
                // The capacity is strictly larger than the carried-over bytes and the buffer is
                // empty. `write_all` therefore only copies them into the buffer, with no I/O.
                let mut writer = BufWriter::with_capacity(write_cap.max(unflushed.len() + 1), w);
                let buffered = writer.write_all(&unflushed);
                debug_assert!(buffered.is_ok(), "buffering into an empty BufWriter failed");
                writer
            };
            (reader, writer)
        }
    }
}

#[cfg(feature = "sync")]
pub use blocking::*;

/// Async (tokio) stream types.
#[cfg(feature = "async")]
mod non_blocking {
    use std::{
        io::IoSlice,
        pin::Pin,
        task::{Context, Poll},
    };

    use tokio::{
        io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf},
        net::TcpStream,
    };

    use crate::codec::Split;

    /// Shorthand trait for any `Unpin` type that is both [`AsyncRead`] and [`AsyncWrite`].
    pub trait AsyncRW: AsyncRead + AsyncWrite + Unpin {}

    impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRW for S {}

    /// a wrapper of most common use raw/ssl tcp based stream
    pub enum AsyncStream {
        /// raw tcp stream
        Raw(TcpStream),
        /// rustls wrapped stream
        #[cfg(feature = "async_tls_rustls")]
        Rustls(tokio_rustls::TlsStream<TcpStream>),
        /// native tls wrapped stream
        #[cfg(feature = "async_tls_native")]
        NativeTls(tokio_native_tls::TlsStream<TcpStream>),
    }

    impl Split for AsyncStream {
        type R = ReadHalf<Self>;

        type W = WriteHalf<Self>;

        /// Split into read and write halves via [`tokio::io::split`].
        fn split(self) -> (Self::R, Self::W) {
            tokio::io::split(self)
        }
    }

    // Every impl below forwards to the active variant. `get_mut` is valid because all wrapped
    // stream types are `Unpin` (the enum is used with `Pin::new` on each of them).
    impl AsyncRead for AsyncStream {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            match self.get_mut() {
                AsyncStream::Raw(s) => Pin::new(s).poll_read(cx, buf),
                #[cfg(feature = "async_tls_rustls")]
                AsyncStream::Rustls(s) => Pin::new(s).poll_read(cx, buf),
                #[cfg(feature = "async_tls_native")]
                AsyncStream::NativeTls(s) => Pin::new(s).poll_read(cx, buf),
            }
        }
    }

    impl AsyncWrite for AsyncStream {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, std::io::Error>> {
            match self.get_mut() {
                AsyncStream::Raw(s) => Pin::new(s).poll_write(cx, buf),
                #[cfg(feature = "async_tls_rustls")]
                AsyncStream::Rustls(s) => Pin::new(s).poll_write(cx, buf),
                #[cfg(feature = "async_tls_native")]
                AsyncStream::NativeTls(s) => Pin::new(s).poll_write(cx, buf),
            }
        }

        // Forward vectored writes so the inner stream's implementation is used. Without this,
        // the trait default would write only the first non-empty buffer per call.
        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, std::io::Error>> {
            match self.get_mut() {
                AsyncStream::Raw(s) => Pin::new(s).poll_write_vectored(cx, bufs),
                #[cfg(feature = "async_tls_rustls")]
                AsyncStream::Rustls(s) => Pin::new(s).poll_write_vectored(cx, bufs),
                #[cfg(feature = "async_tls_native")]
                AsyncStream::NativeTls(s) => Pin::new(s).poll_write_vectored(cx, bufs),
            }
        }

        fn is_write_vectored(&self) -> bool {
            match self {
                AsyncStream::Raw(s) => s.is_write_vectored(),
                #[cfg(feature = "async_tls_rustls")]
                AsyncStream::Rustls(s) => s.is_write_vectored(),
                #[cfg(feature = "async_tls_native")]
                AsyncStream::NativeTls(s) => s.is_write_vectored(),
            }
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), std::io::Error>> {
            match self.get_mut() {
                AsyncStream::Raw(s) => Pin::new(s).poll_flush(cx),
                #[cfg(feature = "async_tls_rustls")]
                AsyncStream::Rustls(s) => Pin::new(s).poll_flush(cx),
                #[cfg(feature = "async_tls_native")]
                AsyncStream::NativeTls(s) => Pin::new(s).poll_flush(cx),
            }
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), std::io::Error>> {
            match self.get_mut() {
                AsyncStream::Raw(s) => Pin::new(s).poll_shutdown(cx),
                #[cfg(feature = "async_tls_rustls")]
                AsyncStream::Rustls(s) => Pin::new(s).poll_shutdown(cx),
                #[cfg(feature = "async_tls_native")]
                AsyncStream::NativeTls(s) => Pin::new(s).poll_shutdown(cx),
            }
        }
    }
}

#[cfg(feature = "async")]
pub use non_blocking::*;