// websocket_server_async/stream.rs

use std::{
    io::IoSlice,
    pin::Pin,
    task::{Context, Poll},
};

use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

8#[non_exhaustive]
9#[derive(Debug)]
10pub enum MaybeRustlsStream<S> {
11    Plain(S),
12    ServerTls(tokio_rustls::server::TlsStream<S>),
13}
14
15impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeRustlsStream<S> {
16    fn poll_read(
17        self: Pin<&mut Self>,
18        cx: &mut Context<'_>,
19        buf: &mut ReadBuf<'_>,
20    ) -> Poll<std::io::Result<()>> {
21        match self.get_mut() {
22            MaybeRustlsStream::Plain(ref mut s) => Pin::new(s).poll_read(cx, buf),
23            MaybeRustlsStream::ServerTls(ref mut s) => Pin::new(s).poll_read(cx, buf),
24        }
25    }
26}
27
28impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeRustlsStream<S> {
29    fn poll_write(
30        self: Pin<&mut Self>,
31        cx: &mut Context<'_>,
32        buf: &[u8],
33    ) -> Poll<Result<usize, std::io::Error>> {
34        match self.get_mut() {
35            MaybeRustlsStream::Plain(ref mut s) => Pin::new(s).poll_write(cx, buf),
36            MaybeRustlsStream::ServerTls(ref mut s) => Pin::new(s).poll_write(cx, buf),
37        }
38    }
39
40    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
41        match self.get_mut() {
42            MaybeRustlsStream::Plain(ref mut s) => Pin::new(s).poll_flush(cx),
43            MaybeRustlsStream::ServerTls(ref mut s) => Pin::new(s).poll_flush(cx),
44        }
45    }
46
47    fn poll_shutdown(
48        self: Pin<&mut Self>,
49        cx: &mut Context<'_>,
50    ) -> Poll<Result<(), std::io::Error>> {
51        match self.get_mut() {
52            MaybeRustlsStream::Plain(ref mut s) => Pin::new(s).poll_shutdown(cx),
53            MaybeRustlsStream::ServerTls(ref mut s) => Pin::new(s).poll_shutdown(cx),
54        }
55    }
56}