// File: websockets_monoio/client.rs

use anyhow::{Context, Result};
use fastwebsockets_monoio::{Role, WebSocket};
use monoio::net::TcpStream;
use monoio_compat::{AsyncRead, AsyncWrite, StreamWrapper};

use crate::http_upgrade::{generate_client_key, read_response, write_request};
use crate::tls::{connect_wss, default_connector};
use crate::url::{Scheme, parse_ws_or_wss};

10/// A unified IO stream that can be plain TCP or TLS over TCP, both wrapped
11/// in `monoio_compat::StreamWrapper` to provide AsyncRead/AsyncWrite.
12#[allow(clippy::large_enum_variant)]
13pub enum AnyStream {
14    Plain(StreamWrapper<TcpStream>),
15    Tls(StreamWrapper<monoio_rustls::ClientTlsStream<TcpStream>>),
16}
17
18impl monoio_compat::AsyncRead for AnyStream {
19    fn poll_read(
20        self: core::pin::Pin<&mut Self>,
21        cx: &mut core::task::Context<'_>,
22        buf: &mut tokio::io::ReadBuf<'_>,
23    ) -> core::task::Poll<std::io::Result<()>> {
24        unsafe {
25            match self.get_unchecked_mut() {
26                AnyStream::Plain(s) => core::pin::Pin::new_unchecked(s).poll_read(cx, buf),
27                AnyStream::Tls(s) => core::pin::Pin::new_unchecked(s).poll_read(cx, buf),
28            }
29        }
30    }
31}
32
33impl monoio_compat::AsyncWrite for AnyStream {
34    fn poll_write(
35        self: core::pin::Pin<&mut Self>,
36        cx: &mut core::task::Context<'_>,
37        buf: &[u8],
38    ) -> core::task::Poll<Result<usize, std::io::Error>> {
39        unsafe {
40            match self.get_unchecked_mut() {
41                AnyStream::Plain(s) => core::pin::Pin::new_unchecked(s).poll_write(cx, buf),
42                AnyStream::Tls(s) => core::pin::Pin::new_unchecked(s).poll_write(cx, buf),
43            }
44        }
45    }
46
47    fn poll_flush(
48        self: core::pin::Pin<&mut Self>,
49        cx: &mut core::task::Context<'_>,
50    ) -> core::task::Poll<Result<(), std::io::Error>> {
51        unsafe {
52            match self.get_unchecked_mut() {
53                AnyStream::Plain(s) => core::pin::Pin::new_unchecked(s).poll_flush(cx),
54                AnyStream::Tls(s) => core::pin::Pin::new_unchecked(s).poll_flush(cx),
55            }
56        }
57    }
58
59    fn poll_shutdown(
60        self: core::pin::Pin<&mut Self>,
61        cx: &mut core::task::Context<'_>,
62    ) -> core::task::Poll<Result<(), std::io::Error>> {
63        unsafe {
64            match self.get_unchecked_mut() {
65                AnyStream::Plain(s) => core::pin::Pin::new_unchecked(s).poll_shutdown(cx),
66                AnyStream::Tls(s) => core::pin::Pin::new_unchecked(s).poll_shutdown(cx),
67            }
68        }
69    }
70}
71
72/// Exposed stream type used by `WsClient`.
73pub type WsStream = AnyStream;
74
75pub struct WsClient {
76    pub ws: WebSocket<WsStream>,
77}
78
79impl WsClient {
80    /// Connect to a `ws://` or `wss://` URL and complete the WebSocket handshake.
81    pub async fn connect(url: &str, extra_headers: &[(&str, &str)]) -> Result<Self> {
82        let u = parse_ws_or_wss(url)?;
83
84        // Establish underlying transport (TCP or TLS over TCP)
85        let mut stream = match u.scheme {
86            Scheme::Ws => {
87                let tcp = TcpStream::connect((u.host, u.port)).await?;
88                AnyStream::Plain(StreamWrapper::new(tcp))
89            }
90            Scheme::Wss => {
91                let connector = default_connector();
92                let tls = connect_wss(u.host, u.port, connector).await?;
93                AnyStream::Tls(StreamWrapper::new(tls))
94            }
95        };
96
97        // HTTP Upgrade handshake
98        let key = generate_client_key();
99        write_request(
100            &mut stream,
101            u.host,
102            u.path_and_query,
103            &key.sec_websocket_key,
104            extra_headers,
105        )
106        .await?;
107        read_response(&mut stream, &key.expected_accept).await?;
108
109        // Switch to WebSocket
110        let mut ws = WebSocket::after_handshake(stream, Role::Client);
111        ws.set_auto_close(true);
112        ws.set_auto_pong(true);
113        if matches!(u.scheme, Scheme::Wss) {
114            // TLS backends generally buffer writes, so gathering is less effective.
115            ws.set_writev(false);
116        }
117
118        Ok(Self { ws })
119    }
120
121    pub fn into_inner(self) -> WebSocket<WsStream> {
122        self.ws
123    }
124}
125
126// Convenience trait bound if you want to reuse upgrade for different streams.
127pub trait TokioIo: AsyncRead + AsyncWrite + Unpin {}
128impl<T: AsyncRead + AsyncWrite + Unpin> TokioIo for T {}