Skip to main content

tf_rust_engineio/client/
client.rs

1use super::super::socket::Socket as InnerSocket;
2use crate::callback::OptionalCallback;
3use crate::socket::DEFAULT_MAX_POLL_TIMEOUT;
4use crate::transport::Transport;
5
6use crate::error::{Error, Result};
7use crate::header::HeaderMap;
8use crate::packet::{HandshakePacket, Packet, PacketId};
9use crate::transports::{PollingTransport, WebsocketSecureTransport, WebsocketTransport};
10use crate::ENGINE_IO_VERSION;
11use bytes::Bytes;
12use native_tls::TlsConnector;
13use std::convert::TryFrom;
14use std::convert::TryInto;
15use std::fmt::Debug;
16use url::Url;
17
18/// An engine.io client that allows interaction with the connected engine.io
19/// server. This client provides means for connecting, disconnecting and sending
20/// packets to the server.
21///
22/// ## Note:
23/// There is no need to put this Client behind an `Arc`, as the type uses `Arc`
24/// internally and provides a shared state beyond all cloned instances.
25#[derive(Clone, Debug)]
26pub struct Client {
27    socket: InnerSocket,
28}
29
30#[derive(Clone, Debug)]
31pub struct ClientBuilder {
32    url: Url,
33    tls_config: Option<TlsConnector>,
34    headers: Option<HeaderMap>,
35    handshake: Option<HandshakePacket>,
36    on_error: OptionalCallback<String>,
37    on_open: OptionalCallback<()>,
38    on_close: OptionalCallback<()>,
39    on_data: OptionalCallback<Bytes>,
40    on_packet: OptionalCallback<Packet>,
41}
42
43impl ClientBuilder {
44    pub fn new(url: Url) -> Self {
45        let mut url = url;
46        url.query_pairs_mut()
47            .append_pair("EIO", &ENGINE_IO_VERSION.to_string());
48
49        // No path add engine.io
50        if url.path() == "/" {
51            url.set_path("/engine.io/");
52        }
53        ClientBuilder {
54            url,
55            headers: None,
56            tls_config: None,
57            handshake: None,
58            on_close: OptionalCallback::default(),
59            on_data: OptionalCallback::default(),
60            on_error: OptionalCallback::default(),
61            on_open: OptionalCallback::default(),
62            on_packet: OptionalCallback::default(),
63        }
64    }
65
66    /// Specify transport's tls config
67    pub fn tls_config(mut self, tls_config: TlsConnector) -> Self {
68        self.tls_config = Some(tls_config);
69        self
70    }
71
72    /// Specify transport's HTTP headers
73    pub fn headers(mut self, headers: HeaderMap) -> Self {
74        self.headers = Some(headers);
75        self
76    }
77
78    /// Registers the `on_close` callback.
79    pub fn on_close<T>(mut self, callback: T) -> Self
80    where
81        T: Fn(()) + 'static + Sync + Send,
82    {
83        self.on_close = OptionalCallback::new(callback);
84        self
85    }
86
87    /// Registers the `on_data` callback.
88    pub fn on_data<T>(mut self, callback: T) -> Self
89    where
90        T: Fn(Bytes) + 'static + Sync + Send,
91    {
92        self.on_data = OptionalCallback::new(callback);
93        self
94    }
95
96    /// Registers the `on_error` callback.
97    pub fn on_error<T>(mut self, callback: T) -> Self
98    where
99        T: Fn(String) + 'static + Sync + Send,
100    {
101        self.on_error = OptionalCallback::new(callback);
102        self
103    }
104
105    /// Registers the `on_open` callback.
106    pub fn on_open<T>(mut self, callback: T) -> Self
107    where
108        T: Fn(()) + 'static + Sync + Send,
109    {
110        self.on_open = OptionalCallback::new(callback);
111        self
112    }
113
114    /// Registers the `on_packet` callback.
115    pub fn on_packet<T>(mut self, callback: T) -> Self
116    where
117        T: Fn(Packet) + 'static + Sync + Send,
118    {
119        self.on_packet = OptionalCallback::new(callback);
120        self
121    }
122
123    /// Performs the handshake
124    fn handshake_with_transport<T: Transport>(&mut self, transport: &T) -> Result<()> {
125        // No need to handshake twice
126        if self.handshake.is_some() {
127            return Ok(());
128        }
129
130        let mut url = self.url.clone();
131
132        let handshake: HandshakePacket =
133            Packet::try_from(transport.poll(DEFAULT_MAX_POLL_TIMEOUT)?)?.try_into()?;
134
135        // update the base_url with the new sid
136        url.query_pairs_mut().append_pair("sid", &handshake.sid[..]);
137
138        self.handshake = Some(handshake);
139
140        self.url = url;
141
142        Ok(())
143    }
144
145    fn handshake(&mut self) -> Result<()> {
146        if self.handshake.is_some() {
147            return Ok(());
148        }
149
150        // Start with polling transport
151        let transport = PollingTransport::new(
152            self.url.clone(),
153            self.tls_config.clone(),
154            self.headers.clone().map(|v| v.try_into().unwrap()),
155        );
156
157        self.handshake_with_transport(&transport)
158    }
159
160    /// Build websocket if allowed, if not fall back to polling
161    pub fn build(mut self) -> Result<Client> {
162        self.handshake()?;
163
164        if self.websocket_upgrade()? {
165            self.build_websocket_with_upgrade()
166        } else {
167            self.build_polling()
168        }
169    }
170
171    /// Build socket with polling transport
172    pub fn build_polling(mut self) -> Result<Client> {
173        self.handshake()?;
174
175        // Make a polling transport with new sid
176        let transport = PollingTransport::new(
177            self.url,
178            self.tls_config,
179            self.headers.map(|v| v.try_into().unwrap()),
180        );
181
182        // SAFETY: handshake function called previously.
183        Ok(Client {
184            socket: InnerSocket::new(
185                transport.into(),
186                self.handshake.unwrap(),
187                self.on_close,
188                self.on_data,
189                self.on_error,
190                self.on_open,
191                self.on_packet,
192            ),
193        })
194    }
195
196    /// Build socket with a polling transport then upgrade to websocket transport
197    pub fn build_websocket_with_upgrade(mut self) -> Result<Client> {
198        self.handshake()?;
199
200        if self.websocket_upgrade()? {
201            self.build_websocket()
202        } else {
203            Err(Error::IllegalWebsocketUpgrade())
204        }
205    }
206
207    /// Build socket with only a websocket transport
208    pub fn build_websocket(mut self) -> Result<Client> {
209        // SAFETY: Already a Url
210        let url = url::Url::parse(self.url.as_ref())?;
211
212        let headers: Option<http::HeaderMap> = if let Some(map) = self.headers.clone() {
213            Some(map.try_into()?)
214        } else {
215            None
216        };
217
218        match url.scheme() {
219            "http" | "ws" => {
220                let transport = WebsocketTransport::new(url, headers)?;
221                if self.handshake.is_some() {
222                    transport.upgrade()?;
223                } else {
224                    self.handshake_with_transport(&transport)?;
225                }
226                // NOTE: Although self.url contains the sid, it does not propagate to the transport
227                // SAFETY: handshake function called previously.
228                Ok(Client {
229                    socket: InnerSocket::new(
230                        transport.into(),
231                        self.handshake.unwrap(),
232                        self.on_close,
233                        self.on_data,
234                        self.on_error,
235                        self.on_open,
236                        self.on_packet,
237                    ),
238                })
239            }
240            "https" | "wss" => {
241                let transport =
242                    WebsocketSecureTransport::new(url, self.tls_config.clone(), headers)?;
243                if self.handshake.is_some() {
244                    transport.upgrade()?;
245                } else {
246                    self.handshake_with_transport(&transport)?;
247                }
248                // NOTE: Although self.url contains the sid, it does not propagate to the transport
249                // SAFETY: handshake function called previously.
250                Ok(Client {
251                    socket: InnerSocket::new(
252                        transport.into(),
253                        self.handshake.unwrap(),
254                        self.on_close,
255                        self.on_data,
256                        self.on_error,
257                        self.on_open,
258                        self.on_packet,
259                    ),
260                })
261            }
262            _ => Err(Error::InvalidUrlScheme(url.scheme().to_string())),
263        }
264    }
265
266    /// Build websocket if allowed, if not allowed or errored fall back to polling.
267    /// WARNING: websocket errors suppressed, no indication of websocket success or failure.
268    pub fn build_with_fallback(self) -> Result<Client> {
269        let result = self.clone().build();
270        if result.is_err() {
271            self.build_polling()
272        } else {
273            result
274        }
275    }
276
277    /// Checks the handshake to see if websocket upgrades are allowed
278    fn websocket_upgrade(&mut self) -> Result<bool> {
279        // SAFETY: handshake set by above function.
280        Ok(self
281            .handshake
282            .as_ref()
283            .unwrap()
284            .upgrades
285            .iter()
286            .any(|upgrade| upgrade.to_lowercase() == *"websocket"))
287    }
288}
289
290impl Client {
291    pub fn close(&self) -> Result<()> {
292        self.socket.disconnect()
293    }
294
295    /// Opens the connection to a specified server. The first Pong packet is sent
296    /// to the server to trigger the Ping-cycle.
297    pub fn connect(&self) -> Result<()> {
298        self.socket.connect()
299    }
300
301    /// Disconnects the connection.
302    pub fn disconnect(&self) -> Result<()> {
303        self.socket.disconnect()
304    }
305
306    /// Sends a packet to the server.
307    pub fn emit(&self, packet: Packet) -> Result<()> {
308        self.socket.emit(packet)
309    }
310
311    /// Polls for next payload
312    #[doc(hidden)]
313    pub fn poll(&self) -> Result<Option<Packet>> {
314        let packet = self.socket.poll()?;
315        if let Some(packet) = packet {
316            // check for the appropriate action or callback
317            self.socket.handle_packet(packet.clone());
318            match packet.packet_id {
319                PacketId::MessageBinary => {
320                    self.socket.handle_data(packet.data.clone());
321                }
322                PacketId::Message => {
323                    self.socket.handle_data(packet.data.clone());
324                }
325                PacketId::Close => {
326                    self.socket.handle_close();
327                }
328                PacketId::Open => {
329                    unreachable!("Won't happen as we open the connection beforehand");
330                }
331                PacketId::Upgrade => {
332                    // this is already checked during the handshake, so just do nothing here
333                }
334                PacketId::Ping => {
335                    self.socket.pinged()?;
336                    self.emit(Packet::new(PacketId::Pong, Bytes::new()))?;
337                }
338                PacketId::Pong => {
339                    // this will never happen as the pong packet is
340                    // only sent by the client
341                    unreachable!();
342                }
343                PacketId::Noop => (),
344            }
345            Ok(Some(packet))
346        } else {
347            Ok(None)
348        }
349    }
350
351    /// Check if the underlying transport client is connected.
352    pub fn is_connected(&self) -> Result<bool> {
353        self.socket.is_connected()
354    }
355
356    pub fn iter(&self) -> Iter<'_> {
357        Iter { socket: self }
358    }
359}
360
361#[derive(Clone)]
362pub struct Iter<'a> {
363    socket: &'a Client,
364}
365
366impl<'a> Iterator for Iter<'a> {
367    type Item = Result<Packet>;
368    fn next(&mut self) -> std::option::Option<<Self as std::iter::Iterator>::Item> {
369        match self.socket.poll() {
370            Ok(Some(packet)) => Some(Ok(packet)),
371            Ok(None) => None,
372            Err(err) => Some(Err(err)),
373        }
374    }
375}
376
#[cfg(test)]
mod test {
    use reqwest::header::HOST;

