Skip to main content

sockudo_protocol/
wire.rs

1use ahash::AHashMap;
2use prost::Message;
3use serde::{Deserialize, Serialize};
4use sonic_rs::Value;
5use std::collections::{BTreeMap, HashMap};
6
7use crate::messages::{ExtrasValue, MessageData, MessageExtras, PusherMessage};
8use crate::versioned_messages::{MessageAction, MessageVersionMetadata, VersionedRealtimeMessage};
9
/// Wire encoding used for realtime frames.
///
/// Serde (de)serializes the variants by their lowercase names (`json`,
/// `messagepack`, `protobuf`); query-string parsing, which also accepts
/// aliases, is done by `WireFormat::parse_query_param`.
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
11#[serde(rename_all = "lowercase")]
12pub enum WireFormat {
    /// Text JSON (via `sonic_rs`); the default format.
13    #[default]
14    Json,
    /// Binary MessagePack (via `rmp_serde`).
15    MessagePack,
    /// Binary Protocol Buffers (via `prost`).
16    Protobuf,
17}
18
19impl WireFormat {
20    pub fn from_query_param(value: Option<&str>) -> Self {
21        Self::parse_query_param(value).unwrap_or(Self::Json)
22    }
23
24    pub fn parse_query_param(value: Option<&str>) -> Result<Self, String> {
25        match value.map(|v| v.trim().to_ascii_lowercase()) {
26            None => Ok(Self::Json),
27            Some(v) if v.is_empty() || v == "json" => Ok(Self::Json),
28            Some(v) if v == "msgpack" || v == "messagepack" => Ok(Self::MessagePack),
29            Some(v) if v == "protobuf" || v == "proto" => Ok(Self::Protobuf),
30            Some(v) => Err(format!("unsupported wire format '{v}'")),
31        }
32    }
33
34    pub const fn is_binary(self) -> bool {
35        !matches!(self, Self::Json)
36    }
37}
38
39pub fn serialize_message(message: &PusherMessage, format: WireFormat) -> Result<Vec<u8>, String> {
40    match format {
41        WireFormat::Json => {
42            sonic_rs::to_vec(message).map_err(|e| format!("JSON serialization failed: {e}"))
43        }
44        WireFormat::MessagePack => rmp_serde::to_vec(&MsgpackPusherMessage::from(message.clone()))
45            .map_err(|e| format!("MessagePack serialization failed: {e}")),
46        WireFormat::Protobuf => {
47            let proto = ProtoPusherMessage::from(message.clone());
48            let mut buf = Vec::with_capacity(proto.encoded_len());
49            proto
50                .encode(&mut buf)
51                .map_err(|e| format!("Protobuf serialization failed: {e}"))?;
52            Ok(buf)
53        }
54    }
55}
56
57pub fn deserialize_message(bytes: &[u8], format: WireFormat) -> Result<PusherMessage, String> {
58    match format {
59        WireFormat::Json => {
60            sonic_rs::from_slice(bytes).map_err(|e| format!("JSON deserialization failed: {e}"))
61        }
62        WireFormat::MessagePack => {
63            let msg: MsgpackPusherMessage = rmp_serde::from_slice(bytes)
64                .map_err(|e| format!("MessagePack deserialization failed: {e}"))?;
65            Ok(msg.into())
66        }
67        WireFormat::Protobuf => {
68            let proto = ProtoPusherMessage::decode(bytes)
69                .map_err(|e| format!("Protobuf deserialization failed: {e}"))?;
70            Ok(proto.into())
71        }
72    }
73}
74
75pub fn serialize_versioned_message(
76    message: &VersionedRealtimeMessage,
77    format: WireFormat,
78) -> Result<Vec<u8>, String> {
79    match format {
80        WireFormat::Json => {
81            sonic_rs::to_vec(message).map_err(|e| format!("JSON serialization failed: {e}"))
82        }
83        WireFormat::MessagePack => {
84            rmp_serde::to_vec(&MsgpackVersionedRealtimeMessage::from(message.clone()))
85                .map_err(|e| format!("MessagePack serialization failed: {e}"))
86        }
87        WireFormat::Protobuf => {
88            let proto = ProtoVersionedRealtimeMessage::from(message.clone());
89            let mut buf = Vec::with_capacity(proto.encoded_len());
90            proto
91                .encode(&mut buf)
92                .map_err(|e| format!("Protobuf serialization failed: {e}"))?;
93            Ok(buf)
94        }
95    }
96}
97
98pub fn deserialize_versioned_message(
99    bytes: &[u8],
100    format: WireFormat,
101) -> Result<VersionedRealtimeMessage, String> {
102    let message: VersionedRealtimeMessage = match format {
103        WireFormat::Json => {
104            sonic_rs::from_slice(bytes).map_err(|e| format!("JSON deserialization failed: {e}"))
105        }
106        WireFormat::MessagePack => {
107            let msg: MsgpackVersionedRealtimeMessage = rmp_serde::from_slice(bytes)
108                .map_err(|e| format!("MessagePack deserialization failed: {e}"))?;
109            Ok(msg.into())
110        }
111        WireFormat::Protobuf => {
112            let proto = ProtoVersionedRealtimeMessage::decode(bytes)
113                .map_err(|e| format!("Protobuf deserialization failed: {e}"))?;
114            Ok(proto.into())
115        }
116    }?;
117
118    message.validate_v2()?;
119    Ok(message)
120}
121
/// Protobuf mirror of [`PusherMessage`], one field per message field.
///
/// The numeric tags are the wire contract with clients: do not renumber or
/// reuse them. Protobuf maps cannot express "absent", so `tags: None` and an
/// empty map both encode as an empty `tags` map, and an empty map decodes back
/// as `None`.
122#[derive(Clone, PartialEq, Message)]
123struct ProtoPusherMessage {
124    #[prost(string, optional, tag = "1")]
125    event: Option<String>,
126    #[prost(string, optional, tag = "2")]
127    channel: Option<String>,
128    #[prost(message, optional, tag = "3")]
129    data: Option<ProtoMessageData>,
130    #[prost(string, optional, tag = "4")]
131    name: Option<String>,
132    #[prost(string, optional, tag = "5")]
133    user_id: Option<String>,
134    #[prost(map = "string, string", tag = "6")]
135    tags: HashMap<String, String>,
136    #[prost(uint64, optional, tag = "7")]
137    sequence: Option<u64>,
138    #[prost(string, optional, tag = "8")]
139    conflation_key: Option<String>,
140    #[prost(string, optional, tag = "9")]
141    message_id: Option<String>,
142    #[prost(string, optional, tag = "10")]
143    stream_id: Option<String>,
144    #[prost(uint64, optional, tag = "11")]
145    serial: Option<u64>,
146    #[prost(string, optional, tag = "12")]
147    idempotency_key: Option<String>,
148    #[prost(message, optional, tag = "13")]
149    extras: Option<ProtoMessageExtras>,
150    #[prost(uint64, optional, tag = "14")]
151    delta_sequence: Option<u64>,
152    #[prost(string, optional, tag = "15")]
153    delta_conflation_key: Option<String>,
154}
155
/// MessagePack mirror of [`PusherMessage`], encoded with `rmp_serde::to_vec`.
///
/// NOTE(review): `rmp_serde::to_vec` writes structs positionally (as arrays),
/// so field order here is effectively part of the wire format. Append new
/// fields; never reorder them.
156#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
157struct MsgpackPusherMessage {
158    event: Option<String>,
159    channel: Option<String>,
160    data: Option<MsgpackMessageData>,
161    name: Option<String>,
162    user_id: Option<String>,
163    tags: Option<BTreeMap<String, String>>,
164    sequence: Option<u64>,
165    conflation_key: Option<String>,
166    message_id: Option<String>,
167    stream_id: Option<String>,
168    serial: Option<u64>,
169    idempotency_key: Option<String>,
170    extras: Option<MsgpackMessageExtras>,
171    delta_sequence: Option<u64>,
172    delta_conflation_key: Option<String>,
173}
174
/// Protobuf mirror of [`VersionedRealtimeMessage`].
///
/// `action` carries the string produced by `MessageAction::as_str` and is
/// mapped back with `parse_message_action`. A frame without `message` decodes
/// as a `PusherMessage` whose fields are all `None`.
175#[derive(Clone, PartialEq, Message)]
176struct ProtoVersionedRealtimeMessage {
177    #[prost(message, optional, tag = "1")]
178    message: Option<ProtoPusherMessage>,
179    #[prost(string, tag = "2")]
180    action: String,
181    #[prost(string, tag = "3")]
182    message_serial: String,
183    #[prost(uint64, optional, tag = "4")]
184    history_serial: Option<u64>,
185    #[prost(uint64, optional, tag = "5")]
186    delivery_serial: Option<u64>,
187    #[prost(message, optional, tag = "6")]
188    version: Option<ProtoMessageVersionMetadata>,
189}
190
/// MessagePack mirror of [`VersionedRealtimeMessage`]; `action` is encoded with
/// `MessageAction`'s own serde representation.
191#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
192struct MsgpackVersionedRealtimeMessage {
193    message: MsgpackPusherMessage,
194    action: MessageAction,
195    message_serial: String,
196    history_serial: Option<u64>,
197    delivery_serial: Option<u64>,
198    version: Option<MsgpackMessageVersionMetadata>,
199}
200
/// Protobuf envelope for [`MessageData`]; the payload lives in the `kind` oneof.
/// An envelope with no `kind` set decodes as `MessageData::Json(null)`.
201#[derive(Clone, PartialEq, Message)]
202struct ProtoMessageData {
203    #[prost(oneof = "proto_message_data::Kind", tags = "1, 2, 3")]
204    kind: Option<proto_message_data::Kind>,
205}
206
/// MessagePack form of [`MessageData`], adjacently tagged with `kind` (the
/// snake_case variant name) and `value` (the payload). The `Json` variant
/// carries the value as JSON text.
207#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
208#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
209enum MsgpackMessageData {
210    String(String),
211    Structured(MsgpackStructuredData),
212    Json(String),
213}
214
/// Oneof variants for [`ProtoMessageData`].
215mod proto_message_data {
216    use super::ProtoStructuredData;
217    use prost::Oneof;
218
219    #[derive(Clone, PartialEq, Oneof)]
220    pub enum Kind {
        /// Plain string payload, passed through unchanged.
221        #[prost(string, tag = "1")]
222        String(String),
        /// Structured payload (channel / user data plus extra fields).
223        #[prost(message, tag = "2")]
224        Structured(ProtoStructuredData),
        /// Arbitrary JSON value, carried as its JSON text.
225        #[prost(string, tag = "3")]
226        Json(String),
227    }
228}
229
/// Protobuf form of `MessageData::Structured`.
///
/// Each `extra` value is the JSON text of the original value. Decoding
/// re-parses it, and text that is not valid JSON is kept as a JSON string.
230#[derive(Clone, PartialEq, Message)]
231struct ProtoStructuredData {
232    #[prost(string, optional, tag = "1")]
233    channel_data: Option<String>,
234    #[prost(string, optional, tag = "2")]
235    channel: Option<String>,
236    #[prost(string, optional, tag = "3")]
237    user_data: Option<String>,
238    #[prost(map = "string, string", tag = "4")]
239    extra: HashMap<String, String>,
240}
241
/// MessagePack form of `MessageData::Structured`. As in the protobuf form,
/// `extra` values are JSON text.
242#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
243struct MsgpackStructuredData {
244    channel_data: Option<String>,
245    channel: Option<String>,
246    user_data: Option<String>,
247    extra: HashMap<String, String>,
248}
249
/// Protobuf mirror of [`MessageExtras`].
///
/// There is no `push` field, so push configuration is not carried and decodes
/// as `None`. An absent `headers` and an empty header map both decode as
/// `None`.
250#[derive(Clone, PartialEq, Message)]
251struct ProtoMessageExtras {
252    #[prost(map = "string, message", tag = "1")]
253    headers: HashMap<String, ProtoExtrasValue>,
254    #[prost(bool, optional, tag = "2")]
255    ephemeral: Option<bool>,
256    #[prost(string, optional, tag = "3")]
257    idempotency_key: Option<String>,
258    #[prost(bool, optional, tag = "4")]
259    echo: Option<bool>,
260}
261
/// MessagePack mirror of [`MessageExtras`]. Unlike the protobuf form it keeps
/// `headers: None` distinct from an empty map; like it, it does not carry `push`.
262#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
263struct MsgpackMessageExtras {
264    headers: Option<HashMap<String, MsgpackExtrasValue>>,
265    ephemeral: Option<bool>,
266    idempotency_key: Option<String>,
267    echo: Option<bool>,
268}
269
/// Protobuf mirror of [`MessageVersionMetadata`].
///
/// `metadata_json` holds the JSON text of `metadata`. A value that fails to
/// encode, or text that fails to parse, becomes `None`.
270#[derive(Clone, PartialEq, Message)]
271struct ProtoMessageVersionMetadata {
272    #[prost(string, tag = "1")]
273    serial: String,
274    #[prost(string, optional, tag = "2")]
275    client_id: Option<String>,
276    #[prost(int64, tag = "3")]
277    timestamp_ms: i64,
278    #[prost(string, optional, tag = "4")]
279    description: Option<String>,
280    #[prost(string, optional, tag = "5")]
281    metadata_json: Option<String>,
282}
283
/// MessagePack mirror of [`MessageVersionMetadata`], using the same
/// `metadata_json` convention as the protobuf form.
284#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
285struct MsgpackMessageVersionMetadata {
286    serial: String,
287    client_id: Option<String>,
288    timestamp_ms: i64,
289    description: Option<String>,
290    metadata_json: Option<String>,
291}
292
/// Protobuf envelope for an [`ExtrasValue`] header value. An envelope with no
/// `kind` set decodes as an empty string.
293#[derive(Clone, PartialEq, Message)]
294struct ProtoExtrasValue {
295    #[prost(oneof = "proto_extras_value::Kind", tags = "1, 2, 3")]
296    kind: Option<proto_extras_value::Kind>,
297}
298
/// MessagePack form of [`ExtrasValue`], adjacently tagged with `kind` / `value`.
299#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
300#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
301enum MsgpackExtrasValue {
302    String(String),
303    Number(f64),
304    Bool(bool),
305}
306
/// Oneof variants for [`ProtoExtrasValue`].
307mod proto_extras_value {
308    use prost::Oneof;
309
310    #[derive(Clone, PartialEq, Oneof)]
311    pub enum Kind {
312        #[prost(string, tag = "1")]
313        String(String),
314        #[prost(double, tag = "2")]
315        Number(f64),
316        #[prost(bool, tag = "3")]
317        Bool(bool),
318    }
319}
320
321impl From<PusherMessage> for ProtoPusherMessage {
322    fn from(value: PusherMessage) -> Self {
323        Self {
324            event: value.event,
325            channel: value.channel,
326            data: value.data.map(Into::into),
327            name: value.name,
328            user_id: value.user_id,
329            tags: value
330                .tags
331                .map(|m| m.into_iter().collect())
332                .unwrap_or_default(),
333            sequence: value.sequence,
334            conflation_key: value.conflation_key,
335            message_id: value.message_id,
336            stream_id: value.stream_id,
337            serial: value.serial,
338            idempotency_key: value.idempotency_key,
339            extras: value.extras.map(Into::into),
340            delta_sequence: value.delta_sequence,
341            delta_conflation_key: value.delta_conflation_key,
342        }
343    }
344}
345
346impl From<PusherMessage> for MsgpackPusherMessage {
347    fn from(value: PusherMessage) -> Self {
348        Self {
349            event: value.event,
350            channel: value.channel,
351            data: value.data.map(Into::into),
352            name: value.name,
353            user_id: value.user_id,
354            tags: value.tags,
355            sequence: value.sequence,
356            conflation_key: value.conflation_key,
357            message_id: value.message_id,
358            stream_id: value.stream_id,
359            serial: value.serial,
360            idempotency_key: value.idempotency_key,
361            extras: value.extras.map(Into::into),
362            delta_sequence: value.delta_sequence,
363            delta_conflation_key: value.delta_conflation_key,
364        }
365    }
366}
367
368impl From<VersionedRealtimeMessage> for ProtoVersionedRealtimeMessage {
369    fn from(value: VersionedRealtimeMessage) -> Self {
370        Self {
371            message: Some(ProtoPusherMessage::from(value.message)),
372            action: value.action.as_str().to_string(),
373            message_serial: value.message_serial,
374            history_serial: value.history_serial,
375            delivery_serial: value.delivery_serial,
376            version: value.version.map(Into::into),
377        }
378    }
379}
380
381impl From<VersionedRealtimeMessage> for MsgpackVersionedRealtimeMessage {
382    fn from(value: VersionedRealtimeMessage) -> Self {
383        Self {
384            message: MsgpackPusherMessage::from(value.message),
385            action: value.action,
386            message_serial: value.message_serial,
387            history_serial: value.history_serial,
388            delivery_serial: value.delivery_serial,
389            version: value.version.map(Into::into),
390        }
391    }
392}
393
394impl From<ProtoPusherMessage> for PusherMessage {
395    fn from(value: ProtoPusherMessage) -> Self {
396        Self {
397            event: value.event,
398            channel: value.channel,
399            data: value.data.map(Into::into),
400            name: value.name,
401            user_id: value.user_id,
402            tags: (!value.tags.is_empty())
403                .then_some(value.tags.into_iter().collect::<BTreeMap<_, _>>()),
404            sequence: value.sequence,
405            conflation_key: value.conflation_key,
406            message_id: value.message_id,
407            stream_id: value.stream_id,
408            serial: value.serial,
409            idempotency_key: value.idempotency_key,
410            extras: value.extras.map(Into::into),
411            delta_sequence: value.delta_sequence,
412            delta_conflation_key: value.delta_conflation_key,
413        }
414    }
415}
416
417impl From<MsgpackPusherMessage> for PusherMessage {
418    fn from(value: MsgpackPusherMessage) -> Self {
419        Self {
420            event: value.event,
421            channel: value.channel,
422            data: value.data.map(Into::into),
423            name: value.name,
424            user_id: value.user_id,
425            tags: value.tags,
426            sequence: value.sequence,
427            conflation_key: value.conflation_key,
428            message_id: value.message_id,
429            stream_id: value.stream_id,
430            serial: value.serial,
431            idempotency_key: value.idempotency_key,
432            extras: value.extras.map(Into::into),
433            delta_sequence: value.delta_sequence,
434            delta_conflation_key: value.delta_conflation_key,
435        }
436    }
437}
438
439impl From<ProtoVersionedRealtimeMessage> for VersionedRealtimeMessage {
440    fn from(value: ProtoVersionedRealtimeMessage) -> Self {
441        Self {
442            message: value.message.map(Into::into).unwrap_or(PusherMessage {
443                event: None,
444                channel: None,
445                data: None,
446                name: None,
447                user_id: None,
448                tags: None,
449                sequence: None,
450                conflation_key: None,
451                message_id: None,
452                stream_id: None,
453                serial: None,
454                idempotency_key: None,
455                extras: None,
456                delta_sequence: None,
457                delta_conflation_key: None,
458            }),
459            action: parse_message_action(&value.action),
460            message_serial: value.message_serial,
461            history_serial: value.history_serial,
462            delivery_serial: value.delivery_serial,
463            version: value.version.map(Into::into),
464        }
465    }
466}
467
468impl From<MsgpackVersionedRealtimeMessage> for VersionedRealtimeMessage {
469    fn from(value: MsgpackVersionedRealtimeMessage) -> Self {
470        Self {
471            message: value.message.into(),
472            action: value.action,
473            message_serial: value.message_serial,
474            history_serial: value.history_serial,
475            delivery_serial: value.delivery_serial,
476            version: value.version.map(Into::into),
477        }
478    }
479}
480
481impl From<MessageData> for ProtoMessageData {
482    fn from(value: MessageData) -> Self {
483        let kind = match value {
484            MessageData::String(s) => Some(proto_message_data::Kind::String(s)),
485            MessageData::Structured {
486                channel_data,
487                channel,
488                user_data,
489                extra,
490            } => Some(proto_message_data::Kind::Structured(ProtoStructuredData {
491                channel_data,
492                channel,
493                user_data,
494                extra: extra
495                    .into_iter()
496                    .map(|(k, v)| {
497                        (
498                            k,
499                            sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()),
500                        )
501                    })
502                    .collect(),
503            })),
504            MessageData::Json(v) => Some(proto_message_data::Kind::Json(
505                sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()),
506            )),
507        };
508
509        Self { kind }
510    }
511}
512
513impl From<MessageData> for MsgpackMessageData {
514    fn from(value: MessageData) -> Self {
515        match value {
516            MessageData::String(s) => Self::String(s),
517            MessageData::Structured {
518                channel_data,
519                channel,
520                user_data,
521                extra,
522            } => Self::Structured(MsgpackStructuredData {
523                channel_data,
524                channel,
525                user_data,
526                extra: extra
527                    .into_iter()
528                    .map(|(k, v)| {
529                        (
530                            k,
531                            sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()),
532                        )
533                    })
534                    .collect(),
535            }),
536            MessageData::Json(v) => {
537                Self::Json(sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()))
538            }
539        }
540    }
541}
542
543impl From<ProtoMessageData> for MessageData {
544    fn from(value: ProtoMessageData) -> Self {
545        match value.kind {
546            Some(proto_message_data::Kind::String(s)) => MessageData::String(s),
547            Some(proto_message_data::Kind::Structured(s)) => MessageData::Structured {
548                channel_data: s.channel_data,
549                channel: s.channel,
550                user_data: s.user_data,
551                extra: s
552                    .extra
553                    .into_iter()
554                    .map(|(k, v)| {
555                        let parsed =
556                            sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str()));
557                        (k, parsed)
558                    })
559                    .collect::<AHashMap<_, _>>(),
560            },
561            Some(proto_message_data::Kind::Json(v)) => MessageData::Json(
562                sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str())),
563            ),
564            None => MessageData::Json(Value::new_null()),
565        }
566    }
567}
568
569impl From<MsgpackMessageData> for MessageData {
570    fn from(value: MsgpackMessageData) -> Self {
571        match value {
572            MsgpackMessageData::String(s) => MessageData::String(s),
573            MsgpackMessageData::Structured(s) => MessageData::Structured {
574                channel_data: s.channel_data,
575                channel: s.channel,
576                user_data: s.user_data,
577                extra: s
578                    .extra
579                    .into_iter()
580                    .map(|(k, v)| {
581                        let parsed =
582                            sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str()));
583                        (k, parsed)
584                    })
585                    .collect::<AHashMap<_, _>>(),
586            },
587            MsgpackMessageData::Json(v) => MessageData::Json(
588                sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str())),
589            ),
590        }
591    }
592}
593
594impl From<MessageExtras> for ProtoMessageExtras {
595    fn from(value: MessageExtras) -> Self {
596        Self {
597            headers: value
598                .headers
599                .unwrap_or_default()
600                .into_iter()
601                .map(|(k, v)| (k, v.into()))
602                .collect(),
603            ephemeral: value.ephemeral,
604            idempotency_key: value.idempotency_key,
605            echo: value.echo,
606        }
607    }
608}
609
610impl From<MessageExtras> for MsgpackMessageExtras {
611    fn from(value: MessageExtras) -> Self {
612        Self {
613            headers: value
614                .headers
615                .map(|headers| headers.into_iter().map(|(k, v)| (k, v.into())).collect()),
616            ephemeral: value.ephemeral,
617            idempotency_key: value.idempotency_key,
618            echo: value.echo,
619        }
620    }
621}
622
623impl From<ProtoMessageExtras> for MessageExtras {
624    fn from(value: ProtoMessageExtras) -> Self {
625        Self {
626            headers: (!value.headers.is_empty()).then_some(
627                value
628                    .headers
629                    .into_iter()
630                    .map(|(k, v)| (k, v.into()))
631                    .collect(),
632            ),
633            ephemeral: value.ephemeral,
634            idempotency_key: value.idempotency_key,
635            push: None,
636            echo: value.echo,
637        }
638    }
639}
640
641impl From<MsgpackMessageExtras> for MessageExtras {
642    fn from(value: MsgpackMessageExtras) -> Self {
643        Self {
644            headers: value
645                .headers
646                .map(|headers| headers.into_iter().map(|(k, v)| (k, v.into())).collect()),
647            ephemeral: value.ephemeral,
648            idempotency_key: value.idempotency_key,
649            push: None,
650            echo: value.echo,
651        }
652    }
653}
654
655impl From<ExtrasValue> for ProtoExtrasValue {
656    fn from(value: ExtrasValue) -> Self {
657        let kind = match value {
658            ExtrasValue::String(s) => Some(proto_extras_value::Kind::String(s)),
659            ExtrasValue::Number(n) => Some(proto_extras_value::Kind::Number(n)),
660            ExtrasValue::Bool(b) => Some(proto_extras_value::Kind::Bool(b)),
661        };
662        Self { kind }
663    }
664}
665
666impl From<ExtrasValue> for MsgpackExtrasValue {
667    fn from(value: ExtrasValue) -> Self {
668        match value {
669            ExtrasValue::String(s) => Self::String(s),
670            ExtrasValue::Number(n) => Self::Number(n),
671            ExtrasValue::Bool(b) => Self::Bool(b),
672        }
673    }
674}
675
676impl From<ProtoExtrasValue> for ExtrasValue {
677    fn from(value: ProtoExtrasValue) -> Self {
678        match value.kind {
679            Some(proto_extras_value::Kind::String(s)) => ExtrasValue::String(s),
680            Some(proto_extras_value::Kind::Number(n)) => ExtrasValue::Number(n),
681            Some(proto_extras_value::Kind::Bool(b)) => ExtrasValue::Bool(b),
682            None => ExtrasValue::String(String::new()),
683        }
684    }
685}
686
687impl From<MsgpackExtrasValue> for ExtrasValue {
688    fn from(value: MsgpackExtrasValue) -> Self {
689        match value {
690            MsgpackExtrasValue::String(s) => ExtrasValue::String(s),
691            MsgpackExtrasValue::Number(n) => ExtrasValue::Number(n),
692            MsgpackExtrasValue::Bool(b) => ExtrasValue::Bool(b),
693        }
694    }
695}
696
697impl From<MessageVersionMetadata> for ProtoMessageVersionMetadata {
698    fn from(value: MessageVersionMetadata) -> Self {
699        Self {
700            serial: value.serial,
701            client_id: value.client_id,
702            timestamp_ms: value.timestamp_ms,
703            description: value.description,
704            metadata_json: value
705                .metadata
706                .and_then(|value| sonic_rs::to_string(&value).ok()),
707        }
708    }
709}
710
711impl From<MessageVersionMetadata> for MsgpackMessageVersionMetadata {
712    fn from(value: MessageVersionMetadata) -> Self {
713        Self {
714            serial: value.serial,
715            client_id: value.client_id,
716            timestamp_ms: value.timestamp_ms,
717            description: value.description,
718            metadata_json: value
719                .metadata
720                .and_then(|value| sonic_rs::to_string(&value).ok()),
721        }
722    }
723}
724
725impl From<ProtoMessageVersionMetadata> for MessageVersionMetadata {
726    fn from(value: ProtoMessageVersionMetadata) -> Self {
727        Self {
728            serial: value.serial,
729            client_id: value.client_id,
730            timestamp_ms: value.timestamp_ms,
731            description: value.description,
732            metadata: value
733                .metadata_json
734                .and_then(|raw| sonic_rs::from_str(&raw).ok()),
735        }
736    }
737}
738
739impl From<MsgpackMessageVersionMetadata> for MessageVersionMetadata {
740    fn from(value: MsgpackMessageVersionMetadata) -> Self {
741        Self {
742            serial: value.serial,
743            client_id: value.client_id,
744            timestamp_ms: value.timestamp_ms,
745            description: value.description,
746            metadata: value
747                .metadata_json
748                .and_then(|raw| sonic_rs::from_str(&raw).ok()),
749        }
750    }
751}
752
753fn parse_message_action(raw: &str) -> MessageAction {
754    match raw {
755        "message.create" => MessageAction::Create,
756        "message.update" => MessageAction::Update,
757        "message.delete" => MessageAction::Delete,
758        "message.append" => MessageAction::Append,
759        "message.summary" => MessageAction::Summary,
760        _ => MessageAction::Update,
761    }
762}
763
#[cfg(test)]
mod tests {
    // Round-trip and parsing tests for every wire format. Binary formats are
    // checked field-by-field because not every message type is comparable as a
    // whole.
    use super::*;
    use crate::versioned_messages::{
        MessageAction, MessageVersionMetadata, VersionedRealtimeMessage,
    };

