Skip to main content

sockudo_protocol/
messages.rs

1use ahash::AHashMap;
2use serde::{Deserialize, Serialize};
3use sonic_rs::prelude::*;
4use sonic_rs::{Value, json};
5use std::collections::{BTreeMap, HashMap};
6use std::time::Duration;
7
8use crate::protocol_version::ProtocolVersion;
9
10/// Allowed value types for extras.headers.
11/// Flat only — no Object or Array variant so nesting is structurally impossible.
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
13#[serde(untagged)]
14pub enum ExtrasValue {
15    String(String),
16    Number(f64),
17    Bool(bool),
18}
19
20/// Structured metadata envelope for V2-specific message features.
21///
22/// Present on the wire for V2 connections only. V1 connections receive messages
23/// with extras stripped entirely. Pusher SDKs ignore unknown fields so the
24/// field is safe to carry through internal pipelines even when the publisher
25/// is V1.
26#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
27#[serde(rename_all = "camelCase")]
28pub struct MessageExtras {
29    /// Flat metadata for server-side event name filtering.
30    /// Must be a flat object — no nested objects, no arrays.
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub headers: Option<HashMap<String, ExtrasValue>>,
33
34    /// If true: skip connection recovery buffer and webhook forwarding.
35    /// Deliver to currently connected V2 subscribers only.
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub ephemeral: Option<bool>,
38
39    /// Server-side deduplication key. If the same key arrives again within
40    /// the app's idempotency TTL window, the message is silently dropped.
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub idempotency_key: Option<String>,
43
44    /// Per-message echo control. Overrides the connection-level echo setting
45    /// when explicitly set.
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub echo: Option<bool>,
48}
49
50impl MessageExtras {
51    /// Validate that headers (if present) contain only flat scalar values.
52    /// This is structurally guaranteed by `ExtrasValue` having no Object/Array
53    /// variants, but this method provides an explicit check with a clear error
54    /// when validating raw JSON before deserialization.
55    pub fn validate_headers_from_json(raw: &Value) -> Result<(), String> {
56        if let Some(extras) = raw.get("extras")
57            && let Some(headers) = extras.get("headers")
58            && let Some(obj) = headers.as_object()
59        {
60            for (key, val) in obj.iter() {
61                if val.is_object() || val.is_array() {
62                    return Err(format!(
63                        "extras.headers must be a flat object — nested objects and arrays are not allowed (key: '{key}')"
64                    ));
65                }
66            }
67        }
68        Ok(())
69    }
70}
71
72/// Generate a unique message ID (UUIDv4) for client-side deduplication.
73pub fn generate_message_id() -> String {
74    uuid::Uuid::new_v4().to_string()
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct PresenceData {
79    pub ids: Vec<String>,
80    pub hash: AHashMap<String, Option<Value>>,
81    pub count: usize,
82}
83
84#[derive(Debug, Clone, Serialize)]
85#[serde(untagged)]
86pub enum MessageData {
87    String(String),
88    Structured {
89        #[serde(skip_serializing_if = "Option::is_none")]
90        channel_data: Option<String>,
91        #[serde(skip_serializing_if = "Option::is_none")]
92        channel: Option<String>,
93        #[serde(skip_serializing_if = "Option::is_none")]
94        user_data: Option<String>,
95        #[serde(flatten)]
96        extra: AHashMap<String, Value>,
97    },
98    Json(Value),
99}
100
101impl<'de> Deserialize<'de> for MessageData {
102    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
103    where
104        D: serde::Deserializer<'de>,
105    {
106        let v = Value::deserialize(deserializer)?;
107        if let Some(s) = v.as_str() {
108            return Ok(MessageData::String(s.to_string()));
109        }
110        if let Some(obj) = v.as_object() {
111            // Flatten workaround for sonic-rs issue #114:
112            // manually split known structured keys and keep remaining keys in `extra`.
113            let channel_data = obj
114                .get(&"channel_data")
115                .and_then(|x| x.as_str())
116                .map(ToString::to_string);
117            let channel = obj
118                .get(&"channel")
119                .and_then(|x| x.as_str())
120                .map(ToString::to_string);
121            let user_data = obj
122                .get(&"user_data")
123                .and_then(|x| x.as_str())
124                .map(ToString::to_string);
125
126            if channel_data.is_some() || channel.is_some() || user_data.is_some() {
127                let mut extra = AHashMap::new();
128                for (k, val) in obj.iter() {
129                    if k != "channel_data" && k != "channel" && k != "user_data" {
130                        extra.insert(k.to_string(), val.clone());
131                    }
132                }
133                return Ok(MessageData::Structured {
134                    channel_data,
135                    channel,
136                    user_data,
137                    extra,
138                });
139            }
140        }
141        Ok(MessageData::Json(v))
142    }
143}
144
145#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct ErrorData {
147    pub code: Option<u16>,
148    pub message: String,
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct PusherMessage {
153    #[serde(skip_serializing_if = "Option::is_none")]
154    pub event: Option<String>,
155    #[serde(skip_serializing_if = "Option::is_none")]
156    pub channel: Option<String>,
157    #[serde(skip_serializing_if = "Option::is_none")]
158    pub data: Option<MessageData>,
159    #[serde(skip_serializing_if = "Option::is_none")]
160    pub name: Option<String>,
161    #[serde(skip_serializing_if = "Option::is_none")]
162    pub user_id: Option<String>,
163    /// Tags for filtering - uses BTreeMap for deterministic serialization order
164    /// which is required for delta compression to work correctly
165    #[serde(skip_serializing_if = "Option::is_none")]
166    pub tags: Option<BTreeMap<String, String>>,
167    /// Delta compression sequence number for full messages
168    #[serde(skip_serializing_if = "Option::is_none")]
169    pub sequence: Option<u64>,
170    /// Delta compression conflation key for message grouping
171    #[serde(skip_serializing_if = "Option::is_none")]
172    pub conflation_key: Option<String>,
173    /// Unique message ID for client-side deduplication
174    #[serde(skip_serializing_if = "Option::is_none")]
175    pub message_id: Option<String>,
176    /// Monotonically increasing serial for connection recovery.
177    /// Assigned per-channel at broadcast time when connection recovery is enabled.
178    #[serde(skip_serializing_if = "Option::is_none")]
179    pub serial: Option<u64>,
180    /// Idempotency key for cross-region deduplication.
181    /// Threaded from the HTTP publish request through the broadcast pipeline
182    /// so that receiving nodes can register it in their local cache.
183    /// Never sent to WebSocket clients.
184    #[serde(skip_serializing_if = "Option::is_none")]
185    pub idempotency_key: Option<String>,
186    /// V2 message extras envelope. Carries ephemeral flag, per-message echo
187    /// control, header-based filtering metadata, and extras-level idempotency.
188    /// Stripped from V1 deliveries; included in V2 wire format.
189    #[serde(skip_serializing_if = "Option::is_none")]
190    pub extras: Option<MessageExtras>,
191    /// Delta sequence marker for full messages in V2 delta streams.
192    #[serde(rename = "__delta_seq", skip_serializing_if = "Option::is_none")]
193    pub delta_sequence: Option<u64>,
194    /// Delta conflation key marker for full messages in V2 delta streams.
195    #[serde(rename = "__conflation_key", skip_serializing_if = "Option::is_none")]
196    pub delta_conflation_key: Option<String>,
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize)]
200pub struct PusherApiMessage {
201    #[serde(skip_serializing_if = "Option::is_none")]
202    pub name: Option<String>,
203    #[serde(skip_serializing_if = "Option::is_none")]
204    pub data: Option<ApiMessageData>,
205    #[serde(skip_serializing_if = "Option::is_none")]
206    pub channel: Option<String>,
207    #[serde(skip_serializing_if = "Option::is_none")]
208    pub channels: Option<Vec<String>>,
209    #[serde(skip_serializing_if = "Option::is_none")]
210    pub socket_id: Option<String>,
211    #[serde(skip_serializing_if = "Option::is_none")]
212    pub info: Option<String>,
213    #[serde(skip_serializing_if = "Option::is_none")]
214    pub tags: Option<AHashMap<String, String>>,
215    /// Per-publish delta compression control.
216    /// - `Some(true)`: Force delta compression for this message (if client supports it)
217    /// - `Some(false)`: Force full message (skip delta compression)
218    /// - `None`: Use default behavior based on channel/global configuration
219    #[serde(skip_serializing_if = "Option::is_none")]
220    pub delta: Option<bool>,
221    /// Idempotency key for deduplicating publish requests.
222    /// If the same key is seen within the TTL window, the server returns the
223    /// cached response without re-broadcasting.
224    #[serde(skip_serializing_if = "Option::is_none")]
225    pub idempotency_key: Option<String>,
226    /// V2 extras envelope. Passed through to PusherMessage for V2 delivery.
227    #[serde(skip_serializing_if = "Option::is_none")]
228    pub extras: Option<MessageExtras>,
229}
230
231#[derive(Debug, Clone, Serialize, Deserialize)]
232pub struct BatchPusherApiMessage {
233    pub batch: Vec<PusherApiMessage>,
234}
235
236#[derive(Debug, Clone, Serialize, Deserialize)]
237#[serde(untagged)]
238pub enum ApiMessageData {
239    String(String),
240    Json(Value),
241}
242
243#[derive(Debug, Clone, Serialize, Deserialize)]
244pub struct SentPusherMessage {
245    #[serde(skip_serializing_if = "Option::is_none")]
246    pub channel: Option<String>,
247    #[serde(skip_serializing_if = "Option::is_none")]
248    pub event: Option<String>,
249    #[serde(skip_serializing_if = "Option::is_none")]
250    pub data: Option<MessageData>,
251}
252
253// Helper implementations
254impl MessageData {
255    pub fn as_string(&self) -> Option<&str> {
256        match self {
257            MessageData::String(s) => Some(s),
258            _ => None,
259        }
260    }
261
262    pub fn into_string(self) -> Option<String> {
263        match self {
264            MessageData::String(s) => Some(s),
265            _ => None,
266        }
267    }
268
269    pub fn as_value(&self) -> Option<&Value> {
270        match self {
271            MessageData::Structured { extra, .. } => extra.values().next(),
272            _ => None,
273        }
274    }
275}
276
277impl From<String> for MessageData {
278    fn from(s: String) -> Self {
279        MessageData::String(s)
280    }
281}
282
283impl From<Value> for MessageData {
284    fn from(v: Value) -> Self {
285        MessageData::Json(v)
286    }
287}
288
289impl PusherMessage {
290    pub fn connection_established(socket_id: String, activity_timeout: u64) -> Self {
291        Self {
292            event: Some("pusher:connection_established".to_string()),
293            data: Some(MessageData::from(
294                json!({
295                    "socket_id": socket_id,
296                    "activity_timeout": activity_timeout  // Now configurable
297                })
298                .to_string(),
299            )),
300            channel: None,
301            name: None,
302            user_id: None,
303            sequence: None,
304            conflation_key: None,
305            tags: None,
306            message_id: None,
307            serial: None,
308            idempotency_key: None,
309            extras: None,
310            delta_sequence: None,
311            delta_conflation_key: None,
312        }
313    }
314    pub fn subscription_succeeded(channel: String, presence_data: Option<PresenceData>) -> Self {
315        let data_obj = if let Some(data) = presence_data {
316            json!({
317                "presence": {
318                    "ids": data.ids,
319                    "hash": data.hash,
320                    "count": data.count
321                }
322            })
323        } else {
324            json!({})
325        };
326
327        Self {
328            event: Some("pusher_internal:subscription_succeeded".to_string()),
329            channel: Some(channel),
330            data: Some(MessageData::String(data_obj.to_string())),
331            name: None,
332            user_id: None,
333            sequence: None,
334            conflation_key: None,
335            tags: None,
336            message_id: None,
337            serial: None,
338            idempotency_key: None,
339            extras: None,
340            delta_sequence: None,
341            delta_conflation_key: None,
342        }
343    }
344
345    pub fn error(code: u16, message: String, channel: Option<String>) -> Self {
346        Self {
347            event: Some("pusher:error".to_string()),
348            data: Some(MessageData::Json(json!({
349                "code": code,
350                "message": message
351            }))),
352            channel,
353            name: None,
354            user_id: None,
355            sequence: None,
356            conflation_key: None,
357            tags: None,
358            message_id: None,
359            serial: None,
360            idempotency_key: None,
361            extras: None,
362            delta_sequence: None,
363            delta_conflation_key: None,
364        }
365    }
366
367    pub fn ping() -> Self {
368        Self {
369            event: Some("pusher:ping".to_string()),
370            data: None,
371            channel: None,
372            name: None,
373            user_id: None,
374            sequence: None,
375            conflation_key: None,
376            tags: None,
377            message_id: None,
378            serial: None,
379            idempotency_key: None,
380            extras: None,
381            delta_sequence: None,
382            delta_conflation_key: None,
383        }
384    }
385    pub fn channel_event<S: Into<String>>(event: S, channel: S, data: Value) -> Self {
386        Self {
387            event: Some(event.into()),
388            channel: Some(channel.into()),
389            data: Some(MessageData::String(data.to_string())),
390            name: None,
391            user_id: None,
392            sequence: None,
393            conflation_key: None,
394            tags: None,
395            message_id: None,
396            serial: None,
397            idempotency_key: None,
398            extras: None,
399            delta_sequence: None,
400            delta_conflation_key: None,
401        }
402    }
403
404    pub fn member_added(channel: String, user_id: String, user_info: Option<Value>) -> Self {
405        Self {
406            event: Some("pusher_internal:member_added".to_string()),
407            channel: Some(channel),
408            // FIX: Use MessageData::String with JSON-encoded string instead of MessageData::Json
409            data: Some(MessageData::String(
410                json!({
411                    "user_id": user_id,
412                    "user_info": user_info.unwrap_or_else(|| json!({}))
413                })
414                .to_string(),
415            )),
416            name: None,
417            user_id: None,
418            sequence: None,
419            conflation_key: None,
420            tags: None,
421            message_id: None,
422            serial: None,
423            idempotency_key: None,
424            extras: None,
425            delta_sequence: None,
426            delta_conflation_key: None,
427        }
428    }
429
430    pub fn member_removed(channel: String, user_id: String) -> Self {
431        Self {
432            event: Some("pusher_internal:member_removed".to_string()),
433            channel: Some(channel),
434            // FIX: Also apply same fix to member_removed for consistency
435            data: Some(MessageData::String(
436                json!({
437                    "user_id": user_id
438                })
439                .to_string(),
440            )),
441            name: None,
442            user_id: None,
443            sequence: None,
444            conflation_key: None,
445            tags: None,
446            message_id: None,
447            serial: None,
448            idempotency_key: None,
449            extras: None,
450            delta_sequence: None,
451            delta_conflation_key: None,
452        }
453    }
454
455    // New helper method for pong response
456    pub fn pong() -> Self {
457        Self {
458            event: Some("pusher:pong".to_string()),
459            data: None,
460            channel: None,
461            name: None,
462            user_id: None,
463            sequence: None,
464            conflation_key: None,
465            tags: None,
466            message_id: None,
467            serial: None,
468            idempotency_key: None,
469            extras: None,
470            delta_sequence: None,
471            delta_conflation_key: None,
472        }
473    }
474
475    // Helper for creating channel info response
476    pub fn channel_info(
477        occupied: bool,
478        subscription_count: Option<u64>,
479        user_count: Option<u64>,
480        cache_data: Option<(String, Duration)>,
481    ) -> Value {
482        let mut response = json!({
483            "occupied": occupied
484        });
485
486        if let Some(count) = subscription_count {
487            response["subscription_count"] = json!(count);
488        }
489
490        if let Some(count) = user_count {
491            response["user_count"] = json!(count);
492        }
493
494        if let Some((data, ttl)) = cache_data {
495            response["cache"] = json!({
496                "data": data,
497                "ttl": ttl.as_secs()
498            });
499        }
500
501        response
502    }
503
504    // Helper for creating channels list response
505    pub fn channels_list(channels_info: AHashMap<String, Value>) -> Value {
506        json!({
507            "channels": channels_info
508        })
509    }
510
511    // Helper for creating user list response
512    pub fn user_list(user_ids: Vec<String>) -> Value {
513        let users = user_ids
514            .into_iter()
515            .map(|id| json!({ "id": id }))
516            .collect::<Vec<_>>();
517
518        json!({ "users": users })
519    }
520
521    // Helper for batch events response
522    pub fn batch_response(batch_info: Vec<Value>) -> Value {
523        json!({ "batch": batch_info })
524    }
525
526    // Helper for simple success response
527    pub fn success_response() -> Value {
528        json!({ "ok": true })
529    }
530
531    pub fn watchlist_online_event(user_ids: Vec<String>) -> Self {
532        Self {
533            event: Some("online".to_string()),
534            channel: None, // Watchlist events don't use channels
535            name: None,
536            data: Some(MessageData::Json(json!({
537                "user_ids": user_ids
538            }))),
539            user_id: None,
540            sequence: None,
541            conflation_key: None,
542            tags: None,
543            message_id: None,
544            serial: None,
545            idempotency_key: None,
546            extras: None,
547            delta_sequence: None,
548            delta_conflation_key: None,
549        }
550    }
551
552    pub fn watchlist_offline_event(user_ids: Vec<String>) -> Self {
553        Self {
554            event: Some("offline".to_string()),
555            channel: None,
556            name: None,
557            data: Some(MessageData::Json(json!({
558                "user_ids": user_ids
559            }))),
560            user_id: None,
561            sequence: None,
562            conflation_key: None,
563            tags: None,
564            message_id: None,
565            serial: None,
566            idempotency_key: None,
567            extras: None,
568            delta_sequence: None,
569            delta_conflation_key: None,
570        }
571    }
572
573    pub fn cache_miss_event(channel: String) -> Self {
574        Self {
575            event: Some("pusher:cache_miss".to_string()),
576            channel: Some(channel),
577            data: Some(MessageData::String("{}".to_string())),
578            name: None,
579            user_id: None,
580            sequence: None,
581            conflation_key: None,
582            tags: None,
583            message_id: None,
584            serial: None,
585            idempotency_key: None,
586            extras: None,
587            delta_sequence: None,
588            delta_conflation_key: None,
589        }
590    }
591
592    pub fn signin_success(user_data: String) -> Self {
593        Self {
594            event: Some("pusher:signin_success".to_string()),
595            data: Some(MessageData::Json(json!({
596                "user_data": user_data
597            }))),
598            channel: None,
599            name: None,
600            user_id: None,
601            sequence: None,
602            conflation_key: None,
603            tags: None,
604            message_id: None,
605            serial: None,
606            idempotency_key: None,
607            extras: None,
608            delta_sequence: None,
609            delta_conflation_key: None,
610        }
611    }
612
613    /// Create a delta-compressed message
614    pub fn delta_message(
615        channel: String,
616        event: String,
617        delta_base64: String,
618        base_sequence: u32,
619        target_sequence: u32,
620        algorithm: &str,
621    ) -> Self {
622        Self {
623            event: Some("pusher:delta".to_string()),
624            channel: Some(channel.clone()),
625            data: Some(MessageData::String(
626                json!({
627                    "channel": channel,
628                    "event": event,
629                    "delta": delta_base64,
630                    "base_seq": base_sequence,
631                    "target_seq": target_sequence,
632                    "algorithm": algorithm,
633                })
634                .to_string(),
635            )),
636            name: None,
637            user_id: None,
638            sequence: None,
639            conflation_key: None,
640            tags: None,
641            message_id: None,
642            serial: None,
643            idempotency_key: None,
644            extras: None,
645            delta_sequence: None,
646            delta_conflation_key: None,
647        }
648    }
649
650    /// Rewrite the event name prefix to match the given protocol version.
651    /// This is the single translation point between V1 (`pusher:`) and V2 (`sockudo:`) wire formats.
652    pub fn rewrite_prefix(&mut self, version: ProtocolVersion) {
653        if let Some(ref event) = self.event {
654            self.event = Some(version.rewrite_event_prefix(event));
655        }
656    }
657
658    /// Returns true if this message is ephemeral (skip recovery buffer and webhooks).
659    pub fn is_ephemeral(&self) -> bool {
660        self.extras
661            .as_ref()
662            .and_then(|e| e.ephemeral)
663            .unwrap_or(false)
664    }
665
666    /// Returns the extras-level idempotency key, if set.
667    pub fn extras_idempotency_key(&self) -> Option<&str> {
668        self.extras
669            .as_ref()
670            .and_then(|e| e.idempotency_key.as_deref())
671    }
672
673    /// Resolve whether this message should be echoed back to the publishing socket.
674    /// Message-level `extras.echo` takes precedence over the connection default.
675    pub fn should_echo(&self, connection_default: bool) -> bool {
676        self.extras
677            .as_ref()
678            .and_then(|e| e.echo)
679            .unwrap_or(connection_default)
680    }
681
682    /// Returns the extras headers for server-side filtering, if present.
683    pub fn filter_headers(&self) -> Option<&HashMap<String, ExtrasValue>> {
684        self.extras.as_ref().and_then(|e| e.headers.as_ref())
685    }
686
687    /// Returns true if the given protocol version should receive extras in delivered messages.
688    pub fn should_include_extras(protocol: &ProtocolVersion) -> bool {
689        matches!(protocol, ProtocolVersion::V2)
690    }
691
692    /// Add base sequence marker to a full message for delta tracking
693    pub fn add_base_sequence(mut self, base_sequence: u32) -> Self {
694        if let Some(MessageData::String(ref data_str)) = self.data
695            && let Ok(mut data_obj) = sonic_rs::from_str::<Value>(data_str)
696            && let Some(obj) = data_obj.as_object_mut()
697        {
698            obj.insert("__delta_base_seq", json!(base_sequence));
699            self.data = Some(MessageData::String(data_obj.to_string()));
700        }
701        self
702    }
703
704    /// Create delta compression enabled confirmation
705    pub fn delta_compression_enabled(default_algorithm: &str) -> Self {
706        Self {
707            event: Some("pusher:delta_compression_enabled".to_string()),
708            data: Some(MessageData::Json(json!({
709                "enabled": true,
710                "default_algorithm": default_algorithm,
711            }))),
712            channel: None,
713            name: None,
714            user_id: None,
715            sequence: None,
716            conflation_key: None,
717            tags: None,
718            message_id: None,
719            serial: None,
720            idempotency_key: None,
721            extras: None,
722            delta_sequence: None,
723            delta_conflation_key: None,
724        }
725    }
726}
727
/// Extension trait for reading a comma-separated `info` parameter.
pub trait InfoQueryParser {
    /// Returns the requested attribute names in the order given.
    ///
    /// Whitespace around each name is ignored and empty entries are dropped,
    /// so `"user_count, cache"` and `"user_count,cache"` are equivalent. A
    /// missing or blank parameter yields an empty list.
    fn parse_info(&self) -> Vec<&str>;
    /// Returns true when `user_count` is among the requested attributes.
    fn wants_user_count(&self) -> bool;
    /// Returns true when `subscription_count` is among the requested attributes.
    fn wants_subscription_count(&self) -> bool;
    /// Returns true when `cache` is among the requested attributes.
    fn wants_cache(&self) -> bool;
}

impl InfoQueryParser for Option<&String> {
    fn parse_info(&self) -> Vec<&str> {
        self.map(|raw| {
            raw.split(',')
                // Tolerate spaces after commas (e.g. from a URL-decoded query).
                .map(str::trim)
                // "" and stray commas must not produce phantom attributes.
                .filter(|field| !field.is_empty())
                .collect::<Vec<_>>()
        })
        .unwrap_or_default()
    }

    fn wants_user_count(&self) -> bool {
        self.parse_info().contains(&"user_count")
    }

    fn wants_subscription_count(&self) -> bool {
        self.parse_info().contains(&"subscription_count")
    }

    fn wants_cache(&self) -> bool {
        self.parse_info().contains(&"cache")
    }
}