1use super::super::socket::Socket as InnerSocket;
2use crate::callback::OptionalCallback;
3use crate::socket::DEFAULT_MAX_POLL_TIMEOUT;
4use crate::transport::Transport;
5
6use crate::error::{Error, Result};
7use crate::header::HeaderMap;
8use crate::packet::{HandshakePacket, Packet, PacketId};
9use crate::transports::{PollingTransport, WebsocketSecureTransport, WebsocketTransport};
10use crate::ENGINE_IO_VERSION;
11use bytes::Bytes;
12use native_tls::TlsConnector;
13use std::convert::TryFrom;
14use std::convert::TryInto;
15use std::fmt::Debug;
16use url::Url;
17
18#[derive(Clone, Debug)]
26pub struct Client {
27 socket: InnerSocket,
28}
29
30#[derive(Clone, Debug)]
31pub struct ClientBuilder {
32 url: Url,
33 tls_config: Option<TlsConnector>,
34 headers: Option<HeaderMap>,
35 handshake: Option<HandshakePacket>,
36 on_error: OptionalCallback<String>,
37 on_open: OptionalCallback<()>,
38 on_close: OptionalCallback<()>,
39 on_data: OptionalCallback<Bytes>,
40 on_packet: OptionalCallback<Packet>,
41}
42
43impl ClientBuilder {
44 pub fn new(url: Url) -> Self {
45 let mut url = url;
46 url.query_pairs_mut()
47 .append_pair("EIO", &ENGINE_IO_VERSION.to_string());
48
49 if url.path() == "/" {
51 url.set_path("/engine.io/");
52 }
53 ClientBuilder {
54 url,
55 headers: None,
56 tls_config: None,
57 handshake: None,
58 on_close: OptionalCallback::default(),
59 on_data: OptionalCallback::default(),
60 on_error: OptionalCallback::default(),
61 on_open: OptionalCallback::default(),
62 on_packet: OptionalCallback::default(),
63 }
64 }
65
66 pub fn tls_config(mut self, tls_config: TlsConnector) -> Self {
68 self.tls_config = Some(tls_config);
69 self
70 }
71
72 pub fn headers(mut self, headers: HeaderMap) -> Self {
74 self.headers = Some(headers);
75 self
76 }
77
78 pub fn on_close<T>(mut self, callback: T) -> Self
80 where
81 T: Fn(()) + 'static + Sync + Send,
82 {
83 self.on_close = OptionalCallback::new(callback);
84 self
85 }
86
87 pub fn on_data<T>(mut self, callback: T) -> Self
89 where
90 T: Fn(Bytes) + 'static + Sync + Send,
91 {
92 self.on_data = OptionalCallback::new(callback);
93 self
94 }
95
96 pub fn on_error<T>(mut self, callback: T) -> Self
98 where
99 T: Fn(String) + 'static + Sync + Send,
100 {
101 self.on_error = OptionalCallback::new(callback);
102 self
103 }
104
105 pub fn on_open<T>(mut self, callback: T) -> Self
107 where
108 T: Fn(()) + 'static + Sync + Send,
109 {
110 self.on_open = OptionalCallback::new(callback);
111 self
112 }
113
114 pub fn on_packet<T>(mut self, callback: T) -> Self
116 where
117 T: Fn(Packet) + 'static + Sync + Send,
118 {
119 self.on_packet = OptionalCallback::new(callback);
120 self
121 }
122
123 fn handshake_with_transport<T: Transport>(&mut self, transport: &T) -> Result<()> {
125 if self.handshake.is_some() {
127 return Ok(());
128 }
129
130 let mut url = self.url.clone();
131
132 let handshake: HandshakePacket =
133 Packet::try_from(transport.poll(DEFAULT_MAX_POLL_TIMEOUT)?)?.try_into()?;
134
135 url.query_pairs_mut().append_pair("sid", &handshake.sid[..]);
137
138 self.handshake = Some(handshake);
139
140 self.url = url;
141
142 Ok(())
143 }
144
145 fn handshake(&mut self) -> Result<()> {
146 if self.handshake.is_some() {
147 return Ok(());
148 }
149
150 let transport = PollingTransport::new(
152 self.url.clone(),
153 self.tls_config.clone(),
154 self.headers.clone().map(|v| v.try_into().unwrap()),
155 );
156
157 self.handshake_with_transport(&transport)
158 }
159
160 pub fn build(mut self) -> Result<Client> {
162 self.handshake()?;
163
164 if self.websocket_upgrade()? {
165 self.build_websocket_with_upgrade()
166 } else {
167 self.build_polling()
168 }
169 }
170
171 pub fn build_polling(mut self) -> Result<Client> {
173 self.handshake()?;
174
175 let transport = PollingTransport::new(
177 self.url,
178 self.tls_config,
179 self.headers.map(|v| v.try_into().unwrap()),
180 );
181
182 Ok(Client {
184 socket: InnerSocket::new(
185 transport.into(),
186 self.handshake.unwrap(),
187 self.on_close,
188 self.on_data,
189 self.on_error,
190 self.on_open,
191 self.on_packet,
192 ),
193 })
194 }
195
196 pub fn build_websocket_with_upgrade(mut self) -> Result<Client> {
198 self.handshake()?;
199
200 if self.websocket_upgrade()? {
201 self.build_websocket()
202 } else {
203 Err(Error::IllegalWebsocketUpgrade())
204 }
205 }
206
207 pub fn build_websocket(mut self) -> Result<Client> {
209 let url = url::Url::parse(self.url.as_ref())?;
211
212 let headers: Option<http::HeaderMap> = if let Some(map) = self.headers.clone() {
213 Some(map.try_into()?)
214 } else {
215 None
216 };
217
218 match url.scheme() {
219 "http" | "ws" => {
220 let transport = WebsocketTransport::new(url, headers)?;
221 if self.handshake.is_some() {
222 transport.upgrade()?;
223 } else {
224 self.handshake_with_transport(&transport)?;
225 }
226 Ok(Client {
229 socket: InnerSocket::new(
230 transport.into(),
231 self.handshake.unwrap(),
232 self.on_close,
233 self.on_data,
234 self.on_error,
235 self.on_open,
236 self.on_packet,
237 ),
238 })
239 }
240 "https" | "wss" => {
241 let transport =
242 WebsocketSecureTransport::new(url, self.tls_config.clone(), headers)?;
243 if self.handshake.is_some() {
244 transport.upgrade()?;
245 } else {
246 self.handshake_with_transport(&transport)?;
247 }
248 Ok(Client {
251 socket: InnerSocket::new(
252 transport.into(),
253 self.handshake.unwrap(),
254 self.on_close,
255 self.on_data,
256 self.on_error,
257 self.on_open,
258 self.on_packet,
259 ),
260 })
261 }
262 _ => Err(Error::InvalidUrlScheme(url.scheme().to_string())),
263 }
264 }
265
266 pub fn build_with_fallback(self) -> Result<Client> {
269 let result = self.clone().build();
270 if result.is_err() {
271 self.build_polling()
272 } else {
273 result
274 }
275 }
276
277 fn websocket_upgrade(&mut self) -> Result<bool> {
279 Ok(self
281 .handshake
282 .as_ref()
283 .unwrap()
284 .upgrades
285 .iter()
286 .any(|upgrade| upgrade.to_lowercase() == *"websocket"))
287 }
288}
289
290impl Client {
291 pub fn close(&self) -> Result<()> {
292 self.socket.disconnect()
293 }
294
295 pub fn connect(&self) -> Result<()> {
298 self.socket.connect()
299 }
300
301 pub fn disconnect(&self) -> Result<()> {
303 self.socket.disconnect()
304 }
305
306 pub fn emit(&self, packet: Packet) -> Result<()> {
308 self.socket.emit(packet)
309 }
310
311 #[doc(hidden)]
313 pub fn poll(&self) -> Result<Option<Packet>> {
314 let packet = self.socket.poll()?;
315 if let Some(packet) = packet {
316 self.socket.handle_packet(packet.clone());
318 match packet.packet_id {
319 PacketId::MessageBinary => {
320 self.socket.handle_data(packet.data.clone());
321 }
322 PacketId::Message => {
323 self.socket.handle_data(packet.data.clone());
324 }
325 PacketId::Close => {
326 self.socket.handle_close();
327 }
328 PacketId::Open => {
329 unreachable!("Won't happen as we open the connection beforehand");
330 }
331 PacketId::Upgrade => {
332 }
334 PacketId::Ping => {
335 self.socket.pinged()?;
336 self.emit(Packet::new(PacketId::Pong, Bytes::new()))?;
337 }
338 PacketId::Pong => {
339 unreachable!();
342 }
343 PacketId::Noop => (),
344 }
345 Ok(Some(packet))
346 } else {
347 Ok(None)
348 }
349 }
350
351 pub fn is_connected(&self) -> Result<bool> {
353 self.socket.is_connected()
354 }
355
356 pub fn iter(&self) -> Iter<'_> {
357 Iter { socket: self }
358 }
359}
360
361#[derive(Clone)]
362pub struct Iter<'a> {
363 socket: &'a Client,
364}
365
366impl<'a> Iterator for Iter<'a> {
367 type Item = Result<Packet>;
368 fn next(&mut self) -> std::option::Option<<Self as std::iter::Iterator>::Item> {
369 match self.socket.poll() {
370 Ok(Some(packet)) => Some(Ok(packet)),
371 Ok(None) => None,
372 Err(err) => Some(Err(err)),
373 }
374 }
375}
376
#[cfg(test)]
mod test {
    use super::*;

