// websocket_relay/stream.rs

use std::net::SocketAddr;
use tokio::{
    io::{AsyncRead, AsyncWrite},
    net::TcpStream,
};

7pub enum StreamType {
8    Plain(TcpStream),
9    Tls(Box<tokio_rustls::server::TlsStream<TcpStream>>),
10}
11
12impl AsyncRead for StreamType {
13    fn poll_read(
14        mut self: std::pin::Pin<&mut Self>,
15        cx: &mut std::task::Context<'_>,
16        buf: &mut tokio::io::ReadBuf<'_>,
17    ) -> std::task::Poll<std::io::Result<()>> {
18        match &mut *self {
19            Self::Plain(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
20            Self::Tls(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
21        }
22    }
23}
24
25impl AsyncWrite for StreamType {
26    fn poll_write(
27        mut self: std::pin::Pin<&mut Self>,
28        cx: &mut std::task::Context<'_>,
29        buf: &[u8],
30    ) -> std::task::Poll<Result<usize, std::io::Error>> {
31        match &mut *self {
32            Self::Plain(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
33            Self::Tls(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
34        }
35    }
36
37    fn poll_flush(
38        mut self: std::pin::Pin<&mut Self>,
39        cx: &mut std::task::Context<'_>,
40    ) -> std::task::Poll<Result<(), std::io::Error>> {
41        match &mut *self {
42            Self::Plain(stream) => std::pin::Pin::new(stream).poll_flush(cx),
43            Self::Tls(stream) => std::pin::Pin::new(stream).poll_flush(cx),
44        }
45    }
46
47    fn poll_shutdown(
48        mut self: std::pin::Pin<&mut Self>,
49        cx: &mut std::task::Context<'_>,
50    ) -> std::task::Poll<Result<(), std::io::Error>> {
51        match &mut *self {
52            Self::Plain(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
53            Self::Tls(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
54        }
55    }
56}
57
58impl StreamType {
59    pub fn peer_addr(&self) -> Result<SocketAddr, std::io::Error> {
60        match self {
61            Self::Plain(stream) => stream.peer_addr(),
62            Self::Tls(stream) => stream.get_ref().0.peer_addr(),
63        }
64    }
65}