1#[cfg(feature = "sync")]
2mod blocking {
3 use std::{
4 io::{BufReader, BufWriter, Read, Write},
5 net::TcpStream,
6 };
7
8 use crate::codec::Split;
/// Marker trait for a blocking transport that can be both read from and
/// written to.
///
/// It declares no methods of its own; it only bundles [`Read`] and [`Write`]
/// into a single bound.
#[allow(missing_docs)]
pub trait RW: Read + Write {}

/// Blanket implementation: every (sized) `Read + Write` type is an `RW`.
impl<S> RW for S where S: Read + Write {}
13
14 #[cfg(any(feature = "sync_tls_rustls", feature = "sync_tls_native"))]
15 mod split {
16 use std::{
17 io::{ErrorKind, Read, Write},
18 sync::{Arc, Mutex},
19 };
20
21 use crate::codec::Split;
22
/// Read side of a stream that has been split in two.
///
/// The `Split` implementations in this module create it together with a
/// [`WriteHalf`] that points at the same mutex-guarded stream.
pub struct ReadHalf<T> {
    /// Shared, mutex-guarded handle to the underlying stream.
    pub inner: Arc<Mutex<T>>,
}
28
/// Write side of a stream that has been split in two.
///
/// The `Split` implementations in this module create it together with a
/// [`ReadHalf`] that points at the same mutex-guarded stream.
pub struct WriteHalf<T> {
    /// Shared, mutex-guarded handle to the underlying stream.
    pub inner: Arc<Mutex<T>>,
}
34
/// Locks `$lock` (an expression with a `Mutex`-style `lock()` method) and
/// evaluates to the guard.
///
/// If the lock is poisoned, the *enclosing function* returns early with a
/// `std::io::Error` of kind `BrokenPipe` and the message `"lock poisoned"`.
/// The macro is therefore only usable inside functions returning
/// `std::io::Result<_>`, with `ErrorKind` in scope at the call site.
macro_rules! try_lock {
    ($lock:expr) => {
        match $lock.lock() {
            Ok(guard) => guard,
            Err(_) => {
                // A plain literal converts into the error payload directly;
                // `format!` without arguments was a needless extra step.
                return Err(std::io::Error::new(
                    ErrorKind::BrokenPipe,
                    "lock poisoned",
                ));
            }
        }
    };
}
48
49 impl<T: Read> Read for ReadHalf<T> {
50 fn read_vectored(
51 &mut self,
52 bufs: &mut [std::io::IoSliceMut<'_>],
53 ) -> std::io::Result<usize> {
54 try_lock!(self.inner).read_vectored(bufs)
55 }
56
57 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
58 try_lock!(self.inner).read(buf)
59 }
60 }
61
62 impl<T: Write> Write for WriteHalf<T> {
63 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
64 try_lock!(self.inner).write(buf)
65 }
66
67 fn flush(&mut self) -> std::io::Result<()> {
68 try_lock!(self.inner).flush()
69 }
70 }
71
/// Splits a rustls TLS stream by moving it behind an `Arc<Mutex<_>>` and
/// handing one handle to each half.
#[cfg(feature = "sync_tls_rustls")]
impl<S> Split for rustls_connector::TlsStream<S>
where
    S: Read + Write,
{
    type R = ReadHalf<Self>;

    type W = WriteHalf<Self>;

    fn split(self) -> (Self::R, Self::W) {
        let shared = Arc::new(Mutex::new(self));
        let reader = ReadHalf {
            inner: Arc::clone(&shared),
        };
        let writer = WriteHalf { inner: shared };
        (reader, writer)
    }
}
84
/// Splits a native-tls stream by moving it behind an `Arc<Mutex<_>>` and
/// handing one handle to each half.
#[cfg(feature = "sync_tls_native")]
impl<S> Split for native_tls::TlsStream<S>
where
    S: Read + Write,
{
    type R = ReadHalf<Self>;

    type W = WriteHalf<Self>;

    fn split(self) -> (Self::R, Self::W) {
        let shared = Arc::new(Mutex::new(self));
        let reader = ReadHalf {
            inner: Arc::clone(&shared),
        };
        let writer = WriteHalf { inner: shared };
        (reader, writer)
    }
}
97 }
98
/// Generates a public stream enum called `$name` plus a `Debug` impl for it.
///
/// * `$raw` is the payload of the always-present `Raw` variant.
/// * `$rustls` is the payload of `Rustls` (only with feature `sync_tls_rustls`).
/// * `$native` is the payload of `NativeTls` (only with feature `sync_tls_native`).
/// * `$doc` becomes the enum's rustdoc text.
///
/// The `Debug` output is just the variant name; the wrapped stream is never
/// printed.
macro_rules! def {
    ($name:ident, $raw:ty, $rustls:ty, $native:ty, $doc:literal) => {
        #[doc = $doc]
        pub enum $name {
            /// Stream without a TLS layer.
            Raw($raw),
            /// Stream using the rustls backend.
            #[cfg(feature = "sync_tls_rustls")]
            Rustls($rustls),
            /// Stream using the native-tls backend.
            #[cfg(feature = "sync_tls_native")]
            NativeTls($native),
        }

        impl std::fmt::Debug for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                // Only the variant name is shown, never the payload.
                let variant = match self {
                    Self::Raw(_) => "Raw",
                    #[cfg(feature = "sync_tls_rustls")]
                    Self::Rustls(_) => "Rustls",
                    #[cfg(feature = "sync_tls_native")]
                    Self::NativeTls(_) => "NativeTls",
                };
                f.write_str(variant)
            }
        }
    };
}
126
127 def!(
128 SyncStreamRead,
129 TcpStream,
130 split::ReadHalf<rustls_connector::TlsStream<TcpStream>>,
131 split::ReadHalf<native_tls::TlsStream<TcpStream>>,
132 "a wrapper of most common use raw/ssl tcp based stream"
133 );
134
135 def!(
136 SyncStreamWrite,
137 TcpStream,
138 split::WriteHalf<rustls_connector::TlsStream<TcpStream>>,
139 split::WriteHalf<native_tls::TlsStream<TcpStream>>,
140 "a wrapper of most common use raw/ssl tcp based stream"
141 );
142
143 def!(
144 SyncStream,
145 TcpStream,
146 rustls_connector::TlsStream<TcpStream>,
147 native_tls::TlsStream<TcpStream>,
148 "a wrapper of most common use raw/ssl tcp based stream"
149 );
150
/// Implements `Read` for a stream enum produced by `def!`.
///
/// Both `read` and `read_vectored` simply forward to whichever variant is
/// active. `Read` must be in scope where the macro is invoked.
macro_rules! impl_read {
    ($target:ty) => {
        impl Read for $target {
            fn read_vectored(
                &mut self,
                bufs: &mut [std::io::IoSliceMut<'_>],
            ) -> std::io::Result<usize> {
                match self {
                    Self::Raw(inner) => Read::read_vectored(inner, bufs),
                    #[cfg(feature = "sync_tls_rustls")]
                    Self::Rustls(inner) => Read::read_vectored(inner, bufs),
                    #[cfg(feature = "sync_tls_native")]
                    Self::NativeTls(inner) => Read::read_vectored(inner, bufs),
                }
            }

            fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
                match self {
                    Self::Raw(inner) => Read::read(inner, buf),
                    #[cfg(feature = "sync_tls_rustls")]
                    Self::Rustls(inner) => Read::read(inner, buf),
                    #[cfg(feature = "sync_tls_native")]
                    Self::NativeTls(inner) => Read::read(inner, buf),
                }
            }
        }
    };
}
179
// Give both the unsplit stream and its read half a `Read` impl that forwards
// to the active variant.
impl_read!(SyncStream);
impl_read!(SyncStreamRead);
182
/// Implements `Write` for a stream enum produced by `def!`.
///
/// `write`, `write_vectored` and `flush` simply forward to whichever variant
/// is active. `Write` must be in scope where the macro is invoked.
macro_rules! impl_write {
    ($target:ty) => {
        impl Write for $target {
            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
                match self {
                    Self::Raw(inner) => Write::write(inner, buf),
                    #[cfg(feature = "sync_tls_rustls")]
                    Self::Rustls(inner) => Write::write(inner, buf),
                    #[cfg(feature = "sync_tls_native")]
                    Self::NativeTls(inner) => Write::write(inner, buf),
                }
            }

            fn write_vectored(
                &mut self,
                bufs: &[std::io::IoSlice<'_>],
            ) -> std::io::Result<usize> {
                match self {
                    Self::Raw(inner) => Write::write_vectored(inner, bufs),
                    #[cfg(feature = "sync_tls_rustls")]
                    Self::Rustls(inner) => Write::write_vectored(inner, bufs),
                    #[cfg(feature = "sync_tls_native")]
                    Self::NativeTls(inner) => Write::write_vectored(inner, bufs),
                }
            }

            fn flush(&mut self) -> std::io::Result<()> {
                match self {
                    Self::Raw(inner) => Write::flush(inner),
                    #[cfg(feature = "sync_tls_rustls")]
                    Self::Rustls(inner) => Write::flush(inner),
                    #[cfg(feature = "sync_tls_native")]
                    Self::NativeTls(inner) => Write::flush(inner),
                }
            }
        }
    };
}
221
// Give both the unsplit stream and its write half a `Write` impl that forwards
// to the active variant.
impl_write!(SyncStream);
impl_write!(SyncStreamWrite);
224
225 impl Split for SyncStream {
226 type R = SyncStreamRead;
227
228 type W = SyncStreamWrite;
229
230 fn split(self) -> (Self::R, Self::W) {
231 match self {
232 Self::Raw(s) => {
233 let (read, write) = s.split();
234 (SyncStreamRead::Raw(read), SyncStreamWrite::Raw(write))
235 }
236 #[cfg(feature = "sync_tls_rustls")]
237 Self::Rustls(s) => {
238 let s = std::sync::Arc::new(std::sync::Mutex::new(s));
239 (
240 SyncStreamRead::Rustls(split::ReadHalf { inner: s.clone() }),
241 SyncStreamWrite::Rustls(split::WriteHalf { inner: s }),
242 )
243 }
244 #[cfg(feature = "sync_tls_native")]
245 Self::NativeTls(s) => {
246 let s = std::sync::Arc::new(std::sync::Mutex::new(s));
247 (
248 SyncStreamRead::NativeTls(split::ReadHalf { inner: s.clone() }),
249 SyncStreamWrite::NativeTls(split::WriteHalf { inner: s }),
250 )
251 }
252 }
253 }
254 }
255
/// Stream wrapper that buffers both reads and writes.
///
/// Field `0` is a `BufReader` layered over a [`WrappedWriter`], which in turn
/// holds a `BufWriter` around the stream `S`.
pub struct BufStream<S: Read + Write>(pub BufReader<WrappedWriter<S>>);
258
259 impl<S: Read + Write> BufStream<S> {
260 pub fn new(stream: S) -> Self {
262 Self(BufReader::new(WrappedWriter(BufWriter::new(stream))))
263 }
264
265 pub fn with_capacity(read: usize, write: usize, stream: S) -> Self {
267 let writer = BufWriter::with_capacity(write, stream);
268 let reader = BufReader::with_capacity(read, WrappedWriter(writer));
269 Self(reader)
270 }
271
272 pub fn get_mut(&mut self) -> &mut S {
274 self.0.get_mut().0.get_mut()
275 }
276 }
277
278 impl<S: Read + Write> std::fmt::Debug for BufStream<S> {
279 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280 f.debug_struct("BufStream").finish()
281 }
282 }
283
284 impl<S: Read + Write> Read for BufStream<S> {
285 fn read_vectored(
286 &mut self,
287 bufs: &mut [std::io::IoSliceMut<'_>],
288 ) -> std::io::Result<usize> {
289 self.0.read_vectored(bufs)
290 }
291
292 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
293 self.0.read(buf)
294 }
295 }
296 impl<S: Read + Write> Write for BufStream<S> {
297 fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> {
298 self.0.get_mut().write_vectored(bufs)
299 }
300
301 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
302 self.0.get_mut().write(buf)
303 }
304
305 fn flush(&mut self) -> std::io::Result<()> {
306 self.0.get_mut().flush()
307 }
308 }
309
/// Newtype around a `BufWriter<S>`.
///
/// Its `Read` impl reads straight from the inner stream and its `Write` impl
/// goes through the `BufWriter`, which lets a `BufReader` be stacked on top.
pub struct WrappedWriter<S: Write>(pub BufWriter<S>);
312
313 impl<S: Read + Write> Read for WrappedWriter<S> {
314 fn read_vectored(
315 &mut self,
316 bufs: &mut [std::io::IoSliceMut<'_>],
317 ) -> std::io::Result<usize> {
318 self.0.get_mut().read_vectored(bufs)
319 }
320
321 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
322 self.0.get_mut().read(buf)
323 }
324 }
325
326 impl<S: Write> Write for WrappedWriter<S> {
327 fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> {
328 self.0.write_vectored(bufs)
329 }
330
331 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
332 self.0.write(buf)
333 }
334
335 fn flush(&mut self) -> std::io::Result<()> {
336 self.0.flush()
337 }
338 }
339
340 impl<S, R, W> crate::codec::Split for BufStream<S>
341 where
342 R: Read,
343 W: Write,
344 S: Read + Write + crate::codec::Split<R = R, W = W> + std::fmt::Debug,
345 {
346 type R = BufReader<R>;
347
348 type W = BufWriter<W>;
349
350 fn split(self) -> (Self::R, Self::W) {
351 let read_cap = self.0.capacity();
352 let write_cap = self.0.get_ref().0.capacity();
353 let inner = self.0.into_inner().0.into_inner().unwrap();
354 let (r, w) = inner.split();
355 (
356 BufReader::with_capacity(read_cap, r),
357 BufWriter::with_capacity(write_cap, w),
358 )
359 }
360 }
361}
362
363#[cfg(feature = "sync")]
364pub use blocking::*;
365
366#[cfg(feature = "async")]
367mod non_blocking {
368 use std::pin::Pin;
369
370 use tokio::{
371 io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf},
372 net::TcpStream,
373 };
374
375 use crate::codec::Split;
376
377 #[allow(missing_docs)]
378 pub trait AsyncRW: AsyncRead + AsyncWrite + Unpin {}
379
380 impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRW for S {}
381
/// Asynchronous stream that is either a plain TCP connection or, depending on
/// the enabled features, a TLS stream on top of one.
pub enum AsyncStream {
    /// Plain tokio `TcpStream` without a TLS layer.
    Raw(TcpStream),
    /// TLS stream from `tokio_rustls` (feature `async_tls_rustls`).
    #[cfg(feature = "async_tls_rustls")]
    Rustls(tokio_rustls::TlsStream<TcpStream>),
    /// TLS stream from `tokio_native_tls` (feature `async_tls_native`).
    #[cfg(feature = "async_tls_native")]
    NativeTls(tokio_native_tls::TlsStream<TcpStream>),
}
393
394 impl Split for AsyncStream {
395 type R = ReadHalf<Self>;
396
397 type W = WriteHalf<Self>;
398
399 fn split(self) -> (Self::R, Self::W) {
400 tokio::io::split(self)
401 }
402 }
403
404 impl AsyncRead for AsyncStream {
405 fn poll_read(
406 self: std::pin::Pin<&mut Self>,
407 cx: &mut std::task::Context<'_>,
408 buf: &mut tokio::io::ReadBuf<'_>,
409 ) -> std::task::Poll<std::io::Result<()>> {
410 match self.get_mut() {
411 AsyncStream::Raw(s) => std::pin::Pin::new(s).poll_read(cx, buf),
412 #[cfg(feature = "async_tls_rustls")]
413 AsyncStream::Rustls(s) => std::pin::Pin::new(s).poll_read(cx, buf),
414 #[cfg(feature = "async_tls_native")]
415 AsyncStream::NativeTls(s) => std::pin::Pin::new(s).poll_read(cx, buf),
416 }
417 }
418 }
419
420 impl AsyncWrite for AsyncStream {
421 fn poll_write(
422 self: Pin<&mut Self>,
423 cx: &mut std::task::Context<'_>,
424 buf: &[u8],
425 ) -> std::task::Poll<Result<usize, std::io::Error>> {
426 match self.get_mut() {
427 AsyncStream::Raw(s) => std::pin::Pin::new(s).poll_write(cx, buf),
428 #[cfg(feature = "async_tls_rustls")]
429 AsyncStream::Rustls(s) => std::pin::Pin::new(s).poll_write(cx, buf),
430 #[cfg(feature = "async_tls_native")]
431 AsyncStream::NativeTls(s) => std::pin::Pin::new(s).poll_write(cx, buf),
432 }
433 }
434
435 fn poll_flush(
436 self: Pin<&mut Self>,
437 cx: &mut std::task::Context<'_>,
438 ) -> std::task::Poll<Result<(), std::io::Error>> {
439 match self.get_mut() {
440 AsyncStream::Raw(s) => std::pin::Pin::new(s).poll_flush(cx),
441 #[cfg(feature = "async_tls_rustls")]
442 AsyncStream::Rustls(s) => std::pin::Pin::new(s).poll_flush(cx),
443 #[cfg(feature = "async_tls_native")]
444 AsyncStream::NativeTls(s) => std::pin::Pin::new(s).poll_flush(cx),
445 }
446 }
447
448 fn poll_shutdown(
449 self: Pin<&mut Self>,
450 cx: &mut std::task::Context<'_>,
451 ) -> std::task::Poll<Result<(), std::io::Error>> {
452 match self.get_mut() {
453 AsyncStream::Raw(s) => std::pin::Pin::new(s).poll_shutdown(cx),
454 #[cfg(feature = "async_tls_rustls")]
455 AsyncStream::Rustls(s) => std::pin::Pin::new(s).poll_shutdown(cx),
456 #[cfg(feature = "async_tls_native")]
457 AsyncStream::NativeTls(s) => std::pin::Pin::new(s).poll_shutdown(cx),
458 }
459 }
460 }
461}
462
463#[cfg(feature = "async")]
464pub use non_blocking::*;