    use crate::packet::{Packet, PacketId};
    use reqwest::header::HOST;

    /// Returns a builder for `url` with logging callbacks registered for every event.
    fn builder(url: Url) -> ClientBuilder {
        ClientBuilder::new(url)
            .on_open(|_| {
                println!("Open event!");
            })
            .on_packet(|packet| {
                println!("Received packet: {:?}", packet);
            })
            .on_data(|data| {
                println!("Received data: {:?}", std::str::from_utf8(&data));
            })
            .on_close(|_| {
                println!("Close event!");
            })
            .on_error(|error| {
                println!("Error {}", error);
            })
    }

    /// Connects `socket` and runs the scripted exchange with the test server:
    /// expects "hello client", sends "respond", expects "Roger Roger". Then it
    /// closes the socket.
    fn test_connection(socket: Client) -> Result<()> {
        socket.connect()?;

        let mut iter = socket
            .iter()
            .map(|packet| packet.unwrap())
            .filter(|packet| packet.packet_id != PacketId::Ping);

        assert_eq!(
            iter.next(),
            Some(Packet::new(PacketId::Message, "hello client"))
        );

        socket.emit(Packet::new(PacketId::Message, "respond"))?;

        assert_eq!(
            iter.next(),
            Some(Packet::new(PacketId::Message, "Roger Roger"))
        );

        socket.close()
    }

