1use ahash::AHashMap;
2use prost::Message;
3use serde::{Deserialize, Serialize};
4use sonic_rs::Value;
5use std::collections::{BTreeMap, HashMap};
6
7use crate::messages::{ExtrasValue, MessageData, MessageExtras, PusherMessage};
8use crate::versioned_messages::{MessageAction, MessageVersionMetadata, VersionedRealtimeMessage};
9
/// Encoding used for messages exchanged on the wire.
///
/// Serde names are the lower-cased variant names (`json`, `messagepack`,
/// `protobuf`). JSON is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum WireFormat {
    /// Text JSON, encoded and decoded with `sonic_rs`.
    #[default]
    Json,
    /// Binary MessagePack, encoded and decoded with `rmp_serde`.
    MessagePack,
    /// Binary Protocol Buffers, encoded and decoded with `prost`.
    Protobuf,
}
18
19impl WireFormat {
20 pub fn from_query_param(value: Option<&str>) -> Self {
21 Self::parse_query_param(value).unwrap_or(Self::Json)
22 }
23
24 pub fn parse_query_param(value: Option<&str>) -> Result<Self, String> {
25 match value.map(|v| v.trim().to_ascii_lowercase()) {
26 None => Ok(Self::Json),
27 Some(v) if v.is_empty() || v == "json" => Ok(Self::Json),
28 Some(v) if v == "msgpack" || v == "messagepack" => Ok(Self::MessagePack),
29 Some(v) if v == "protobuf" || v == "proto" => Ok(Self::Protobuf),
30 Some(v) => Err(format!("unsupported wire format '{v}'")),
31 }
32 }
33
34 pub const fn is_binary(self) -> bool {
35 !matches!(self, Self::Json)
36 }
37}
38
39pub fn serialize_message(message: &PusherMessage, format: WireFormat) -> Result<Vec<u8>, String> {
40 match format {
41 WireFormat::Json => {
42 sonic_rs::to_vec(message).map_err(|e| format!("JSON serialization failed: {e}"))
43 }
44 WireFormat::MessagePack => rmp_serde::to_vec(&MsgpackPusherMessage::from(message.clone()))
45 .map_err(|e| format!("MessagePack serialization failed: {e}")),
46 WireFormat::Protobuf => {
47 let proto = ProtoPusherMessage::from(message.clone());
48 let mut buf = Vec::with_capacity(proto.encoded_len());
49 proto
50 .encode(&mut buf)
51 .map_err(|e| format!("Protobuf serialization failed: {e}"))?;
52 Ok(buf)
53 }
54 }
55}
56
57pub fn deserialize_message(bytes: &[u8], format: WireFormat) -> Result<PusherMessage, String> {
58 match format {
59 WireFormat::Json => {
60 sonic_rs::from_slice(bytes).map_err(|e| format!("JSON deserialization failed: {e}"))
61 }
62 WireFormat::MessagePack => {
63 let msg: MsgpackPusherMessage = rmp_serde::from_slice(bytes)
64 .map_err(|e| format!("MessagePack deserialization failed: {e}"))?;
65 Ok(msg.into())
66 }
67 WireFormat::Protobuf => {
68 let proto = ProtoPusherMessage::decode(bytes)
69 .map_err(|e| format!("Protobuf deserialization failed: {e}"))?;
70 Ok(proto.into())
71 }
72 }
73}
74
75pub fn serialize_versioned_message(
76 message: &VersionedRealtimeMessage,
77 format: WireFormat,
78) -> Result<Vec<u8>, String> {
79 match format {
80 WireFormat::Json => {
81 sonic_rs::to_vec(message).map_err(|e| format!("JSON serialization failed: {e}"))
82 }
83 WireFormat::MessagePack => {
84 rmp_serde::to_vec(&MsgpackVersionedRealtimeMessage::from(message.clone()))
85 .map_err(|e| format!("MessagePack serialization failed: {e}"))
86 }
87 WireFormat::Protobuf => {
88 let proto = ProtoVersionedRealtimeMessage::from(message.clone());
89 let mut buf = Vec::with_capacity(proto.encoded_len());
90 proto
91 .encode(&mut buf)
92 .map_err(|e| format!("Protobuf serialization failed: {e}"))?;
93 Ok(buf)
94 }
95 }
96}
97
98pub fn deserialize_versioned_message(
99 bytes: &[u8],
100 format: WireFormat,
101) -> Result<VersionedRealtimeMessage, String> {
102 let message: VersionedRealtimeMessage = match format {
103 WireFormat::Json => {
104 sonic_rs::from_slice(bytes).map_err(|e| format!("JSON deserialization failed: {e}"))
105 }
106 WireFormat::MessagePack => {
107 let msg: MsgpackVersionedRealtimeMessage = rmp_serde::from_slice(bytes)
108 .map_err(|e| format!("MessagePack deserialization failed: {e}"))?;
109 Ok(msg.into())
110 }
111 WireFormat::Protobuf => {
112 let proto = ProtoVersionedRealtimeMessage::decode(bytes)
113 .map_err(|e| format!("Protobuf deserialization failed: {e}"))?;
114 Ok(proto.into())
115 }
116 }?;
117
118 message.validate_v2()?;
119 Ok(message)
120}
121
/// Protobuf mirror of [`PusherMessage`], used by [`WireFormat::Protobuf`].
///
/// The `tag` numbers identify fields on the wire. They must stay stable once
/// encoded payloads exist.
#[derive(Clone, PartialEq, Message)]
struct ProtoPusherMessage {
    #[prost(string, optional, tag = "1")]
    event: Option<String>,
    #[prost(string, optional, tag = "2")]
    channel: Option<String>,
    #[prost(message, optional, tag = "3")]
    data: Option<ProtoMessageData>,
    #[prost(string, optional, tag = "4")]
    name: Option<String>,
    #[prost(string, optional, tag = "5")]
    user_id: Option<String>,
    /// Map fields have no "absent" state. The conversions write `None` tags
    /// as an empty map and read an empty map back as `None`.
    #[prost(map = "string, string", tag = "6")]
    tags: HashMap<String, String>,
    #[prost(uint64, optional, tag = "7")]
    sequence: Option<u64>,
    #[prost(string, optional, tag = "8")]
    conflation_key: Option<String>,
    #[prost(string, optional, tag = "9")]
    message_id: Option<String>,
    #[prost(string, optional, tag = "10")]
    stream_id: Option<String>,
    #[prost(uint64, optional, tag = "11")]
    serial: Option<u64>,
    #[prost(string, optional, tag = "12")]
    idempotency_key: Option<String>,
    #[prost(message, optional, tag = "13")]
    extras: Option<ProtoMessageExtras>,
    #[prost(uint64, optional, tag = "14")]
    delta_sequence: Option<u64>,
    #[prost(string, optional, tag = "15")]
    delta_conflation_key: Option<String>,
}
155
/// Serde mirror of [`PusherMessage`], used by [`WireFormat::MessagePack`].
///
/// It is encoded with `rmp_serde::to_vec`, which writes structs as positional
/// arrays rather than maps. The declaration order of these fields is
/// therefore part of the wire format; reordering them changes the encoding.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct MsgpackPusherMessage {
    event: Option<String>,
    channel: Option<String>,
    data: Option<MsgpackMessageData>,
    name: Option<String>,
    user_id: Option<String>,
    // Kept as an `Option` so "no tags" and "empty tags" stay distinct.
    tags: Option<BTreeMap<String, String>>,
    sequence: Option<u64>,
    conflation_key: Option<String>,
    message_id: Option<String>,
    stream_id: Option<String>,
    serial: Option<u64>,
    idempotency_key: Option<String>,
    extras: Option<MsgpackMessageExtras>,
    delta_sequence: Option<u64>,
    delta_conflation_key: Option<String>,
}
174
/// Protobuf mirror of [`VersionedRealtimeMessage`].
#[derive(Clone, PartialEq, Message)]
struct ProtoVersionedRealtimeMessage {
    /// Always set on encode. If it is missing on decode, the conversion
    /// produces a `PusherMessage` whose fields are all `None`.
    #[prost(message, optional, tag = "1")]
    message: Option<ProtoPusherMessage>,
    /// String written from `MessageAction::as_str` and read back by
    /// `parse_message_action`.
    #[prost(string, tag = "2")]
    action: String,
    #[prost(string, tag = "3")]
    message_serial: String,
    #[prost(uint64, optional, tag = "4")]
    history_serial: Option<u64>,
    #[prost(uint64, optional, tag = "5")]
    delivery_serial: Option<u64>,
    #[prost(message, optional, tag = "6")]
    version: Option<ProtoMessageVersionMetadata>,
}
190
/// Serde mirror of [`VersionedRealtimeMessage`] for MessagePack.
///
/// `action` uses `MessageAction`'s own serde representation. Field order is
/// wire-significant because of the positional `rmp_serde::to_vec` encoding.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct MsgpackVersionedRealtimeMessage {
    message: MsgpackPusherMessage,
    action: MessageAction,
    message_serial: String,
    history_serial: Option<u64>,
    delivery_serial: Option<u64>,
    version: Option<MsgpackMessageVersionMetadata>,
}
200
/// Protobuf wrapper for [`MessageData`]. The payload lives in the `kind`
/// oneof, and an unset `kind` decodes to `MessageData::Json(null)`.
#[derive(Clone, PartialEq, Message)]
struct ProtoMessageData {
    #[prost(oneof = "proto_message_data::Kind", tags = "1, 2, 3")]
    kind: Option<proto_message_data::Kind>,
}
206
/// MessagePack form of [`MessageData`].
///
/// The enum is adjacently tagged (`kind` / `value`) with snake_case variant
/// names. `Json` carries serialized JSON text, not a native MessagePack value.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
enum MsgpackMessageData {
    String(String),
    Structured(MsgpackStructuredData),
    Json(String),
}
214
/// Oneof payload variants for [`ProtoMessageData`].
mod proto_message_data {
    use super::ProtoStructuredData;
    use prost::Oneof;

