Skip to main content

sockudo_protocol/
messages.rs

1use ahash::AHashMap;
2use serde::de::Error as _;
3use serde::{Deserialize, Serialize};
4use serde_json::Value as JsonValue;
5use sonic_rs::prelude::*;
6use sonic_rs::{Value, json};
7use std::collections::{BTreeMap, HashMap};
8use std::time::Duration;
9
10use crate::protocol_version::ProtocolVersion;
11
/// Allowed value types for extras.headers.
/// Flat only — no Object or Array variant so nesting is structurally impossible.
///
/// Serialized untagged: the bare JSON scalar is written/read with no wrapping tag.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ExtrasValue {
    /// A string header value.
    String(String),
    /// A numeric header value, held as `f64`.
    Number(f64),
    /// A boolean header value.
    Bool(bool),
}
21
/// Structured metadata envelope for V2-specific message features.
///
/// Present on the wire for V2 connections only. V1 connections receive messages
/// with extras stripped entirely. Pusher SDKs ignore unknown fields so the
/// field is safe to carry through internal pipelines even when the publisher
/// is V1.
///
/// Field names are camelCase on the wire (`idempotency_key` is `idempotencyKey`),
/// and every field is omitted from the serialized output when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct MessageExtras {
    /// Flat metadata for server-side event name filtering.
    /// Must be a flat object — no nested objects, no arrays.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, ExtrasValue>>,

    /// If true: skip connection recovery buffer and webhook forwarding.
    /// Deliver to currently connected V2 subscribers only.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ephemeral: Option<bool>,

    /// Server-side deduplication key. If the same key arrives again within
    /// the app's idempotency TTL window, the message is silently dropped.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub idempotency_key: Option<String>,

    /// Per-message echo control. Overrides the connection-level echo setting
    /// when explicitly set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub echo: Option<bool>,
}
51
52impl MessageExtras {
53    /// Validate that headers (if present) contain only flat scalar values.
54    /// This is structurally guaranteed by `ExtrasValue` having no Object/Array
55    /// variants, but this method provides an explicit check with a clear error
56    /// when validating raw JSON before deserialization.
57    pub fn validate_headers_from_json(raw: &Value) -> Result<(), String> {
58        if let Some(extras) = raw.get("extras")
59            && let Some(headers) = extras.get("headers")
60            && let Some(obj) = headers.as_object()
61        {
62            for (key, val) in obj.iter() {
63                if val.is_object() || val.is_array() {
64                    return Err(format!(
65                        "extras.headers must be a flat object — nested objects and arrays are not allowed (key: '{key}')"
66                    ));
67                }
68            }
69        }
70        Ok(())
71    }
72}
73
74/// Generate a unique message ID (UUIDv4) for client-side deduplication.
75pub fn generate_message_id() -> String {
76    uuid::Uuid::new_v4().to_string()
77}
78
/// Presence information embedded under the `presence` key of the
/// subscription-succeeded payload built by `PusherMessage::subscription_succeeded`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PresenceData {
    /// List of identifiers (presumably member/user ids — confirm against callers).
    pub ids: Vec<String>,
    /// Map from a string key to an optional JSON value
    /// (presumably user id → user info — confirm against callers).
    pub hash: AHashMap<String, Option<Value>>,
    /// A count supplied by the caller (presumably the number of members).
    pub count: usize,
}
85
/// Payload carried in a message's `data` field.
///
/// Serialized untagged, so each variant is written as its bare contents.
/// `Deserialize` is implemented by hand in this file rather than derived.
#[derive(Debug, Clone, Serialize, PartialEq)]
#[serde(untagged)]
pub enum MessageData {
    /// A plain string payload. The constructors in this file use it to carry
    /// JSON-encoded text.
    String(String),
    /// An object payload whose well-known string fields are split out.
    /// Each named field is omitted from the output when `None`.
    Structured {
        /// Value of the `channel_data` key, when present as a string.
        #[serde(skip_serializing_if = "Option::is_none")]
        channel_data: Option<String>,
        /// Value of the `channel` key, when present as a string.
        #[serde(skip_serializing_if = "Option::is_none")]
        channel: Option<String>,
        /// Value of the `user_data` key, when present as a string.
        #[serde(skip_serializing_if = "Option::is_none")]
        user_data: Option<String>,
        /// Remaining keys of the object; flattened back into the same object
        /// on serialization.
        #[serde(flatten)]
        extra: AHashMap<String, Value>,
    },
    /// Any other JSON value.
    Json(Value),
}
102
103impl<'de> Deserialize<'de> for MessageData {
104    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
105    where
106        D: serde::Deserializer<'de>,
107    {
108        let v = JsonValue::deserialize(deserializer)?;
109        if let Some(s) = v.as_str() {
110            return Ok(MessageData::String(s.to_string()));
111        }
112        if let Some(obj) = v.as_object() {
113            // Flatten workaround for sonic-rs issue #114:
114            // manually split known structured keys and keep remaining keys in `extra`.
115            let channel_data = obj
116                .get("channel_data")
117                .and_then(|x| x.as_str())
118                .map(ToString::to_string);
119            let channel = obj
120                .get("channel")
121                .and_then(|x| x.as_str())
122                .map(ToString::to_string);
123            let user_data = obj
124                .get("user_data")
125                .and_then(|x| x.as_str())
126                .map(ToString::to_string);
127
128            if channel_data.is_some() || channel.is_some() || user_data.is_some() {
129                let mut extra = AHashMap::new();
130                for (k, val) in obj.iter() {
131                    if k != "channel_data" && k != "channel" && k != "user_data" {
132                        extra.insert(
133                            k.to_string(),
134                            serde_json_value_to_sonic(val.clone()).map_err(D::Error::custom)?,
135                        );
136                    }
137                }
138                return Ok(MessageData::Structured {
139                    channel_data,
140                    channel,
141                    user_data,
142                    extra,
143                });
144            }
145        }
146        Ok(MessageData::Json(
147            serde_json_value_to_sonic(v).map_err(D::Error::custom)?,
148        ))
149    }
150}
151
152fn serde_json_value_to_sonic(value: JsonValue) -> Result<Value, String> {
153    let encoded = serde_json::to_string(&value)
154        .map_err(|err| format!("failed to encode json value for MessageData: {err}"))?;
155    sonic_rs::from_str(&encoded)
156        .map_err(|err| format!("failed to decode json value for MessageData: {err}"))
157}
158
/// Error payload consisting of an optional numeric code and a message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorData {
    /// Optional numeric error code.
    pub code: Option<u16>,
    /// Human-readable error description.
    pub message: String,
}
164
/// Message envelope exchanged over the protocol.
///
/// Every field is optional and is omitted from the serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PusherMessage {
    /// Event name, e.g. `pusher:ping`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub event: Option<String>,
    /// Channel the message relates to, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub channel: Option<String>,
    /// Message payload.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<MessageData>,
    /// NOTE(review): purpose not evident from this file — every constructor
    /// here sets it to `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// NOTE(review): presumably the id of an associated user — every
    /// constructor here sets it to `None`; confirm against callers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_id: Option<String>,
    /// Tags for filtering - uses BTreeMap for deterministic serialization order
    /// which is required for delta compression to work correctly
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tags: Option<BTreeMap<String, String>>,
    /// Delta compression sequence number for full messages
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sequence: Option<u64>,
    /// Delta compression conflation key for message grouping
    #[serde(skip_serializing_if = "Option::is_none")]
    pub conflation_key: Option<String>,
    /// Unique message ID for client-side deduplication
    #[serde(skip_serializing_if = "Option::is_none")]
    pub message_id: Option<String>,
    /// Opaque per-channel continuity token for durable history and recovery.
    /// Changes only when the server can no longer prove continuity for the channel stream.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_id: Option<String>,
    /// Monotonically increasing serial for connection recovery.
    /// Assigned per-channel at broadcast time when connection recovery is enabled.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub serial: Option<u64>,
    /// Idempotency key for cross-region deduplication.
    /// Threaded from the HTTP publish request through the broadcast pipeline
    /// so that receiving nodes can register it in their local cache.
    /// Never sent to WebSocket clients.
    // NOTE(review): the serde attribute below only omits the field when it is
    // `None`; stripping before client delivery must happen elsewhere — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub idempotency_key: Option<String>,
    /// V2 message extras envelope. Carries ephemeral flag, per-message echo
    /// control, header-based filtering metadata, and extras-level idempotency.
    /// Stripped from V1 deliveries; included in V2 wire format.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub extras: Option<MessageExtras>,
    /// Delta sequence marker for full messages in V2 delta streams.
    /// Appears on the wire as `__delta_seq`.
    #[serde(rename = "__delta_seq", skip_serializing_if = "Option::is_none")]
    pub delta_sequence: Option<u64>,
    /// Delta conflation key marker for full messages in V2 delta streams.
    /// Appears on the wire as `__conflation_key`.
    #[serde(rename = "__conflation_key", skip_serializing_if = "Option::is_none")]
    pub delta_conflation_key: Option<String>,
}
216
/// Publish request body. All fields are optional and omitted from the
/// serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PusherApiMessage {
    /// Presumably the event name to publish — confirm against the HTTP handler.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Event payload, either a string or arbitrary JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<ApiMessageData>,
    /// A single target channel.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub channel: Option<String>,
    /// A list of target channels.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub channels: Option<Vec<String>>,
    /// Presumably the publishing socket's id, used to exclude it from
    /// delivery — confirm against the HTTP handler.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub socket_id: Option<String>,
    /// Presumably a comma-separated list of requested info attributes
    /// (see `InfoQueryParser`) — confirm against the HTTP handler.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub info: Option<String>,
    /// String tags attached to the message. Note this is an unordered
    /// `AHashMap`, unlike the `BTreeMap` used by `PusherMessage::tags`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tags: Option<AHashMap<String, String>>,
    /// Per-publish delta compression control.
    /// - `Some(true)`: Force delta compression for this message (if client supports it)
    /// - `Some(false)`: Force full message (skip delta compression)
    /// - `None`: Use default behavior based on channel/global configuration
    #[serde(skip_serializing_if = "Option::is_none")]
    pub delta: Option<bool>,
    /// Idempotency key for deduplicating publish requests.
    /// If the same key is seen within the TTL window, the server returns the
    /// cached response without re-broadcasting.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub idempotency_key: Option<String>,
    /// V2 extras envelope. Passed through to PusherMessage for V2 delivery.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub extras: Option<MessageExtras>,
}
248
/// Request body wrapping several publish requests under a `batch` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchPusherApiMessage {
    /// The individual messages of the batch, in request order.
    pub batch: Vec<PusherApiMessage>,
}
253
/// Payload of a publish request.
///
/// Untagged: variants are tried in declaration order on deserialization, so
/// a JSON string becomes `String` and every other value becomes `Json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ApiMessageData {
    /// A plain string payload.
    String(String),
    /// Any other JSON value.
    Json(Value),
}
260
/// Reduced message shape holding only `channel`, `event` and `data`.
/// Each field is omitted from the serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SentPusherMessage {
    /// Channel the message relates to, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub channel: Option<String>,
    /// Event name, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub event: Option<String>,
    /// Message payload, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<MessageData>,
}
270
271// Helper implementations
272impl MessageData {
273    pub fn as_string(&self) -> Option<&str> {
274        match self {
275            MessageData::String(s) => Some(s),
276            _ => None,
277        }
278    }
279
280    pub fn into_string(self) -> Option<String> {
281        match self {
282            MessageData::String(s) => Some(s),
283            _ => None,
284        }
285    }
286
287    pub fn as_value(&self) -> Option<&Value> {
288        match self {
289            MessageData::Structured { extra, .. } => extra.values().next(),
290            _ => None,
291        }
292    }
293}
294
295impl From<String> for MessageData {
296    fn from(s: String) -> Self {
297        MessageData::String(s)
298    }
299}
300
301impl From<Value> for MessageData {
302    fn from(v: Value) -> Self {
303        MessageData::Json(v)
304    }
305}
306
307impl PusherMessage {
308    pub fn is_protocol_ping_or_pong(&self) -> bool {
309        let Some(event) = self.event.as_deref() else {
310            return false;
311        };
312
313        matches!(
314            ProtocolVersion::parse_any_protocol_event(event),
315            Some(("ping", _)) | Some(("pong", _))
316        )
317    }
318
319    pub fn connection_established(socket_id: String, activity_timeout: u64) -> Self {
320        Self {
321            event: Some("pusher:connection_established".to_string()),
322            data: Some(MessageData::from(
323                json!({
324                    "socket_id": socket_id,
325                    "activity_timeout": activity_timeout  // Now configurable
326                })
327                .to_string(),
328            )),
329            channel: None,
330            name: None,
331            user_id: None,
332            sequence: None,
333            conflation_key: None,
334            tags: None,
335            message_id: None,
336            stream_id: None,
337            serial: None,
338            idempotency_key: None,
339            extras: None,
340            delta_sequence: None,
341            delta_conflation_key: None,
342        }
343    }
344    pub fn subscription_succeeded(channel: String, presence_data: Option<PresenceData>) -> Self {
345        let data_obj = if let Some(data) = presence_data {
346            json!({
347                "presence": {
348                    "ids": data.ids,
349                    "hash": data.hash,
350                    "count": data.count
351                }
352            })
353        } else {
354            json!({})
355        };
356
357        Self {
358            event: Some("pusher_internal:subscription_succeeded".to_string()),
359            channel: Some(channel),
360            data: Some(MessageData::String(data_obj.to_string())),
361            name: None,
362            user_id: None,
363            sequence: None,
364            conflation_key: None,
365            tags: None,
366            message_id: None,
367            stream_id: None,
368            serial: None,
369            idempotency_key: None,
370            extras: None,
371            delta_sequence: None,
372            delta_conflation_key: None,
373        }
374    }
375
376    pub fn error(code: u16, message: String, channel: Option<String>) -> Self {
377        Self {
378            event: Some("pusher:error".to_string()),
379            data: Some(MessageData::Json(json!({
380                "code": code,
381                "message": message
382            }))),
383            channel,
384            name: None,
385            user_id: None,
386            sequence: None,
387            conflation_key: None,
388            tags: None,
389            message_id: None,
390            stream_id: None,
391            serial: None,
392            idempotency_key: None,
393            extras: None,
394            delta_sequence: None,
395            delta_conflation_key: None,
396        }
397    }
398
399    pub fn ping() -> Self {
400        Self {
401            event: Some("pusher:ping".to_string()),
402            data: None,
403            channel: None,
404            name: None,
405            user_id: None,
406            sequence: None,
407            conflation_key: None,
408            tags: None,
409            message_id: None,
410            stream_id: None,
411            serial: None,
412            idempotency_key: None,
413            extras: None,
414            delta_sequence: None,
415            delta_conflation_key: None,
416        }
417    }
418    pub fn channel_event<S: Into<String>>(event: S, channel: S, data: Value) -> Self {
419        Self {
420            event: Some(event.into()),
421            channel: Some(channel.into()),
422            data: Some(MessageData::String(data.to_string())),
423            name: None,
424            user_id: None,
425            sequence: None,
426            conflation_key: None,
427            tags: None,
428            message_id: None,
429            stream_id: None,
430            serial: None,
431            idempotency_key: None,
432            extras: None,
433            delta_sequence: None,
434            delta_conflation_key: None,
435        }
436    }
437
438    pub fn member_added(channel: String, user_id: String, user_info: Option<Value>) -> Self {
439        Self {
440            event: Some("pusher_internal:member_added".to_string()),
441            channel: Some(channel),
442            // FIX: Use MessageData::String with JSON-encoded string instead of MessageData::Json
443            data: Some(MessageData::String(
444                json!({
445                    "user_id": user_id,
446                    "user_info": user_info.unwrap_or_else(|| json!({}))
447                })
448                .to_string(),
449            )),
450            name: None,
451            user_id: None,
452            sequence: None,
453            conflation_key: None,
454            tags: None,
455            message_id: None,
456            stream_id: None,
457            serial: None,
458            idempotency_key: None,
459            extras: None,
460            delta_sequence: None,
461            delta_conflation_key: None,
462        }
463    }
464
465    pub fn member_removed(channel: String, user_id: String) -> Self {
466        Self {
467            event: Some("pusher_internal:member_removed".to_string()),
468            channel: Some(channel),
469            // FIX: Also apply same fix to member_removed for consistency
470            data: Some(MessageData::String(
471                json!({
472                    "user_id": user_id
473                })
474                .to_string(),
475            )),
476            name: None,
477            user_id: None,
478            sequence: None,
479            conflation_key: None,
480            tags: None,
481            message_id: None,
482            stream_id: None,
483            serial: None,
484            idempotency_key: None,
485            extras: None,
486            delta_sequence: None,
487            delta_conflation_key: None,
488        }
489    }
490
491    // New helper method for pong response
492    pub fn pong() -> Self {
493        Self {
494            event: Some("pusher:pong".to_string()),
495            data: None,
496            channel: None,
497            name: None,
498            user_id: None,
499            sequence: None,
500            conflation_key: None,
501            tags: None,
502            message_id: None,
503            stream_id: None,
504            serial: None,
505            idempotency_key: None,
506            extras: None,
507            delta_sequence: None,
508            delta_conflation_key: None,
509        }
510    }
511
512    // Helper for creating channel info response
513    pub fn channel_info(
514        occupied: bool,
515        subscription_count: Option<u64>,
516        user_count: Option<u64>,
517        cache_data: Option<(String, Duration)>,
518    ) -> Value {
519        let mut response = json!({
520            "occupied": occupied
521        });
522
523        if let Some(count) = subscription_count {
524            response["subscription_count"] = json!(count);
525        }
526
527        if let Some(count) = user_count {
528            response["user_count"] = json!(count);
529        }
530
531        if let Some((data, ttl)) = cache_data {
532            response["cache"] = json!({
533                "data": data,
534                "ttl": ttl.as_secs()
535            });
536        }
537
538        response
539    }
540
541    // Helper for creating channels list response
542    pub fn channels_list(channels_info: AHashMap<String, Value>) -> Value {
543        json!({
544            "channels": channels_info
545        })
546    }
547
548    // Helper for creating user list response
549    pub fn user_list(user_ids: Vec<String>) -> Value {
550        let users = user_ids
551            .into_iter()
552            .map(|id| json!({ "id": id }))
553            .collect::<Vec<_>>();
554
555        json!({ "users": users })
556    }
557
558    // Helper for batch events response
559    pub fn batch_response(batch_info: Vec<Value>) -> Value {
560        json!({ "batch": batch_info })
561    }
562
563    // Helper for simple success response
564    pub fn success_response() -> Value {
565        json!({ "ok": true })
566    }
567
568    pub fn watchlist_online_event(user_ids: Vec<String>) -> Self {
569        Self {
570            event: Some("online".to_string()),
571            channel: None, // Watchlist events don't use channels
572            name: None,
573            data: Some(MessageData::Json(json!({
574                "user_ids": user_ids
575            }))),
576            user_id: None,
577            sequence: None,
578            conflation_key: None,
579            tags: None,
580            message_id: None,
581            stream_id: None,
582            serial: None,
583            idempotency_key: None,
584            extras: None,
585            delta_sequence: None,
586            delta_conflation_key: None,
587        }
588    }
589
590    pub fn watchlist_offline_event(user_ids: Vec<String>) -> Self {
591        Self {
592            event: Some("offline".to_string()),
593            channel: None,
594            name: None,
595            data: Some(MessageData::Json(json!({
596                "user_ids": user_ids
597            }))),
598            user_id: None,
599            sequence: None,
600            conflation_key: None,
601            tags: None,
602            message_id: None,
603            stream_id: None,
604            serial: None,
605            idempotency_key: None,
606            extras: None,
607            delta_sequence: None,
608            delta_conflation_key: None,
609        }
610    }
611
612    pub fn cache_miss_event(channel: String) -> Self {
613        Self {
614            event: Some("pusher:cache_miss".to_string()),
615            channel: Some(channel),
616            data: Some(MessageData::String("{}".to_string())),
617            name: None,
618            user_id: None,
619            sequence: None,
620            conflation_key: None,
621            tags: None,
622            message_id: None,
623            stream_id: None,
624            serial: None,
625            idempotency_key: None,
626            extras: None,
627            delta_sequence: None,
628            delta_conflation_key: None,
629        }
630    }
631
632    pub fn signin_success(user_data: String) -> Self {
633        Self {
634            event: Some("pusher:signin_success".to_string()),
635            data: Some(MessageData::Json(json!({
636                "user_data": user_data
637            }))),
638            channel: None,
639            name: None,
640            user_id: None,
641            sequence: None,
642            conflation_key: None,
643            tags: None,
644            message_id: None,
645            stream_id: None,
646            serial: None,
647            idempotency_key: None,
648            extras: None,
649            delta_sequence: None,
650            delta_conflation_key: None,
651        }
652    }
653
654    /// Create a delta-compressed message
655    pub fn delta_message(
656        channel: String,
657        event: String,
658        delta_base64: String,
659        base_sequence: u32,
660        target_sequence: u32,
661        algorithm: &str,
662    ) -> Self {
663        Self {
664            event: Some("pusher:delta".to_string()),
665            channel: Some(channel.clone()),
666            data: Some(MessageData::String(
667                json!({
668                    "channel": channel,
669                    "event": event,
670                    "delta": delta_base64,
671                    "base_seq": base_sequence,
672                    "target_seq": target_sequence,
673                    "algorithm": algorithm,
674                })
675                .to_string(),
676            )),
677            name: None,
678            user_id: None,
679            sequence: None,
680            conflation_key: None,
681            tags: None,
682            message_id: None,
683            stream_id: None,
684            serial: None,
685            idempotency_key: None,
686            extras: None,
687            delta_sequence: None,
688            delta_conflation_key: None,
689        }
690    }
691
692    /// Rewrite the event name prefix to match the given protocol version.
693    /// This is the single translation point between V1 (`pusher:`) and V2 (`sockudo:`) wire formats.
694    pub fn rewrite_prefix(&mut self, version: ProtocolVersion) {
695        if let Some(ref event) = self.event {
696            self.event = Some(version.rewrite_event_prefix(event));
697        }
698    }
699
700    /// Returns true if this message is ephemeral (skip recovery buffer and webhooks).
701    pub fn is_ephemeral(&self) -> bool {
702        self.extras
703            .as_ref()
704            .and_then(|e| e.ephemeral)
705            .unwrap_or(false)
706    }
707
708    /// Returns the extras-level idempotency key, if set.
709    pub fn extras_idempotency_key(&self) -> Option<&str> {
710        self.extras
711            .as_ref()
712            .and_then(|e| e.idempotency_key.as_deref())
713    }
714
715    /// Resolve whether this message should be echoed back to the publishing socket.
716    /// Message-level `extras.echo` takes precedence over the connection default.
717    pub fn should_echo(&self, connection_default: bool) -> bool {
718        self.extras
719            .as_ref()
720            .and_then(|e| e.echo)
721            .unwrap_or(connection_default)
722    }
723
724    /// Returns the extras headers for server-side filtering, if present.
725    pub fn filter_headers(&self) -> Option<&HashMap<String, ExtrasValue>> {
726        self.extras.as_ref().and_then(|e| e.headers.as_ref())
727    }
728
729    /// Returns true if the given protocol version should receive extras in delivered messages.
730    pub fn should_include_extras(protocol: &ProtocolVersion) -> bool {
731        matches!(protocol, ProtocolVersion::V2)
732    }
733
734    /// Add base sequence marker to a full message for delta tracking
735    pub fn add_base_sequence(mut self, base_sequence: u32) -> Self {
736        if let Some(MessageData::String(ref data_str)) = self.data
737            && let Ok(mut data_obj) = sonic_rs::from_str::<Value>(data_str)
738            && let Some(obj) = data_obj.as_object_mut()
739        {
740            obj.insert("__delta_base_seq", json!(base_sequence));
741            self.data = Some(MessageData::String(data_obj.to_string()));
742        }
743        self
744    }
745
746    /// Create delta compression enabled confirmation
747    pub fn delta_compression_enabled(default_algorithm: &str) -> Self {
748        Self {
749            event: Some("pusher:delta_compression_enabled".to_string()),
750            data: Some(MessageData::Json(json!({
751                "enabled": true,
752                "default_algorithm": default_algorithm,
753            }))),
754            channel: None,
755            name: None,
756            user_id: None,
757            sequence: None,
758            conflation_key: None,
759            tags: None,
760            message_id: None,
761            stream_id: None,
762            serial: None,
763            idempotency_key: None,
764            extras: None,
765            delta_sequence: None,
766            delta_conflation_key: None,
767        }
768    }
769}
770
/// Extension trait for inspecting a comma-separated `info` query parameter.
pub trait InfoQueryParser {
    /// Split the parameter into its individual attribute names.
    fn parse_info(&self) -> Vec<&str>;
    /// True when `user_count` was requested.
    fn wants_user_count(&self) -> bool;
    /// True when `subscription_count` was requested.
    fn wants_subscription_count(&self) -> bool;
    /// True when `cache` was requested.
    fn wants_cache(&self) -> bool;
}

