Skip to main content

tf_rust_engineio/asynchronous/async_transports/
websocket.rs

1use std::fmt::Debug;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use crate::asynchronous::transport::AsyncTransport;
6use crate::error::Result;
7use async_trait::async_trait;
8use bytes::Bytes;
9use futures_util::stream::StreamExt;
10use futures_util::Stream;
11use http::HeaderMap;
12use tokio::sync::RwLock;
13use tokio_tungstenite::connect_async;
14use tungstenite::client::IntoClientRequest;
15use url::Url;
16
17use super::websocket_general::AsyncWebsocketGeneralTransport;
18
19/// An asynchronous websocket transport type.
20/// This type only allows for plain websocket
21/// connections ("ws://").
22#[derive(Clone)]
23pub struct WebsocketTransport {
24    inner: AsyncWebsocketGeneralTransport,
25    base_url: Arc<RwLock<Url>>,
26}
27
28impl WebsocketTransport {
29    /// Creates a new instance over a request that might hold additional headers and an URL.
30    pub async fn new(base_url: Url, headers: Option<HeaderMap>) -> Result<Self> {
31        let mut url = base_url;
32        url.query_pairs_mut().append_pair("transport", "websocket");
33        url.set_scheme("ws").unwrap();
34
35        let mut req = url.clone().into_client_request()?;
36        if let Some(map) = headers {
37            // SAFETY: this unwrap never panics as the underlying request is just initialized and in proper state
38            req.headers_mut().extend(map);
39        }
40
41        let (ws_stream, _) = connect_async(req).await?;
42        let (sen, rec) = ws_stream.split();
43
44        let inner = AsyncWebsocketGeneralTransport::new(sen, rec).await;
45        Ok(WebsocketTransport {
46            inner,
47            base_url: Arc::new(RwLock::new(url)),
48        })
49    }
50
51    /// Sends probe packet to ensure connection is valid, then sends upgrade
52    /// request
53    pub(crate) async fn upgrade(&self) -> Result<()> {
54        self.inner.upgrade().await
55    }
56
57    pub(crate) async fn poll_next(&self) -> Result<Option<Bytes>> {
58        self.inner.poll_next().await
59    }
60}
61
62#[async_trait]
63impl AsyncTransport for WebsocketTransport {
64    async fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
65        self.inner.emit(data, is_binary_att).await
66    }
67
68    async fn base_url(&self) -> Result<Url> {
69        Ok(self.base_url.read().await.clone())
70    }
71
72    async fn set_base_url(&self, base_url: Url) -> Result<()> {
73        let mut url = base_url;
74        if !url
75            .query_pairs()
76            .any(|(k, v)| k == "transport" && v == "websocket")
77        {
78            url.query_pairs_mut().append_pair("transport", "websocket");
79        }
80        url.set_scheme("ws").unwrap();
81        *self.base_url.write().await = url;
82        Ok(())
83    }
84}
85
86impl Stream for WebsocketTransport {
87    type Item = Result<Bytes>;
88
89    fn poll_next(
90        mut self: Pin<&mut Self>,
91        cx: &mut std::task::Context<'_>,
92    ) -> std::task::Poll<Option<Self::Item>> {
93        self.inner.poll_next_unpin(cx)
94    }
95}
96
97impl Debug for WebsocketTransport {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        f.debug_struct("AsyncWebsocketTransport")
100            .field(
101                "base_url",
102                &self
103                    .base_url
104                    .try_read()
105                    .map_or("Currently not available".to_owned(), |url| url.to_string()),
106            )
107            .finish()
108    }
109}
110
#[cfg(test)]
mod test {
    use super::*;
    use crate::ENGINE_IO_VERSION;
    use std::str::FromStr;

    /// Connects a new transport to the shared test engine.io server.
    async fn new() -> Result<WebsocketTransport> {
        let url = crate::test::engine_io_server()?.to_string()
            + "engine.io/?EIO="
            + &ENGINE_IO_VERSION.to_string();
        WebsocketTransport::new(Url::from_str(&url[..])?, None).await
    }

    /// `base_url` reflects the normalized connect URL, and `set_base_url` forces the
    /// `ws` scheme without duplicating the `transport=websocket` pair.
    #[tokio::test]
    async fn websocket_transport_base_url() -> Result<()> {
        let transport = new().await?;
        let mut url = crate::test::engine_io_server()?;
        url.set_path("/engine.io/");
        url.query_pairs_mut()
            .append_pair("EIO", &ENGINE_IO_VERSION.to_string())
            .append_pair("transport", "websocket");
        url.set_scheme("ws").unwrap();
        assert_eq!(transport.base_url().await?.to_string(), url.to_string());
        transport
            .set_base_url(Url::parse("https://127.0.0.1")?)
            .await?;
        assert_eq!(
            transport.base_url().await?.to_string(),
            "ws://127.0.0.1/?transport=websocket"
        );
        assert_ne!(transport.base_url().await?.to_string(), url.to_string());

        transport
            .set_base_url(Url::parse("http://127.0.0.1/?transport=websocket")?)
            .await?;
        assert_eq!(
            transport.base_url().await?.to_string(),
            "ws://127.0.0.1/?transport=websocket"
        );
        assert_ne!(transport.base_url().await?.to_string(), url.to_string());
        Ok(())
    }