    #[derive(Clone, PartialEq, Oneof)]
    pub enum Kind {
        /// Plain text payload.
        #[prost(string, tag = "1")]
        String(String),
        /// Structured payload; its `extra` values are JSON text.
        #[prost(message, tag = "2")]
        Structured(ProtoStructuredData),
        /// Arbitrary JSON payload, serialized as text.
        #[prost(string, tag = "3")]
        Json(String),
    }
}
229
/// Protobuf form of `MessageData::Structured`.
#[derive(Clone, PartialEq, Message)]
struct ProtoStructuredData {
    #[prost(string, optional, tag = "1")]
    channel_data: Option<String>,
    #[prost(string, optional, tag = "2")]
    channel: Option<String>,
    #[prost(string, optional, tag = "3")]
    user_data: Option<String>,
    /// Each value is serialized JSON text. On decode, text that is not valid
    /// JSON becomes a JSON string value.
    #[prost(map = "string, string", tag = "4")]
    extra: HashMap<String, String>,
}
241
/// MessagePack form of `MessageData::Structured`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct MsgpackStructuredData {
    channel_data: Option<String>,
    channel: Option<String>,
    user_data: Option<String>,
    // Values are serialized JSON text, mirroring `ProtoStructuredData::extra`.
    extra: HashMap<String, String>,
}
249
/// Protobuf form of [`MessageExtras`].
#[derive(Clone, PartialEq, Message)]
struct ProtoMessageExtras {
    /// Map fields have no presence bit. Missing headers are written as an
    /// empty map, and an empty map reads back as `None`.
    #[prost(map = "string, message", tag = "1")]
    headers: HashMap<String, ProtoExtrasValue>,
    #[prost(bool, optional, tag = "2")]
    ephemeral: Option<bool>,
    #[prost(string, optional, tag = "3")]
    idempotency_key: Option<String>,
    #[prost(bool, optional, tag = "4")]
    echo: Option<bool>,
}
261
/// MessagePack form of [`MessageExtras`]. `headers` stays optional, so
/// missing and empty header maps remain distinct.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct MsgpackMessageExtras {
    headers: Option<HashMap<String, MsgpackExtrasValue>>,
    ephemeral: Option<bool>,
    idempotency_key: Option<String>,
    echo: Option<bool>,
}
269
/// Protobuf form of [`MessageVersionMetadata`].
#[derive(Clone, PartialEq, Message)]
struct ProtoMessageVersionMetadata {
    #[prost(string, tag = "1")]
    serial: String,
    #[prost(string, optional, tag = "2")]
    client_id: Option<String>,
    #[prost(int64, tag = "3")]
    timestamp_ms: i64,
    #[prost(string, optional, tag = "4")]
    description: Option<String>,
    /// `metadata` serialized as JSON text. The conversions drop it silently
    /// if it fails to serialize or to parse.
    #[prost(string, optional, tag = "5")]
    metadata_json: Option<String>,
}
283
/// MessagePack form of [`MessageVersionMetadata`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct MsgpackMessageVersionMetadata {
    serial: String,
    client_id: Option<String>,
    timestamp_ms: i64,
    description: Option<String>,
    // JSON text of `metadata`; dropped by the conversions if it cannot be
    // serialized or parsed.
    metadata_json: Option<String>,
}
292
/// Protobuf wrapper for an [`ExtrasValue`] header. An unset `kind` decodes
/// to an empty string.
#[derive(Clone, PartialEq, Message)]
struct ProtoExtrasValue {
    #[prost(oneof = "proto_extras_value::Kind", tags = "1, 2, 3")]
    kind: Option<proto_extras_value::Kind>,
}
298
/// MessagePack form of an [`ExtrasValue`] header. It is adjacently tagged
/// (`kind` / `value`) with snake_case variant names.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
enum MsgpackExtrasValue {
    String(String),
    Number(f64),
    Bool(bool),
}
306
/// Oneof variants for [`ProtoExtrasValue`].
mod proto_extras_value {
    use prost::Oneof;