impl InfoQueryParser for Option<&String> {
    /// Split on commas, trim surrounding whitespace from every entry and drop
    /// empty entries, so `"user_count, cache"` and `"user_count,,cache"` both
    /// yield `["user_count", "cache"]`. A missing parameter yields an empty list.
    fn parse_info(&self) -> Vec<&str> {
        self.map(|s| {
            s.split(',')
                .map(str::trim)
                .filter(|entry| !entry.is_empty())
                .collect::<Vec<_>>()
        })
        .unwrap_or_default()
    }

    fn wants_user_count(&self) -> bool {
        self.parse_info().contains(&"user_count")
    }

    fn wants_subscription_count(&self) -> bool {
        self.parse_info().contains(&"subscription_count")
    }

    fn wants_cache(&self) -> bool {
        self.parse_info().contains(&"cache")
    }
}
797
#[cfg(test)]
mod tests {
    use super::PusherMessage;
    use crate::protocol_version::ProtocolVersion;

    /// Ping and pong must be recognised as heartbeats both with their
    /// original prefix and after being rewritten for V2.
    #[test]
    fn protocol_heartbeat_detection_matches_both_prefix_families() {
        for mut heartbeat in [PusherMessage::ping(), PusherMessage::pong()] {
            assert!(heartbeat.is_protocol_ping_or_pong());

            heartbeat.rewrite_prefix(ProtocolVersion::V2);
            assert!(heartbeat.is_protocol_ping_or_pong());
        }
    }

    /// An ordinary channel event is not a heartbeat.
    #[test]
    fn protocol_heartbeat_detection_ignores_regular_messages() {
        let payload = sonic_rs::json!({"text": "hello"});
        let message = PusherMessage::channel_event("chat.message", "room", payload);

        assert!(!message.is_protocol_ping_or_pong());
    }
}