tf_rust_engineio/transports/
polling.rs1use crate::error::{Error, Result};
2use crate::transport::Transport;
3use base64::{engine::general_purpose, Engine as _};
4use bytes::{BufMut, Bytes, BytesMut};
5use native_tls::TlsConnector;
6use reqwest::{
7 blocking::{Client, ClientBuilder},
8 header::HeaderMap,
9};
10use std::sync::{Arc, RwLock};
11use std::time::Duration;
12use url::Url;
13
14#[derive(Debug, Clone)]
15pub struct PollingTransport {
16 client: Arc<Client>,
17 base_url: Arc<RwLock<Url>>,
18}
19
20impl PollingTransport {
21 pub fn new(
23 base_url: Url,
24 tls_config: Option<TlsConnector>,
25 opening_headers: Option<HeaderMap>,
26 ) -> Self {
27 let client = match (tls_config, opening_headers) {
28 (Some(config), Some(map)) => ClientBuilder::new()
29 .use_preconfigured_tls(config)
30 .default_headers(map)
31 .build()
32 .unwrap(),
33 (Some(config), None) => ClientBuilder::new()
34 .use_preconfigured_tls(config)
35 .build()
36 .unwrap(),
37 (None, Some(map)) => ClientBuilder::new().default_headers(map).build().unwrap(),
38 (None, None) => Client::new(),
39 };
40
41 let mut url = base_url;
42 url.query_pairs_mut().append_pair("transport", "polling");
43
44 PollingTransport {
45 client: Arc::new(client),
46 base_url: Arc::new(RwLock::new(url)),
47 }
48 }
49}
50
51impl Transport for PollingTransport {
52 fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
53 let data_to_send = if is_binary_att {
54 let mut packet_bytes = BytesMut::with_capacity(data.len() + 1);
56 packet_bytes.put_u8(b'b');
57
58 let encoded_data = general_purpose::STANDARD.encode(data);
59 packet_bytes.put(encoded_data.as_bytes());
60
61 packet_bytes.freeze()
62 } else {
63 data
64 };
65 let response = self
66 .client
67 .post(self.address()?)
68 .body(data_to_send)
69 .send()?;
70
71 let status = response.status().as_u16();
72 if status != 200 {
73 return Err(match response.text() {
74 Ok(body) => Error::HttpErrorWithBody { status, body },
75 Err(_) => Error::IncompleteHttp(status),
76 });
77 }
78
79 Ok(())
80 }
81
82 fn poll(&self, timeout: Duration) -> Result<Bytes> {
83 let response = self.client.get(self.address()?).timeout(timeout).send()?;
84
85 let status = response.status().as_u16();
86 if status != 200 {
87 return Err(match response.text() {
88 Ok(body) => Error::HttpErrorWithBody { status, body },
89 Err(_) => Error::IncompleteHttp(status),
90 });
91 }
92
93 Ok(response.bytes()?)
94 }
95
96 fn base_url(&self) -> Result<Url> {
97 Ok(self.base_url.read()?.clone())
98 }
99
100 fn set_base_url(&self, base_url: Url) -> Result<()> {
101 let mut url = base_url;
102 if !url
103 .query_pairs()
104 .any(|(k, v)| k == "transport" && v == "polling")
105 {
106 url.query_pairs_mut().append_pair("transport", "polling");
107 }
108 *self.base_url.write()? = url;
109 Ok(())
110 }
111}
112
#[cfg(test)]
mod test {
    use super::*;
    use std::str::FromStr;

    /// Error payload returned by the mock server in the HTTP-error tests.
    const MISMATCH_BODY: &str = r#"{"code":4008,"message":"Protocol version mismatch"}"#;

    /// Asserts that `err` is `Error::HttpErrorWithBody` with status 400 and `expected_body`.
    fn assert_http_400_with_body(err: Error, expected_body: &str) {
        match err {
            Error::HttpErrorWithBody { status, body } => {
                assert_eq!(status, 400);
                assert_eq!(body, expected_body);
            }
            other => panic!("expected HttpErrorWithBody, got: {other:?}"),
        }
    }

    #[test]
    fn polling_transport_emit_returns_http_error_with_body() {
        let url = crate::test::spawn_http_error_mock(400, MISMATCH_BODY);
        let err = PollingTransport::new(url, None, None)
            .emit(Bytes::from_static(b"hello"), false)
            .expect_err("emit should fail when server returns 400");
        assert_http_400_with_body(err, MISMATCH_BODY);
    }

    #[test]
    fn polling_transport_poll_returns_http_error_with_body() {
        let url = crate::test::spawn_http_error_mock(400, MISMATCH_BODY);
        let err = PollingTransport::new(url, None, None)
            .poll(Duration::from_secs(2))
            .expect_err("poll should fail when server returns 400");
        assert_http_400_with_body(err, MISMATCH_BODY);
    }

    #[test]
    fn polling_transport_base_url() -> Result<()> {
        let server = crate::test::engine_io_server()?.to_string();
        let transport = PollingTransport::new(Url::from_str(&server).unwrap(), None, None);
        assert_eq!(
            transport.base_url()?.to_string(),
            format!("{server}?transport=polling")
        );

        // Each replacement URL must end up with exactly one `transport=polling` pair.
        for (input, expected) in [
            ("https://127.0.0.1", "https://127.0.0.1/?transport=polling"),
            (
                "http://127.0.0.1/?transport=polling",
                "http://127.0.0.1/?transport=polling",
            ),
        ] {
            transport.set_base_url(Url::parse(input)?)?;
            let current = transport.base_url()?.to_string();
            assert_eq!(current, expected);
            assert_ne!(current, server);
        }
        Ok(())
    }

    #[test]
    fn transport_debug() -> Result<()> {
        let mut url = crate::test::engine_io_server()?;
        let transport =
            PollingTransport::new(Url::from_str(&url.to_string()).unwrap(), None, None);
        url.query_pairs_mut().append_pair("transport", "polling");

        let expected_struct_debug = format!(
            "PollingTransport {{ client: {:?}, base_url: RwLock {{ data: {:?}, poisoned: false, .. }} }}",
            transport.client, url
        );
        assert_eq!(expected_struct_debug, format!("{:?}", transport));

        let boxed: Box<dyn Transport> = Box::new(transport);
        assert_eq!(
            format!("Transport(base_url: Ok({:?}))", url),
            format!("{:?}", boxed)
        );
        Ok(())
    }
}