    #[derive(Clone, PartialEq, Oneof)]
    pub enum Kind {
        #[prost(string, tag = "1")]
        String(String),
        #[prost(double, tag = "2")]
        Number(f64),
        #[prost(bool, tag = "3")]
        Bool(bool),
    }
}
320
321impl From<PusherMessage> for ProtoPusherMessage {
322 fn from(value: PusherMessage) -> Self {
323 Self {
324 event: value.event,
325 channel: value.channel,
326 data: value.data.map(Into::into),
327 name: value.name,
328 user_id: value.user_id,
329 tags: value
330 .tags
331 .map(|m| m.into_iter().collect())
332 .unwrap_or_default(),
333 sequence: value.sequence,
334 conflation_key: value.conflation_key,
335 message_id: value.message_id,
336 stream_id: value.stream_id,
337 serial: value.serial,
338 idempotency_key: value.idempotency_key,
339 extras: value.extras.map(Into::into),
340 delta_sequence: value.delta_sequence,
341 delta_conflation_key: value.delta_conflation_key,
342 }
343 }
344}
345
346impl From<PusherMessage> for MsgpackPusherMessage {
347 fn from(value: PusherMessage) -> Self {
348 Self {
349 event: value.event,
350 channel: value.channel,
351 data: value.data.map(Into::into),
352 name: value.name,
353 user_id: value.user_id,
354 tags: value.tags,
355 sequence: value.sequence,
356 conflation_key: value.conflation_key,
357 message_id: value.message_id,
358 stream_id: value.stream_id,
359 serial: value.serial,
360 idempotency_key: value.idempotency_key,
361 extras: value.extras.map(Into::into),
362 delta_sequence: value.delta_sequence,
363 delta_conflation_key: value.delta_conflation_key,
364 }
365 }
366}
367
368impl From<VersionedRealtimeMessage> for ProtoVersionedRealtimeMessage {
369 fn from(value: VersionedRealtimeMessage) -> Self {
370 Self {
371 message: Some(ProtoPusherMessage::from(value.message)),
372 action: value.action.as_str().to_string(),
373 message_serial: value.message_serial,
374 history_serial: value.history_serial,
375 delivery_serial: value.delivery_serial,
376 version: value.version.map(Into::into),
377 }
378 }
379}
380
381impl From<VersionedRealtimeMessage> for MsgpackVersionedRealtimeMessage {
382 fn from(value: VersionedRealtimeMessage) -> Self {
383 Self {
384 message: MsgpackPusherMessage::from(value.message),
385 action: value.action,
386 message_serial: value.message_serial,
387 history_serial: value.history_serial,
388 delivery_serial: value.delivery_serial,
389 version: value.version.map(Into::into),
390 }
391 }
392}
393
394impl From<ProtoPusherMessage> for PusherMessage {
395 fn from(value: ProtoPusherMessage) -> Self {
396 Self {
397 event: value.event,
398 channel: value.channel,
399 data: value.data.map(Into::into),
400 name: value.name,
401 user_id: value.user_id,
402 tags: (!value.tags.is_empty())
403 .then_some(value.tags.into_iter().collect::<BTreeMap<_, _>>()),
404 sequence: value.sequence,
405 conflation_key: value.conflation_key,
406 message_id: value.message_id,
407 stream_id: value.stream_id,
408 serial: value.serial,
409 idempotency_key: value.idempotency_key,
410 extras: value.extras.map(Into::into),
411 delta_sequence: value.delta_sequence,
412 delta_conflation_key: value.delta_conflation_key,
413 }
414 }
415}
416
417impl From<MsgpackPusherMessage> for PusherMessage {
418 fn from(value: MsgpackPusherMessage) -> Self {
419 Self {
420 event: value.event,
421 channel: value.channel,
422 data: value.data.map(Into::into),
423 name: value.name,
424 user_id: value.user_id,
425 tags: value.tags,
426 sequence: value.sequence,
427 conflation_key: value.conflation_key,
428 message_id: value.message_id,
429 stream_id: value.stream_id,
430 serial: value.serial,
431 idempotency_key: value.idempotency_key,
432 extras: value.extras.map(Into::into),
433 delta_sequence: value.delta_sequence,
434 delta_conflation_key: value.delta_conflation_key,
435 }
436 }
437}
438
439impl From<ProtoVersionedRealtimeMessage> for VersionedRealtimeMessage {
440 fn from(value: ProtoVersionedRealtimeMessage) -> Self {
441 Self {
442 message: value.message.map(Into::into).unwrap_or(PusherMessage {
443 event: None,
444 channel: None,
445 data: None,
446 name: None,
447 user_id: None,
448 tags: None,
449 sequence: None,
450 conflation_key: None,
451 message_id: None,
452 stream_id: None,
453 serial: None,
454 idempotency_key: None,
455 extras: None,
456 delta_sequence: None,
457 delta_conflation_key: None,
458 }),
459 action: parse_message_action(&value.action),
460 message_serial: value.message_serial,
461 history_serial: value.history_serial,
462 delivery_serial: value.delivery_serial,
463 version: value.version.map(Into::into),
464 }
465 }
466}
467
468impl From<MsgpackVersionedRealtimeMessage> for VersionedRealtimeMessage {
469 fn from(value: MsgpackVersionedRealtimeMessage) -> Self {
470 Self {
471 message: value.message.into(),
472 action: value.action,
473 message_serial: value.message_serial,
474 history_serial: value.history_serial,
475 delivery_serial: value.delivery_serial,
476 version: value.version.map(Into::into),
477 }
478 }
479}
480
481impl From<MessageData> for ProtoMessageData {
482 fn from(value: MessageData) -> Self {
483 let kind = match value {
484 MessageData::String(s) => Some(proto_message_data::Kind::String(s)),
485 MessageData::Structured {
486 channel_data,
487 channel,
488 user_data,
489 extra,
490 } => Some(proto_message_data::Kind::Structured(ProtoStructuredData {
491 channel_data,
492 channel,
493 user_data,
494 extra: extra
495 .into_iter()
496 .map(|(k, v)| {
497 (
498 k,
499 sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()),
500 )
501 })
502 .collect(),
503 })),
504 MessageData::Json(v) => Some(proto_message_data::Kind::Json(
505 sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()),
506 )),
507 };
508
509 Self { kind }
510 }
511}
512
513impl From<MessageData> for MsgpackMessageData {
514 fn from(value: MessageData) -> Self {
515 match value {
516 MessageData::String(s) => Self::String(s),
517 MessageData::Structured {
518 channel_data,
519 channel,
520 user_data,
521 extra,
522 } => Self::Structured(MsgpackStructuredData {
523 channel_data,
524 channel,
525 user_data,
526 extra: extra
527 .into_iter()
528 .map(|(k, v)| {
529 (
530 k,
531 sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()),
532 )
533 })
534 .collect(),
535 }),
536 MessageData::Json(v) => {
537 Self::Json(sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()))
538 }
539 }
540 }
541}
542
543impl From<ProtoMessageData> for MessageData {
544 fn from(value: ProtoMessageData) -> Self {
545 match value.kind {
546 Some(proto_message_data::Kind::String(s)) => MessageData::String(s),
547 Some(proto_message_data::Kind::Structured(s)) => MessageData::Structured {
548 channel_data: s.channel_data,
549 channel: s.channel,
550 user_data: s.user_data,
551 extra: s
552 .extra
553 .into_iter()
554 .map(|(k, v)| {
555 let parsed =
556 sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str()));
557 (k, parsed)
558 })
559 .collect::<AHashMap<_, _>>(),
560 },
561 Some(proto_message_data::Kind::Json(v)) => MessageData::Json(
562 sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str())),
563 ),
564 None => MessageData::Json(Value::new_null()),
565 }
566 }
567}
568
569impl From<MsgpackMessageData> for MessageData {
570 fn from(value: MsgpackMessageData) -> Self {
571 match value {
572 MsgpackMessageData::String(s) => MessageData::String(s),
573 MsgpackMessageData::Structured(s) => MessageData::Structured {
574 channel_data: s.channel_data,
575 channel: s.channel,
576 user_data: s.user_data,
577 extra: s
578 .extra
579 .into_iter()
580 .map(|(k, v)| {
581 let parsed =
582 sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str()));
583 (k, parsed)
584 })
585 .collect::<AHashMap<_, _>>(),
586 },
587 MsgpackMessageData::Json(v) => MessageData::Json(
588 sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str())),
589 ),
590 }
591 }
592}
593
594impl From<MessageExtras> for ProtoMessageExtras {
595 fn from(value: MessageExtras) -> Self {
596 Self {
597 headers: value
598 .headers
599 .unwrap_or_default()
600 .into_iter()
601 .map(|(k, v)| (k, v.into()))
602 .collect(),
603 ephemeral: value.ephemeral,
604 idempotency_key: value.idempotency_key,
605 echo: value.echo,
606 }
607 }
608}
609
610impl From<MessageExtras> for MsgpackMessageExtras {
611 fn from(value: MessageExtras) -> Self {
612 Self {
613 headers: value
614 .headers
615 .map(|headers| headers.into_iter().map(|(k, v)| (k, v.into())).collect()),
616 ephemeral: value.ephemeral,
617 idempotency_key: value.idempotency_key,
618 echo: value.echo,
619 }
620 }
621}
622
623impl From<ProtoMessageExtras> for MessageExtras {
624 fn from(value: ProtoMessageExtras) -> Self {
625 Self {
626 headers: (!value.headers.is_empty()).then_some(
627 value
628 .headers
629 .into_iter()
630 .map(|(k, v)| (k, v.into()))
631 .collect(),
632 ),
633 ephemeral: value.ephemeral,
634 idempotency_key: value.idempotency_key,
635 echo: value.echo,
636 }
637 }
638}
639
640impl From<MsgpackMessageExtras> for MessageExtras {
641 fn from(value: MsgpackMessageExtras) -> Self {
642 Self {
643 headers: value
644 .headers
645 .map(|headers| headers.into_iter().map(|(k, v)| (k, v.into())).collect()),
646 ephemeral: value.ephemeral,
647 idempotency_key: value.idempotency_key,
648 echo: value.echo,
649 }
650 }
651}
652
653impl From<ExtrasValue> for ProtoExtrasValue {
654 fn from(value: ExtrasValue) -> Self {
655 let kind = match value {
656 ExtrasValue::String(s) => Some(proto_extras_value::Kind::String(s)),
657 ExtrasValue::Number(n) => Some(proto_extras_value::Kind::Number(n)),
658 ExtrasValue::Bool(b) => Some(proto_extras_value::Kind::Bool(b)),
659 };
660 Self { kind }
661 }
662}
663
664impl From<ExtrasValue> for MsgpackExtrasValue {
665 fn from(value: ExtrasValue) -> Self {
666 match value {
667 ExtrasValue::String(s) => Self::String(s),
668 ExtrasValue::Number(n) => Self::Number(n),
669 ExtrasValue::Bool(b) => Self::Bool(b),
670 }
671 }
672}
673
674impl From<ProtoExtrasValue> for ExtrasValue {
675 fn from(value: ProtoExtrasValue) -> Self {
676 match value.kind {
677 Some(proto_extras_value::Kind::String(s)) => ExtrasValue::String(s),
678 Some(proto_extras_value::Kind::Number(n)) => ExtrasValue::Number(n),
679 Some(proto_extras_value::Kind::Bool(b)) => ExtrasValue::Bool(b),
680 None => ExtrasValue::String(String::new()),
681 }
682 }
683}
684
685impl From<MsgpackExtrasValue> for ExtrasValue {
686 fn from(value: MsgpackExtrasValue) -> Self {
687 match value {
688 MsgpackExtrasValue::String(s) => ExtrasValue::String(s),
689 MsgpackExtrasValue::Number(n) => ExtrasValue::Number(n),
690 MsgpackExtrasValue::Bool(b) => ExtrasValue::Bool(b),
691 }
692 }
693}
694
695impl From<MessageVersionMetadata> for ProtoMessageVersionMetadata {
696 fn from(value: MessageVersionMetadata) -> Self {
697 Self {
698 serial: value.serial,
699 client_id: value.client_id,
700 timestamp_ms: value.timestamp_ms,
701 description: value.description,
702 metadata_json: value
703 .metadata
704 .and_then(|value| sonic_rs::to_string(&value).ok()),
705 }
706 }
707}
708
709impl From<MessageVersionMetadata> for MsgpackMessageVersionMetadata {
710 fn from(value: MessageVersionMetadata) -> Self {
711 Self {
712 serial: value.serial,
713 client_id: value.client_id,
714 timestamp_ms: value.timestamp_ms,
715 description: value.description,
716 metadata_json: value
717 .metadata
718 .and_then(|value| sonic_rs::to_string(&value).ok()),
719 }
720 }
721}
722
723impl From<ProtoMessageVersionMetadata> for MessageVersionMetadata {
724 fn from(value: ProtoMessageVersionMetadata) -> Self {
725 Self {
726 serial: value.serial,
727 client_id: value.client_id,
728 timestamp_ms: value.timestamp_ms,
729 description: value.description,
730 metadata: value
731 .metadata_json
732 .and_then(|raw| sonic_rs::from_str(&raw).ok()),
733 }
734 }
735}
736
737impl From<MsgpackMessageVersionMetadata> for MessageVersionMetadata {
738 fn from(value: MsgpackMessageVersionMetadata) -> Self {
739 Self {
740 serial: value.serial,
741 client_id: value.client_id,
742 timestamp_ms: value.timestamp_ms,
743 description: value.description,
744 metadata: value
745 .metadata_json
746 .and_then(|raw| sonic_rs::from_str(&raw).ok()),
747 }
748 }
749}
750
751fn parse_message_action(raw: &str) -> MessageAction {
752 match raw {
753 "message.create" => MessageAction::Create,
754 "message.update" => MessageAction::Update,
755 "message.delete" => MessageAction::Delete,
756 "message.append" => MessageAction::Append,
757 "message.summary" => MessageAction::Summary,
758 _ => MessageAction::Update,
759 }
760}
761
#[cfg(test)]
mod tests {
    use super::*;
    use crate::versioned_messages::{
        MessageAction, MessageVersionMetadata, VersionedRealtimeMessage,
    };