    /// A message that populates every optional field, with a JSON payload.
    fn sample_message() -> PusherMessage {
        PusherMessage {
            event: Some("sockudo:test".to_string()),
            channel: Some("chat:room-1".to_string()),
            data: Some(MessageData::Json(sonic_rs::json!({
                "hello": "world",
                "count": 3,
                "nested": { "ok": true },
                "items": [1, 2, 3]
            }))),
            name: None,
            user_id: Some("user-1".to_string()),
            tags: Some(BTreeMap::from([
                ("region".to_string(), "eu".to_string()),
                ("tier".to_string(), "gold".to_string()),
            ])),
            sequence: Some(7),
            conflation_key: Some("room".to_string()),
            message_id: Some("mid-1".to_string()),
            stream_id: Some("stream-1".to_string()),
            serial: Some(9),
            idempotency_key: Some("idem-1".to_string()),
            extras: Some(MessageExtras {
                headers: Some(HashMap::from([
                    (
                        "priority".to_string(),
                        ExtrasValue::String("high".to_string()),
                    ),
                    ("ttl".to_string(), ExtrasValue::Number(5.0)),
                ])),
                ephemeral: Some(true),
                idempotency_key: Some("extra-idem".to_string()),
                push: None,
                echo: Some(false),
            }),
            delta_sequence: Some(11),
            delta_conflation_key: Some("btc".to_string()),
        }
    }

