Skip to main content

tf_rust_engineio/transports/
websocket.rs

1use crate::{
2    asynchronous::{
3        async_transports::WebsocketTransport as AsyncWebsocketTransport, transport::AsyncTransport,
4    },
5    error::Result,
6    transport::Transport,
7    Error,
8};
9use bytes::Bytes;
10use http::HeaderMap;
11use std::{sync::Arc, time::Duration};
12use tokio::runtime::Runtime;
13use url::Url;
14
15#[derive(Clone)]
16pub struct WebsocketTransport {
17    runtime: Arc<Runtime>,
18    inner: Arc<AsyncWebsocketTransport>,
19}
20
21impl WebsocketTransport {
22    /// Creates an instance of `WebsocketTransport`.
23    pub fn new(base_url: Url, headers: Option<HeaderMap>) -> Result<Self> {
24        let runtime = tokio::runtime::Builder::new_current_thread()
25            .enable_all()
26            .build()?;
27
28        let inner = runtime.block_on(AsyncWebsocketTransport::new(base_url, headers))?;
29
30        Ok(WebsocketTransport {
31            runtime: Arc::new(runtime),
32            inner: Arc::new(inner),
33        })
34    }
35
36    /// Sends probe packet to ensure connection is valid, then sends upgrade
37    /// request
38    pub(crate) fn upgrade(&self) -> Result<()> {
39        self.runtime.block_on(async { self.inner.upgrade().await })
40    }
41}
42
43impl Transport for WebsocketTransport {
44    fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
45        self.runtime
46            .block_on(async { self.inner.emit(data, is_binary_att).await })
47    }
48
49    fn poll(&self, timeout: Duration) -> Result<Bytes> {
50        self.runtime.block_on(async {
51            let r = match tokio::time::timeout(timeout, self.inner.poll_next()).await {
52                Ok(r) => r,
53                Err(_) => return Err(Error::PingTimeout()),
54            };
55            match r {
56                Ok(b) => b.ok_or(Error::IncompletePacket()),
57                // propagate the real transport error (e.g. `WebsocketClosed`)
58                // instead of masking it, so the close code is not lost
59                Err(e) => Err(e),
60            }
61        })
62    }
63
64    fn base_url(&self) -> Result<url::Url> {
65        self.runtime.block_on(async { self.inner.base_url().await })
66    }
67
68    fn set_base_url(&self, url: url::Url) -> Result<()> {
69        self.runtime
70            .block_on(async { self.inner.set_base_url(url).await })
71    }
72}
73
74impl std::fmt::Debug for WebsocketTransport {
75    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
76        f.write_fmt(format_args!(
77            "WebsocketTransport(base_url: {:?})",
78            self.base_url(),
79        ))
80    }
81}
82
#[cfg(test)]
mod test {
    use super::*;
    use crate::ENGINE_IO_VERSION;
    use std::str::FromStr;

    /// Upper bound on how long each `poll` in these tests may wait.
    const TIMEOUT_DURATION: Duration = Duration::from_secs(45);

    /// Builds a transport pointed at the test engine.io server's endpoint.
    fn new() -> Result<WebsocketTransport> {
        let server = crate::test::engine_io_server()?;
        let endpoint = format!("{}engine.io/?EIO={}", server, ENGINE_IO_VERSION);
        WebsocketTransport::new(Url::from_str(&endpoint)?, None)
    }

    #[test]
    fn websocket_transport_base_url() -> Result<()> {
        let transport = new()?;

        let mut expected = crate::test::engine_io_server()?;
        expected.set_path("/engine.io/");
        expected
            .query_pairs_mut()
            .append_pair("EIO", &ENGINE_IO_VERSION.to_string())
            .append_pair("transport", "websocket");
        expected.set_scheme("ws").unwrap();
        assert_eq!(transport.base_url()?.to_string(), expected.to_string());

        // Both an https URL and an http URL that already carries the transport
        // query must come back as the same ws URL.
        for candidate in ["https://127.0.0.1", "http://127.0.0.1/?transport=websocket"] {
            transport.set_base_url(reqwest::Url::parse(candidate)?)?;
            assert_eq!(
                transport.base_url()?.to_string(),
                "ws://127.0.0.1/?transport=websocket"
            );
            assert_ne!(transport.base_url()?.to_string(), expected.to_string());
        }
        Ok(())
    }

    #[test]
    fn websocket_secure_debug() -> Result<()> {
        let transport = new()?;
        let rendered = format!("{:?}", transport);
        assert_eq!(
            rendered,
            format!("WebsocketTransport(base_url: {:?})", transport.base_url())
        );
        for _ in 0..2 {
            println!("{:?}", transport.poll(TIMEOUT_DURATION).unwrap());
        }
        Ok(())
    }
}