tf_rust_engineio/asynchronous/async_transports/
polling.rs1use adler32::adler32;
2use async_stream::try_stream;
3use async_trait::async_trait;
4use base64::{engine::general_purpose, Engine as _};
5use bytes::{BufMut, Bytes, BytesMut};
6use futures_util::{Stream, StreamExt};
7use http::HeaderMap;
8use native_tls::TlsConnector;
9use reqwest::{Client, ClientBuilder, Response};
10use std::fmt::Debug;
11use std::time::SystemTime;
12use std::{pin::Pin, sync::Arc};
13use tokio::sync::RwLock;
14use url::Url;
15
16use crate::asynchronous::generator::StreamGenerator;
17use crate::{asynchronous::transport::AsyncTransport, error::Result, Error};
18
19#[derive(Clone)]
22pub struct PollingTransport {
23 client: Client,
24 base_url: Arc<RwLock<Url>>,
25 generator: StreamGenerator<Bytes>,
26}
27
28impl PollingTransport {
29 pub fn new(
30 base_url: Url,
31 tls_config: Option<TlsConnector>,
32 opening_headers: Option<HeaderMap>,
33 ) -> Self {
34 let client = match (tls_config, opening_headers) {
35 (Some(config), Some(map)) => ClientBuilder::new()
36 .use_preconfigured_tls(config)
37 .default_headers(map)
38 .build()
39 .unwrap(),
40 (Some(config), None) => ClientBuilder::new()
41 .use_preconfigured_tls(config)
42 .build()
43 .unwrap(),
44 (None, Some(map)) => ClientBuilder::new().default_headers(map).build().unwrap(),
45 (None, None) => Client::new(),
46 };
47
48 let mut url = base_url;
49 url.query_pairs_mut().append_pair("transport", "polling");
50
51 PollingTransport {
52 client: client.clone(),
53 base_url: Arc::new(RwLock::new(url.clone())),
54 generator: StreamGenerator::new(Self::stream(url, client)),
55 }
56 }
57
58 fn address(mut url: Url) -> Result<Url> {
59 let reader = format!("{:#?}", SystemTime::now());
60 let hash = adler32(reader.as_bytes()).unwrap();
61 url.query_pairs_mut().append_pair("t", &hash.to_string());
62 Ok(url)
63 }
64
65 fn send_request(url: Url, client: Client) -> impl Stream<Item = Result<Response>> {
66 try_stream! {
67 let address = Self::address(url);
68
69 let response = client
70 .get(address?)
71 .send().await?;
72
73 let status = response.status().as_u16();
74 if status != 200 {
75 let err = match response.text().await {
76 Ok(body) => Error::HttpErrorWithBody { status, body },
77 Err(_) => Error::IncompleteHttp(status),
78 };
79 Err(err)?;
80 unreachable!();
81 }
82
83 yield response
84 }
85 }
86
87 fn stream(
88 url: Url,
89 client: Client,
90 ) -> Pin<Box<dyn Stream<Item = Result<Bytes>> + 'static + Send>> {
91 Box::pin(try_stream! {
92 loop {
93 for await elem in Self::send_request(url.clone(), client.clone()) {
94 for await bytes in elem?.bytes_stream() {
95 yield bytes?;
96 }
97 }
98 }
99 })
100 }
101}
102
103impl Stream for PollingTransport {
104 type Item = Result<Bytes>;
105
106 fn poll_next(
107 mut self: Pin<&mut Self>,
108 cx: &mut std::task::Context<'_>,
109 ) -> std::task::Poll<Option<Self::Item>> {
110 self.generator.poll_next_unpin(cx)
111 }
112}
113
114#[async_trait]
115impl AsyncTransport for PollingTransport {
116 async fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
117 let data_to_send = if is_binary_att {
118 let mut packet_bytes = BytesMut::with_capacity(data.len() + 1);
120 packet_bytes.put_u8(b'b');
121
122 let encoded_data = general_purpose::STANDARD.encode(data);
123 packet_bytes.put(encoded_data.as_bytes());
124
125 packet_bytes.freeze()
126 } else {
127 data
128 };
129
130 let response = self
131 .client
132 .post(self.address().await?)
133 .body(data_to_send)
134 .send()
135 .await?;
136
137 let status = response.status().as_u16();
138 if status != 200 {
139 return Err(match response.text().await {
140 Ok(body) => Error::HttpErrorWithBody { status, body },
141 Err(_) => Error::IncompleteHttp(status),
142 });
143 }
144
145 Ok(())
146 }
147
148 async fn base_url(&self) -> Result<Url> {
149 Ok(self.base_url.read().await.clone())
150 }
151
152 async fn set_base_url(&self, base_url: Url) -> Result<()> {
153 let mut url = base_url;
154 if !url
155 .query_pairs()
156 .any(|(k, v)| k == "transport" && v == "polling")
157 {
158 url.query_pairs_mut().append_pair("transport", "polling");
159 }
160 *self.base_url.write().await = url;
161 Ok(())
162 }
163}
164
165impl Debug for PollingTransport {
166 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167 f.debug_struct("PollingTransport")
168 .field("client", &self.client)
169 .field("base_url", &self.base_url)
170 .finish()
171 }
172}
173
#[cfg(test)]
mod test {
    use crate::asynchronous::transport::AsyncTransport;