    /// An `Update` whose event name matches its action.
    fn sample_versioned_message() -> VersionedRealtimeMessage {
        let mut message = sample_message();
        message.event = Some("sockudo:message.update".to_string());

        VersionedRealtimeMessage {
            message,
            action: MessageAction::Update,
            message_serial: "msg:1".to_string(),
            history_serial: Some(7),
            delivery_serial: Some(9),
            version: Some(MessageVersionMetadata {
                serial: "ver:2".to_string(),
                client_id: Some("user-1".to_string()),
                timestamp_ms: 1_713_100_805_000,
                description: Some("patched".to_string()),
                metadata: Some(sonic_rs::json!({"source": "test"})),
            }),
        }
    }

    #[test]
    fn round_trip_messagepack() {
        let msg = sample_message();
        let bytes = serialize_message(&msg, WireFormat::MessagePack).unwrap();
        let decoded = deserialize_message(&bytes, WireFormat::MessagePack).unwrap();
        assert_eq!(decoded.event, msg.event);
        assert_eq!(decoded.user_id, msg.user_id);
        assert_eq!(decoded.tags, msg.tags);
        assert_eq!(decoded.sequence, msg.sequence);
        assert_eq!(decoded.delta_sequence, msg.delta_sequence);
    }

    #[test]
    fn round_trip_protobuf() {
        let msg = sample_message();
        let bytes = serialize_message(&msg, WireFormat::Protobuf).unwrap();
        let decoded = deserialize_message(&bytes, WireFormat::Protobuf).unwrap();
        assert_eq!(decoded.event, msg.event);
        assert_eq!(decoded.channel, msg.channel);
        assert_eq!(decoded.message_id, msg.message_id);
        assert_eq!(decoded.tags, msg.tags);
        assert_eq!(decoded.serial, msg.serial);
        assert_eq!(decoded.delta_sequence, msg.delta_sequence);
        assert_eq!(decoded.delta_conflation_key, msg.delta_conflation_key);
    }