    use crate::packet::{Packet, PacketId};

    use super::*;

    /// The purpose of this test is to check whether the Client is properly cloneable or not.
    /// As the documentation of the engine.io client states, the object needs to maintain its internal
    /// state when cloned and the cloned object should reflect the same state throughout the lifetime
    /// of both objects (initial and cloned).
    #[test]
    fn test_client_cloneable() -> Result<()> {
        let url = crate::test::engine_io_server()?;
        let sut = builder(url).build()?;

        let cloned = sut.clone();

        sut.connect()?;

        // when the underlying socket is connected, the
        // state should also change on the cloned one
        assert!(sut.is_connected()?);
        assert!(cloned.is_connected()?);

        // both clients should reflect the same messages.
        let mut iter = sut
            .iter()
            .map(|packet| packet.unwrap())
            .filter(|packet| packet.packet_id != PacketId::Ping);

        let mut iter_cloned = cloned
            .iter()
            .map(|packet| packet.unwrap())
            .filter(|packet| packet.packet_id != PacketId::Ping);

        assert_eq!(
            iter.next(),
            Some(Packet::new(PacketId::Message, "hello client"))
        );

        sut.emit(Packet::new(PacketId::Message, "respond"))?;

        assert_eq!(
            iter_cloned.next(),
            Some(Packet::new(PacketId::Message, "Roger Roger"))
        );

        cloned.disconnect()?;

        // when the underlying socket is disconnected, the
        // state should also change on the cloned one
        assert!(!sut.is_connected()?);
        assert!(!cloned.is_connected()?);

        Ok(())
    }

