use crate::tls::{ChannelBinding, TlsStream};
use bytes::{Buf, BufMut};
use std::io;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};

/// A stream that may or may not be protected with TLS.
pub enum MaybeTlsStream<S, T> {
    /// An unencrypted stream.
    Raw(S),
    /// A TLS-encrypted stream.
    Tls(T),
}

// Every I/O method below forwards to whichever variant is active.
impl<S, T> AsyncRead for MaybeTlsStream<S, T>
where
    S: AsyncRead + Unpin,
    T: AsyncRead + Unpin,
{
    unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
        match self {
            MaybeTlsStream::Raw(s) => s.prepare_uninitialized_buffer(buf),
            MaybeTlsStream::Tls(s) => s.prepare_uninitialized_buffer(buf),
        }
    }

    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        match &mut *self {
            MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf),
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
        }
    }

    fn poll_read_buf<B>(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut B,
    ) -> Poll<io::Result<usize>>
    where
        Self: Sized,
        B: BufMut,
    {
        match &mut *self {
            MaybeTlsStream::Raw(s) => Pin::new(s).poll_read_buf(cx, buf),
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_read_buf(cx, buf),
        }
    }
}

impl<S, T> AsyncWrite for MaybeTlsStream<S, T>
where
    S: AsyncWrite + Unpin,
    T: AsyncWrite + Unpin,
{
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        match &mut *self {
            MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf),
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
        }
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match &mut *self {
            MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx),
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
        }
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match &mut *self {
            MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx),
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
        }
    }

    fn poll_write_buf<B>(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut B,
    ) -> Poll<io::Result<usize>>
    where
        Self: Sized,
        B: Buf,
    {
        match &mut *self {
            MaybeTlsStream::Raw(s) => Pin::new(s).poll_write_buf(cx, buf),
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write_buf(cx, buf),
        }
    }
}

impl<S, T> TlsStream for MaybeTlsStream<S, T>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: TlsStream + Unpin,
{
    fn channel_binding(&self) -> ChannelBinding {
        match self {
            // An unencrypted connection has no TLS channel binding data.
            MaybeTlsStream::Raw(_) => ChannelBinding::none(),
            MaybeTlsStream::Tls(s) => s.channel_binding(),
        }
    }
}