    #[test]
    fn protobuf_absent_tags_stay_absent() {
        let mut msg = sample_message();
        msg.tags = None;
        let bytes = serialize_message(&msg, WireFormat::Protobuf).unwrap();
        let decoded = deserialize_message(&bytes, WireFormat::Protobuf).unwrap();
        assert_eq!(decoded.tags, None);
    }

    #[test]
    fn round_trip_string_data_in_binary_formats() {
        for format in [WireFormat::MessagePack, WireFormat::Protobuf] {
            let mut msg = sample_message();
            msg.data = Some(MessageData::String("hello".to_string()));
            let bytes = serialize_message(&msg, format).unwrap();
            let decoded = deserialize_message(&bytes, format).unwrap();
            assert!(
                matches!(&decoded.data, Some(MessageData::String(s)) if s == "hello"),
                "string payload did not survive {format:?}"
            );
        }
    }

    #[test]
    fn round_trip_versioned_messagepack() {
        let msg = sample_versioned_message();
        let bytes = serialize_versioned_message(&msg, WireFormat::MessagePack).unwrap();
        let decoded = deserialize_versioned_message(&bytes, WireFormat::MessagePack).unwrap();
        assert_eq!(decoded.action, msg.action);
        assert_eq!(decoded.message_serial, msg.message_serial);
        assert_eq!(decoded.version, msg.version);
    }

