1use ahash::AHashMap;
2use prost::Message;
3use serde::{Deserialize, Serialize};
4use sonic_rs::Value;
5use std::collections::{BTreeMap, HashMap};
6
7use crate::messages::{ExtrasValue, MessageData, MessageExtras, PusherMessage};
8
/// Encoding used for a message on the wire.
///
/// Serde (de)serializes the variants by their lowercase names (`json`,
/// `messagepack`, `protobuf`); these names are part of the external contract.
/// Query-string parsing additionally accepts aliases (see
/// `WireFormat::parse_query_param`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum WireFormat {
    /// Text JSON; the default when nothing else is requested.
    #[default]
    Json,
    /// Binary MessagePack.
    MessagePack,
    /// Binary Protocol Buffers.
    Protobuf,
}
17
18impl WireFormat {
19 pub fn from_query_param(value: Option<&str>) -> Self {
20 Self::parse_query_param(value).unwrap_or(Self::Json)
21 }
22
23 pub fn parse_query_param(value: Option<&str>) -> Result<Self, String> {
24 match value.map(|v| v.trim().to_ascii_lowercase()) {
25 None => Ok(Self::Json),
26 Some(v) if v.is_empty() || v == "json" => Ok(Self::Json),
27 Some(v) if v == "msgpack" || v == "messagepack" => Ok(Self::MessagePack),
28 Some(v) if v == "protobuf" || v == "proto" => Ok(Self::Protobuf),
29 Some(v) => Err(format!("unsupported wire format '{v}'")),
30 }
31 }
32
33 pub const fn is_binary(self) -> bool {
34 !matches!(self, Self::Json)
35 }
36}
37
38pub fn serialize_message(message: &PusherMessage, format: WireFormat) -> Result<Vec<u8>, String> {
39 match format {
40 WireFormat::Json => {
41 sonic_rs::to_vec(message).map_err(|e| format!("JSON serialization failed: {e}"))
42 }
43 WireFormat::MessagePack => rmp_serde::to_vec(&MsgpackPusherMessage::from(message.clone()))
44 .map_err(|e| format!("MessagePack serialization failed: {e}")),
45 WireFormat::Protobuf => {
46 let proto = ProtoPusherMessage::from(message.clone());
47 let mut buf = Vec::with_capacity(proto.encoded_len());
48 proto
49 .encode(&mut buf)
50 .map_err(|e| format!("Protobuf serialization failed: {e}"))?;
51 Ok(buf)
52 }
53 }
54}
55
56pub fn deserialize_message(bytes: &[u8], format: WireFormat) -> Result<PusherMessage, String> {
57 match format {
58 WireFormat::Json => {
59 sonic_rs::from_slice(bytes).map_err(|e| format!("JSON deserialization failed: {e}"))
60 }
61 WireFormat::MessagePack => {
62 let msg: MsgpackPusherMessage = rmp_serde::from_slice(bytes)
63 .map_err(|e| format!("MessagePack deserialization failed: {e}"))?;
64 Ok(msg.into())
65 }
66 WireFormat::Protobuf => {
67 let proto = ProtoPusherMessage::decode(bytes)
68 .map_err(|e| format!("Protobuf deserialization failed: {e}"))?;
69 Ok(proto.into())
70 }
71 }
72}
73
/// Protobuf mirror of [`PusherMessage`], used only for [`WireFormat::Protobuf`].
///
/// The `tag` numbers are the wire contract: renumbering or reusing one breaks
/// compatibility with peers built against the current schema. Scalar fields are
/// declared `optional`, so an unset value decodes as `None` rather than as a
/// type default (empty string / zero).
#[derive(Clone, PartialEq, Message)]
struct ProtoPusherMessage {
    #[prost(string, optional, tag = "1")]
    event: Option<String>,
    #[prost(string, optional, tag = "2")]
    channel: Option<String>,
    #[prost(message, optional, tag = "3")]
    data: Option<ProtoMessageData>,
    #[prost(string, optional, tag = "4")]
    name: Option<String>,
    #[prost(string, optional, tag = "5")]
    user_id: Option<String>,
    // Protobuf maps have no null state: the encoder writes `None` as an empty
    // map and the decoder turns an empty map back into `None`.
    #[prost(map = "string, string", tag = "6")]
    tags: HashMap<String, String>,
    #[prost(uint64, optional, tag = "7")]
    sequence: Option<u64>,
    #[prost(string, optional, tag = "8")]
    conflation_key: Option<String>,
    #[prost(string, optional, tag = "9")]
    message_id: Option<String>,
    #[prost(uint64, optional, tag = "10")]
    serial: Option<u64>,
    #[prost(string, optional, tag = "11")]
    idempotency_key: Option<String>,
    #[prost(message, optional, tag = "12")]
    extras: Option<ProtoMessageExtras>,
    #[prost(uint64, optional, tag = "13")]
    delta_sequence: Option<u64>,
    #[prost(string, optional, tag = "14")]
    delta_conflation_key: Option<String>,
}
105
/// MessagePack mirror of [`PusherMessage`], used only for [`WireFormat::MessagePack`].
///
/// `serialize_message` encodes this with `rmp_serde::to_vec`, which writes
/// structs as positional arrays rather than keyed maps. Field declaration order
/// is therefore part of the wire format: reordering, inserting or removing a
/// field changes what peers see.
/// NOTE(review): confirm every client decodes the array form before changing this.
///
/// Unlike the protobuf mirror, `tags` stays an `Option`, so `None` and an empty
/// map remain distinct.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct MsgpackPusherMessage {
    event: Option<String>,
    channel: Option<String>,
    data: Option<MsgpackMessageData>,
    name: Option<String>,
    user_id: Option<String>,
    tags: Option<BTreeMap<String, String>>,
    sequence: Option<u64>,
    conflation_key: Option<String>,
    message_id: Option<String>,
    serial: Option<u64>,
    idempotency_key: Option<String>,
    extras: Option<MsgpackMessageExtras>,
    delta_sequence: Option<u64>,
    delta_conflation_key: Option<String>,
}
123
/// Protobuf wrapper for `MessageData`; the payload variant lives in the `kind` oneof.
///
/// The local encoder always sets `kind`. A message with `kind` unset decodes
/// to JSON `null`.
#[derive(Clone, PartialEq, Message)]
struct ProtoMessageData {
    #[prost(oneof = "proto_message_data::Kind", tags = "1, 2, 3")]
    kind: Option<proto_message_data::Kind>,
}
129
/// MessagePack mirror of `MessageData`.
///
/// Serde adjacent tagging records the variant name (`string`, `structured`,
/// `json`) under `kind` and the payload under `value`. The `Json` variant
/// carries JSON *text*, not a nested MessagePack value.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
enum MsgpackMessageData {
    String(String),
    Structured(MsgpackStructuredData),
    Json(String),
}
137
/// Oneof payload variants for [`ProtoMessageData`].
mod proto_message_data {
    use super::ProtoStructuredData;
    use prost::Oneof;

