Skip to main content

tf_rust_engineio/asynchronous/async_transports/
polling.rs

1use adler32::adler32;
2use async_stream::try_stream;
3use async_trait::async_trait;
4use base64::{engine::general_purpose, Engine as _};
5use bytes::{BufMut, Bytes, BytesMut};
6use futures_util::{Stream, StreamExt};
7use http::HeaderMap;
8use native_tls::TlsConnector;
9use reqwest::{Client, ClientBuilder, Response};
10use std::fmt::Debug;
11use std::time::SystemTime;
12use std::{pin::Pin, sync::Arc};
13use tokio::sync::RwLock;
14use url::Url;
15
16use crate::asynchronous::generator::StreamGenerator;
17use crate::{asynchronous::transport::AsyncTransport, error::Result, Error};
18
19/// An asynchronous polling type. Makes use of the nonblocking reqwest types and
20/// methods.
21#[derive(Clone)]
22pub struct PollingTransport {
23    client: Client,
24    base_url: Arc<RwLock<Url>>,
25    generator: StreamGenerator<Bytes>,
26}
27
28impl PollingTransport {
29    pub fn new(
30        base_url: Url,
31        tls_config: Option<TlsConnector>,
32        opening_headers: Option<HeaderMap>,
33    ) -> Self {
34        let client = match (tls_config, opening_headers) {
35            (Some(config), Some(map)) => ClientBuilder::new()
36                .use_preconfigured_tls(config)
37                .default_headers(map)
38                .build()
39                .unwrap(),
40            (Some(config), None) => ClientBuilder::new()
41                .use_preconfigured_tls(config)
42                .build()
43                .unwrap(),
44            (None, Some(map)) => ClientBuilder::new().default_headers(map).build().unwrap(),
45            (None, None) => Client::new(),
46        };
47
48        let mut url = base_url;
49        url.query_pairs_mut().append_pair("transport", "polling");
50
51        PollingTransport {
52            client: client.clone(),
53            base_url: Arc::new(RwLock::new(url.clone())),
54            generator: StreamGenerator::new(Self::stream(url, client)),
55        }
56    }
57
58    fn address(mut url: Url) -> Result<Url> {
59        let reader = format!("{:#?}", SystemTime::now());
60        let hash = adler32(reader.as_bytes()).unwrap();
61        url.query_pairs_mut().append_pair("t", &hash.to_string());
62        Ok(url)
63    }
64
65    fn send_request(url: Url, client: Client) -> impl Stream<Item = Result<Response>> {
66        try_stream! {
67            let address = Self::address(url);
68
69            yield client
70                .get(address?)
71                .send().await?
72        }
73    }
74
75    fn stream(
76        url: Url,
77        client: Client,
78    ) -> Pin<Box<dyn Stream<Item = Result<Bytes>> + 'static + Send>> {
79        Box::pin(try_stream! {
80            loop {
81                for await elem in Self::send_request(url.clone(), client.clone()) {
82                    for await bytes in elem?.bytes_stream() {
83                        yield bytes?;
84                    }
85                }
86            }
87        })
88    }
89}
90
91impl Stream for PollingTransport {
92    type Item = Result<Bytes>;
93
94    fn poll_next(
95        mut self: Pin<&mut Self>,
96        cx: &mut std::task::Context<'_>,
97    ) -> std::task::Poll<Option<Self::Item>> {
98        self.generator.poll_next_unpin(cx)
99    }
100}
101
102#[async_trait]
103impl AsyncTransport for PollingTransport {
104    async fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
105        let data_to_send = if is_binary_att {
106            // the binary attachment gets `base64` encoded
107            let mut packet_bytes = BytesMut::with_capacity(data.len() + 1);
108            packet_bytes.put_u8(b'b');
109
110            let encoded_data = general_purpose::STANDARD.encode(data);
111            packet_bytes.put(encoded_data.as_bytes());
112
113            packet_bytes.freeze()
114        } else {
115            data
116        };
117
118        let status = self
119            .client
120            .post(self.address().await?)
121            .body(data_to_send)
122            .send()
123            .await?
124            .status()
125            .as_u16();
126
127        if status != 200 {
128            let error = Error::IncompleteHttp(status);
129            return Err(error);
130        }
131
132        Ok(())
133    }
134
135    async fn base_url(&self) -> Result<Url> {
136        Ok(self.base_url.read().await.clone())
137    }
138
139    async fn set_base_url(&self, base_url: Url) -> Result<()> {
140        let mut url = base_url;
141        if !url
142            .query_pairs()
143            .any(|(k, v)| k == "transport" && v == "polling")
144        {
145            url.query_pairs_mut().append_pair("transport", "polling");
146        }
147        *self.base_url.write().await = url;
148        Ok(())
149    }
150}
151
152impl Debug for PollingTransport {
153    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154        f.debug_struct("PollingTransport")
155            .field("client", &self.client)
156            .field("base_url", &self.base_url)
157            .finish()
158    }
159}
160
#[cfg(test)]
mod test {
    use super::*;
    use crate::asynchronous::transport::AsyncTransport;
    use std::str::FromStr;

    /// `base_url()` must report a URL carrying `transport=polling`, both for
    /// the constructor's URL and for URLs installed through `set_base_url`.
    #[tokio::test]
    async fn polling_transport_base_url() -> Result<()> {
        let server = crate::test::engine_io_server()?.to_string();
        let transport = PollingTransport::new(Url::from_str(&server).unwrap(), None, None);

        let initial = transport.base_url().await?.to_string();
        assert_eq!(initial, format!("{}?transport=polling", server));

        // A URL without a transport parameter gets one appended.
        transport
            .set_base_url(Url::parse("https://127.0.0.1")?)
            .await?;
        let appended = transport.base_url().await?.to_string();
        assert_eq!(appended, "https://127.0.0.1/?transport=polling");
        assert_ne!(appended, server);

        // A URL that already carries the parameter is stored as-is.
        transport
            .set_base_url(Url::parse("http://127.0.0.1/?transport=polling")?)
            .await?;
        let kept = transport.base_url().await?.to_string();
        assert_eq!(kept, "http://127.0.0.1/?transport=polling");
        assert_ne!(kept, server);

        Ok(())
    }
}