1use base64::{engine::general_purpose, Engine as _};
2use bytes::{BufMut, Bytes, BytesMut};
3use serde::{Deserialize, Serialize};
4use std::char;
5use std::convert::TryFrom;
6use std::convert::TryInto;
7use std::fmt::{Display, Formatter, Result as FmtResult, Write};
8use std::ops::Index;
9
10use crate::error::{Error, Result};
/// Type tag of a single packet on the wire.
///
/// Every variant except `MessageBinary` is written as one ASCII digit (`'0'`..=`'6'`);
/// `MessageBinary` is written as `'b'` and carries a base64-encoded payload.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketId {
    /// Numeric id 0.
    Open,
    /// Numeric id 1.
    Close,
    /// Numeric id 2.
    Ping,
    /// Numeric id 3.
    Pong,
    /// Numeric id 4, textual payload.
    Message,
    /// Numeric id 4 as well, but serialised with a `'b'` prefix and a base64 payload.
    MessageBinary,
    /// Numeric id 5.
    Upgrade,
    /// Numeric id 6.
    Noop,
}
24
25impl PacketId {
26 fn to_string_byte(self) -> u8 {
28 match self {
29 Self::MessageBinary => b'b',
30 _ => u8::from(self) + b'0',
31 }
32 }
33}
34
35impl Display for PacketId {
36 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
37 f.write_char(self.to_string_byte() as char)
38 }
39}
40
41impl From<PacketId> for u8 {
42 fn from(packet_id: PacketId) -> Self {
43 match packet_id {
44 PacketId::Open => 0,
45 PacketId::Close => 1,
46 PacketId::Ping => 2,
47 PacketId::Pong => 3,
48 PacketId::Message => 4,
49 PacketId::MessageBinary => 4,
50 PacketId::Upgrade => 5,
51 PacketId::Noop => 6,
52 }
53 }
54}
55
56impl TryFrom<u8> for PacketId {
57 type Error = Error;
58 fn try_from(b: u8) -> Result<PacketId> {
60 match b {
61 0 | b'0' => Ok(PacketId::Open),
62 1 | b'1' => Ok(PacketId::Close),
63 2 | b'2' => Ok(PacketId::Ping),
64 3 | b'3' => Ok(PacketId::Pong),
65 4 | b'4' => Ok(PacketId::Message),
66 5 | b'5' => Ok(PacketId::Upgrade),
67 6 | b'6' => Ok(PacketId::Noop),
68 _ => Err(Error::InvalidPacketId(b)),
69 }
70 }
71}
72
73#[derive(Debug, Clone, Eq, PartialEq)]
75pub struct Packet {
76 pub packet_id: PacketId,
77 pub data: Bytes,
78}
79
80#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
82pub struct HandshakePacket {
83 pub sid: String,
84 pub upgrades: Vec<String>,
85 #[serde(rename = "pingInterval")]
86 pub ping_interval: u64,
87 #[serde(rename = "pingTimeout")]
88 pub ping_timeout: u64,
89}
90
91impl TryFrom<Packet> for HandshakePacket {
92 type Error = Error;
93 fn try_from(packet: Packet) -> Result<HandshakePacket> {
94 Ok(serde_json::from_slice(packet.data[..].as_ref())?)
95 }
96}
97
98impl Packet {
99 pub fn new<T: Into<Bytes>>(packet_id: PacketId, data: T) -> Self {
101 Packet {
102 packet_id,
103 data: data.into(),
104 }
105 }
106}
107
108impl TryFrom<Bytes> for Packet {
109 type Error = Error;
110 fn try_from(
112 bytes: Bytes,
113 ) -> std::result::Result<Self, <Self as std::convert::TryFrom<Bytes>>::Error> {
114 if bytes.is_empty() {
115 return Err(Error::IncompletePacket());
116 }
117
118 let is_base64 = *bytes.first().ok_or(Error::IncompletePacket())? == b'b';
119
120 let packet_id = if is_base64 {
122 PacketId::MessageBinary
123 } else {
124 (*bytes.first().ok_or(Error::IncompletePacket())?).try_into()?
125 };
126
127 if bytes.len() == 1 && packet_id == PacketId::Message {
128 return Err(Error::IncompletePacket());
129 }
130
131 let data: Bytes = bytes.slice(1..);
132
133 Ok(Packet {
134 packet_id,
135 data: if is_base64 {
136 Bytes::from(general_purpose::STANDARD.decode(data.as_ref())?)
137 } else {
138 data
139 },
140 })
141 }
142}
143
144impl From<Packet> for Bytes {
145 fn from(packet: Packet) -> Self {
147 let mut result = BytesMut::with_capacity(packet.data.len() + 1);
148 result.put_u8(packet.packet_id.to_string_byte());
149 if packet.packet_id == PacketId::MessageBinary {
150 result.extend(general_purpose::STANDARD.encode(packet.data).into_bytes());
151 } else {
152 result.put(packet.data);
153 }
154 result.freeze()
155 }
156}
157
158#[derive(Debug, Clone)]
159pub(crate) struct Payload(Vec<Packet>);
160
161impl Payload {
162 const SEPARATOR: char = '\x1e';
164
165 #[cfg(test)]
166 pub fn len(&self) -> usize {
167 self.0.len()
168 }
169}
170
171impl TryFrom<Bytes> for Payload {
172 type Error = Error;
173 fn try_from(payload: Bytes) -> Result<Self> {
176 payload
177 .split(|&c| c as char == Self::SEPARATOR)
178 .map(|slice| Packet::try_from(payload.slice_ref(slice)))
179 .collect::<Result<Vec<_>>>()
180 .map(Self)
181 }
182}
183
184impl TryFrom<Payload> for Bytes {
185 type Error = Error;
186 fn try_from(packets: Payload) -> Result<Self> {
190 let mut buf = BytesMut::new();
191 for packet in packets {
192 buf.extend(Bytes::from(packet.clone()));
194 buf.put_u8(Payload::SEPARATOR as u8);
195 }
196
197 let _ = buf.split_off(buf.len() - 1);
199 Ok(buf.freeze())
200 }
201}
202
203#[derive(Clone, Debug)]
204pub struct IntoIter {
205 iter: std::vec::IntoIter<Packet>,
206}
207
208impl Iterator for IntoIter {
209 type Item = Packet;
210 fn next(&mut self) -> std::option::Option<<Self as std::iter::Iterator>::Item> {
211 self.iter.next()
212 }
213}
214
215impl IntoIterator for Payload {
216 type Item = Packet;
217 type IntoIter = IntoIter;
218 fn into_iter(self) -> <Self as std::iter::IntoIterator>::IntoIter {
219 IntoIter {
220 iter: self.0.into_iter(),
221 }
222 }
223}
224
225impl Index<usize> for Payload {
226 type Output = Packet;
227 fn index(&self, index: usize) -> &Packet {
228 &self.0[index]
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_packet_error() {
        let err = Packet::try_from(BytesMut::with_capacity(10).freeze());
        assert!(err.is_err())
    }

    #[test]
    fn test_is_reflexive() {
        let data = Bytes::from_static(b"1Hello World");
        let packet = Packet::try_from(data).unwrap();

        assert_eq!(packet.packet_id, PacketId::Close);
        assert_eq!(packet.data, Bytes::from_static(b"Hello World"));

        let data = Bytes::from_static(b"1Hello World");
        assert_eq!(Bytes::from(packet), data);
    }

    #[test]
    fn test_binary_packet() {
        let data = Bytes::from_static(b"bSGVsbG8=");
        let packet = Packet::try_from(data.clone()).unwrap();

        assert_eq!(packet.packet_id, PacketId::MessageBinary);
        assert_eq!(packet.data, Bytes::from_static(b"Hello"));

        assert_eq!(Bytes::from(packet), data);
    }

    #[test]
    fn test_decode_payload() -> Result<()> {
        let data = Bytes::from_static(b"1Hello\x1e1HelloWorld");
        let packets = Payload::try_from(data)?;

        assert_eq!(packets[0].packet_id, PacketId::Close);
        assert_eq!(packets[0].data, Bytes::from_static(b"Hello"));
        assert_eq!(packets[1].packet_id, PacketId::Close);
        assert_eq!(packets[1].data, Bytes::from_static(b"HelloWorld"));

        let data = "1Hello\x1e1HelloWorld".to_owned().into_bytes();
        assert_eq!(Bytes::try_from(packets).unwrap(), data);

        Ok(())
    }

    #[test]
    fn test_binary_payload() {
        let data = Bytes::from_static(b"bSGVsbG8=\x1ebSGVsbG9Xb3JsZA==\x1ebSGVsbG8=");
        let packets = Payload::try_from(data.clone()).unwrap();

        assert!(packets.len() == 3);
        assert_eq!(packets[0].packet_id, PacketId::MessageBinary);
        assert_eq!(packets[0].data, Bytes::from_static(b"Hello"));
        assert_eq!(packets[1].packet_id, PacketId::MessageBinary);
        assert_eq!(packets[1].data, Bytes::from_static(b"HelloWorld"));
        assert_eq!(packets[2].packet_id, PacketId::MessageBinary);
        assert_eq!(packets[2].data, Bytes::from_static(b"Hello"));

        assert_eq!(Bytes::try_from(packets).unwrap(), data);
    }

    #[test]
    fn test_packet_id_conversion_and_incompl_packet() -> Result<()> {
        // A bare `4` (text message) with no payload must be rejected as incomplete.
        let sut = Packet::try_from(Bytes::from_static(b"4"));
        assert!(sut.is_err());
        // `matches!` takes (value, pattern). With a binding pattern in the second position
        // the check would accept any value, so the error variant is matched explicitly.
        assert!(matches!(sut.unwrap_err(), Error::IncompletePacket()));

        assert_eq!(PacketId::MessageBinary.to_string(), "b");

        let sut = PacketId::try_from(b'0')?;
        assert_eq!(sut, PacketId::Open);
        assert_eq!(sut.to_string(), "0");

        let sut = PacketId::try_from(b'1')?;
        assert_eq!(sut, PacketId::Close);
        assert_eq!(sut.to_string(), "1");

        let sut = PacketId::try_from(b'2')?;
        assert_eq!(sut, PacketId::Ping);
        assert_eq!(sut.to_string(), "2");

        let sut = PacketId::try_from(b'3')?;
        assert_eq!(sut, PacketId::Pong);
        assert_eq!(sut.to_string(), "3");

        let sut = PacketId::try_from(b'4')?;
        assert_eq!(sut, PacketId::Message);
        assert_eq!(sut.to_string(), "4");

        let sut = PacketId::try_from(b'5')?;
        assert_eq!(sut, PacketId::Upgrade);
        assert_eq!(sut.to_string(), "5");

        let sut = PacketId::try_from(b'6')?;
        assert_eq!(sut, PacketId::Noop);
        assert_eq!(sut.to_string(), "6");

        let sut = PacketId::try_from(42);
        assert!(sut.is_err());
        assert!(matches!(sut.unwrap_err(), Error::InvalidPacketId(42)));

        Ok(())
    }

    #[test]
    fn test_handshake_packet() {
        assert!(
            HandshakePacket::try_from(Packet::new(PacketId::Message, Bytes::from("test"))).is_err()
        );
        let packet = HandshakePacket {
            ping_interval: 10000,
            ping_timeout: 1000,
            sid: "Test".to_owned(),
            upgrades: vec!["websocket".to_owned(), "test".to_owned()],
        };
        let encoded: String = serde_json::to_string(&packet).unwrap();

        assert_eq!(
            packet,
            HandshakePacket::try_from(Packet::new(PacketId::Message, Bytes::from(encoded)))
                .unwrap()
        );
    }
}
361}