    #[test]
    fn test_client_cloneable() -> Result<()> {
        let url = crate::test::engine_io_server()?;
        let sut = builder(url).build()?;

        let cloned = sut.clone();

        sut.connect()?;

        // Connection state is shared between clones.
        assert!(sut.is_connected()?);
        assert!(cloned.is_connected()?);

        let mut iter = sut
            .iter()
            .map(|packet| packet.unwrap())
            .filter(|packet| packet.packet_id != PacketId::Ping);

        let mut iter_cloned = cloned
            .iter()
            .map(|packet| packet.unwrap())
            .filter(|packet| packet.packet_id != PacketId::Ping);

        assert_eq!(
            iter.next(),
            Some(Packet::new(PacketId::Message, "hello client"))
        );

        sut.emit(Packet::new(PacketId::Message, "respond"))?;

        assert_eq!(
            iter_cloned.next(),
            Some(Packet::new(PacketId::Message, "Roger Roger"))
        );

        cloned.disconnect()?;

        assert!(!sut.is_connected()?);
        assert!(!cloned.is_connected()?);

        Ok(())
    }

    #[test]
    fn test_illegal_actions() -> Result<()> {
        let url = crate::test::engine_io_server()?;
        let sut = builder(url).build()?;

        // Emitting before `connect` must be rejected.
        assert!(sut
            .emit(Packet::new(PacketId::Close, Bytes::new()))
            .is_err());

        sut.connect()?;

        assert!(sut.poll().is_ok());

        assert!(builder(Url::parse("fake://fake.fake").unwrap())
            .build_websocket()
            .is_err());

        Ok(())
    }

    #[test]
    fn test_connection_long() -> Result<()> {
        let url = crate::test::engine_io_server()?;
        let socket = builder(url).build()?;

        socket.connect()?;

        // Consume two packets from the server before disconnecting.
        let mut iter = socket.iter();
        iter.next();
        iter.next();

        socket.disconnect()?;

        assert!(!socket.is_connected()?);

        Ok(())
    }

