Skip to main content

tf_rust_engineio/asynchronous/async_transports/polling.rs

1use adler32::adler32;
2use async_stream::try_stream;
3use async_trait::async_trait;
4use base64::{engine::general_purpose, Engine as _};
5use bytes::{BufMut, Bytes, BytesMut};
6use futures_util::{Stream, StreamExt};
7use http::HeaderMap;
8use native_tls::TlsConnector;
9use reqwest::{Client, ClientBuilder, Response};
10use std::fmt::Debug;
11use std::time::SystemTime;
12use std::{pin::Pin, sync::Arc};
13use tokio::sync::RwLock;
14use url::Url;
15
16use crate::asynchronous::generator::StreamGenerator;
17use crate::{asynchronous::transport::AsyncTransport, error::Result, Error};
18
19/// An asynchronous polling type. Makes use of the nonblocking reqwest types and
20/// methods.
21#[derive(Clone)]
22pub struct PollingTransport {
23    client: Client,
24    base_url: Arc<RwLock<Url>>,
25    generator: StreamGenerator<Bytes>,
26}
27
28impl PollingTransport {
29    pub fn new(
30        base_url: Url,
31        tls_config: Option<TlsConnector>,
32        opening_headers: Option<HeaderMap>,
33    ) -> Self {
34        let client = match (tls_config, opening_headers) {
35            (Some(config), Some(map)) => ClientBuilder::new()
36                .use_preconfigured_tls(config)
37                .default_headers(map)
38                .build()
39                .unwrap(),
40            (Some(config), None) => ClientBuilder::new()
41                .use_preconfigured_tls(config)
42                .build()
43                .unwrap(),
44            (None, Some(map)) => ClientBuilder::new().default_headers(map).build().unwrap(),
45            (None, None) => Client::new(),
46        };
47
48        let mut url = base_url;
49        url.query_pairs_mut().append_pair("transport", "polling");
50
51        PollingTransport {
52            client: client.clone(),
53            base_url: Arc::new(RwLock::new(url.clone())),
54            generator: StreamGenerator::new(Self::stream(url, client)),
55        }
56    }
57
58    fn address(mut url: Url) -> Result<Url> {
59        let reader = format!("{:#?}", SystemTime::now());
60        let hash = adler32(reader.as_bytes()).unwrap();
61        url.query_pairs_mut().append_pair("t", &hash.to_string());
62        Ok(url)
63    }
64
65    fn send_request(url: Url, client: Client) -> impl Stream<Item = Result<Response>> {
66        try_stream! {
67            let address = Self::address(url);
68
69            let response = client
70                .get(address?)
71                .send().await?;
72
73            let status = response.status().as_u16();
74            if status != 200 {
75                let err = match response.text().await {
76                    Ok(body) => Error::HttpErrorWithBody { status, body },
77                    Err(_) => Error::IncompleteHttp(status),
78                };
79                Err(err)?;
80                unreachable!();
81            }
82
83            yield response
84        }
85    }
86
87    fn stream(
88        url: Url,
89        client: Client,
90    ) -> Pin<Box<dyn Stream<Item = Result<Bytes>> + 'static + Send>> {
91        Box::pin(try_stream! {
92            loop {
93                for await elem in Self::send_request(url.clone(), client.clone()) {
94                    for await bytes in elem?.bytes_stream() {
95                        yield bytes?;
96                    }
97                }
98            }
99        })
100    }
101}
102
103impl Stream for PollingTransport {
104    type Item = Result<Bytes>;
105
106    fn poll_next(
107        mut self: Pin<&mut Self>,
108        cx: &mut std::task::Context<'_>,
109    ) -> std::task::Poll<Option<Self::Item>> {
110        self.generator.poll_next_unpin(cx)
111    }
112}
113
114#[async_trait]
115impl AsyncTransport for PollingTransport {
116    async fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
117        let data_to_send = if is_binary_att {
118            // the binary attachment gets `base64` encoded
119            let mut packet_bytes = BytesMut::with_capacity(data.len() + 1);
120            packet_bytes.put_u8(b'b');
121
122            let encoded_data = general_purpose::STANDARD.encode(data);
123            packet_bytes.put(encoded_data.as_bytes());
124
125            packet_bytes.freeze()
126        } else {
127            data
128        };
129
130        let response = self
131            .client
132            .post(self.address().await?)
133            .body(data_to_send)
134            .send()
135            .await?;
136
137        let status = response.status().as_u16();
138        if status != 200 {
139            return Err(match response.text().await {
140                Ok(body) => Error::HttpErrorWithBody { status, body },
141                Err(_) => Error::IncompleteHttp(status),
142            });
143        }
144
145        Ok(())
146    }
147
148    async fn base_url(&self) -> Result<Url> {
149        Ok(self.base_url.read().await.clone())
150    }
151
152    async fn set_base_url(&self, base_url: Url) -> Result<()> {
153        let mut url = base_url;
154        if !url
155            .query_pairs()
156            .any(|(k, v)| k == "transport" && v == "polling")
157        {
158            url.query_pairs_mut().append_pair("transport", "polling");
159        }
160        *self.base_url.write().await = url;
161        Ok(())
162    }
163}
164
165impl Debug for PollingTransport {
166    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167        f.debug_struct("PollingTransport")
168            .field("client", &self.client)
169            .field("base_url", &self.base_url)
170            .finish()
171    }
172}
173
#[cfg(test)]
mod test {
    use crate::asynchronous::transport::AsyncTransport;

