Skip to main content

tf_rust_engineio/transports/
polling.rs

1use crate::error::{Error, Result};
2use crate::transport::Transport;
3use base64::{engine::general_purpose, Engine as _};
4use bytes::{BufMut, Bytes, BytesMut};
5use native_tls::TlsConnector;
6use reqwest::{
7    blocking::{Client, ClientBuilder},
8    header::HeaderMap,
9};
10use std::sync::{Arc, RwLock};
11use std::time::Duration;
12use url::Url;
13
14#[derive(Debug, Clone)]
15pub struct PollingTransport {
16    client: Arc<Client>,
17    base_url: Arc<RwLock<Url>>,
18}
19
20impl PollingTransport {
21    /// Creates an instance of `PollingTransport`.
22    pub fn new(
23        base_url: Url,
24        tls_config: Option<TlsConnector>,
25        opening_headers: Option<HeaderMap>,
26    ) -> Self {
27        let client = match (tls_config, opening_headers) {
28            (Some(config), Some(map)) => ClientBuilder::new()
29                .use_preconfigured_tls(config)
30                .default_headers(map)
31                .build()
32                .unwrap(),
33            (Some(config), None) => ClientBuilder::new()
34                .use_preconfigured_tls(config)
35                .build()
36                .unwrap(),
37            (None, Some(map)) => ClientBuilder::new().default_headers(map).build().unwrap(),
38            (None, None) => Client::new(),
39        };
40
41        let mut url = base_url;
42        url.query_pairs_mut().append_pair("transport", "polling");
43
44        PollingTransport {
45            client: Arc::new(client),
46            base_url: Arc::new(RwLock::new(url)),
47        }
48    }
49}
50
51impl Transport for PollingTransport {
52    fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
53        let data_to_send = if is_binary_att {
54            // the binary attachment gets `base64` encoded
55            let mut packet_bytes = BytesMut::with_capacity(data.len() + 1);
56            packet_bytes.put_u8(b'b');
57
58            let encoded_data = general_purpose::STANDARD.encode(data);
59            packet_bytes.put(encoded_data.as_bytes());
60
61            packet_bytes.freeze()
62        } else {
63            data
64        };
65        let response = self
66            .client
67            .post(self.address()?)
68            .body(data_to_send)
69            .send()?;
70
71        let status = response.status().as_u16();
72        if status != 200 {
73            return Err(match response.text() {
74                Ok(body) => Error::HttpErrorWithBody { status, body },
75                Err(_) => Error::IncompleteHttp(status),
76            });
77        }
78
79        Ok(())
80    }
81
82    fn poll(&self, timeout: Duration) -> Result<Bytes> {
83        let response = self.client.get(self.address()?).timeout(timeout).send()?;
84
85        let status = response.status().as_u16();
86        if status != 200 {
87            return Err(match response.text() {
88                Ok(body) => Error::HttpErrorWithBody { status, body },
89                Err(_) => Error::IncompleteHttp(status),
90            });
91        }
92
93        Ok(response.bytes()?)
94    }
95
96    fn base_url(&self) -> Result<Url> {
97        Ok(self.base_url.read()?.clone())
98    }
99
100    fn set_base_url(&self, base_url: Url) -> Result<()> {
101        let mut url = base_url;
102        if !url
103            .query_pairs()
104            .any(|(k, v)| k == "transport" && v == "polling")
105        {
106            url.query_pairs_mut().append_pair("transport", "polling");
107        }
108        *self.base_url.write()? = url;
109        Ok(())
110    }
111}
112
#[cfg(test)]
mod test {
    use super::*;

    /// Body the mock server returns in the HTTP-error tests.
    const ERROR_BODY: &str = r#"{"code":4008,"message":"Protocol version mismatch"}"#;

    /// Asserts that `err` is `HttpErrorWithBody` with status 400 and `ERROR_BODY`.
    fn assert_http_400_with_body(err: Error) {
        match err {
            Error::HttpErrorWithBody { status, body } => {
                assert_eq!(status, 400);
                assert_eq!(body, ERROR_BODY);
            }
            other => panic!("expected HttpErrorWithBody, got: {other:?}"),
        }
    }

    #[test]
    fn polling_transport_emit_returns_http_error_with_body() {
        let url = crate::test::spawn_http_error_mock(400, ERROR_BODY);
        let transport = PollingTransport::new(url, None, None);

        let err = transport
            .emit(Bytes::from_static(b"hello"), false)
            .expect_err("emit should fail when server returns 400");

        assert_http_400_with_body(err);
    }

    #[test]
    fn polling_transport_poll_returns_http_error_with_body() {
        let url = crate::test::spawn_http_error_mock(400, ERROR_BODY);
        let transport = PollingTransport::new(url, None, None);

        let err = transport
            .poll(Duration::from_secs(2))
            .expect_err("poll should fail when server returns 400");

        assert_http_400_with_body(err);
    }

    #[test]
    fn polling_transport_base_url() -> Result<()> {
        let server = crate::test::engine_io_server()?;
        let url = server.to_string();
        let transport = PollingTransport::new(server, None, None);
        assert_eq!(
            transport.base_url()?.to_string(),
            format!("{url}?transport=polling")
        );

        transport.set_base_url(Url::parse("https://127.0.0.1")?)?;
        assert_eq!(
            transport.base_url()?.to_string(),
            "https://127.0.0.1/?transport=polling"
        );
        assert_ne!(transport.base_url()?.to_string(), url);

        transport.set_base_url(Url::parse("http://127.0.0.1/?transport=polling")?)?;
        assert_eq!(
            transport.base_url()?.to_string(),
            "http://127.0.0.1/?transport=polling"
        );
        assert_ne!(transport.base_url()?.to_string(), url);
        Ok(())
    }

    #[test]
    fn transport_debug() -> Result<()> {
        let mut url = crate::test::engine_io_server()?;
        let transport = PollingTransport::new(url.clone(), None, None);
        url.query_pairs_mut().append_pair("transport", "polling");

        let expected = format!(
            "PollingTransport {{ client: {:?}, base_url: RwLock {{ data: {:?}, poisoned: false, .. }} }}",
            transport.client, url
        );
        assert_eq!(expected, format!("{transport:?}"));

        let boxed: Box<dyn Transport> = Box::new(transport);
        assert_eq!(
            format!("Transport(base_url: Ok({url:?}))"),
            format!("{boxed:?}")
        );
        Ok(())
    }
}