Skip to main content

sockudo_protocol/
wire.rs

1use ahash::AHashMap;
2use prost::Message;
3use serde::{Deserialize, Serialize};
4use sonic_rs::Value;
5use std::collections::{BTreeMap, HashMap};
6
7use crate::messages::{ExtrasValue, MessageData, MessageExtras, PusherMessage};
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
10#[serde(rename_all = "lowercase")]
11pub enum WireFormat {
12    #[default]
13    Json,
14    MessagePack,
15    Protobuf,
16}
17
18impl WireFormat {
19    pub fn from_query_param(value: Option<&str>) -> Self {
20        Self::parse_query_param(value).unwrap_or(Self::Json)
21    }
22
23    pub fn parse_query_param(value: Option<&str>) -> Result<Self, String> {
24        match value.map(|v| v.trim().to_ascii_lowercase()) {
25            None => Ok(Self::Json),
26            Some(v) if v.is_empty() || v == "json" => Ok(Self::Json),
27            Some(v) if v == "msgpack" || v == "messagepack" => Ok(Self::MessagePack),
28            Some(v) if v == "protobuf" || v == "proto" => Ok(Self::Protobuf),
29            Some(v) => Err(format!("unsupported wire format '{v}'")),
30        }
31    }
32
33    pub const fn is_binary(self) -> bool {
34        !matches!(self, Self::Json)
35    }
36}
37
38pub fn serialize_message(message: &PusherMessage, format: WireFormat) -> Result<Vec<u8>, String> {
39    match format {
40        WireFormat::Json => {
41            sonic_rs::to_vec(message).map_err(|e| format!("JSON serialization failed: {e}"))
42        }
43        WireFormat::MessagePack => rmp_serde::to_vec(&MsgpackPusherMessage::from(message.clone()))
44            .map_err(|e| format!("MessagePack serialization failed: {e}")),
45        WireFormat::Protobuf => {
46            let proto = ProtoPusherMessage::from(message.clone());
47            let mut buf = Vec::with_capacity(proto.encoded_len());
48            proto
49                .encode(&mut buf)
50                .map_err(|e| format!("Protobuf serialization failed: {e}"))?;
51            Ok(buf)
52        }
53    }
54}
55
56pub fn deserialize_message(bytes: &[u8], format: WireFormat) -> Result<PusherMessage, String> {
57    match format {
58        WireFormat::Json => {
59            sonic_rs::from_slice(bytes).map_err(|e| format!("JSON deserialization failed: {e}"))
60        }
61        WireFormat::MessagePack => {
62            let msg: MsgpackPusherMessage = rmp_serde::from_slice(bytes)
63                .map_err(|e| format!("MessagePack deserialization failed: {e}"))?;
64            Ok(msg.into())
65        }
66        WireFormat::Protobuf => {
67            let proto = ProtoPusherMessage::decode(bytes)
68                .map_err(|e| format!("Protobuf deserialization failed: {e}"))?;
69            Ok(proto.into())
70        }
71    }
72}
73
74#[derive(Clone, PartialEq, Message)]
75struct ProtoPusherMessage {
76    #[prost(string, optional, tag = "1")]
77    event: Option<String>,
78    #[prost(string, optional, tag = "2")]
79    channel: Option<String>,
80    #[prost(message, optional, tag = "3")]
81    data: Option<ProtoMessageData>,
82    #[prost(string, optional, tag = "4")]
83    name: Option<String>,
84    #[prost(string, optional, tag = "5")]
85    user_id: Option<String>,
86    #[prost(map = "string, string", tag = "6")]
87    tags: HashMap<String, String>,
88    #[prost(uint64, optional, tag = "7")]
89    sequence: Option<u64>,
90    #[prost(string, optional, tag = "8")]
91    conflation_key: Option<String>,
92    #[prost(string, optional, tag = "9")]
93    message_id: Option<String>,
94    #[prost(string, optional, tag = "10")]
95    stream_id: Option<String>,
96    #[prost(uint64, optional, tag = "11")]
97    serial: Option<u64>,
98    #[prost(string, optional, tag = "12")]
99    idempotency_key: Option<String>,
100    #[prost(message, optional, tag = "13")]
101    extras: Option<ProtoMessageExtras>,
102    #[prost(uint64, optional, tag = "14")]
103    delta_sequence: Option<u64>,
104    #[prost(string, optional, tag = "15")]
105    delta_conflation_key: Option<String>,
106}
107
108#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
109struct MsgpackPusherMessage {
110    event: Option<String>,
111    channel: Option<String>,
112    data: Option<MsgpackMessageData>,
113    name: Option<String>,
114    user_id: Option<String>,
115    tags: Option<BTreeMap<String, String>>,
116    sequence: Option<u64>,
117    conflation_key: Option<String>,
118    message_id: Option<String>,
119    stream_id: Option<String>,
120    serial: Option<u64>,
121    idempotency_key: Option<String>,
122    extras: Option<MsgpackMessageExtras>,
123    delta_sequence: Option<u64>,
124    delta_conflation_key: Option<String>,
125}
126
127#[derive(Clone, PartialEq, Message)]
128struct ProtoMessageData {
129    #[prost(oneof = "proto_message_data::Kind", tags = "1, 2, 3")]
130    kind: Option<proto_message_data::Kind>,
131}
132
133#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
134#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
135enum MsgpackMessageData {
136    String(String),
137    Structured(MsgpackStructuredData),
138    Json(String),
139}
140
141mod proto_message_data {
142    use super::ProtoStructuredData;
143    use prost::Oneof;
144
145    #[derive(Clone, PartialEq, Oneof)]
146    pub enum Kind {
147        #[prost(string, tag = "1")]
148        String(String),
149        #[prost(message, tag = "2")]
150        Structured(ProtoStructuredData),
151        #[prost(string, tag = "3")]
152        Json(String),
153    }
154}
155
156#[derive(Clone, PartialEq, Message)]
157struct ProtoStructuredData {
158    #[prost(string, optional, tag = "1")]
159    channel_data: Option<String>,
160    #[prost(string, optional, tag = "2")]
161    channel: Option<String>,
162    #[prost(string, optional, tag = "3")]
163    user_data: Option<String>,
164    #[prost(map = "string, string", tag = "4")]
165    extra: HashMap<String, String>,
166}
167
168#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
169struct MsgpackStructuredData {
170    channel_data: Option<String>,
171    channel: Option<String>,
172    user_data: Option<String>,
173    extra: HashMap<String, String>,
174}
175
176#[derive(Clone, PartialEq, Message)]
177struct ProtoMessageExtras {
178    #[prost(map = "string, message", tag = "1")]
179    headers: HashMap<String, ProtoExtrasValue>,
180    #[prost(bool, optional, tag = "2")]
181    ephemeral: Option<bool>,
182    #[prost(string, optional, tag = "3")]
183    idempotency_key: Option<String>,
184    #[prost(bool, optional, tag = "4")]
185    echo: Option<bool>,
186}
187
188#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
189struct MsgpackMessageExtras {
190    headers: Option<HashMap<String, MsgpackExtrasValue>>,
191    ephemeral: Option<bool>,
192    idempotency_key: Option<String>,
193    echo: Option<bool>,
194}
195
196#[derive(Clone, PartialEq, Message)]
197struct ProtoExtrasValue {
198    #[prost(oneof = "proto_extras_value::Kind", tags = "1, 2, 3")]
199    kind: Option<proto_extras_value::Kind>,
200}
201
202#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
203#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
204enum MsgpackExtrasValue {
205    String(String),
206    Number(f64),
207    Bool(bool),
208}
209
210mod proto_extras_value {
211    use prost::Oneof;
212
213    #[derive(Clone, PartialEq, Oneof)]
214    pub enum Kind {
215        #[prost(string, tag = "1")]
216        String(String),
217        #[prost(double, tag = "2")]
218        Number(f64),
219        #[prost(bool, tag = "3")]
220        Bool(bool),
221    }
222}
223
224impl From<PusherMessage> for ProtoPusherMessage {
225    fn from(value: PusherMessage) -> Self {
226        Self {
227            event: value.event,
228            channel: value.channel,
229            data: value.data.map(Into::into),
230            name: value.name,
231            user_id: value.user_id,
232            tags: value
233                .tags
234                .map(|m| m.into_iter().collect())
235                .unwrap_or_default(),
236            sequence: value.sequence,
237            conflation_key: value.conflation_key,
238            message_id: value.message_id,
239            stream_id: value.stream_id,
240            serial: value.serial,
241            idempotency_key: value.idempotency_key,
242            extras: value.extras.map(Into::into),
243            delta_sequence: value.delta_sequence,
244            delta_conflation_key: value.delta_conflation_key,
245        }
246    }
247}
248
249impl From<PusherMessage> for MsgpackPusherMessage {
250    fn from(value: PusherMessage) -> Self {
251        Self {
252            event: value.event,
253            channel: value.channel,
254            data: value.data.map(Into::into),
255            name: value.name,
256            user_id: value.user_id,
257            tags: value.tags,
258            sequence: value.sequence,
259            conflation_key: value.conflation_key,
260            message_id: value.message_id,
261            stream_id: value.stream_id,
262            serial: value.serial,
263            idempotency_key: value.idempotency_key,
264            extras: value.extras.map(Into::into),
265            delta_sequence: value.delta_sequence,
266            delta_conflation_key: value.delta_conflation_key,
267        }
268    }
269}
270
271impl From<ProtoPusherMessage> for PusherMessage {
272    fn from(value: ProtoPusherMessage) -> Self {
273        Self {
274            event: value.event,
275            channel: value.channel,
276            data: value.data.map(Into::into),
277            name: value.name,
278            user_id: value.user_id,
279            tags: (!value.tags.is_empty())
280                .then_some(value.tags.into_iter().collect::<BTreeMap<_, _>>()),
281            sequence: value.sequence,
282            conflation_key: value.conflation_key,
283            message_id: value.message_id,
284            stream_id: value.stream_id,
285            serial: value.serial,
286            idempotency_key: value.idempotency_key,
287            extras: value.extras.map(Into::into),
288            delta_sequence: value.delta_sequence,
289            delta_conflation_key: value.delta_conflation_key,
290        }
291    }
292}
293
294impl From<MsgpackPusherMessage> for PusherMessage {
295    fn from(value: MsgpackPusherMessage) -> Self {
296        Self {
297            event: value.event,
298            channel: value.channel,
299            data: value.data.map(Into::into),
300            name: value.name,
301            user_id: value.user_id,
302            tags: value.tags,
303            sequence: value.sequence,
304            conflation_key: value.conflation_key,
305            message_id: value.message_id,
306            stream_id: value.stream_id,
307            serial: value.serial,
308            idempotency_key: value.idempotency_key,
309            extras: value.extras.map(Into::into),
310            delta_sequence: value.delta_sequence,
311            delta_conflation_key: value.delta_conflation_key,
312        }
313    }
314}
315
316impl From<MessageData> for ProtoMessageData {
317    fn from(value: MessageData) -> Self {
318        let kind = match value {
319            MessageData::String(s) => Some(proto_message_data::Kind::String(s)),
320            MessageData::Structured {
321                channel_data,
322                channel,
323                user_data,
324                extra,
325            } => Some(proto_message_data::Kind::Structured(ProtoStructuredData {
326                channel_data,
327                channel,
328                user_data,
329                extra: extra
330                    .into_iter()
331                    .map(|(k, v)| {
332                        (
333                            k,
334                            sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()),
335                        )
336                    })
337                    .collect(),
338            })),
339            MessageData::Json(v) => Some(proto_message_data::Kind::Json(
340                sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()),
341            )),
342        };
343
344        Self { kind }
345    }
346}
347
348impl From<MessageData> for MsgpackMessageData {
349    fn from(value: MessageData) -> Self {
350        match value {
351            MessageData::String(s) => Self::String(s),
352            MessageData::Structured {
353                channel_data,
354                channel,
355                user_data,
356                extra,
357            } => Self::Structured(MsgpackStructuredData {
358                channel_data,
359                channel,
360                user_data,
361                extra: extra
362                    .into_iter()
363                    .map(|(k, v)| {
364                        (
365                            k,
366                            sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()),
367                        )
368                    })
369                    .collect(),
370            }),
371            MessageData::Json(v) => {
372                Self::Json(sonic_rs::to_string(&v).unwrap_or_else(|_| "null".to_string()))
373            }
374        }
375    }
376}
377
378impl From<ProtoMessageData> for MessageData {
379    fn from(value: ProtoMessageData) -> Self {
380        match value.kind {
381            Some(proto_message_data::Kind::String(s)) => MessageData::String(s),
382            Some(proto_message_data::Kind::Structured(s)) => MessageData::Structured {
383                channel_data: s.channel_data,
384                channel: s.channel,
385                user_data: s.user_data,
386                extra: s
387                    .extra
388                    .into_iter()
389                    .map(|(k, v)| {
390                        let parsed =
391                            sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str()));
392                        (k, parsed)
393                    })
394                    .collect::<AHashMap<_, _>>(),
395            },
396            Some(proto_message_data::Kind::Json(v)) => MessageData::Json(
397                sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str())),
398            ),
399            None => MessageData::Json(Value::new_null()),
400        }
401    }
402}
403
404impl From<MsgpackMessageData> for MessageData {
405    fn from(value: MsgpackMessageData) -> Self {
406        match value {
407            MsgpackMessageData::String(s) => MessageData::String(s),
408            MsgpackMessageData::Structured(s) => MessageData::Structured {
409                channel_data: s.channel_data,
410                channel: s.channel,
411                user_data: s.user_data,
412                extra: s
413                    .extra
414                    .into_iter()
415                    .map(|(k, v)| {
416                        let parsed =
417                            sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str()));
418                        (k, parsed)
419                    })
420                    .collect::<AHashMap<_, _>>(),
421            },
422            MsgpackMessageData::Json(v) => MessageData::Json(
423                sonic_rs::from_str(&v).unwrap_or_else(|_| Value::from(v.as_str())),
424            ),
425        }
426    }
427}
428
429impl From<MessageExtras> for ProtoMessageExtras {
430    fn from(value: MessageExtras) -> Self {
431        Self {
432            headers: value
433                .headers
434                .unwrap_or_default()
435                .into_iter()
436                .map(|(k, v)| (k, v.into()))
437                .collect(),
438            ephemeral: value.ephemeral,
439            idempotency_key: value.idempotency_key,
440            echo: value.echo,
441        }
442    }
443}
444
445impl From<MessageExtras> for MsgpackMessageExtras {
446    fn from(value: MessageExtras) -> Self {
447        Self {
448            headers: value
449                .headers
450                .map(|headers| headers.into_iter().map(|(k, v)| (k, v.into())).collect()),
451            ephemeral: value.ephemeral,
452            idempotency_key: value.idempotency_key,
453            echo: value.echo,
454        }
455    }
456}
457
458impl From<ProtoMessageExtras> for MessageExtras {
459    fn from(value: ProtoMessageExtras) -> Self {
460        Self {
461            headers: (!value.headers.is_empty()).then_some(
462                value
463                    .headers
464                    .into_iter()
465                    .map(|(k, v)| (k, v.into()))
466                    .collect(),
467            ),
468            ephemeral: value.ephemeral,
469            idempotency_key: value.idempotency_key,
470            echo: value.echo,
471        }
472    }
473}
474
475impl From<MsgpackMessageExtras> for MessageExtras {
476    fn from(value: MsgpackMessageExtras) -> Self {
477        Self {
478            headers: value
479                .headers
480                .map(|headers| headers.into_iter().map(|(k, v)| (k, v.into())).collect()),
481            ephemeral: value.ephemeral,
482            idempotency_key: value.idempotency_key,
483            echo: value.echo,
484        }
485    }
486}
487
488impl From<ExtrasValue> for ProtoExtrasValue {
489    fn from(value: ExtrasValue) -> Self {
490        let kind = match value {
491            ExtrasValue::String(s) => Some(proto_extras_value::Kind::String(s)),
492            ExtrasValue::Number(n) => Some(proto_extras_value::Kind::Number(n)),
493            ExtrasValue::Bool(b) => Some(proto_extras_value::Kind::Bool(b)),
494        };
495        Self { kind }
496    }
497}
498
499impl From<ExtrasValue> for MsgpackExtrasValue {
500    fn from(value: ExtrasValue) -> Self {
501        match value {
502            ExtrasValue::String(s) => Self::String(s),
503            ExtrasValue::Number(n) => Self::Number(n),
504            ExtrasValue::Bool(b) => Self::Bool(b),
505        }
506    }
507}
508
509impl From<ProtoExtrasValue> for ExtrasValue {
510    fn from(value: ProtoExtrasValue) -> Self {
511        match value.kind {
512            Some(proto_extras_value::Kind::String(s)) => ExtrasValue::String(s),
513            Some(proto_extras_value::Kind::Number(n)) => ExtrasValue::Number(n),
514            Some(proto_extras_value::Kind::Bool(b)) => ExtrasValue::Bool(b),
515            None => ExtrasValue::String(String::new()),
516        }
517    }
518}
519
520impl From<MsgpackExtrasValue> for ExtrasValue {
521    fn from(value: MsgpackExtrasValue) -> Self {
522        match value {
523            MsgpackExtrasValue::String(s) => ExtrasValue::String(s),
524            MsgpackExtrasValue::Number(n) => ExtrasValue::Number(n),
525            MsgpackExtrasValue::Bool(b) => ExtrasValue::Bool(b),
526        }
527    }
528}
529
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a message with every field populated except `name`.
    fn sample_message() -> PusherMessage {
        let payload = sonic_rs::json!({
            "hello": "world",
            "count": 3,
            "nested": { "ok": true },
            "items": [1, 2, 3]
        });

        let tags = BTreeMap::from([
            ("region".to_string(), "eu".to_string()),
            ("tier".to_string(), "gold".to_string()),
        ]);

        let headers = HashMap::from([
            (
                "priority".to_string(),
                ExtrasValue::String("high".to_string()),
            ),
            ("ttl".to_string(), ExtrasValue::Number(5.0)),
        ]);

        let extras = MessageExtras {
            headers: Some(headers),
            ephemeral: Some(true),
            idempotency_key: Some("extra-idem".to_string()),
            echo: Some(false),
        };

        PusherMessage {
            event: Some("sockudo:test".to_string()),
            channel: Some("chat:room-1".to_string()),
            data: Some(MessageData::Json(payload)),
            name: None,
            user_id: Some("user-1".to_string()),
            tags: Some(tags),
            sequence: Some(7),
            conflation_key: Some("room".to_string()),
            message_id: Some("mid-1".to_string()),
            stream_id: Some("stream-1".to_string()),
            serial: Some(9),
            idempotency_key: Some("idem-1".to_string()),
            extras: Some(extras),
            delta_sequence: Some(11),
            delta_conflation_key: Some("btc".to_string()),
        }
    }

    /// Serializes `original` with `format` and decodes the bytes again.
    fn round_trip(original: &PusherMessage, format: WireFormat) -> PusherMessage {
        let bytes = serialize_message(original, format).unwrap();
        deserialize_message(&bytes, format).unwrap()
    }

    #[test]
    fn round_trip_messagepack() {
        let original = sample_message();
        let decoded = round_trip(&original, WireFormat::MessagePack);

        assert_eq!(decoded.event, original.event);
        assert_eq!(decoded.delta_sequence, original.delta_sequence);
    }

    #[test]
    fn round_trip_protobuf() {
        let original = sample_message();
        let decoded = round_trip(&original, WireFormat::Protobuf);

        assert_eq!(decoded.event, original.event);
        assert_eq!(decoded.channel, original.channel);
        assert_eq!(decoded.message_id, original.message_id);
        assert_eq!(decoded.delta_conflation_key, original.delta_conflation_key);
    }

    #[test]
    fn parse_query_param_accepts_known_values() {
        let cases = [
            (None, WireFormat::Json),
            (Some("json"), WireFormat::Json),
            (Some("messagepack"), WireFormat::MessagePack),
            (Some("msgpack"), WireFormat::MessagePack),
            (Some("protobuf"), WireFormat::Protobuf),
            (Some("proto"), WireFormat::Protobuf),
        ];

        for &(input, expected) in &cases {
            assert_eq!(WireFormat::parse_query_param(input).unwrap(), expected);
        }
    }

    #[test]
    fn parse_query_param_rejects_unknown_value() {
        let outcome = WireFormat::parse_query_param(Some("avro"));
        assert!(outcome.is_err());
    }
}