    use super::*;
    use bytes::Bytes;
    use futures_util::StreamExt;
    use std::str::FromStr;

    /// Body the mock server returns together with its error status.
    const MISMATCH_BODY: &str = r#"{"code":4008,"message":"Protocol version mismatch"}"#;

    /// Asserts that `err` is `HttpErrorWithBody` with status 400 and `MISMATCH_BODY`.
    fn assert_mismatch_error(err: Error) {
        match err {
            Error::HttpErrorWithBody { status, body } => {
                assert_eq!(status, 400);
                assert_eq!(body, MISMATCH_BODY);
            }
            other => panic!("expected HttpErrorWithBody, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn polling_transport_emit_returns_http_error_with_body() {
        let mock_url = crate::test::spawn_http_error_mock(400, MISMATCH_BODY);
        let transport = PollingTransport::new(mock_url, None, None);

        let outcome = transport.emit(Bytes::from_static(b"hello"), false).await;
        assert_mismatch_error(outcome.expect_err("emit should fail when server returns 400"));
    }

    #[tokio::test]
    async fn polling_transport_get_returns_http_error_with_body() {
        let mock_url = crate::test::spawn_http_error_mock(400, MISMATCH_BODY);
        let mut transport = PollingTransport::new(mock_url, None, None);

        let first = transport.next().await.expect("stream should yield an item");
        assert_mismatch_error(first.expect_err("GET should fail when server returns 400"));
    }

    #[tokio::test]
    async fn polling_transport_base_url() -> Result<()> {
        let server_url = crate::test::engine_io_server()?.to_string();
        let transport = PollingTransport::new(Url::from_str(&server_url).unwrap(), None, None);

        let initial = transport.base_url().await?.to_string();
        assert_eq!(initial, format!("{server_url}?transport=polling"));

        transport
            .set_base_url(Url::parse("https://127.0.0.1")?)
            .await?;
        let appended = transport.base_url().await?.to_string();
        assert_eq!(appended, "https://127.0.0.1/?transport=polling");
        assert_ne!(appended, server_url);

        transport
            .set_base_url(Url::parse("http://127.0.0.1/?transport=polling")?)
            .await?;
        let preserved = transport.base_url().await?.to_string();
        assert_eq!(preserved, "http://127.0.0.1/?transport=polling");
        assert_ne!(preserved, server_url);

        Ok(())
    }
}