    #[test]
    fn test_illegal_actions() -> Result<()> {
        let url = crate::test::engine_io_server()?;
        let sut = builder(url).build()?;

        assert!(sut
            .emit(Packet::new(PacketId::Close, Bytes::new()))
            .is_err());

        sut.connect()?;

        assert!(sut.poll().is_ok());

        assert!(builder(Url::parse("fake://fake.fake").unwrap())
            .build_websocket()
            .is_err());

        Ok(())
    }

    /// Builder with logging callbacks registered for every event.
    fn builder(url: Url) -> ClientBuilder {
        ClientBuilder::new(url)
            .on_open(|_| {
                println!("Open event!");
            })
            .on_packet(|packet| {
                println!("Received packet: {:?}", packet);
            })
            .on_data(|data| {
                println!("Received data: {:?}", std::str::from_utf8(&data));
            })
            .on_close(|_| {
                println!("Close event!");
            })
            .on_error(|error| {
                println!("Error {}", error);
            })
    }

    /// Connects `socket`, checks the test server's greeting and echo reply,
    /// then closes the socket.
    fn test_connection(socket: Client) -> Result<()> {
        socket.connect()?;

        // TODO: 0.3.X better tests

        let mut iter = socket
            .iter()
            .map(|packet| packet.unwrap())
            .filter(|packet| packet.packet_id != PacketId::Ping);

        assert_eq!(
            iter.next(),
            Some(Packet::new(PacketId::Message, "hello client"))
        );

        socket.emit(Packet::new(PacketId::Message, "respond"))?;

        assert_eq!(
            iter.next(),
            Some(Packet::new(PacketId::Message, "Roger Roger"))
        );

        socket.close()
    }