    fn sample_message() -> PusherMessage {
        PusherMessage {
            event: Some("sockudo:test".to_string()),
            channel: Some("chat:room-1".to_string()),
            data: Some(MessageData::Json(sonic_rs::json!({
                "hello": "world",
                "count": 3,
                "nested": { "ok": true },
                "items": [1, 2, 3]
            }))),
            name: None,
            user_id: Some("user-1".to_string()),
            tags: Some(BTreeMap::from([
                ("region".to_string(), "eu".to_string()),
                ("tier".to_string(), "gold".to_string()),
            ])),
            sequence: Some(7),
            conflation_key: Some("room".to_string()),
            message_id: Some("mid-1".to_string()),
            stream_id: Some("stream-1".to_string()),
            serial: Some(9),
            idempotency_key: Some("idem-1".to_string()),
            extras: Some(MessageExtras {
                headers: Some(HashMap::from([
                    (
                        "priority".to_string(),
                        ExtrasValue::String("high".to_string()),
                    ),
                    ("ttl".to_string(), ExtrasValue::Number(5.0)),
                ])),
                ephemeral: Some(true),
                idempotency_key: Some("extra-idem".to_string()),
                echo: Some(false),
            }),
            delta_sequence: Some(11),
            delta_conflation_key: Some("btc".to_string()),
        }
    }

