Skip to main content

tf_rust_engineio/asynchronous/client/
async_client.rs

1use std::{fmt::Debug, pin::Pin};
2
3use crate::{
4    asynchronous::{async_socket::Socket as InnerSocket, generator::StreamGenerator},
5    error::Result,
6    Packet,
7};
8use async_stream::try_stream;
9use futures_util::{Stream, StreamExt};
10
/// An engine.io client that allows interaction with the connected engine.io
/// server. This client provides means for connecting, disconnecting and sending
/// packets to the server.
///
/// ## Note:
/// There is no need to put this Client behind an `Arc`, as the type uses `Arc`
/// internally and provides a shared state across all cloned instances.
#[derive(Clone)]
pub struct Client {
    /// Underlying engine.io socket; the connection-related methods of the
    /// client forward to it.
    pub(super) socket: InnerSocket,
    /// Packet stream built from the socket in `Client::new`; it is what the
    /// `Stream` implementation of the client polls.
    generator: StreamGenerator<Packet>,
}
23
24impl Client {
25    pub(super) fn new(socket: InnerSocket) -> Self {
26        Client {
27            socket: socket.clone(),
28            generator: StreamGenerator::new(Self::stream(socket)),
29        }
30    }
31
32    pub async fn close(&self) -> Result<()> {
33        self.socket.disconnect().await
34    }
35
36    /// Opens the connection to a specified server. The first Pong packet is sent
37    /// to the server to trigger the Ping-cycle.
38    pub async fn connect(&self) -> Result<()> {
39        self.socket.connect().await
40    }
41
42    /// Disconnects the connection.
43    pub async fn disconnect(&self) -> Result<()> {
44        self.socket.disconnect().await
45    }
46
47    /// Sends a packet to the server.
48    pub async fn emit(&self, packet: Packet) -> Result<()> {
49        self.socket.emit(packet).await
50    }
51
52    /// Static method that returns a generator for each element of the stream.
53    fn stream(
54        socket: InnerSocket,
55    ) -> Pin<Box<impl Stream<Item = Result<Packet>> + 'static + Send>> {
56        Box::pin(try_stream! {
57            let socket = socket.clone();
58            for await item in socket.as_stream() {
59                let packet = item?;
60                socket.handle_incoming_packet(packet.clone()).await?;
61                yield packet;
62            }
63        })
64    }
65
66    /// Check if the underlying transport client is connected.
67    pub fn is_connected(&self) -> bool {
68        self.socket.is_connected()
69    }
70}
71
72impl Stream for Client {
73    type Item = Result<Packet>;
74
75    fn poll_next(
76        mut self: Pin<&mut Self>,
77        cx: &mut std::task::Context<'_>,
78    ) -> std::task::Poll<Option<Self::Item>> {
79        self.generator.poll_next_unpin(cx)
80    }
81}
82
83impl Debug for Client {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        f.debug_struct("Client")
86            .field("socket", &self.socket)
87            .finish()
88    }
89}
90
91#[cfg(test)]
92mod test {
93
94    use super::*;
95    use crate::{asynchronous::ClientBuilder, header::HeaderMap, packet::PacketId, Error};
96    use bytes::Bytes;
97    use futures_util::StreamExt;
98    use native_tls::TlsConnector;
99    use url::Url;
100
101    /// The purpose of this test is to check whether the Client is properly cloneable or not.
102    /// As the documentation of the engine.io client states, the object needs to maintain it's internal
103    /// state when cloned and the cloned object should reflect the same state throughout the lifetime
104    /// of both objects (initial and cloned).
105    #[tokio::test]
106    async fn test_client_cloneable() -> Result<()> {
107        let url = crate::test::engine_io_server()?;
108
109        let mut sut = builder(url).build().await?;
110        let mut cloned = sut.clone();
111
112        sut.connect().await?;
113
114        // when the underlying socket is connected, the
115        // state should also change on the cloned one
116        assert!(sut.is_connected());
117        assert!(cloned.is_connected());
118
119        // both clients should reflect the same messages.
120        assert_eq!(
121            sut.next().await.unwrap()?,
122            Packet::new(PacketId::Message, "hello client")
123        );
124
125        sut.emit(Packet::new(PacketId::Message, "respond")).await?;
126
127        assert_eq!(
128            cloned.next().await.unwrap()?,
129            Packet::new(PacketId::Message, "Roger Roger")
130        );
131
132        cloned.disconnect().await?;
133
134        // when the underlying socket is disconnected, the
135        // state should also change on the cloned one
136        assert!(!sut.is_connected());
137        assert!(!cloned.is_connected());
138
139        Ok(())
140    }
141
142    #[tokio::test]
143    async fn test_illegal_actions() -> Result<()> {
144        let url = crate::test::engine_io_server()?;
145        let mut sut = builder(url.clone()).build().await?;
146
147        assert!(sut
148            .emit(Packet::new(PacketId::Close, Bytes::new()))
149            .await
150            .is_err());
151
152        sut.connect().await?;
153
154        assert!(sut.next().await.unwrap().is_ok());
155
156        assert!(builder(Url::parse("fake://fake.fake").unwrap())
157            .build_websocket()
158            .await
159            .is_err());
160
161        sut.disconnect().await?;
162
163        Ok(())
164    }
165
166    use reqwest::header::HOST;
167
168    use crate::packet::Packet;
169
170    fn builder(url: Url) -> ClientBuilder {
171        ClientBuilder::new(url)
172            .on_open(|_| {
173                Box::pin(async {
174                    println!("Open event!");
175                })
176            })
177            .on_packet(|packet| {
178                Box::pin(async move {
179                    println!("Received packet: {:?}", packet);
180                })
181            })
182            .on_data(|data| {
183                Box::pin(async move {
184                    println!("Received data: {:?}", std::str::from_utf8(&data));
185                })
186            })
187            .on_close(|_| {
188                Box::pin(async {
189                    println!("Close event!");
190                })
191            })
192            .on_error(|error| {
193                Box::pin(async move {
194                    println!("Error {}", error);
195                })
196            })
197    }
198
199    async fn test_connection(socket: Client) -> Result<()> {
200        let mut socket = socket;
201
202        socket.connect().await.unwrap();
203
204        assert_eq!(
205            socket.next().await.unwrap()?,
206            Packet::new(PacketId::Message, "hello client")
207        );
208        println!("received msg, about to send");
209
210        socket
211            .emit(Packet::new(PacketId::Message, "respond"))
212            .await?;
213
214        println!("send msg");
215
216        assert_eq!(
217            socket.next().await.unwrap()?,
218            Packet::new(PacketId::Message, "Roger Roger")
219        );
220        println!("received 2");
221
222        socket.close().await
223    }
224
225    #[tokio::test]
226    async fn test_connection_long() -> Result<()> {
227        // Long lived socket to receive pings
228        let url = crate::test::engine_io_server()?;
229        let mut socket = builder(url).build().await?;
230
231        socket.connect().await?;
232
233        // hello client
234        assert!(matches!(
235            socket.next().await.unwrap()?,
236            Packet {
237                packet_id: PacketId::Message,
238                ..
239            }
240        ));
241        // Ping
242        assert!(matches!(
243            socket.next().await.unwrap()?,
244            Packet {
245                packet_id: PacketId::Ping,
246                ..
247            }
248        ));
249
250        socket.disconnect().await?;
251
252        assert!(!socket.is_connected());
253
254        Ok(())
255    }
256
257    #[tokio::test]
258    async fn test_connection_dynamic() -> Result<()> {
259        let url = crate::test::engine_io_server()?;
260        let socket = builder(url).build().await?;
261        test_connection(socket).await?;
262
263        let url = crate::test::engine_io_polling_server()?;
264        let socket = builder(url).build().await?;
265        test_connection(socket).await
266    }
267
268    #[tokio::test]
269    async fn test_connection_fallback() -> Result<()> {
270        let url = crate::test::engine_io_server()?;
271        let socket = builder(url).build_with_fallback().await?;
272        test_connection(socket).await?;
273
274        let url = crate::test::engine_io_polling_server()?;
275        let socket = builder(url).build_with_fallback().await?;
276        test_connection(socket).await
277    }
278
279    #[tokio::test]
280    async fn test_connection_dynamic_secure() -> Result<()> {
281        let url = crate::test::engine_io_server_secure()?;
282        let mut socket_builder = builder(url);
283        socket_builder = socket_builder.tls_config(crate::test::tls_connector()?);
284        let socket = socket_builder.build().await?;
285        test_connection(socket).await
286    }
287
288    #[tokio::test]
289    async fn test_connection_polling() -> Result<()> {
290        let url = crate::test::engine_io_server()?;
291        let socket = builder(url).build_polling().await?;
292        test_connection(socket).await
293    }
294
295    #[tokio::test]
296    async fn test_connection_wss() -> Result<()> {
297        let url = crate::test::engine_io_polling_server()?;
298        assert!(builder(url).build_websocket_with_upgrade().await.is_err());
299
300        let host =
301            std::env::var("ENGINE_IO_SECURE_HOST").unwrap_or_else(|_| "localhost".to_owned());
302        let mut url = crate::test::engine_io_server_secure()?;
303
304        let mut headers = HeaderMap::default();
305        headers.insert(HOST, host);
306        let mut builder = builder(url.clone());
307
308        builder = builder.tls_config(crate::test::tls_connector()?);
309        builder = builder.headers(headers.clone());
310        let socket = builder.clone().build_websocket_with_upgrade().await?;
311
312        test_connection(socket).await?;
313
314        let socket = builder.build_websocket().await?;
315
316        test_connection(socket).await?;
317
318        url.set_scheme("wss").unwrap();
319
320        let builder = self::builder(url)
321            .tls_config(crate::test::tls_connector()?)
322            .headers(headers);
323        let socket = builder.clone().build_websocket().await?;
324
325        test_connection(socket).await?;
326
327        assert!(builder.build_websocket_with_upgrade().await.is_err());
328
329        Ok(())
330    }
331
332    #[tokio::test]
333    async fn test_connection_ws() -> Result<()> {
334        let url = crate::test::engine_io_polling_server()?;
335        assert!(builder(url.clone()).build_websocket().await.is_err());
336        assert!(builder(url).build_websocket_with_upgrade().await.is_err());
337
338        let mut url = crate::test::engine_io_server()?;
339
340        let builder = builder(url.clone());
341        let socket = builder.clone().build_websocket().await?;
342        test_connection(socket).await?;
343
344        let socket = builder.build_websocket_with_upgrade().await?;
345        test_connection(socket).await?;
346
347        url.set_scheme("ws").unwrap();
348
349        let builder = self::builder(url);
350        let socket = builder.clone().build_websocket().await?;
351
352        test_connection(socket).await?;
353
354        assert!(builder.build_websocket_with_upgrade().await.is_err());
355
356        Ok(())
357    }
358
359    #[tokio::test]
360    async fn test_open_invariants() -> Result<()> {
361        let url = crate::test::engine_io_server()?;
362        let illegal_url = "this is illegal";
363
364        assert!(Url::parse(illegal_url).is_err());
365
366        let invalid_protocol = "file:///tmp/foo";
367        assert!(builder(Url::parse(invalid_protocol).unwrap())
368            .build()
369            .await
370            .is_err());
371
372        let sut = builder(url.clone()).build().await?;
373        let _error = sut
374            .emit(Packet::new(PacketId::Close, Bytes::new()))
375            .await
376            .expect_err("error");
377        assert!(matches!(Error::IllegalActionBeforeOpen(), _error));
378
379        // test missing match arm in socket constructor
380        let mut headers = HeaderMap::default();
381        // Use the correct Host header value including the port
382        let host =
383            std::env::var("ENGINE_IO_SECURE_HOST").unwrap_or_else(|_| "localhost:4201".to_owned());
384        headers.insert(HOST, host);
385
386        let _ = builder(url.clone())
387            .tls_config(
388                TlsConnector::builder()
389                    .danger_accept_invalid_certs(true)
390                    .build()
391                    .unwrap(),
392            )
393            .build()
394            .await?;
395        let _ = builder(url).headers(headers).build().await?;
396        Ok(())
397    }
398}