1use serde::{Deserialize, Serialize};
5
6pub use engineioxide_core::{Sid, Str};
7
8use crate::Value;
9
/// A packet addressed to a namespace: the namespace `ns` plus its typed content `inner`.
#[derive(Debug, Clone, PartialEq)]
pub struct Packet {
    /// Type-specific content of the packet (see [`PacketData`]).
    pub inner: PacketData,
    /// Namespace the packet belongs to (e.g. `"/"`).
    pub ns: Str,
}
19
20impl Packet {
21 pub fn connect(ns: impl Into<Str>, value: Option<Value>) -> Self {
23 Self {
24 inner: PacketData::Connect(value),
25 ns: ns.into(),
26 }
27 }
28
29 pub fn disconnect(ns: impl Into<Str>) -> Self {
31 Self {
32 inner: PacketData::Disconnect,
33 ns: ns.into(),
34 }
35 }
36}
37
38impl Packet {
39 pub fn connect_error(ns: impl Into<Str>, message: impl Into<String>) -> Self {
41 Self {
42 inner: PacketData::ConnectError(message.into()),
43 ns: ns.into(),
44 }
45 }
46
47 pub fn event(ns: impl Into<Str>, data: Value) -> Self {
50 Self {
51 inner: match data {
52 Value::Str(_, Some(ref bins)) if !bins.is_empty() => {
53 PacketData::BinaryEvent(data, None)
54 }
55 _ => PacketData::Event(data, None),
56 },
57 ns: ns.into(),
58 }
59 }
60
61 pub fn ack(ns: impl Into<Str>, data: Value, ack: i64) -> Self {
64 Self {
65 inner: match data {
66 Value::Str(_, Some(ref bins)) if !bins.is_empty() => {
67 PacketData::BinaryAck(data, ack)
68 }
69 _ => PacketData::EventAck(data, ack),
70 },
71 ns: ns.into(),
72 }
73 }
74}
75
/// The type-specific content of a [`Packet`].
///
/// Each variant has a fixed numeric type id, returned by [`PacketData::index`]
/// (0 = `Connect` through 6 = `BinaryAck`). The serde impls in this module write and
/// read the same ids as the `type` field.
#[derive(Debug, Clone, PartialEq)]
pub enum PacketData {
    /// Connect packet with an optional payload (type 0).
    Connect(Option<Value>),
    /// Disconnect packet; carries no data (type 1).
    Disconnect,
    /// Event payload plus an optional ack id, which [`PacketData::set_ack_id`] can fill in (type 2).
    Event(Value, Option<i64>),
    /// Ack payload plus its ack id (type 3).
    EventAck(Value, i64),
    /// Connection error carrying an error message (type 4).
    ConnectError(String),
    /// Like `Event`; [`Packet::event`] picks this variant when the payload has binary attachments (type 5).
    BinaryEvent(Value, Option<i64>),
    /// Like `EventAck`; [`Packet::ack`] picks this variant when the payload has binary attachments (type 6).
    BinaryAck(Value, i64),
}
102
103impl PacketData {
104 pub fn index(&self) -> usize {
106 match self {
107 PacketData::Connect(_) => 0,
108 PacketData::Disconnect => 1,
109 PacketData::Event(_, _) => 2,
110 PacketData::EventAck(_, _) => 3,
111 PacketData::ConnectError(_) => 4,
112 PacketData::BinaryEvent(_, _) => 5,
113 PacketData::BinaryAck(_, _) => 6,
114 }
115 }
116
117 pub fn set_ack_id(&mut self, ack_id: i64) {
120 match self {
121 PacketData::Event(_, ack) | PacketData::BinaryEvent(_, ack) => *ack = Some(ack_id),
122 _ => {}
123 };
124 }
125
126 pub fn is_binary(&self) -> bool {
128 matches!(
129 self,
130 PacketData::BinaryEvent(_, _) | PacketData::BinaryAck(_, _)
131 )
132 }
133}
134
/// Serializable record holding a single `sid`.
///
/// NOTE(review): presumably the payload exchanged on connect; usage is not visible in this
/// module, so confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectPacket {
    /// Session id (an `engineioxide_core::Sid`).
    pub sid: Sid,
}
141
142impl Serialize for Packet {
143 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
144 #[derive(Serialize)]
145 struct RawPacket<'a> {
146 ns: &'a Str,
147 r#type: u8,
148 data: Option<&'a Value>,
149 ack: Option<i64>,
150 error: Option<&'a String>,
151 }
152 let (r#type, data, ack, error) = match &self.inner {
153 PacketData::Connect(v) => (0, v.as_ref(), None, None),
154 PacketData::Disconnect => (1, None, None, None),
155 PacketData::Event(v, ack) => (2, Some(v), *ack, None),
156 PacketData::EventAck(v, ack) => (3, Some(v), Some(*ack), None),
157 PacketData::ConnectError(e) => (4, None, None, Some(e)),
158 PacketData::BinaryEvent(v, ack) => (5, Some(v), *ack, None),
159 PacketData::BinaryAck(v, ack) => (6, Some(v), Some(*ack), None),
160 };
161 let raw = RawPacket {
162 ns: &self.ns,
163 data,
164 ack,
165 error,
166 r#type,
167 };
168 raw.serialize(serializer)
169 }
170}
171impl<'de> Deserialize<'de> for Packet {
172 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
173 #[derive(Deserialize)]
174 struct RawPacket {
175 ns: Str,
176 r#type: u8,
177 data: Option<Value>,
178 ack: Option<i64>,
179 error: Option<String>,
180 }
181 let raw = RawPacket::deserialize(deserializer)?;
182 let err = |field| serde::de::Error::custom(format!("missing field: {}", field));
183 let inner = match raw.r#type {
184 0 => PacketData::Connect(raw.data),
185 1 => PacketData::Disconnect,
186 2 => PacketData::Event(raw.data.ok_or(err("data"))?, raw.ack),
187 3 => PacketData::EventAck(raw.data.ok_or(err("data"))?, raw.ack.ok_or(err("ack"))?),
188 4 => PacketData::ConnectError(raw.error.ok_or(err("error"))?),
189 5 => PacketData::BinaryEvent(raw.data.ok_or(err("data"))?, raw.ack),
190 6 => PacketData::BinaryAck(raw.data.ok_or(err("data"))?, raw.ack.ok_or(err("ack"))?),
191 i => return Err(serde::de::Error::custom(format!("invalid packet type {i}"))),
192 };
193 Ok(Self { inner, ns: raw.ns })
194 }
195}
196
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;

    use super::{Packet, PacketData, Value};
    use bytes::Bytes;

    /// Two identical binary attachments, shared by the serde round-trip tests.
    fn two_attachments() -> VecDeque<Bytes> {
        vec![
            Bytes::from_static(&[1, 2, 3, 4]),
            Bytes::from_static(&[1, 2, 3, 4]),
        ]
        .into()
    }

    #[test]
    fn should_create_bin_packet_with_adjacent_binary() {
        let val = Value::Str(
            "test".into(),
            Some(vec![Bytes::from_static(&[1, 2, 3])].into()),
        );
        assert!(matches!(
            Packet::event("/", val.clone()).inner,
            PacketData::BinaryEvent(v, None) if v == val));

        assert!(matches!(
            Packet::ack("/", val.clone(), 120).inner,
            PacketData::BinaryAck(v, 120) if v == val));
    }

    #[test]
    fn should_create_default_packet_with_base_data() {
        let val = Value::Str("test".into(), None);
        let val1 = Value::Bytes(Bytes::from_static(b"test"));

        assert!(matches!(
            Packet::event("/", val.clone()).inner,
            PacketData::Event(v, None) if v == val));

        assert!(matches!(
            Packet::ack("/", val.clone(), 120).inner,
            PacketData::EventAck(v, 120) if v == val));

        assert!(matches!(
            Packet::event("/", val1.clone()).inner,
            PacketData::Event(v, None) if v == val1));

        assert!(matches!(
            Packet::ack("/", val1.clone(), 120).inner,
            PacketData::EventAck(v, 120) if v == val1));
    }

    /// Asserts that `packet` survives a JSON serialize/deserialize round trip unchanged.
    fn assert_serde_packet(packet: Packet) {
        let serialized = serde_json::to_string(&packet).unwrap();
        let deserialized: Packet = serde_json::from_str(&serialized).unwrap();
        assert_eq!(packet, deserialized);
    }

    /// Serializes a `Disconnect` packet (so `data`, `ack` and `error` are all null) and
    /// overrides its `type`, giving a raw record that lacks every optional field.
    fn raw_with_type(packet_type: u64) -> serde_json::Value {
        let mut raw = serde_json::to_value(Packet::disconnect("/")).unwrap();
        raw["type"] = packet_type.into();
        raw
    }

    /// Deserializes `raw` as a [`Packet`] through the JSON text path, expecting failure, and
    /// returns the error message.
    fn deser_error(raw: serde_json::Value) -> String {
        serde_json::from_str::<Packet>(&raw.to_string())
            .unwrap_err()
            .to_string()
    }

    #[test]
    fn packet_serde_connect() {
        let packet = Packet {
            ns: "/".into(),
            inner: PacketData::Connect(Some(Value::Str("test_data".into(), None))),
        };
        assert_serde_packet(packet);
    }

    #[test]
    fn packet_serde_disconnect() {
        let packet = Packet {
            ns: "/".into(),
            inner: PacketData::Disconnect,
        };
        assert_serde_packet(packet);
    }

    #[test]
    fn packet_serde_event() {
        let packet = Packet {
            ns: "/".into(),
            inner: PacketData::Event(Value::Str("event_data".into(), None), None),
        };
        assert_serde_packet(packet);

        let packet = Packet {
            ns: "/".into(),
            inner: PacketData::Event(
                Value::Str("event_data".into(), Some(two_attachments())),
                Some(12),
            ),
        };
        assert_serde_packet(packet);
    }

    #[test]
    fn packet_serde_event_ack() {
        let packet = Packet {
            ns: "/".into(),
            inner: PacketData::EventAck(Value::Str("event_ack_data".into(), None), 42),
        };
        assert_serde_packet(packet);
    }

    #[test]
    fn packet_serde_connect_error() {
        let packet = Packet {
            ns: "/".into(),
            inner: PacketData::ConnectError("connection_error".into()),
        };
        assert_serde_packet(packet);
    }

    #[test]
    fn packet_serde_binary_event() {
        let packet = Packet {
            ns: "/".into(),
            inner: PacketData::BinaryEvent(Value::Str("binary_event_data".into(), None), None),
        };
        assert_serde_packet(packet);

        let packet = Packet {
            ns: "/".into(),
            inner: PacketData::BinaryEvent(
                Value::Str("event_data".into(), Some(two_attachments())),
                Some(12),
            ),
        };
        assert_serde_packet(packet);
    }

    #[test]
    fn packet_serde_binary_ack() {
        let packet = Packet {
            ns: "/".into(),
            inner: PacketData::BinaryAck(Value::Str("binary_ack_data".into(), None), 99),
        };
        assert_serde_packet(packet);

        let packet = Packet {
            ns: "/".into(),
            inner: PacketData::BinaryAck(
                Value::Str("binary_ack_data".into(), Some(two_attachments())),
                99,
            ),
        };
        assert_serde_packet(packet);
    }

    #[test]
    fn packet_deser_rejects_unknown_type() {
        let err = deser_error(raw_with_type(7));
        assert!(err.contains("invalid packet type 7"), "{err}");
    }

    #[test]
    fn packet_deser_reports_missing_fields() {
        let cases = [(2, "data"), (3, "data"), (4, "error"), (5, "data"), (6, "data")];
        for (packet_type, field) in cases {
            let err = deser_error(raw_with_type(packet_type));
            assert!(
                err.contains(&format!("missing field: {field}")),
                "type {packet_type}: {err}"
            );
        }
    }

    #[test]
    fn packet_deser_reports_missing_ack() {
        let packet = Packet {
            ns: "/".into(),
            inner: PacketData::EventAck(Value::Str("event_ack_data".into(), None), 1),
        };
        let mut raw = serde_json::to_value(packet).unwrap();
        raw["ack"] = serde_json::Value::Null;
        let err = deser_error(raw);
        assert!(err.contains("missing field: ack"), "{err}");
    }
}