    #[derive(Clone, PartialEq, Oneof)]
    pub enum Kind {
        /// Raw string payload, carried verbatim.
        #[prost(string, tag = "1")]
        String(String),
        /// Structured payload with its known fields broken out.
        #[prost(message, tag = "2")]
        Structured(ProtoStructuredData),
        /// Arbitrary JSON payload, stored as JSON text.
        #[prost(string, tag = "3")]
        Json(String),
    }
}
152
/// Protobuf form of `MessageData::Structured`.
#[derive(Clone, PartialEq, Message)]
struct ProtoStructuredData {
    #[prost(string, optional, tag = "1")]
    channel_data: Option<String>,
    #[prost(string, optional, tag = "2")]
    channel: Option<String>,
    #[prost(string, optional, tag = "3")]
    user_data: Option<String>,
    // Each value is the JSON text of the original value. On decode, text that
    // does not parse as JSON is kept as a JSON string.
    #[prost(map = "string, string", tag = "4")]
    extra: HashMap<String, String>,
}
164
/// MessagePack form of `MessageData::Structured`.
///
/// Like the top-level mirror, this is encoded positionally by
/// `rmp_serde::to_vec`, so field order matters.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct MsgpackStructuredData {
    channel_data: Option<String>,
    channel: Option<String>,
    user_data: Option<String>,
    // Values are JSON text; text that fails to parse on decode becomes a JSON string.
    extra: HashMap<String, String>,
}
172
/// Protobuf form of `MessageExtras`.
#[derive(Clone, PartialEq, Message)]
struct ProtoMessageExtras {
    // Cannot express "absent": missing headers are encoded as an empty map, and
    // an empty map decodes back to `None`.
    #[prost(map = "string, message", tag = "1")]
    headers: HashMap<String, ProtoExtrasValue>,
    #[prost(bool, optional, tag = "2")]
    ephemeral: Option<bool>,
    #[prost(string, optional, tag = "3")]
    idempotency_key: Option<String>,
    #[prost(bool, optional, tag = "4")]
    echo: Option<bool>,
}
184
/// MessagePack form of `MessageExtras`.
///
/// `headers` stays an `Option`, so `None` and an empty map round-trip distinctly.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct MsgpackMessageExtras {
    headers: Option<HashMap<String, MsgpackExtrasValue>>,
    ephemeral: Option<bool>,
    idempotency_key: Option<String>,
    echo: Option<bool>,
}
192
/// Protobuf wrapper for a single `ExtrasValue` header value.
///
/// The local encoder always sets `kind`. An unset `kind` decodes to an empty
/// string value.
#[derive(Clone, PartialEq, Message)]
struct ProtoExtrasValue {
    #[prost(oneof = "proto_extras_value::Kind", tags = "1, 2, 3")]
    kind: Option<proto_extras_value::Kind>,
}
198
/// MessagePack form of a single `ExtrasValue`.
///
/// Serde adjacent tagging records `string` / `number` / `bool` under `kind`
/// and the value under `value`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
enum MsgpackExtrasValue {
    String(String),
    Number(f64),
    Bool(bool),
}
206
/// Oneof value variants for [`ProtoExtrasValue`].
mod proto_extras_value {
    use prost::Oneof;