    #[test]
    fn test_connection_dynamic() -> Result<()> {
        let url = crate::test::engine_io_server()?;
        let socket = builder(url).build()?;
        test_connection(socket)?;

        let url = crate::test::engine_io_polling_server()?;
        let socket = builder(url).build()?;
        test_connection(socket)
    }

    #[test]
    fn test_connection_fallback() -> Result<()> {
        let url = crate::test::engine_io_server()?;
        let socket = builder(url).build_with_fallback()?;
        test_connection(socket)?;

        let url = crate::test::engine_io_polling_server()?;
        let socket = builder(url).build_with_fallback()?;
        test_connection(socket)
    }

    #[test]
    fn test_connection_dynamic_secure() -> Result<()> {
        let url = crate::test::engine_io_server_secure()?;
        let socket = builder(url)
            .tls_config(crate::test::tls_connector()?)
            .build()?;
        test_connection(socket)
    }

    #[test]
    fn test_connection_polling() -> Result<()> {
        let url = crate::test::engine_io_server()?;
        let socket = builder(url).build_polling()?;
        test_connection(socket)
    }

    #[test]
    fn test_connection_wss() -> Result<()> {
        let url = crate::test::engine_io_polling_server()?;
        assert!(builder(url).build_websocket_with_upgrade().is_err());

        let host =
            std::env::var("ENGINE_IO_SECURE_HOST").unwrap_or_else(|_| "localhost".to_owned());
        let mut url = crate::test::engine_io_server_secure()?;

        let mut headers = HeaderMap::default();
        headers.insert(HOST, host);

        let builder = builder(url.clone())
            .tls_config(crate::test::tls_connector()?)
            .headers(headers.clone());
        let socket = builder.clone().build_websocket_with_upgrade()?;

        test_connection(socket)?;

        let socket = builder.build_websocket()?;

        test_connection(socket)?;

        url.set_scheme("wss").unwrap();

        let builder = self::builder(url)
            .tls_config(crate::test::tls_connector()?)
            .headers(headers);
        let socket = builder.clone().build_websocket()?;

        test_connection(socket)?;

        assert!(builder.build_websocket_with_upgrade().is_err());

        Ok(())
    }

    #[test]
    fn test_connection_ws() -> Result<()> {
        let url = crate::test::engine_io_polling_server()?;
        assert!(builder(url.clone()).build_websocket().is_err());
        assert!(builder(url).build_websocket_with_upgrade().is_err());

        let mut url = crate::test::engine_io_server()?;

        let builder = builder(url.clone());
        let socket = builder.clone().build_websocket()?;
        test_connection(socket)?;

        let socket = builder.build_websocket_with_upgrade()?;
        test_connection(socket)?;

        url.set_scheme("ws").unwrap();

        let builder = self::builder(url);
        let socket = builder.clone().build_websocket()?;

        test_connection(socket)?;

        assert!(builder.build_websocket_with_upgrade().is_err());

        Ok(())
    }

    #[test]
    fn test_open_invariants() -> Result<()> {
        let url = crate::test::engine_io_server()?;
        let illegal_url = "this is illegal";

        assert!(Url::parse(illegal_url).is_err());

        let invalid_protocol = "file:///tmp/foo";
        assert!(builder(Url::parse(invalid_protocol).unwrap())
            .build()
            .is_err());

        let sut = builder(url.clone()).build()?;
        let error = sut
            .emit(Packet::new(PacketId::Close, Bytes::new()))
            .expect_err("emitting before connect must fail");
        // `matches!(value, pattern)`: the error must be the value, the variant the pattern.
        assert!(matches!(error, Error::IllegalActionBeforeOpen()));

        let mut headers = HeaderMap::default();
        let host =
            std::env::var("ENGINE_IO_SECURE_HOST").unwrap_or_else(|_| "localhost:4201".to_owned());
        headers.insert(HOST, host);

        let _ = builder(url.clone())
            .tls_config(
                TlsConnector::builder()
                    .danger_accept_invalid_certs(true)
                    .build()
                    .unwrap(),
            )
            .build()?;
        let _ = builder(url).headers(headers).build()?;
        Ok(())
    }
}