    #[test]
    fn test_connection_long() -> Result<()> {
        // Long lived socket to receive pings
        let url = crate::test::engine_io_server()?;
        let socket = builder(url).build()?;

        socket.connect()?;

        let mut iter = socket.iter();
        // hello client
        iter.next();
        // Ping
        iter.next();

        socket.disconnect()?;

        assert!(!socket.is_connected()?);

        Ok(())
    }

    #[test]
    fn test_connection_dynamic() -> Result<()> {
        let url = crate::test::engine_io_server()?;
        let socket = builder(url).build()?;
        test_connection(socket)?;

        let url = crate::test::engine_io_polling_server()?;
        let socket = builder(url).build()?;
        test_connection(socket)
    }

    #[test]
    fn test_connection_fallback() -> Result<()> {
        let url = crate::test::engine_io_server()?;
        let socket = builder(url).build_with_fallback()?;
        test_connection(socket)?;

        let url = crate::test::engine_io_polling_server()?;
        let socket = builder(url).build_with_fallback()?;
        test_connection(socket)
    }

    #[test]
    fn test_connection_dynamic_secure() -> Result<()> {
        let url = crate::test::engine_io_server_secure()?;
        let mut builder = builder(url);
        builder = builder.tls_config(crate::test::tls_connector()?);
        let socket = builder.build()?;
        test_connection(socket)
    }