    #[derive(Clone, PartialEq, Oneof)]
    pub enum Kind {
        #[prost(string, tag = "1")]
        String(String),
        /// Numeric header value, carried as a protobuf `double`.
        #[prost(double, tag = "2")]
        Number(f64),
        #[prost(bool, tag = "3")]
        Bool(bool),
    }
}
220
221impl From<PusherMessage> for ProtoPusherMessage {
222 fn from(value: PusherMessage) -> Self {
223 Self {
224 event: value.event,
225 channel: value.channel,
226 data: value.data.map(Into::into),
227 name: value.name,
228 user_id: value.user_id,
229 tags: value
230 .tags
231 .map(|m| m.into_iter().collect())
232 .unwrap_or_default(),
233 sequence: value.sequence,
234 conflation_key: value.conflation_key,
235 message_id: value.message_id,
236 serial: value.serial,
237 idempotency_key: value.idempotency_key,
238 extras: value.extras.map(Into::into),
239 delta_sequence: value.delta_sequence,
240 delta_conflation_key: value.delta_conflation_key,
241 }
242 }
243}
244
245impl From<PusherMessage> for MsgpackPusherMessage {
246 fn from(value: PusherMessage) -> Self {
247 Self {
248 event: value.event,
249 channel: value.channel,
250 data: value.data.map(Into::into),
251 name: value.name,
252 user_id: value.user_id,
253 tags: value.tags,
254 sequence: value.sequence,
255 conflation_key: value.conflation_key,
256 message_id: value.message_id,
257 serial: value.serial,
258 idempotency_key: value.idempotency_key,
259 extras: value.extras.map(Into::into),
260 delta_sequence: value.delta_sequence,
261 delta_conflation_key: value.delta_conflation_key,
262 }
263 }
264}
265
266impl From<ProtoPusherMessage> for PusherMessage {
267 fn from(value: ProtoPusherMessage) -> Self {
268 Self {
269 event: value.event,
270 channel: value.channel,
271 data: value.data.map(Into::into),
272 name: value.name,
273 user_id: value.user_id,
274 tags: (!value.tags.is_empty())
275 .then_some(value.tags.into_iter().collect::<BTreeMap<_, _>>()),
276 sequence: value.sequence,
277 conflation_key: value.conflation_key,
278 message_id: value.message_id,
279 serial: value.serial,
280 idempotency_key: value.idempotency_key,
281 extras: value.extras.map(Into::into),
282 delta_sequence: value.delta_sequence,
283 delta_conflation_key: value.delta_conflation_key,
284 }
285 }
286}
287
288impl From<MsgpackPusherMessage> for PusherMessage {
289 fn from(value: MsgpackPusherMessage) -> Self {
290 Self {
291 event: value.event,
292 channel: value.channel,
293 data: value.data.map(Into::into),
294 name: value.name,
295 user_id: value.user_id,
296 tags: value.tags,
297 sequence: value.sequence,
298 conflation_key: value.conflation_key,
299 message_id: value.message_id,
300 serial: value.serial,
301 idempotency_key: value.idempotency_key,
302 extras: value.extras.map(Into::into),
303 delta_sequence: value.delta_sequence,
304 delta_conflation_key: value.delta_conflation_key,
305 }
306 }
307}
308
309impl From<MessageData> for ProtoMessageData {
310 fn from(value: MessageData) -> Self {
311 let kind = match value {
312 MessageData::String(s) => Some(proto_message_data::Kind::String(s)),
313 MessageData::Structured {
314 channel_data,
315 channel,
316 user_data,
317 extra,
318 } => Some(proto_message_data::Kind::Structured(ProtoStructuredData {
319 channel_data,
320 channel,
321 user_data,
322 extra: extra
323 .into_iter()
324 .map(|(k, v)| {
325 (
326 k,
327 sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()),
328 )
329 })
330 .collect(),
331 })),
332 MessageData::Json(v) => Some(proto_message_data::Kind::Json(
333 sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()),
334 )),
335 };
336
337 Self { kind }
338 }
339}
340
341impl From<MessageData> for MsgpackMessageData {
342 fn from(value: MessageData) -> Self {
343 match value {
344 MessageData::String(s) => Self::String(s),
345 MessageData::Structured {
346 channel_data,
347 channel,
348 user_data,
349 extra,
350 } => Self::Structured(MsgpackStructuredData {
351 channel_data,
352 channel,
353 user_data,
354 extra: extra
355 .into_iter()
356 .map(|(k, v)| {
357 (
358 k,
359 sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()),
360 )
361 })
362 .collect(),
363 }),
364 MessageData::Json(v) => {
365 Self::Json(sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()))
366 }
367 }
368 }
369}
370
371impl From<ProtoMessageData> for MessageData {
372 fn from(value: ProtoMessageData) -> Self {
373 match value.kind {
374 Some(proto_message_data::Kind::String(s)) => MessageData::String(s),
375 Some(proto_message_data::Kind::Structured(s)) => MessageData::Structured {
376 channel_data: s.channel_data,
377 channel: s.channel,
378 user_data: s.user_data,
379 extra: s
380 .extra
381 .into_iter()
382 .map(|(k, v)| {
383 let parsed =
384 sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str()));
385 (k, parsed)
386 })
387 .collect::<AHashMap<_, _>>(),
388 },
389 Some(proto_message_data::Kind::Json(v)) => MessageData::Json(
390 sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str())),
391 ),
392 None => MessageData::Json(Value::new_null()),
393 }
394 }
395}
396
397impl From<MsgpackMessageData> for MessageData {
398 fn from(value: MsgpackMessageData) -> Self {
399 match value {
400 MsgpackMessageData::String(s) => MessageData::String(s),
401 MsgpackMessageData::Structured(s) => MessageData::Structured {
402 channel_data: s.channel_data,
403 channel: s.channel,
404 user_data: s.user_data,
405 extra: s
406 .extra
407 .into_iter()
408 .map(|(k, v)| {
409 let parsed =
410 sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str()));
411 (k, parsed)
412 })
413 .collect::<AHashMap<_, _>>(),
414 },
415 MsgpackMessageData::Json(v) => MessageData::Json(
416 sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str())),
417 ),
418 }
419 }
420}
421
422impl From<MessageExtras> for ProtoMessageExtras {
423 fn from(value: MessageExtras) -> Self {
424 Self {
425 headers: value
426 .headers
427 .unwrap_or_default()
428 .into_iter()
429 .map(|(k, v)| (k, v.into()))
430 .collect(),
431 ephemeral: value.ephemeral,
432 idempotency_key: value.idempotency_key,
433 echo: value.echo,
434 }
435 }
436}
437
438impl From<MessageExtras> for MsgpackMessageExtras {
439 fn from(value: MessageExtras) -> Self {
440 Self {
441 headers: value
442 .headers
443 .map(|headers| headers.into_iter().map(|(k, v)| (k, v.into())).collect()),
444 ephemeral: value.ephemeral,
445 idempotency_key: value.idempotency_key,
446 echo: value.echo,
447 }
448 }
449}
450
451impl From<ProtoMessageExtras> for MessageExtras {
452 fn from(value: ProtoMessageExtras) -> Self {
453 Self {
454 headers: (!value.headers.is_empty()).then_some(
455 value
456 .headers
457 .into_iter()
458 .map(|(k, v)| (k, v.into()))
459 .collect(),
460 ),
461 ephemeral: value.ephemeral,
462 idempotency_key: value.idempotency_key,
463 echo: value.echo,
464 }
465 }
466}
467
468impl From<MsgpackMessageExtras> for MessageExtras {
469 fn from(value: MsgpackMessageExtras) -> Self {
470 Self {
471 headers: value
472 .headers
473 .map(|headers| headers.into_iter().map(|(k, v)| (k, v.into())).collect()),
474 ephemeral: value.ephemeral,
475 idempotency_key: value.idempotency_key,
476 echo: value.echo,
477 }
478 }
479}
480
481impl From<ExtrasValue> for ProtoExtrasValue {
482 fn from(value: ExtrasValue) -> Self {
483 let kind = match value {
484 ExtrasValue::String(s) => Some(proto_extras_value::Kind::String(s)),
485 ExtrasValue::Number(n) => Some(proto_extras_value::Kind::Number(n)),
486 ExtrasValue::Bool(b) => Some(proto_extras_value::Kind::Bool(b)),
487 };
488 Self { kind }
489 }
490}
491
492impl From<ExtrasValue> for MsgpackExtrasValue {
493 fn from(value: ExtrasValue) -> Self {
494 match value {
495 ExtrasValue::String(s) => Self::String(s),
496 ExtrasValue::Number(n) => Self::Number(n),
497 ExtrasValue::Bool(b) => Self::Bool(b),
498 }
499 }
500}
501
502impl From<ProtoExtrasValue> for ExtrasValue {
503 fn from(value: ProtoExtrasValue) -> Self {
504 match value.kind {
505 Some(proto_extras_value::Kind::String(s)) => ExtrasValue::String(s),
506 Some(proto_extras_value::Kind::Number(n)) => ExtrasValue::Number(n),
507 Some(proto_extras_value::Kind::Bool(b)) => ExtrasValue::Bool(b),
508 None => ExtrasValue::String(String::new()),
509 }
510 }
511}
512
513impl From<MsgpackExtrasValue> for ExtrasValue {
514 fn from(value: MsgpackExtrasValue) -> Self {
515 match value {
516 MsgpackExtrasValue::String(s) => ExtrasValue::String(s),
517 MsgpackExtrasValue::Number(n) => ExtrasValue::Number(n),
518 MsgpackExtrasValue::Bool(b) => ExtrasValue::Bool(b),
519 }
520 }
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    fn sample_message() -> PusherMessage {
        PusherMessage {
            event: Some("sockudo:test".to_string()),
            channel: Some("chat:room-1".to_string()),
            data: Some(MessageData::Json(sonic_rs::json!({
                "hello": "world",
                "count": 3,
                "nested": { "ok": true },
                "items": [1, 2, 3]
            }))),
            name: None,
            user_id: Some("user-1".to_string()),
            tags: Some(BTreeMap::from([
                ("region".to_string(), "eu".to_string()),
                ("tier".to_string(), "gold".to_string()),
            ])),
            sequence: Some(7),
            conflation_key: Some("room".to_string()),
            message_id: Some("mid-1".to_string()),
            serial: Some(9),
            idempotency_key: Some("idem-1".to_string()),
            extras: Some(MessageExtras {
                headers: Some(HashMap::from([
                    (
                        "priority".to_string(),
                        ExtrasValue::String("high".to_string()),
                    ),
                    ("ttl".to_string(), ExtrasValue::Number(5.0)),
                ])),
                ephemeral: Some(true),
                idempotency_key: Some("extra-idem".to_string()),
                echo: Some(false),
            }),
            delta_sequence: Some(11),
            delta_conflation_key: Some("btc".to_string()),
        }
    }

    /// Asserts that every field populated by `sample_message` survived a round trip.
    ///
    /// Fields are compared one by one because only the field types, not
    /// `PusherMessage` itself, are known here to implement `PartialEq`.
    fn assert_round_trip(decoded: &PusherMessage, original: &PusherMessage) {
        assert_eq!(decoded.event, original.event);
        assert_eq!(decoded.channel, original.channel);
        assert_eq!(decoded.name, original.name);
        assert_eq!(decoded.user_id, original.user_id);
        assert_eq!(decoded.tags, original.tags);
        assert_eq!(decoded.sequence, original.sequence);
        assert_eq!(decoded.conflation_key, original.conflation_key);
        assert_eq!(decoded.message_id, original.message_id);
        assert_eq!(decoded.serial, original.serial);
        assert_eq!(decoded.idempotency_key, original.idempotency_key);
        assert_eq!(decoded.delta_sequence, original.delta_sequence);
        assert_eq!(decoded.delta_conflation_key, original.delta_conflation_key);

        // Compare JSON payloads by their serialized text, so the check does not
        // depend on `MessageData` implementing `PartialEq`.
        match (&decoded.data, &original.data) {
            (Some(MessageData::Json(got)), Some(MessageData::Json(want))) => assert_eq!(
                sonic_rs::to_string(got).unwrap(),
                sonic_rs::to_string(want).unwrap()
            ),
            _ => panic!("JSON data payload did not survive the round trip"),
        }

        let got = decoded.extras.as_ref().expect("extras lost in round trip");
        let want = original.extras.as_ref().expect("sample message has extras");
        assert_eq!(got.ephemeral, want.ephemeral);
        assert_eq!(got.idempotency_key, want.idempotency_key);
        assert_eq!(got.echo, want.echo);

        let headers = got.headers.as_ref().expect("headers lost in round trip");
        assert_eq!(headers.len(), 2);
        assert!(matches!(
            headers.get("priority"),
            Some(ExtrasValue::String(s)) if s.as_str() == "high"
        ));
        assert!(matches!(
            headers.get("ttl"),
            Some(ExtrasValue::Number(n)) if (*n - 5.0).abs() < f64::EPSILON
        ));
    }

    #[test]
    fn round_trip_messagepack() {
        let msg = sample_message();
        let bytes = serialize_message(&msg, WireFormat::MessagePack).unwrap();
        let decoded = deserialize_message(&bytes, WireFormat::MessagePack).unwrap();
        assert_round_trip(&decoded, &msg);
    }

    #[test]
    fn round_trip_protobuf() {
        let msg = sample_message();
        let bytes = serialize_message(&msg, WireFormat::Protobuf).unwrap();
        let decoded = deserialize_message(&bytes, WireFormat::Protobuf).unwrap();
        assert_round_trip(&decoded, &msg);
    }

    #[test]
    fn protobuf_decodes_absent_collections_as_none() {
        let mut msg = sample_message();
        msg.tags = None;
        msg.extras = None;
        msg.data = None;
        let bytes = serialize_message(&msg, WireFormat::Protobuf).unwrap();
        let decoded = deserialize_message(&bytes, WireFormat::Protobuf).unwrap();
        assert_eq!(decoded.tags, None);
        assert!(decoded.extras.is_none());
        assert!(decoded.data.is_none());
        assert_eq!(decoded.event, msg.event);
    }

    #[test]
    fn parse_query_param_accepts_known_values() {
        assert_eq!(
            WireFormat::parse_query_param(None).unwrap(),
            WireFormat::Json
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("json")).unwrap(),
            WireFormat::Json
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("messagepack")).unwrap(),
            WireFormat::MessagePack
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("msgpack")).unwrap(),
            WireFormat::MessagePack
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("protobuf")).unwrap(),
            WireFormat::Protobuf
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("proto")).unwrap(),
            WireFormat::Protobuf
        );
    }

    #[test]
    fn parse_query_param_ignores_case_and_whitespace() {
        assert_eq!(
            WireFormat::parse_query_param(Some("  MsgPack ")).unwrap(),
            WireFormat::MessagePack
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("PROTOBUF")).unwrap(),
            WireFormat::Protobuf
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("   ")).unwrap(),
            WireFormat::Json
        );
    }

    #[test]
    fn parse_query_param_rejects_unknown_value() {
        let err = WireFormat::parse_query_param(Some("avro")).unwrap_err();
        assert!(err.contains("avro"));
    }

    #[test]
    fn from_query_param_falls_back_to_json() {
        assert_eq!(WireFormat::from_query_param(Some("avro")), WireFormat::Json);
        assert_eq!(WireFormat::from_query_param(None), WireFormat::Json);
        assert_eq!(
            WireFormat::from_query_param(Some("proto")),
            WireFormat::Protobuf
        );
    }

    #[test]
    fn only_json_is_textual() {
        assert!(!WireFormat::Json.is_binary());
        assert!(WireFormat::MessagePack.is_binary());
        assert!(WireFormat::Protobuf.is_binary());
    }
}