websocket_relay/
stream.rs1use std::net::SocketAddr;
2use tokio::{
3 io::{AsyncRead, AsyncWrite},
4 net::TcpStream,
5};
6
7pub enum StreamType {
8 Plain(TcpStream),
9 Tls(Box<tokio_rustls::server::TlsStream<TcpStream>>),
10}
11
12impl AsyncRead for StreamType {
13 fn poll_read(
14 mut self: std::pin::Pin<&mut Self>,
15 cx: &mut std::task::Context<'_>,
16 buf: &mut tokio::io::ReadBuf<'_>,
17 ) -> std::task::Poll<std::io::Result<()>> {
18 match &mut *self {
19 Self::Plain(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
20 Self::Tls(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
21 }
22 }
23}
24
25impl AsyncWrite for StreamType {
26 fn poll_write(
27 mut self: std::pin::Pin<&mut Self>,
28 cx: &mut std::task::Context<'_>,
29 buf: &[u8],
30 ) -> std::task::Poll<Result<usize, std::io::Error>> {
31 match &mut *self {
32 Self::Plain(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
33 Self::Tls(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
34 }
35 }
36
37 fn poll_flush(
38 mut self: std::pin::Pin<&mut Self>,
39 cx: &mut std::task::Context<'_>,
40 ) -> std::task::Poll<Result<(), std::io::Error>> {
41 match &mut *self {
42 Self::Plain(stream) => std::pin::Pin::new(stream).poll_flush(cx),
43 Self::Tls(stream) => std::pin::Pin::new(stream).poll_flush(cx),
44 }
45 }
46
47 fn poll_shutdown(
48 mut self: std::pin::Pin<&mut Self>,
49 cx: &mut std::task::Context<'_>,
50 ) -> std::task::Poll<Result<(), std::io::Error>> {
51 match &mut *self {
52 Self::Plain(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
53 Self::Tls(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
54 }
55 }
56}
57
58impl StreamType {
59 pub fn peer_addr(&self) -> Result<SocketAddr, std::io::Error> {
60 match self {
61 Self::Plain(stream) => stream.peer_addr(),
62 Self::Tls(stream) => stream.get_ref().0.peer_addr(),
63 }
64 }
65}