    fn sample_versioned_message() -> VersionedRealtimeMessage {
        let mut message = sample_message();
        message.event = Some("sockudo:message.update".to_string());

        VersionedRealtimeMessage {
            message,
            action: MessageAction::Update,
            message_serial: "msg:1".to_string(),
            history_serial: Some(7),
            delivery_serial: Some(9),
            version: Some(MessageVersionMetadata {
                serial: "ver:2".to_string(),
                client_id: Some("user-1".to_string()),
                timestamp_ms: 1_713_100_805_000,
                description: Some("patched".to_string()),
                metadata: Some(sonic_rs::json!({"source": "test"})),
            }),
        }
    }

    /// Asserts that every string/number/tag field survived a round trip.
    /// `data` and `extras` are checked separately by the callers.
    fn assert_plain_fields_eq(decoded: &PusherMessage, expected: &PusherMessage) {
        assert_eq!(decoded.event, expected.event);
        assert_eq!(decoded.channel, expected.channel);
        assert_eq!(decoded.name, expected.name);
        assert_eq!(decoded.user_id, expected.user_id);
        assert_eq!(decoded.tags, expected.tags);
        assert_eq!(decoded.sequence, expected.sequence);
        assert_eq!(decoded.conflation_key, expected.conflation_key);
        assert_eq!(decoded.message_id, expected.message_id);
        assert_eq!(decoded.stream_id, expected.stream_id);
        assert_eq!(decoded.serial, expected.serial);
        assert_eq!(decoded.idempotency_key, expected.idempotency_key);
        assert_eq!(decoded.delta_sequence, expected.delta_sequence);
        assert_eq!(decoded.delta_conflation_key, expected.delta_conflation_key);
    }