    #[test]
    fn round_trip_versioned_protobuf() {
        let msg = sample_versioned_message();
        let bytes = serialize_versioned_message(&msg, WireFormat::Protobuf).unwrap();
        let decoded = deserialize_versioned_message(&bytes, WireFormat::Protobuf).unwrap();
        assert_eq!(decoded.action, msg.action);
        assert_eq!(decoded.message_serial, msg.message_serial);
        assert_eq!(decoded.history_serial, msg.history_serial);
        assert_eq!(decoded.delivery_serial, msg.delivery_serial);
    }

    #[test]
    fn deserialize_versioned_message_rejects_invalid_action_event_pair() {
        let bytes = sonic_rs::to_vec(&VersionedRealtimeMessage {
            message: PusherMessage {
                event: Some("sockudo:message.delete".to_string()),
                channel: Some("chat:room-1".to_string()),
                data: Some(MessageData::String("hello".to_string())),
                name: Some("chat.message".to_string()),
                user_id: None,
                tags: None,
                sequence: None,
                conflation_key: None,
                message_id: None,
                stream_id: None,
                serial: Some(9),
                idempotency_key: None,
                extras: None,
                delta_sequence: None,
                delta_conflation_key: None,
            },
            action: MessageAction::Update,
            message_serial: "msg:1".to_string(),
            history_serial: Some(7),
            delivery_serial: Some(9),
            version: Some(MessageVersionMetadata {
                serial: "ver:2".to_string(),
                client_id: Some("user-1".to_string()),
                timestamp_ms: 1_713_100_805_000,
                description: None,
                metadata: None,
            }),
        })
        .unwrap();

        let error = deserialize_versioned_message(&bytes, WireFormat::Json).unwrap_err();
        assert!(
            error.contains("does not match action")
                || error.contains("JSON deserialization failed"),
            "unexpected error: {error}"
        );
    }

