tf_rust_engineio/asynchronous/async_transports/
websocket.rs1use std::fmt::Debug;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use crate::asynchronous::transport::AsyncTransport;
6use crate::error::Result;
7use async_trait::async_trait;
8use bytes::Bytes;
9use futures_util::stream::StreamExt;
10use futures_util::Stream;
11use http::HeaderMap;
12use tokio::sync::RwLock;
13use tokio_tungstenite::connect_async;
14use tungstenite::client::IntoClientRequest;
15use url::Url;
16
17use super::websocket_general::AsyncWebsocketGeneralTransport;
18
19#[derive(Clone)]
23pub struct WebsocketTransport {
24 inner: AsyncWebsocketGeneralTransport,
25 base_url: Arc<RwLock<Url>>,
26}
27
28impl WebsocketTransport {
29 pub async fn new(base_url: Url, headers: Option<HeaderMap>) -> Result<Self> {
31 let mut url = base_url;
32 url.query_pairs_mut().append_pair("transport", "websocket");
33 url.set_scheme("ws").unwrap();
34
35 let mut req = url.clone().into_client_request()?;
36 if let Some(map) = headers {
37 req.headers_mut().extend(map);
39 }
40
41 let (ws_stream, _) = connect_async(req).await?;
42 let (sen, rec) = ws_stream.split();
43
44 let inner = AsyncWebsocketGeneralTransport::new(sen, rec).await;
45 Ok(WebsocketTransport {
46 inner,
47 base_url: Arc::new(RwLock::new(url)),
48 })
49 }
50
51 pub(crate) async fn upgrade(&self) -> Result<()> {
54 self.inner.upgrade().await
55 }
56
57 pub(crate) async fn poll_next(&self) -> Result<Option<Bytes>> {
58 self.inner.poll_next().await
59 }
60}
61
62#[async_trait]
63impl AsyncTransport for WebsocketTransport {
64 async fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
65 self.inner.emit(data, is_binary_att).await
66 }
67
68 async fn base_url(&self) -> Result<Url> {
69 Ok(self.base_url.read().await.clone())
70 }
71
72 async fn set_base_url(&self, base_url: Url) -> Result<()> {
73 let mut url = base_url;
74 if !url
75 .query_pairs()
76 .any(|(k, v)| k == "transport" && v == "websocket")
77 {
78 url.query_pairs_mut().append_pair("transport", "websocket");
79 }
80 url.set_scheme("ws").unwrap();
81 *self.base_url.write().await = url;
82 Ok(())
83 }
84}
85
86impl Stream for WebsocketTransport {
87 type Item = Result<Bytes>;
88
89 fn poll_next(
90 mut self: Pin<&mut Self>,
91 cx: &mut std::task::Context<'_>,
92 ) -> std::task::Poll<Option<Self::Item>> {
93 self.inner.poll_next_unpin(cx)
94 }
95}
96
97impl Debug for WebsocketTransport {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 f.debug_struct("AsyncWebsocketTransport")
100 .field(
101 "base_url",
102 &self
103 .base_url
104 .try_read()
105 .map_or("Currently not available".to_owned(), |url| url.to_string()),
106 )
107 .finish()
108 }
109}
110
#[cfg(test)]
mod test {
    use super::*;
    use crate::ENGINE_IO_VERSION;
    use std::str::FromStr;

    /// Connects a `WebsocketTransport` to the shared test server.
    ///
    /// The server URL is `crate::test::engine_io_server()` with
    /// `engine.io/?EIO=<ENGINE_IO_VERSION>` appended.
    async fn new() -> Result<WebsocketTransport> {
        let url = crate::test::engine_io_server()?.to_string()
            + "engine.io/?EIO="
            + &ENGINE_IO_VERSION.to_string();
        WebsocketTransport::new(Url::from_str(&url[..])?, None).await
    }

    /// Checks `base_url` and `set_base_url` normalization.
    ///
    /// `base_url` reports the normalized connect URL. `set_base_url` forces the
    /// `ws` scheme and adds `transport=websocket` only when it is not already
    /// present.
    #[tokio::test]
    async fn websocket_transport_base_url() -> Result<()> {
        let transport = new().await?;
        let mut url = crate::test::engine_io_server()?;
        url.set_path("/engine.io/");
        url.query_pairs_mut()
            .append_pair("EIO", &ENGINE_IO_VERSION.to_string())
            .append_pair("transport", "websocket");
        url.set_scheme("ws").unwrap();
        assert_eq!(transport.base_url().await?.to_string(), url.to_string());
        transport
            .set_base_url(reqwest::Url::parse("https://127.0.0.1")?)
            .await?;
        assert_eq!(
            transport.base_url().await?.to_string(),
            "ws://127.0.0.1/?transport=websocket"
        );
        assert_ne!(transport.base_url().await?.to_string(), url.to_string());

        transport
            .set_base_url(reqwest::Url::parse(
                "http://127.0.0.1/?transport=websocket",
            )?)
            .await?;
        assert_eq!(
            transport.base_url().await?.to_string(),
            "ws://127.0.0.1/?transport=websocket"
        );
        assert_ne!(transport.base_url().await?.to_string(), url.to_string());
        Ok(())
    }