    /// (ephemeral, echo, idempotency_key, sorted header names)
    type ExtrasSummary = (Option<bool>, Option<bool>, Option<String>, Vec<String>);

    /// Summarises the directly comparable parts of `extras`.
    fn extras_summary(message: &PusherMessage) -> Option<ExtrasSummary> {
        message.extras.as_ref().map(|extras| {
            let mut header_names: Vec<String> = extras
                .headers
                .as_ref()
                .map(|headers| headers.keys().cloned().collect())
                .unwrap_or_default();
            header_names.sort();
            (
                extras.ephemeral,
                extras.echo,
                extras.idempotency_key.clone(),
                header_names,
            )
        })
    }

    #[test]
    fn round_trip_messagepack() {
        let msg = sample_message();
        let bytes = serialize_message(&msg, WireFormat::MessagePack).unwrap();
        let decoded = deserialize_message(&bytes, WireFormat::MessagePack).unwrap();
        assert_plain_fields_eq(&decoded, &msg);
        assert!(matches!(decoded.data, Some(MessageData::Json(_))));
        assert_eq!(extras_summary(&decoded), extras_summary(&msg));
    }

    #[test]
    fn round_trip_protobuf() {
        let msg = sample_message();
        let bytes = serialize_message(&msg, WireFormat::Protobuf).unwrap();
        let decoded = deserialize_message(&bytes, WireFormat::Protobuf).unwrap();
        assert_plain_fields_eq(&decoded, &msg);
        assert!(matches!(decoded.data, Some(MessageData::Json(_))));
        assert_eq!(extras_summary(&decoded), extras_summary(&msg));
    }