    use super::*;
    use bytes::Bytes;
    use futures_util::StreamExt;

    /// Body served by the HTTP-error mock in the error-propagation tests.
    const ERROR_BODY: &str = r#"{"code":4008,"message":"Protocol version mismatch"}"#;

    /// Asserts that `err` is `HttpErrorWithBody` with the expected status and body.
    fn assert_http_error_with_body(err: Error, expected_status: u16, expected_body: &str) {
        match err {
            Error::HttpErrorWithBody { status, body } => {
                assert_eq!(status, expected_status);
                assert_eq!(body, expected_body);
            }
            other => panic!("expected HttpErrorWithBody, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn polling_transport_emit_returns_http_error_with_body() {
        let url = crate::test::spawn_http_error_mock(400, ERROR_BODY);
        let transport = PollingTransport::new(url, None, None);

        let err = transport
            .emit(Bytes::from_static(b"hello"), false)
            .await
            .expect_err("emit should fail when server returns 400");

        assert_http_error_with_body(err, 400, ERROR_BODY);
    }

    #[tokio::test]
    async fn polling_transport_get_returns_http_error_with_body() {
        let url = crate::test::spawn_http_error_mock(400, ERROR_BODY);
        let mut transport = PollingTransport::new(url, None, None);

        let err = transport
            .next()
            .await
            .expect("stream should yield an item")
            .expect_err("GET should fail when server returns 400");

        assert_http_error_with_body(err, 400, ERROR_BODY);
    }

    #[tokio::test]
    async fn polling_transport_base_url() -> Result<()> {
        let url = crate::test::engine_io_server()?.to_string();
        let transport = PollingTransport::new(Url::parse(&url)?, None, None);
        assert_eq!(
            transport.base_url().await?.to_string(),
            format!("{url}?transport=polling")
        );

        // A URL without the transport pair gets it appended.
        transport
            .set_base_url(Url::parse("https://127.0.0.1")?)
            .await?;
        assert_eq!(
            transport.base_url().await?.to_string(),
            "https://127.0.0.1/?transport=polling"
        );
        assert_ne!(transport.base_url().await?.to_string(), url);

        // A URL that already carries it is stored unchanged (no duplicate pair).
        transport
            .set_base_url(Url::parse("http://127.0.0.1/?transport=polling")?)
            .await?;
        assert_eq!(
            transport.base_url().await?.to_string(),
            "http://127.0.0.1/?transport=polling"
        );
        assert_ne!(transport.base_url().await?.to_string(), url);
        Ok(())
    }
}