    /// Checks the `Debug` output and that the stream yields items.
    ///
    /// The `Debug` output embeds the current base URL, and the transport yields
    /// at least two items from the test server. This covers the plain `ws://`
    /// transport; the previous name `websocket_secure_debug` was misleading.
    #[tokio::test]
    async fn websocket_debug() -> Result<()> {
        let mut transport = new().await?;
        assert_eq!(
            format!("{:?}", transport),
            format!(
                "AsyncWebsocketTransport {{ base_url: {:?} }}",
                transport.base_url().await?.to_string()
            )
        );
        println!("{:?}", transport.next().await.unwrap());
        println!("{:?}", transport.next().await.unwrap());
        Ok(())
    }

    /// Spawns a one-shot local websocket server on an ephemeral port and
    /// returns its address.
    ///
    /// The server accepts a single connection and immediately sends a close
    /// frame with code 4900 and reason "version handshake rejected". It then
    /// keeps reading until the stream ends or yields an error, instead of
    /// dropping the socket right away.
    async fn spawn_close_4900_server() -> std::net::SocketAddr {
        use futures_util::SinkExt;
        use tokio::net::TcpListener;
        use tokio_tungstenite::accept_async;
        use tungstenite::protocol::frame::{coding::CloseCode, CloseFrame};
        use tungstenite::Message;

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            if let Ok((stream, _)) = listener.accept().await {
                if let Ok(mut ws) = accept_async(stream).await {
                    let _ = ws
                        .send(Message::Close(Some(CloseFrame {
                            code: CloseCode::Library(4900),
                            reason: "version handshake rejected".into(),
                        })))
                        .await;
                    while let Some(Ok(_)) = ws.next().await {}
                }
            }
        });
        addr
    }

    /// The inherent `poll_next` reports the server's close frame as
    /// `Err(WebsocketClosed { code: 4900, reason })`.
    #[tokio::test]
    async fn websocket_close_code_surfaced_via_poll_next() -> Result<()> {
        let addr = spawn_close_4900_server().await;
        let url = Url::parse(&format!("ws://{}/engine.io/", addr))?;
        let transport = WebsocketTransport::new(url, None).await?;

        match transport.poll_next().await {
            Err(crate::error::Error::WebsocketClosed { code, reason }) => {
                assert_eq!(code, 4900);
                assert_eq!(reason, "version handshake rejected");
            }
            other => panic!("expected Err(WebsocketClosed {{ code: 4900, .. }}), got {other:?}"),
        }
        Ok(())
    }

    /// `ClientBuilder::build_websocket` fails with `WebsocketClosed` (code 4900
    /// and its reason) when the server closes the socket right after accepting
    /// it.
    #[tokio::test]
    async fn websocket_close_code_surfaced_via_build_handshake() -> Result<()> {
        use crate::asynchronous::ClientBuilder;

        let addr = spawn_close_4900_server().await;
        let url = Url::parse(&format!("ws://{}/", addr))?;

        match ClientBuilder::new(url).build_websocket().await {
            Err(crate::error::Error::WebsocketClosed { code, reason }) => {
                assert_eq!(code, 4900);
                assert_eq!(reason, "version handshake rejected");
            }
            Err(other) => {
                panic!("expected WebsocketClosed {{ code: 4900, .. }}, got Err({other:?})")
            }
            Ok(_) => panic!("expected build_websocket to fail with WebsocketClosed {{ 4900 }}"),
        }
        Ok(())
    }

    /// The `Stream` implementation (`StreamExt::next`) reports the server's
    /// close frame as `Some(Err(WebsocketClosed { code: 4900, reason }))`.
    #[tokio::test]
    async fn websocket_close_code_surfaced_via_stream_impl() -> Result<()> {
        let addr = spawn_close_4900_server().await;
        let url = Url::parse(&format!("ws://{}/engine.io/", addr))?;
        let mut transport = WebsocketTransport::new(url, None).await?;

        match transport.next().await {
            Some(Err(crate::error::Error::WebsocketClosed { code, reason })) => {
                assert_eq!(code, 4900);
                assert_eq!(reason, "version handshake rejected");
            }
            other => {
                panic!("expected Some(Err(WebsocketClosed {{ code: 4900, .. }})), got {other:?}")
            }
        }
        Ok(())
    }
}