Skip to main content

tf_rust_engineio/
packet.rs

1use base64::{engine::general_purpose, Engine as _};
2use bytes::{BufMut, Bytes, BytesMut};
3use serde::{Deserialize, Serialize};
4use std::char;
5use std::convert::TryFrom;
6use std::convert::TryInto;
7use std::fmt::{Display, Formatter, Result as FmtResult, Write};
8use std::ops::Index;
9
10use crate::error::{Error, Result};
/// The packet types known to the `engine.io` protocol.
///
/// Every variant except [`PacketId::MessageBinary`] maps to a numeric id
/// between `0` and `6`; a binary message shares the numeric id of a regular
/// message but is written with the type character `b`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketId {
    /// Numeric id `0`.
    Open,
    /// Numeric id `1`.
    Close,
    /// Numeric id `2`.
    Ping,
    /// Numeric id `3`.
    Pong,
    /// Numeric id `4`.
    Message,
    /// A message whose payload is base64 encoded; written as `b`.
    MessageBinary,
    /// Numeric id `5`.
    Upgrade,
    /// Numeric id `6`.
    Noop,
}
24
25impl PacketId {
26    /// Returns the byte that represents the [`PacketId`] as a [`char`].
27    fn to_string_byte(self) -> u8 {
28        match self {
29            Self::MessageBinary => b'b',
30            _ => u8::from(self) + b'0',
31        }
32    }
33}
34
35impl Display for PacketId {
36    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
37        f.write_char(self.to_string_byte() as char)
38    }
39}
40
41impl From<PacketId> for u8 {
42    fn from(packet_id: PacketId) -> Self {
43        match packet_id {
44            PacketId::Open => 0,
45            PacketId::Close => 1,
46            PacketId::Ping => 2,
47            PacketId::Pong => 3,
48            PacketId::Message => 4,
49            PacketId::MessageBinary => 4,
50            PacketId::Upgrade => 5,
51            PacketId::Noop => 6,
52        }
53    }
54}
55
56impl TryFrom<u8> for PacketId {
57    type Error = Error;
58    /// Converts a byte into the corresponding `PacketId`.
59    fn try_from(b: u8) -> Result<PacketId> {
60        match b {
61            0 | b'0' => Ok(PacketId::Open),
62            1 | b'1' => Ok(PacketId::Close),
63            2 | b'2' => Ok(PacketId::Ping),
64            3 | b'3' => Ok(PacketId::Pong),
65            4 | b'4' => Ok(PacketId::Message),
66            5 | b'5' => Ok(PacketId::Upgrade),
67            6 | b'6' => Ok(PacketId::Noop),
68            _ => Err(Error::InvalidPacketId(b)),
69        }
70    }
71}
72
73/// A `Packet` sent via the `engine.io` protocol.
74#[derive(Debug, Clone, Eq, PartialEq)]
75pub struct Packet {
76    pub packet_id: PacketId,
77    pub data: Bytes,
78}
79
80/// Data which gets exchanged in a handshake as defined by the server.
81#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
82pub struct HandshakePacket {
83    pub sid: String,
84    pub upgrades: Vec<String>,
85    #[serde(rename = "pingInterval")]
86    pub ping_interval: u64,
87    #[serde(rename = "pingTimeout")]
88    pub ping_timeout: u64,
89}
90
91impl TryFrom<Packet> for HandshakePacket {
92    type Error = Error;
93    fn try_from(packet: Packet) -> Result<HandshakePacket> {
94        Ok(serde_json::from_slice(packet.data[..].as_ref())?)
95    }
96}
97
98impl Packet {
99    /// Creates a new `Packet`.
100    pub fn new<T: Into<Bytes>>(packet_id: PacketId, data: T) -> Self {
101        Packet {
102            packet_id,
103            data: data.into(),
104        }
105    }
106}
107
108impl TryFrom<Bytes> for Packet {
109    type Error = Error;
110    /// Decodes a single `Packet` from an `u8` byte stream.
111    fn try_from(
112        bytes: Bytes,
113    ) -> std::result::Result<Self, <Self as std::convert::TryFrom<Bytes>>::Error> {
114        if bytes.is_empty() {
115            return Err(Error::IncompletePacket());
116        }
117
118        let is_base64 = *bytes.first().ok_or(Error::IncompletePacket())? == b'b';
119
120        // only 'messages' packets could be encoded
121        let packet_id = if is_base64 {
122            PacketId::MessageBinary
123        } else {
124            (*bytes.first().ok_or(Error::IncompletePacket())?).try_into()?
125        };
126
127        if bytes.len() == 1 && packet_id == PacketId::Message {
128            return Err(Error::IncompletePacket());
129        }
130
131        let data: Bytes = bytes.slice(1..);
132
133        Ok(Packet {
134            packet_id,
135            data: if is_base64 {
136                Bytes::from(general_purpose::STANDARD.decode(data.as_ref())?)
137            } else {
138                data
139            },
140        })
141    }
142}
143
144impl From<Packet> for Bytes {
145    /// Encodes a `Packet` into an `u8` byte stream.
146    fn from(packet: Packet) -> Self {
147        let mut result = BytesMut::with_capacity(packet.data.len() + 1);
148        result.put_u8(packet.packet_id.to_string_byte());
149        if packet.packet_id == PacketId::MessageBinary {
150            result.extend(general_purpose::STANDARD.encode(packet.data).into_bytes());
151        } else {
152            result.put(packet.data);
153        }
154        result.freeze()
155    }
156}
157
158#[derive(Debug, Clone)]
159pub(crate) struct Payload(Vec<Packet>);
160
161impl Payload {
162    // see https://en.wikipedia.org/wiki/Delimiter#ASCII_delimited_text
163    const SEPARATOR: char = '\x1e';
164
165    #[cfg(test)]
166    pub fn len(&self) -> usize {
167        self.0.len()
168    }
169}
170
171impl TryFrom<Bytes> for Payload {
172    type Error = Error;
173    /// Decodes a `payload` which in the `engine.io` context means a chain of normal
174    /// packets separated by a certain SEPARATOR, in this case the delimiter `\x30`.
175    fn try_from(payload: Bytes) -> Result<Self> {
176        payload
177            .split(|&c| c as char == Self::SEPARATOR)
178            .map(|slice| Packet::try_from(payload.slice_ref(slice)))
179            .collect::<Result<Vec<_>>>()
180            .map(Self)
181    }
182}
183
184impl TryFrom<Payload> for Bytes {
185    type Error = Error;
186    /// Encodes a payload. Payload in the `engine.io` context means a chain of
187    /// normal `packets` separated by a SEPARATOR, in this case the delimiter
188    /// `\x30`.
189    fn try_from(packets: Payload) -> Result<Self> {
190        let mut buf = BytesMut::new();
191        for packet in packets {
192            // at the moment no base64 encoding is used
193            buf.extend(Bytes::from(packet.clone()));
194            buf.put_u8(Payload::SEPARATOR as u8);
195        }
196
197        // remove the last separator
198        let _ = buf.split_off(buf.len() - 1);
199        Ok(buf.freeze())
200    }
201}
202
203#[derive(Clone, Debug)]
204pub struct IntoIter {
205    iter: std::vec::IntoIter<Packet>,
206}
207
208impl Iterator for IntoIter {
209    type Item = Packet;
210    fn next(&mut self) -> std::option::Option<<Self as std::iter::Iterator>::Item> {
211        self.iter.next()
212    }
213}
214
215impl IntoIterator for Payload {
216    type Item = Packet;
217    type IntoIter = IntoIter;
218    fn into_iter(self) -> <Self as std::iter::IntoIterator>::IntoIter {
219        IntoIter {
220            iter: self.0.into_iter(),
221        }
222    }
223}
224
225impl Index<usize> for Payload {
226    type Output = Packet;
227    fn index(&self, index: usize) -> &Packet {
228        &self.0[index]
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty buffer cannot be decoded into a packet.
    #[test]
    fn test_packet_error() {
        let err = Packet::try_from(BytesMut::with_capacity(10).freeze());
        assert!(err.is_err())
    }

    /// Decoding and re-encoding a packet yields the original bytes.
    #[test]
    fn test_is_reflexive() {
        let data = Bytes::from_static(b"1Hello World");
        let packet = Packet::try_from(data).unwrap();

        assert_eq!(packet.packet_id, PacketId::Close);
        assert_eq!(packet.data, Bytes::from_static(b"Hello World"));

        let data = Bytes::from_static(b"1Hello World");
        assert_eq!(Bytes::from(packet), data);
    }

    /// A `b` prefixed packet is base64 decoded and encoded again on output.
    #[test]
    fn test_binary_packet() {
        // SGVsbG8= is the encoded string for 'Hello'
        let data = Bytes::from_static(b"bSGVsbG8=");
        let packet = Packet::try_from(data.clone()).unwrap();

        assert_eq!(packet.packet_id, PacketId::MessageBinary);
        assert_eq!(packet.data, Bytes::from_static(b"Hello"));

        assert_eq!(Bytes::from(packet), data);
    }

    /// A payload is split at the `\x1e` separator and round-trips.
    #[test]
    fn test_decode_payload() -> Result<()> {
        let data = Bytes::from_static(b"1Hello\x1e1HelloWorld");
        let packets = Payload::try_from(data)?;

        assert_eq!(packets[0].packet_id, PacketId::Close);
        assert_eq!(packets[0].data, Bytes::from_static(b"Hello"));
        assert_eq!(packets[1].packet_id, PacketId::Close);
        assert_eq!(packets[1].data, Bytes::from_static(b"HelloWorld"));

        let data = "1Hello\x1e1HelloWorld".to_owned().into_bytes();
        assert_eq!(Bytes::try_from(packets).unwrap(), data);

        Ok(())
    }

    /// A payload of binary packets round-trips as well.
    #[test]
    fn test_binary_payload() {
        let data = Bytes::from_static(b"bSGVsbG8=\x1ebSGVsbG9Xb3JsZA==\x1ebSGVsbG8=");
        let packets = Payload::try_from(data.clone()).unwrap();

        assert!(packets.len() == 3);
        assert_eq!(packets[0].packet_id, PacketId::MessageBinary);
        assert_eq!(packets[0].data, Bytes::from_static(b"Hello"));
        assert_eq!(packets[1].packet_id, PacketId::MessageBinary);
        assert_eq!(packets[1].data, Bytes::from_static(b"HelloWorld"));
        assert_eq!(packets[2].packet_id, PacketId::MessageBinary);
        assert_eq!(packets[2].data, Bytes::from_static(b"Hello"));

        assert_eq!(Bytes::try_from(packets).unwrap(), data);
    }

    /// Checks the byte <-> `PacketId` conversions and that a `Message` packet
    /// without payload is rejected as incomplete.
    #[test]
    fn test_packet_id_conversion_and_incompl_packet() -> Result<()> {
        let sut = Packet::try_from(Bytes::from_static(b"4"));
        assert!(sut.is_err());
        // The value under test goes first and the expected variant is the
        // pattern; with the arguments swapped the pattern is a catch-all
        // binding and the assertion can never fail.
        assert!(matches!(sut.unwrap_err(), Error::IncompletePacket()));

        assert_eq!(PacketId::MessageBinary.to_string(), "b");

        let sut = PacketId::try_from(b'0')?;
        assert_eq!(sut, PacketId::Open);
        assert_eq!(sut.to_string(), "0");

        let sut = PacketId::try_from(b'1')?;
        assert_eq!(sut, PacketId::Close);
        assert_eq!(sut.to_string(), "1");

        let sut = PacketId::try_from(b'2')?;
        assert_eq!(sut, PacketId::Ping);
        assert_eq!(sut.to_string(), "2");

        let sut = PacketId::try_from(b'3')?;
        assert_eq!(sut, PacketId::Pong);
        assert_eq!(sut.to_string(), "3");

        let sut = PacketId::try_from(b'4')?;
        assert_eq!(sut, PacketId::Message);
        assert_eq!(sut.to_string(), "4");

        let sut = PacketId::try_from(b'5')?;
        assert_eq!(sut, PacketId::Upgrade);
        assert_eq!(sut.to_string(), "5");

        let sut = PacketId::try_from(b'6')?;
        assert_eq!(sut, PacketId::Noop);
        assert_eq!(sut.to_string(), "6");

        let sut = PacketId::try_from(42);
        assert!(sut.is_err());
        assert!(matches!(sut.unwrap_err(), Error::InvalidPacketId(42)));

        Ok(())
    }

    /// A handshake packet is parsed from JSON data and rejects non-JSON data.
    #[test]
    fn test_handshake_packet() {
        assert!(
            HandshakePacket::try_from(Packet::new(PacketId::Message, Bytes::from("test"))).is_err()
        );
        let packet = HandshakePacket {
            ping_interval: 10000,
            ping_timeout: 1000,
            sid: "Test".to_owned(),
            upgrades: vec!["websocket".to_owned(), "test".to_owned()],
        };
        let encoded: String = serde_json::to_string(&packet).unwrap();

        assert_eq!(
            packet,
            HandshakePacket::try_from(Packet::new(PacketId::Message, Bytes::from(encoded)))
                .unwrap()
        );
    }
}