    #[test]
    fn round_trip_string_payload_binary_formats() {
        let mut msg = sample_message();
        msg.data = Some(MessageData::String("hello".to_string()));

        for format in [WireFormat::MessagePack, WireFormat::Protobuf] {
            let bytes = serialize_message(&msg, format).unwrap();
            let decoded = deserialize_message(&bytes, format).unwrap();
            assert!(
                matches!(&decoded.data, Some(MessageData::String(text)) if text == "hello"),
                "string payload lost for {format:?}"
            );
        }
    }

    #[test]
    fn round_trip_structured_payload_binary_formats() {
        let mut extra = AHashMap::new();
        extra.insert("note".to_string(), Value::from("hi"));
        let mut msg = sample_message();
        msg.data = Some(MessageData::Structured {
            channel_data: Some("{\"user_id\":\"u1\"}".to_string()),
            channel: Some("presence-room".to_string()),
            user_data: None,
            extra,
        });

        for format in [WireFormat::MessagePack, WireFormat::Protobuf] {
            let bytes = serialize_message(&msg, format).unwrap();
            let decoded = deserialize_message(&bytes, format).unwrap();
            match &decoded.data {
                Some(MessageData::Structured {
                    channel_data,
                    channel,
                    user_data,
                    extra,
                }) => {
                    assert_eq!(channel_data.as_deref(), Some("{\"user_id\":\"u1\"}"));
                    assert_eq!(channel.as_deref(), Some("presence-room"));
                    assert!(user_data.is_none());
                    assert_eq!(extra.get("note"), Some(&Value::from("hi")));
                }
                _ => panic!("structured payload lost for {format:?}"),
            }
        }
    }

    #[test]
    fn empty_protobuf_payload_decodes_to_empty_message() {
        let decoded = deserialize_message(&[], WireFormat::Protobuf).unwrap();
        assert_eq!(decoded.event, None);
        assert_eq!(decoded.channel, None);
        assert_eq!(decoded.tags, None);
        assert!(decoded.data.is_none());
        assert!(decoded.extras.is_none());
    }

