// websockets_monoio/client.rs
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

use anyhow::Result;
use fastwebsockets_monoio::{Role, WebSocket};
use monoio::net::TcpStream;
use monoio_compat::{AsyncRead, AsyncWrite, StreamWrapper};

use crate::http_upgrade::{generate_client_key, read_response, write_request};
use crate::tls::{connect_wss, default_connector};
use crate::url::{Scheme, parse_ws_or_wss};

10#[allow(clippy::large_enum_variant)]
13pub enum AnyStream {
14 Plain(StreamWrapper<TcpStream>),
15 Tls(StreamWrapper<monoio_rustls::ClientTlsStream<TcpStream>>),
16}
17
18impl monoio_compat::AsyncRead for AnyStream {
19 fn poll_read(
20 self: core::pin::Pin<&mut Self>,
21 cx: &mut core::task::Context<'_>,
22 buf: &mut tokio::io::ReadBuf<'_>,
23 ) -> core::task::Poll<std::io::Result<()>> {
24 unsafe {
25 match self.get_unchecked_mut() {
26 AnyStream::Plain(s) => core::pin::Pin::new_unchecked(s).poll_read(cx, buf),
27 AnyStream::Tls(s) => core::pin::Pin::new_unchecked(s).poll_read(cx, buf),
28 }
29 }
30 }
31}
32
33impl monoio_compat::AsyncWrite for AnyStream {
34 fn poll_write(
35 self: core::pin::Pin<&mut Self>,
36 cx: &mut core::task::Context<'_>,
37 buf: &[u8],
38 ) -> core::task::Poll<Result<usize, std::io::Error>> {
39 unsafe {
40 match self.get_unchecked_mut() {
41 AnyStream::Plain(s) => core::pin::Pin::new_unchecked(s).poll_write(cx, buf),
42 AnyStream::Tls(s) => core::pin::Pin::new_unchecked(s).poll_write(cx, buf),
43 }
44 }
45 }
46
47 fn poll_flush(
48 self: core::pin::Pin<&mut Self>,
49 cx: &mut core::task::Context<'_>,
50 ) -> core::task::Poll<Result<(), std::io::Error>> {
51 unsafe {
52 match self.get_unchecked_mut() {
53 AnyStream::Plain(s) => core::pin::Pin::new_unchecked(s).poll_flush(cx),
54 AnyStream::Tls(s) => core::pin::Pin::new_unchecked(s).poll_flush(cx),
55 }
56 }
57 }
58
59 fn poll_shutdown(
60 self: core::pin::Pin<&mut Self>,
61 cx: &mut core::task::Context<'_>,
62 ) -> core::task::Poll<Result<(), std::io::Error>> {
63 unsafe {
64 match self.get_unchecked_mut() {
65 AnyStream::Plain(s) => core::pin::Pin::new_unchecked(s).poll_shutdown(cx),
66 AnyStream::Tls(s) => core::pin::Pin::new_unchecked(s).poll_shutdown(cx),
67 }
68 }
69 }
70}
71
72pub type WsStream = AnyStream;
74
75pub struct WsClient {
76 pub ws: WebSocket<WsStream>,
77}
78
79impl WsClient {
80 pub async fn connect(url: &str, extra_headers: &[(&str, &str)]) -> Result<Self> {
82 let u = parse_ws_or_wss(url)?;
83
84 let mut stream = match u.scheme {
86 Scheme::Ws => {
87 let tcp = TcpStream::connect((u.host, u.port)).await?;
88 AnyStream::Plain(StreamWrapper::new(tcp))
89 }
90 Scheme::Wss => {
91 let connector = default_connector();
92 let tls = connect_wss(u.host, u.port, connector).await?;
93 AnyStream::Tls(StreamWrapper::new(tls))
94 }
95 };
96
97 let key = generate_client_key();
99 write_request(
100 &mut stream,
101 u.host,
102 u.path_and_query,
103 &key.sec_websocket_key,
104 extra_headers,
105 )
106 .await?;
107 read_response(&mut stream, &key.expected_accept).await?;
108
109 let mut ws = WebSocket::after_handshake(stream, Role::Client);
111 ws.set_auto_close(true);
112 ws.set_auto_pong(true);
113 if matches!(u.scheme, Scheme::Wss) {
114 ws.set_writev(false);
116 }
117
118 Ok(Self { ws })
119 }
120
121 pub fn into_inner(self) -> WebSocket<WsStream> {
122 self.ws
123 }
124}
125
126pub trait TokioIo: AsyncRead + AsyncWrite + Unpin {}
128impl<T: AsyncRead + AsyncWrite + Unpin> TokioIo for T {}