    /// The `Debug` output shows the base URL, and the transport yields frames
    /// through its `Stream` impl.
    #[tokio::test]
    async fn websocket_debug() -> Result<()> {
        let mut transport = new().await?;
        assert_eq!(
            format!("{:?}", transport),
            format!(
                "AsyncWebsocketTransport {{ base_url: {:?} }}",
                transport.base_url().await?.to_string()
            )
        );
        println!("{:?}", transport.next().await.unwrap());
        println!("{:?}", transport.next().await.unwrap());
        Ok(())
    }

    /// Spawns a minimal in-process websocket server that accepts a single
    /// connection, immediately closes it with the A2C-SMCP version-handshake
    /// rejection close code (4900, RFC6455 private range) and keeps the socket
    /// open until the client has read the frame. Returns the bound address.
    async fn spawn_close_4900_server() -> std::net::SocketAddr {
        use futures_util::SinkExt;
        use tokio::net::TcpListener;
        use tokio_tungstenite::accept_async;
        use tungstenite::protocol::frame::{coding::CloseCode, CloseFrame};
        use tungstenite::Message;

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            if let Ok((stream, _)) = listener.accept().await {
                if let Ok(mut ws) = accept_async(stream).await {
                    let _ = ws
                        .send(Message::Close(Some(CloseFrame {
                            code: CloseCode::Library(4900),
                            reason: "version handshake rejected".into(),
                        })))
                        .await;
                    // keep the connection alive until the client has read the
                    // close frame and disconnected
                    while let Some(Ok(_)) = ws.next().await {}
                }
            }
        });
        addr
    }

    /// A close frame's numeric code (incl. the RFC6455 private range 4000-4999)
    /// must be surfaced through `poll_next` rather than silently dropped.
    #[tokio::test]
    async fn websocket_close_code_surfaced_via_poll_next() -> Result<()> {
        let addr = spawn_close_4900_server().await;
        let url = Url::parse(&format!("ws://{}/engine.io/", addr))?;
        let transport = WebsocketTransport::new(url, None).await?;

        match transport.poll_next().await {
            Err(crate::error::Error::WebsocketClosed { code, reason }) => {
                assert_eq!(code, 4900);
                assert_eq!(reason, "version handshake rejected");
            }
            other => panic!("expected Err(WebsocketClosed {{ code: 4900, .. }}), got {other:?}"),
        }
        Ok(())
    }

    /// When the server rejects a WS-only connection at the handshake phase by
    /// closing with code 4900, the error must propagate out of
    /// `build_websocket()` so consumers can read it from the `connect()` path.
    #[tokio::test]
    async fn websocket_close_code_surfaced_via_build_handshake() -> Result<()> {
        use crate::asynchronous::ClientBuilder;

        let addr = spawn_close_4900_server().await;
        let url = Url::parse(&format!("ws://{}/", addr))?;

        match ClientBuilder::new(url).build_websocket().await {
            Err(crate::error::Error::WebsocketClosed { code, reason }) => {
                assert_eq!(code, 4900);
                assert_eq!(reason, "version handshake rejected");
            }
            Err(other) => {
                panic!("expected WebsocketClosed {{ code: 4900, .. }}, got Err({other:?})")
            }
            Ok(_) => panic!("expected build_websocket to fail with WebsocketClosed {{ 4900 }}"),
        }
        Ok(())
    }

    /// The production async path drives the transport through its `Stream` impl
    /// (`StreamExt::next`), not the inherent `poll_next`. Cover that branch
    /// directly so the shared close-capture logic in
    /// `AsyncWebsocketGeneralTransport` (reused by both ws and wss) is guarded.
    #[tokio::test]
    async fn websocket_close_code_surfaced_via_stream_impl() -> Result<()> {
        let addr = spawn_close_4900_server().await;
        let url = Url::parse(&format!("ws://{}/engine.io/", addr))?;
        let mut transport = WebsocketTransport::new(url, None).await?;

        match transport.next().await {
            Some(Err(crate::error::Error::WebsocketClosed { code, reason })) => {
                assert_eq!(code, 4900);
                assert_eq!(reason, "version handshake rejected");
            }
            other => {
                panic!("expected Some(Err(WebsocketClosed {{ code: 4900, .. }})), got {other:?}")
            }
        }
        Ok(())
    }
}