    #[test]
    fn round_trip_versioned_messagepack() {
        let msg = sample_versioned_message();
        let bytes = serialize_versioned_message(&msg, WireFormat::MessagePack).unwrap();
        let decoded = deserialize_versioned_message(&bytes, WireFormat::MessagePack).unwrap();
        assert_eq!(decoded.action, msg.action);
        assert_eq!(decoded.message_serial, msg.message_serial);
        assert_eq!(decoded.history_serial, msg.history_serial);
        assert_eq!(decoded.delivery_serial, msg.delivery_serial);
        assert_eq!(decoded.version, msg.version);
        assert_plain_fields_eq(&decoded.message, &msg.message);
    }

    #[test]
    fn round_trip_versioned_protobuf() {
        let msg = sample_versioned_message();
        let bytes = serialize_versioned_message(&msg, WireFormat::Protobuf).unwrap();
        let decoded = deserialize_versioned_message(&bytes, WireFormat::Protobuf).unwrap();
        assert_eq!(decoded.action, msg.action);
        assert_eq!(decoded.message_serial, msg.message_serial);
        assert_eq!(decoded.history_serial, msg.history_serial);
        assert_eq!(decoded.delivery_serial, msg.delivery_serial);
        assert_eq!(decoded.version, msg.version);
        assert_plain_fields_eq(&decoded.message, &msg.message);
    }

    #[test]
    fn deserialize_versioned_message_rejects_invalid_action_event_pair() {
        let bytes = sonic_rs::to_vec(&VersionedRealtimeMessage {
            message: PusherMessage {
                event: Some("sockudo:message.delete".to_string()),
                channel: Some("chat:room-1".to_string()),
                data: Some(MessageData::String("hello".to_string())),
                name: Some("chat.message".to_string()),
                user_id: None,
                tags: None,
                sequence: None,
                conflation_key: None,
                message_id: None,
                stream_id: None,
                serial: Some(9),
                idempotency_key: None,
                extras: None,
                delta_sequence: None,
                delta_conflation_key: None,
            },
            action: MessageAction::Update,
            message_serial: "msg:1".to_string(),
            history_serial: Some(7),
            delivery_serial: Some(9),
            version: Some(MessageVersionMetadata {
                serial: "ver:2".to_string(),
                client_id: Some("user-1".to_string()),
                timestamp_ms: 1_713_100_805_000,
                description: None,
                metadata: None,
            }),
        })
        .unwrap();

        // The mismatch may be caught either by `validate_v2` or already by
        // the JSON `Deserialize` implementation, so both messages are accepted.
        let error = deserialize_versioned_message(&bytes, WireFormat::Json).unwrap_err();
        assert!(
            error.contains("does not match action")
                || error.contains("JSON deserialization failed"),
            "unexpected error: {error}"
        );
    }

    #[test]
    fn parse_query_param_accepts_known_values() {
        assert_eq!(
            WireFormat::parse_query_param(None).unwrap(),
            WireFormat::Json
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("json")).unwrap(),
            WireFormat::Json
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("messagepack")).unwrap(),
            WireFormat::MessagePack
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("msgpack")).unwrap(),
            WireFormat::MessagePack
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("protobuf")).unwrap(),
            WireFormat::Protobuf
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("proto")).unwrap(),
            WireFormat::Protobuf
        );
    }

    #[test]
    fn parse_query_param_normalizes_case_and_whitespace() {
        assert_eq!(
            WireFormat::parse_query_param(Some("  MsgPack ")).unwrap(),
            WireFormat::MessagePack
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("PROTO")).unwrap(),
            WireFormat::Protobuf
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("")).unwrap(),
            WireFormat::Json
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("   ")).unwrap(),
            WireFormat::Json
        );
    }

    #[test]
    fn parse_query_param_rejects_unknown_value() {
        assert!(WireFormat::parse_query_param(Some("avro")).is_err());
        let error = WireFormat::parse_query_param(Some(" Avro ")).unwrap_err();
        assert!(error.contains("'avro'"), "unexpected error: {error}");
    }

    #[test]
    fn from_query_param_falls_back_to_json() {
        assert_eq!(WireFormat::from_query_param(Some("avro")), WireFormat::Json);
        assert_eq!(WireFormat::from_query_param(None), WireFormat::Json);
        assert_eq!(
            WireFormat::from_query_param(Some("proto")),
            WireFormat::Protobuf
        );
    }

    #[test]
    fn wire_format_defaults_and_binary_flag() {
        assert_eq!(WireFormat::default(), WireFormat::Json);
        assert!(!WireFormat::Json.is_binary());
        assert!(WireFormat::MessagePack.is_binary());
        assert!(WireFormat::Protobuf.is_binary());
    }

    #[test]
    fn parse_message_action_maps_known_names_and_falls_back_to_update() {
        assert_eq!(parse_message_action("message.create"), MessageAction::Create);
        assert_eq!(parse_message_action("message.update"), MessageAction::Update);
        assert_eq!(parse_message_action("message.delete"), MessageAction::Delete);
        assert_eq!(parse_message_action("message.append"), MessageAction::Append);
        assert_eq!(
            parse_message_action("message.summary"),
            MessageAction::Summary
        );
        assert_eq!(parse_message_action("bogus"), MessageAction::Update);
        assert_eq!(parse_message_action(""), MessageAction::Update);
    }
}