    #[test]
    fn parse_query_param_accepts_known_values() {
        assert_eq!(
            WireFormat::parse_query_param(None).unwrap(),
            WireFormat::Json
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("json")).unwrap(),
            WireFormat::Json
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("messagepack")).unwrap(),
            WireFormat::MessagePack
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("msgpack")).unwrap(),
            WireFormat::MessagePack
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("protobuf")).unwrap(),
            WireFormat::Protobuf
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("proto")).unwrap(),
            WireFormat::Protobuf
        );
    }

    #[test]
    fn parse_query_param_normalizes_case_and_whitespace() {
        assert_eq!(
            WireFormat::parse_query_param(Some("  MsgPack ")).unwrap(),
            WireFormat::MessagePack
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("PROTO")).unwrap(),
            WireFormat::Protobuf
        );
        assert_eq!(
            WireFormat::parse_query_param(Some("")).unwrap(),
            WireFormat::Json
        );
    }

    #[test]
    fn parse_query_param_rejects_unknown_value() {
        assert!(WireFormat::parse_query_param(Some("avro")).is_err());
    }

    #[test]
    fn from_query_param_falls_back_to_json() {
        assert_eq!(WireFormat::from_query_param(Some("avro")), WireFormat::Json);
        assert_eq!(
            WireFormat::from_query_param(Some("msgpack")),
            WireFormat::MessagePack
        );
    }

    #[test]
    fn is_binary_only_for_binary_formats() {
        assert!(!WireFormat::Json.is_binary());
        assert!(WireFormat::MessagePack.is_binary());
        assert!(WireFormat::Protobuf.is_binary());
    }
}