Skip to main content

tf_rust_engineio/asynchronous/client/
builder.rs

1use crate::{
2    asynchronous::{
3        async_socket::Socket as InnerSocket,
4        async_transports::{PollingTransport, WebsocketSecureTransport, WebsocketTransport},
5        callback::OptionalCallback,
6        transport::AsyncTransport,
7    },
8    error::Result,
9    header::HeaderMap,
10    packet::HandshakePacket,
11    Error, Packet, ENGINE_IO_VERSION,
12};
13use bytes::Bytes;
14use futures_util::{future::BoxFuture, StreamExt};
15use native_tls::TlsConnector;
16use url::Url;
17
18use super::Client;
19
/// Builder for an asynchronous Engine.IO [`Client`].
///
/// Collects the endpoint URL, optional TLS and HTTP header configuration and the
/// event callbacks. One of the `build*` methods then performs the Engine.IO
/// handshake and creates the client.
#[derive(Clone, Debug)]
pub struct ClientBuilder {
    /// Engine.IO endpoint. `new` appends the `EIO` query parameter, and the
    /// handshake appends the server-issued `sid`.
    url: Url,
    /// TLS connector passed to the polling transport and to the secure
    /// (`https`/`wss`) websocket transport.
    tls_config: Option<TlsConnector>,
    /// Extra HTTP headers. They are converted to each transport's header type
    /// when that transport is created.
    headers: Option<HeaderMap>,
    /// Handshake packet received from the server; `None` until a handshake succeeds.
    handshake: Option<HandshakePacket>,
    /// Callback handed to the socket as its error handler.
    on_error: OptionalCallback<String>,
    /// Callback handed to the socket as its open handler.
    on_open: OptionalCallback<()>,
    /// Callback handed to the socket as its close handler.
    on_close: OptionalCallback<()>,
    /// Callback handed to the socket as its data handler.
    on_data: OptionalCallback<Bytes>,
    /// Callback handed to the socket as its packet handler.
    on_packet: OptionalCallback<Packet>,
}
32
33impl ClientBuilder {
34    pub fn new(url: Url) -> Self {
35        let mut url = url;
36        url.query_pairs_mut()
37            .append_pair("EIO", &ENGINE_IO_VERSION.to_string());
38
39        // No path add engine.io
40        if url.path() == "/" {
41            url.set_path("/engine.io/");
42        }
43        ClientBuilder {
44            url,
45            headers: None,
46            tls_config: None,
47            handshake: None,
48            on_close: OptionalCallback::default(),
49            on_data: OptionalCallback::default(),
50            on_error: OptionalCallback::default(),
51            on_open: OptionalCallback::default(),
52            on_packet: OptionalCallback::default(),
53        }
54    }
55
56    /// Specify transport's tls config
57    pub fn tls_config(mut self, tls_config: TlsConnector) -> Self {
58        self.tls_config = Some(tls_config);
59        self
60    }
61
62    /// Specify transport's HTTP headers
63    pub fn headers(mut self, headers: HeaderMap) -> Self {
64        self.headers = Some(headers);
65        self
66    }
67
68    /// Registers the `on_close` callback.
69    #[cfg(feature = "async-callbacks")]
70    pub fn on_close<T>(mut self, callback: T) -> Self
71    where
72        T: 'static + Send + Sync + Fn(()) -> BoxFuture<'static, ()>,
73    {
74        self.on_close = OptionalCallback::new(callback);
75        self
76    }
77
78    /// Registers the `on_data` callback.
79    #[cfg(feature = "async-callbacks")]
80    pub fn on_data<T>(mut self, callback: T) -> Self
81    where
82        T: 'static + Send + Sync + Fn(Bytes) -> BoxFuture<'static, ()>,
83    {
84        self.on_data = OptionalCallback::new(callback);
85        self
86    }
87
88    /// Registers the `on_error` callback.
89    #[cfg(feature = "async-callbacks")]
90    pub fn on_error<T>(mut self, callback: T) -> Self
91    where
92        T: 'static + Send + Sync + Fn(String) -> BoxFuture<'static, ()>,
93    {
94        self.on_error = OptionalCallback::new(callback);
95        self
96    }
97
98    /// Registers the `on_open` callback.
99    #[cfg(feature = "async-callbacks")]
100    pub fn on_open<T>(mut self, callback: T) -> Self
101    where
102        T: 'static + Send + Sync + Fn(()) -> BoxFuture<'static, ()>,
103    {
104        self.on_open = OptionalCallback::new(callback);
105        self
106    }
107
108    /// Registers the `on_packet` callback.
109    #[cfg(feature = "async-callbacks")]
110    pub fn on_packet<T>(mut self, callback: T) -> Self
111    where
112        T: 'static + Send + Sync + Fn(Packet) -> BoxFuture<'static, ()>,
113    {
114        self.on_packet = OptionalCallback::new(callback);
115        self
116    }
117
118    /// Performs the handshake
119    async fn handshake_with_transport<T: AsyncTransport + Unpin>(
120        &mut self,
121        transport: &mut T,
122    ) -> Result<()> {
123        // No need to handshake twice
124        if self.handshake.is_some() {
125            return Ok(());
126        }
127
128        let mut url = self.url.clone();
129
130        let handshake: HandshakePacket =
131            Packet::try_from(transport.next().await.ok_or(Error::IncompletePacket())??)?
132                .try_into()?;
133
134        // update the base_url with the new sid
135        url.query_pairs_mut().append_pair("sid", &handshake.sid[..]);
136
137        self.handshake = Some(handshake);
138
139        self.url = url;
140
141        Ok(())
142    }
143
144    async fn handshake(&mut self) -> Result<()> {
145        if self.handshake.is_some() {
146            return Ok(());
147        }
148
149        let headers = if let Some(map) = self.headers.clone() {
150            Some(map.try_into()?)
151        } else {
152            None
153        };
154
155        // Start with polling transport
156        let mut transport =
157            PollingTransport::new(self.url.clone(), self.tls_config.clone(), headers);
158
159        self.handshake_with_transport(&mut transport).await
160    }
161
162    /// Build websocket if allowed, if not fall back to polling
163    pub async fn build(mut self) -> Result<Client> {
164        self.handshake().await?;
165
166        if self.websocket_upgrade()? {
167            self.build_websocket_with_upgrade().await
168        } else {
169            self.build_polling().await
170        }
171    }
172
173    /// Build socket with polling transport
174    pub async fn build_polling(mut self) -> Result<Client> {
175        self.handshake().await?;
176
177        // Make a polling transport with new sid
178        let transport = PollingTransport::new(
179            self.url,
180            self.tls_config,
181            self.headers.map(|v| v.try_into().unwrap()),
182        );
183
184        // SAFETY: handshake function called previously.
185        Ok(Client::new(InnerSocket::new(
186            transport.into(),
187            self.handshake.unwrap(),
188            self.on_close,
189            self.on_data,
190            self.on_error,
191            self.on_open,
192            self.on_packet,
193        )))
194    }
195
196    /// Build socket with a polling transport then upgrade to websocket transport
197    pub async fn build_websocket_with_upgrade(mut self) -> Result<Client> {
198        self.handshake().await?;
199
200        if self.websocket_upgrade()? {
201            self.build_websocket().await
202        } else {
203            Err(Error::IllegalWebsocketUpgrade())
204        }
205    }
206
207    /// Build socket with only a websocket transport
208    pub async fn build_websocket(mut self) -> Result<Client> {
209        let headers = if let Some(map) = self.headers.clone() {
210            Some(map.try_into()?)
211        } else {
212            None
213        };
214
215        match self.url.scheme() {
216            "http" | "ws" => {
217                let mut transport = WebsocketTransport::new(self.url.clone(), headers).await?;
218
219                if self.handshake.is_some() {
220                    transport.upgrade().await?;
221                } else {
222                    self.handshake_with_transport(&mut transport).await?;
223                }
224                // NOTE: Although self.url contains the sid, it does not propagate to the transport
225                // SAFETY: handshake function called previously.
226                Ok(Client::new(InnerSocket::new(
227                    transport.into(),
228                    self.handshake.unwrap(),
229                    self.on_close,
230                    self.on_data,
231                    self.on_error,
232                    self.on_open,
233                    self.on_packet,
234                )))
235            }
236            "https" | "wss" => {
237                let mut transport = WebsocketSecureTransport::new(
238                    self.url.clone(),
239                    self.tls_config.clone(),
240                    headers,
241                )
242                .await?;
243
244                if self.handshake.is_some() {
245                    transport.upgrade().await?;
246                } else {
247                    self.handshake_with_transport(&mut transport).await?;
248                }
249                // NOTE: Although self.url contains the sid, it does not propagate to the transport
250                // SAFETY: handshake function called previously.
251                Ok(Client::new(InnerSocket::new(
252                    transport.into(),
253                    self.handshake.unwrap(),
254                    self.on_close,
255                    self.on_data,
256                    self.on_error,
257                    self.on_open,
258                    self.on_packet,
259                )))
260            }
261            _ => Err(Error::InvalidUrlScheme(self.url.scheme().to_string())),
262        }
263    }
264
265    /// Build websocket if allowed, if not allowed or errored fall back to polling.
266    /// WARNING: websocket errors suppressed, no indication of websocket success or failure.
267    pub async fn build_with_fallback(self) -> Result<Client> {
268        let result = self.clone().build().await;
269        if result.is_err() {
270            self.build_polling().await
271        } else {
272            result
273        }
274    }
275
276    /// Checks the handshake to see if websocket upgrades are allowed
277    fn websocket_upgrade(&mut self) -> Result<bool> {
278        if self.handshake.is_none() {
279            return Ok(false);
280        }
281
282        Ok(self
283            .handshake
284            .as_ref()
285            .unwrap()
286            .upgrades
287            .iter()
288            .any(|upgrade| upgrade.to_lowercase() == *"websocket"))
289    }
290}