    #[test]
    fn test_connection_polling() -> Result<()> {
        let url = crate::test::engine_io_server()?;
        let socket = builder(url).build_polling()?;
        test_connection(socket)
    }

    #[test]
    fn test_connection_wss() -> Result<()> {
        let url = crate::test::engine_io_polling_server()?;
        assert!(builder(url).build_websocket_with_upgrade().is_err());

        let host =
            std::env::var("ENGINE_IO_SECURE_HOST").unwrap_or_else(|_| "localhost".to_owned());
        let mut url = crate::test::engine_io_server_secure()?;

        let mut headers = HeaderMap::default();
        headers.insert(HOST, host);
        let mut builder = builder(url.clone());

        builder = builder.tls_config(crate::test::tls_connector()?);
        builder = builder.headers(headers.clone());
        let socket = builder.clone().build_websocket_with_upgrade()?;

        test_connection(socket)?;

        let socket = builder.build_websocket()?;

        test_connection(socket)?;

        url.set_scheme("wss").unwrap();

        let builder = self::builder(url)
            .tls_config(crate::test::tls_connector()?)
            .headers(headers);
        let socket = builder.clone().build_websocket()?;

        test_connection(socket)?;

        assert!(builder.build_websocket_with_upgrade().is_err());

        Ok(())
    }

    #[test]
    fn test_connection_ws() -> Result<()> {
        let url = crate::test::engine_io_polling_server()?;
        assert!(builder(url.clone()).build_websocket().is_err());
        assert!(builder(url).build_websocket_with_upgrade().is_err());

        let mut url = crate::test::engine_io_server()?;

        let builder = builder(url.clone());
        let socket = builder.clone().build_websocket()?;
        test_connection(socket)?;

        let socket = builder.build_websocket_with_upgrade()?;
        test_connection(socket)?;

        url.set_scheme("ws").unwrap();

        let builder = self::builder(url);
        let socket = builder.clone().build_websocket()?;

        test_connection(socket)?;

        assert!(builder.build_websocket_with_upgrade().is_err());

        Ok(())
    }

    #[test]
    fn test_open_invariants() -> Result<()> {
        let url = crate::test::engine_io_server()?;
        let illegal_url = "this is illegal";

        assert!(Url::parse(illegal_url).is_err());

        let invalid_protocol = "file:///tmp/foo";
        assert!(builder(Url::parse(invalid_protocol).unwrap())
            .build()
            .is_err());

        let sut = builder(url.clone()).build()?;
        let error = sut
            .emit(Packet::new(PacketId::Close, Bytes::new()))
            .expect_err("error");
        // `matches!` takes the value first and the pattern second. With the
        // arguments swapped, the pattern is a catch-all binding and the
        // assertion can never fail.
        assert!(matches!(error, Error::IllegalActionBeforeOpen()));

        // test missing match arm in socket constructor
        let mut headers = HeaderMap::default();
        // Use the correct Host header value including the port
        let host =
            std::env::var("ENGINE_IO_SECURE_HOST").unwrap_or_else(|_| "localhost:4201".to_owned());
        headers.insert(HOST, host);

        let _ = builder(url.clone())
            .tls_config(
                TlsConnector::builder()
                    .danger_accept_invalid_certs(true)
                    .build()
                    .unwrap(),
            )
            .build()?;
        let _ = builder(url).headers(headers).build()?;
        Ok(())
    }
}