tf_rust_engineio/asynchronous/async_transports/
polling.rs1use adler32::adler32;
2use async_stream::try_stream;
3use async_trait::async_trait;
4use base64::{engine::general_purpose, Engine as _};
5use bytes::{BufMut, Bytes, BytesMut};
6use futures_util::{Stream, StreamExt};
7use http::HeaderMap;
8use native_tls::TlsConnector;
9use reqwest::{Client, ClientBuilder, Response};
10use std::fmt::Debug;
11use std::time::SystemTime;
12use std::{pin::Pin, sync::Arc};
13use tokio::sync::RwLock;
14use url::Url;
15
16use crate::asynchronous::generator::StreamGenerator;
17use crate::{asynchronous::transport::AsyncTransport, error::Result, Error};
18
19#[derive(Clone)]
22pub struct PollingTransport {
23 client: Client,
24 base_url: Arc<RwLock<Url>>,
25 generator: StreamGenerator<Bytes>,
26}
27
28impl PollingTransport {
29 pub fn new(
30 base_url: Url,
31 tls_config: Option<TlsConnector>,
32 opening_headers: Option<HeaderMap>,
33 ) -> Self {
34 let client = match (tls_config, opening_headers) {
35 (Some(config), Some(map)) => ClientBuilder::new()
36 .use_preconfigured_tls(config)
37 .default_headers(map)
38 .build()
39 .unwrap(),
40 (Some(config), None) => ClientBuilder::new()
41 .use_preconfigured_tls(config)
42 .build()
43 .unwrap(),
44 (None, Some(map)) => ClientBuilder::new().default_headers(map).build().unwrap(),
45 (None, None) => Client::new(),
46 };
47
48 let mut url = base_url;
49 url.query_pairs_mut().append_pair("transport", "polling");
50
51 PollingTransport {
52 client: client.clone(),
53 base_url: Arc::new(RwLock::new(url.clone())),
54 generator: StreamGenerator::new(Self::stream(url, client)),
55 }
56 }
57
58 fn address(mut url: Url) -> Result<Url> {
59 let reader = format!("{:#?}", SystemTime::now());
60 let hash = adler32(reader.as_bytes()).unwrap();
61 url.query_pairs_mut().append_pair("t", &hash.to_string());
62 Ok(url)
63 }
64
65 fn send_request(url: Url, client: Client) -> impl Stream<Item = Result<Response>> {
66 try_stream! {
67 let address = Self::address(url);
68
69 yield client
70 .get(address?)
71 .send().await?
72 }
73 }
74
75 fn stream(
76 url: Url,
77 client: Client,
78 ) -> Pin<Box<dyn Stream<Item = Result<Bytes>> + 'static + Send>> {
79 Box::pin(try_stream! {
80 loop {
81 for await elem in Self::send_request(url.clone(), client.clone()) {
82 for await bytes in elem?.bytes_stream() {
83 yield bytes?;
84 }
85 }
86 }
87 })
88 }
89}
90
91impl Stream for PollingTransport {
92 type Item = Result<Bytes>;
93
94 fn poll_next(
95 mut self: Pin<&mut Self>,
96 cx: &mut std::task::Context<'_>,
97 ) -> std::task::Poll<Option<Self::Item>> {
98 self.generator.poll_next_unpin(cx)
99 }
100}
101
102#[async_trait]
103impl AsyncTransport for PollingTransport {
104 async fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
105 let data_to_send = if is_binary_att {
106 let mut packet_bytes = BytesMut::with_capacity(data.len() + 1);
108 packet_bytes.put_u8(b'b');
109
110 let encoded_data = general_purpose::STANDARD.encode(data);
111 packet_bytes.put(encoded_data.as_bytes());
112
113 packet_bytes.freeze()
114 } else {
115 data
116 };
117
118 let status = self
119 .client
120 .post(self.address().await?)
121 .body(data_to_send)
122 .send()
123 .await?
124 .status()
125 .as_u16();
126
127 if status != 200 {
128 let error = Error::IncompleteHttp(status);
129 return Err(error);
130 }
131
132 Ok(())
133 }
134
135 async fn base_url(&self) -> Result<Url> {
136 Ok(self.base_url.read().await.clone())
137 }
138
139 async fn set_base_url(&self, base_url: Url) -> Result<()> {
140 let mut url = base_url;
141 if !url
142 .query_pairs()
143 .any(|(k, v)| k == "transport" && v == "polling")
144 {
145 url.query_pairs_mut().append_pair("transport", "polling");
146 }
147 *self.base_url.write().await = url;
148 Ok(())
149 }
150}
151
152impl Debug for PollingTransport {
153 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154 f.debug_struct("PollingTransport")
155 .field("client", &self.client)
156 .field("base_url", &self.base_url)
157 .finish()
158 }
159}
160
#[cfg(test)]
mod test {
    use crate::asynchronous::transport::AsyncTransport;

    use super::*;
    use std::str::FromStr;

    /// Checks that the stored base URL always carries `transport=polling`:
    /// after construction, after setting a URL without the pair, and after
    /// setting a URL that already contains it (no duplicate is added).
    #[tokio::test]
    async fn polling_transport_base_url() -> Result<()> {
        let url = crate::test::engine_io_server()?.to_string();
        let transport = PollingTransport::new(Url::from_str(&url).unwrap(), None, None);

        let initial = transport.base_url().await?.to_string();
        assert_eq!(initial, format!("{}?transport=polling", url));

        // (URL handed to `set_base_url`, URL expected to be stored afterwards)
        let cases = [
            ("https://127.0.0.1", "https://127.0.0.1/?transport=polling"),
            (
                "http://127.0.0.1/?transport=polling",
                "http://127.0.0.1/?transport=polling",
            ),
        ];

        for &(input, expected) in &cases {
            transport.set_base_url(Url::parse(input)?).await?;

            let stored = transport.base_url().await?.to_string();
            assert_eq!(stored, expected);
            assert_ne!(transport.base_url().await?.to_string(), url);
        }

        Ok(())
    }
}