// sockudo_protocol/messages.rs

1use ahash::AHashMap;
2use serde::{Deserialize, Serialize};
3use sonic_rs::prelude::*;
4use sonic_rs::{Value, json};
5use std::collections::BTreeMap;
6use std::time::Duration;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct PresenceData {
10    pub ids: Vec<String>,
11    pub hash: AHashMap<String, Option<Value>>,
12    pub count: usize,
13}
14
15#[derive(Debug, Clone, Serialize)]
16#[serde(untagged)]
17pub enum MessageData {
18    String(String),
19    Structured {
20        #[serde(skip_serializing_if = "Option::is_none")]
21        channel_data: Option<String>,
22        #[serde(skip_serializing_if = "Option::is_none")]
23        channel: Option<String>,
24        #[serde(skip_serializing_if = "Option::is_none")]
25        user_data: Option<String>,
26        #[serde(flatten)]
27        extra: AHashMap<String, Value>,
28    },
29    Json(Value),
30}
31
32impl<'de> Deserialize<'de> for MessageData {
33    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
34    where
35        D: serde::Deserializer<'de>,
36    {
37        let v = Value::deserialize(deserializer)?;
38        if let Some(s) = v.as_str() {
39            return Ok(MessageData::String(s.to_string()));
40        }
41        if let Some(obj) = v.as_object() {
42            // Flatten workaround for sonic-rs issue #114:
43            // manually split known structured keys and keep remaining keys in `extra`.
44            let channel_data = obj
45                .get(&"channel_data")
46                .and_then(|x| x.as_str())
47                .map(ToString::to_string);
48            let channel = obj
49                .get(&"channel")
50                .and_then(|x| x.as_str())
51                .map(ToString::to_string);
52            let user_data = obj
53                .get(&"user_data")
54                .and_then(|x| x.as_str())
55                .map(ToString::to_string);
56
57            if channel_data.is_some() || channel.is_some() || user_data.is_some() {
58                let mut extra = AHashMap::new();
59                for (k, val) in obj.iter() {
60                    if k != "channel_data" && k != "channel" && k != "user_data" {
61                        extra.insert(k.to_string(), val.clone());
62                    }
63                }
64                return Ok(MessageData::Structured {
65                    channel_data,
66                    channel,
67                    user_data,
68                    extra,
69                });
70            }
71        }
72        Ok(MessageData::Json(v))
73    }
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct ErrorData {
78    pub code: Option<u16>,
79    pub message: String,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct PusherMessage {
84    #[serde(skip_serializing_if = "Option::is_none")]
85    pub event: Option<String>,
86    #[serde(skip_serializing_if = "Option::is_none")]
87    pub channel: Option<String>,
88    #[serde(skip_serializing_if = "Option::is_none")]
89    pub data: Option<MessageData>,
90    #[serde(skip_serializing_if = "Option::is_none")]
91    pub name: Option<String>,
92    #[serde(skip_serializing_if = "Option::is_none")]
93    pub user_id: Option<String>,
94    /// Tags for filtering - uses BTreeMap for deterministic serialization order
95    /// which is required for delta compression to work correctly
96    #[serde(skip_serializing_if = "Option::is_none")]
97    pub tags: Option<BTreeMap<String, String>>,
98    /// Delta compression sequence number for full messages
99    #[serde(skip_serializing_if = "Option::is_none")]
100    pub sequence: Option<u64>,
101    /// Delta compression conflation key for message grouping
102    #[serde(skip_serializing_if = "Option::is_none")]
103    pub conflation_key: Option<String>,
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct PusherApiMessage {
108    #[serde(skip_serializing_if = "Option::is_none")]
109    pub name: Option<String>,
110    #[serde(skip_serializing_if = "Option::is_none")]
111    pub data: Option<ApiMessageData>,
112    #[serde(skip_serializing_if = "Option::is_none")]
113    pub channel: Option<String>,
114    #[serde(skip_serializing_if = "Option::is_none")]
115    pub channels: Option<Vec<String>>,
116    #[serde(skip_serializing_if = "Option::is_none")]
117    pub socket_id: Option<String>,
118    #[serde(skip_serializing_if = "Option::is_none")]
119    pub info: Option<String>,
120    #[serde(skip_serializing_if = "Option::is_none")]
121    pub tags: Option<AHashMap<String, String>>,
122    /// Per-publish delta compression control.
123    /// - `Some(true)`: Force delta compression for this message (if client supports it)
124    /// - `Some(false)`: Force full message (skip delta compression)
125    /// - `None`: Use default behavior based on channel/global configuration
126    #[serde(skip_serializing_if = "Option::is_none")]
127    pub delta: Option<bool>,
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct BatchPusherApiMessage {
132    pub batch: Vec<PusherApiMessage>,
133}
134
135#[derive(Debug, Clone, Serialize, Deserialize)]
136#[serde(untagged)]
137pub enum ApiMessageData {
138    String(String),
139    Json(Value),
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct SentPusherMessage {
144    #[serde(skip_serializing_if = "Option::is_none")]
145    pub channel: Option<String>,
146    #[serde(skip_serializing_if = "Option::is_none")]
147    pub event: Option<String>,
148    #[serde(skip_serializing_if = "Option::is_none")]
149    pub data: Option<MessageData>,
150}
151
152// Helper implementations
153impl MessageData {
154    pub fn as_string(&self) -> Option<&str> {
155        match self {
156            MessageData::String(s) => Some(s),
157            _ => None,
158        }
159    }
160
161    pub fn into_string(self) -> Option<String> {
162        match self {
163            MessageData::String(s) => Some(s),
164            _ => None,
165        }
166    }
167
168    pub fn as_value(&self) -> Option<&Value> {
169        match self {
170            MessageData::Structured { extra, .. } => extra.values().next(),
171            _ => None,
172        }
173    }
174}
175
176impl From<String> for MessageData {
177    fn from(s: String) -> Self {
178        MessageData::String(s)
179    }
180}
181
182impl From<Value> for MessageData {
183    fn from(v: Value) -> Self {
184        MessageData::Json(v)
185    }
186}
187
188impl PusherMessage {
189    pub fn connection_established(socket_id: String, activity_timeout: u64) -> Self {
190        Self {
191            event: Some("pusher:connection_established".to_string()),
192            data: Some(MessageData::from(
193                json!({
194                    "socket_id": socket_id,
195                    "activity_timeout": activity_timeout  // Now configurable
196                })
197                .to_string(),
198            )),
199            channel: None,
200            name: None,
201            user_id: None,
202            sequence: None,
203            conflation_key: None,
204            tags: None,
205        }
206    }
207    pub fn subscription_succeeded(channel: String, presence_data: Option<PresenceData>) -> Self {
208        let data_obj = if let Some(data) = presence_data {
209            json!({
210                "presence": {
211                    "ids": data.ids,
212                    "hash": data.hash,
213                    "count": data.count
214                }
215            })
216        } else {
217            json!({})
218        };
219
220        Self {
221            event: Some("pusher_internal:subscription_succeeded".to_string()),
222            channel: Some(channel),
223            data: Some(MessageData::String(data_obj.to_string())),
224            name: None,
225            user_id: None,
226            sequence: None,
227            conflation_key: None,
228            tags: None,
229        }
230    }
231
232    pub fn error(code: u16, message: String, channel: Option<String>) -> Self {
233        Self {
234            event: Some("pusher:error".to_string()),
235            data: Some(MessageData::Json(json!({
236                "code": code,
237                "message": message
238            }))),
239            channel,
240            name: None,
241            user_id: None,
242            sequence: None,
243            conflation_key: None,
244            tags: None,
245        }
246    }
247
248    pub fn ping() -> Self {
249        Self {
250            event: Some("pusher:ping".to_string()),
251            data: None,
252            channel: None,
253            name: None,
254            user_id: None,
255            sequence: None,
256            conflation_key: None,
257            tags: None,
258        }
259    }
260    pub fn channel_event<S: Into<String>>(event: S, channel: S, data: Value) -> Self {
261        Self {
262            event: Some(event.into()),
263            channel: Some(channel.into()),
264            data: Some(MessageData::String(data.to_string())),
265            name: None,
266            user_id: None,
267            sequence: None,
268            conflation_key: None,
269            tags: None,
270        }
271    }
272
273    pub fn member_added(channel: String, user_id: String, user_info: Option<Value>) -> Self {
274        Self {
275            event: Some("pusher_internal:member_added".to_string()),
276            channel: Some(channel),
277            // FIX: Use MessageData::String with JSON-encoded string instead of MessageData::Json
278            data: Some(MessageData::String(
279                json!({
280                    "user_id": user_id,
281                    "user_info": user_info.unwrap_or_else(|| json!({}))
282                })
283                .to_string(),
284            )),
285            name: None,
286            user_id: None,
287            sequence: None,
288            conflation_key: None,
289            tags: None,
290        }
291    }
292
293    pub fn member_removed(channel: String, user_id: String) -> Self {
294        Self {
295            event: Some("pusher_internal:member_removed".to_string()),
296            channel: Some(channel),
297            // FIX: Also apply same fix to member_removed for consistency
298            data: Some(MessageData::String(
299                json!({
300                    "user_id": user_id
301                })
302                .to_string(),
303            )),
304            name: None,
305            user_id: None,
306            sequence: None,
307            conflation_key: None,
308            tags: None,
309        }
310    }
311
312    // New helper method for pong response
313    pub fn pong() -> Self {
314        Self {
315            event: Some("pusher:pong".to_string()),
316            data: None,
317            channel: None,
318            name: None,
319            user_id: None,
320            sequence: None,
321            conflation_key: None,
322            tags: None,
323        }
324    }
325
326    // Helper for creating channel info response
327    pub fn channel_info(
328        occupied: bool,
329        subscription_count: Option<u64>,
330        user_count: Option<u64>,
331        cache_data: Option<(String, Duration)>,
332    ) -> Value {
333        let mut response = json!({
334            "occupied": occupied
335        });
336
337        if let Some(count) = subscription_count {
338            response["subscription_count"] = json!(count);
339        }
340
341        if let Some(count) = user_count {
342            response["user_count"] = json!(count);
343        }
344
345        if let Some((data, ttl)) = cache_data {
346            response["cache"] = json!({
347                "data": data,
348                "ttl": ttl.as_secs()
349            });
350        }
351
352        response
353    }
354
355    // Helper for creating channels list response
356    pub fn channels_list(channels_info: AHashMap<String, Value>) -> Value {
357        json!({
358            "channels": channels_info
359        })
360    }
361
362    // Helper for creating user list response
363    pub fn user_list(user_ids: Vec<String>) -> Value {
364        let users = user_ids
365            .into_iter()
366            .map(|id| json!({ "id": id }))
367            .collect::<Vec<_>>();
368
369        json!({ "users": users })
370    }
371
372    // Helper for batch events response
373    pub fn batch_response(batch_info: Vec<Value>) -> Value {
374        json!({ "batch": batch_info })
375    }
376
377    // Helper for simple success response
378    pub fn success_response() -> Value {
379        json!({ "ok": true })
380    }
381
382    pub fn watchlist_online_event(user_ids: Vec<String>) -> Self {
383        Self {
384            event: Some("online".to_string()),
385            channel: None, // Watchlist events don't use channels
386            name: None,
387            data: Some(MessageData::Json(json!({
388                "user_ids": user_ids
389            }))),
390            user_id: None,
391            sequence: None,
392            conflation_key: None,
393            tags: None,
394        }
395    }
396
397    pub fn watchlist_offline_event(user_ids: Vec<String>) -> Self {
398        Self {
399            event: Some("offline".to_string()),
400            channel: None,
401            name: None,
402            data: Some(MessageData::Json(json!({
403                "user_ids": user_ids
404            }))),
405            user_id: None,
406            sequence: None,
407            conflation_key: None,
408            tags: None,
409        }
410    }
411
412    pub fn cache_miss_event(channel: String) -> Self {
413        Self {
414            event: Some("pusher:cache_miss".to_string()),
415            channel: Some(channel),
416            data: Some(MessageData::String("{}".to_string())),
417            name: None,
418            user_id: None,
419            sequence: None,
420            conflation_key: None,
421            tags: None,
422        }
423    }
424
425    pub fn signin_success(user_data: String) -> Self {
426        Self {
427            event: Some("pusher:signin_success".to_string()),
428            data: Some(MessageData::Json(json!({
429                "user_data": user_data
430            }))),
431            channel: None,
432            name: None,
433            user_id: None,
434            sequence: None,
435            conflation_key: None,
436            tags: None,
437        }
438    }
439
440    /// Create a delta-compressed message
441    pub fn delta_message(
442        channel: String,
443        event: String,
444        delta_base64: String,
445        base_sequence: u32,
446        target_sequence: u32,
447        algorithm: &str,
448    ) -> Self {
449        Self {
450            event: Some("pusher:delta".to_string()),
451            channel: Some(channel.clone()),
452            data: Some(MessageData::String(
453                json!({
454                    "channel": channel,
455                    "event": event,
456                    "delta": delta_base64,
457                    "base_seq": base_sequence,
458                    "target_seq": target_sequence,
459                    "algorithm": algorithm,
460                })
461                .to_string(),
462            )),
463            name: None,
464            user_id: None,
465            sequence: None,
466            conflation_key: None,
467            tags: None,
468        }
469    }
470
471    /// Add base sequence marker to a full message for delta tracking
472    pub fn add_base_sequence(mut self, base_sequence: u32) -> Self {
473        if let Some(MessageData::String(ref data_str)) = self.data
474            && let Ok(mut data_obj) = sonic_rs::from_str::<Value>(data_str)
475            && let Some(obj) = data_obj.as_object_mut()
476        {
477            obj.insert("__delta_base_seq", json!(base_sequence));
478            self.data = Some(MessageData::String(data_obj.to_string()));
479        }
480        self
481    }
482
483    /// Create delta compression enabled confirmation
484    pub fn delta_compression_enabled(default_algorithm: &str) -> Self {
485        Self {
486            event: Some("pusher:delta_compression_enabled".to_string()),
487            data: Some(MessageData::Json(json!({
488                "enabled": true,
489                "default_algorithm": default_algorithm,
490            }))),
491            channel: None,
492            name: None,
493            user_id: None,
494            sequence: None,
495            conflation_key: None,
496            tags: None,
497        }
498    }
499}
500
/// Extension trait for reading a comma-separated `info` attribute list.
pub trait InfoQueryParser {
    /// Returns the requested attributes in order, with surrounding
    /// whitespace removed and empty entries dropped.
    fn parse_info(&self) -> Vec<&str>;
    /// `true` when `user_count` is among the requested attributes.
    fn wants_user_count(&self) -> bool;
    /// `true` when `subscription_count` is among the requested attributes.
    fn wants_subscription_count(&self) -> bool;
    /// `true` when `cache` is among the requested attributes.
    fn wants_cache(&self) -> bool;
}

/// Reports whether `wanted` appears in the comma-separated list `info`,
/// ignoring whitespace around entries. A missing list requests nothing.
fn info_requests(info: Option<&str>, wanted: &str) -> bool {
    info.is_some_and(|raw| raw.split(',').any(|attr| attr.trim() == wanted))
}

impl InfoQueryParser for Option<&String> {
    fn parse_info(&self) -> Vec<&str> {
        self.map(|raw| {
            raw.split(',')
                // Tolerate "a, b" style lists: without trimming, " b" would
                // never match an attribute name.
                .map(str::trim)
                // An empty string or stray commas must not yield "" entries.
                .filter(|attr| !attr.is_empty())
                .collect::<Vec<_>>()
        })
        .unwrap_or_default()
    }

    fn wants_user_count(&self) -> bool {
        info_requests(self.map(String::as_str), "user_count")
    }

    fn wants_subscription_count(&self) -> bool {
        info_requests(self.map(String::as_str), "subscription_count")
    }

    fn wants_cache(&self) -> bool {
        info_requests(self.map(String::as_str), "cache")
    }
}