1use ahash::AHashMap;
2use serde::de::Error as _;
3use serde::{Deserialize, Serialize};
4use serde_json::Value as JsonValue;
5use sonic_rs::prelude::*;
6use sonic_rs::{Value, json};
7use std::collections::{BTreeMap, HashMap};
8use std::time::Duration;
9
10use crate::protocol_version::ProtocolVersion;
11
/// Scalar value permitted inside `MessageExtras::headers`.
///
/// (De)serialized `untagged`, so it appears on the wire as a bare JSON string,
/// number, or boolean. Deserialization tries the variants in declaration order;
/// JSON `null`, arrays and objects match none of them and are rejected.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ExtrasValue {
    /// A JSON string.
    String(String),
    /// Any JSON number, stored as `f64` (integers beyond 2^53 lose precision).
    Number(f64),
    /// A JSON boolean.
    Bool(bool),
}
21
/// Optional extras that can accompany a message (carried by both
/// `PusherMessage` and `PusherApiMessage`).
///
/// Field names are serialized in camelCase (`headers`, `ephemeral`,
/// `idempotencyKey`, `echo`). Any field that is `None` is omitted from the
/// output, and `Default` yields an all-`None` value.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct MessageExtras {
    /// Flat key -> scalar map. `MessageExtras::validate_headers_from_json`
    /// rejects nested objects/arrays in the raw JSON before parsing.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, ExtrasValue>>,

    /// Read by `PusherMessage::is_ephemeral`, which treats `None` as `false`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ephemeral: Option<bool>,

    /// Serialized as `idempotencyKey`; read by
    /// `PusherMessage::extras_idempotency_key`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub idempotency_key: Option<String>,

    /// Read by `PusherMessage::should_echo`, which falls back to a
    /// caller-supplied connection default when this is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub echo: Option<bool>,
}
51
52impl MessageExtras {
53 pub fn validate_headers_from_json(raw: &Value) -> Result<(), String> {
58 if let Some(extras) = raw.get("extras")
59 && let Some(headers) = extras.get("headers")
60 && let Some(obj) = headers.as_object()
61 {
62 for (key, val) in obj.iter() {
63 if val.is_object() || val.is_array() {
64 return Err(format!(
65 "extras.headers must be a flat object — nested objects and arrays are not allowed (key: '{key}')"
66 ));
67 }
68 }
69 }
70 Ok(())
71 }
72}
73
74pub fn generate_message_id() -> String {
76 uuid::Uuid::new_v4().to_string()
77}
78
/// Presence information for a channel.
///
/// `PusherMessage::subscription_succeeded` embeds these three fields under a
/// `presence` key in its payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PresenceData {
    /// Member ids. NOTE(review): presumably the user ids of present members — confirm.
    pub ids: Vec<String>,
    /// Per-member info keyed by id; `None` when a member has no info.
    /// NOTE(review): keys presumably match `ids` — not enforced here.
    pub hash: AHashMap<String, Option<Value>>,
    /// Member count. NOTE(review): presumably equals `ids.len()` — not enforced here.
    pub count: usize,
}
85
/// The `data` payload of a Pusher message.
///
/// Serialization is `untagged`: `String` becomes a JSON string, `Structured`
/// becomes a JSON object (its named fields when `Some`, plus the flattened
/// `extra` entries), and `Json` is written as-is. Deserialization is
/// hand-written (see the `Deserialize` impl for `MessageData`) rather than
/// derived.
#[derive(Debug, Clone, Serialize, PartialEq)]
#[serde(untagged)]
pub enum MessageData {
    /// Opaque string payload; many constructors store JSON text here.
    String(String),
    /// Object payload; deserialization produces this when at least one of
    /// `channel_data`, `channel` or `user_data` is a string.
    Structured {
        #[serde(skip_serializing_if = "Option::is_none")]
        channel_data: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        channel: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        user_data: Option<String>,
        /// Remaining object keys, emitted at the same level as the named fields.
        #[serde(flatten)]
        extra: AHashMap<String, Value>,
    },
    /// Any other JSON value.
    Json(Value),
}
102
103impl<'de> Deserialize<'de> for MessageData {
104 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
105 where
106 D: serde::Deserializer<'de>,
107 {
108 let v = JsonValue::deserialize(deserializer)?;
109 if let Some(s) = v.as_str() {
110 return Ok(MessageData::String(s.to_string()));
111 }
112 if let Some(obj) = v.as_object() {
113 let channel_data = obj
116 .get("channel_data")
117 .and_then(|x| x.as_str())
118 .map(ToString::to_string);
119 let channel = obj
120 .get("channel")
121 .and_then(|x| x.as_str())
122 .map(ToString::to_string);
123 let user_data = obj
124 .get("user_data")
125 .and_then(|x| x.as_str())
126 .map(ToString::to_string);
127
128 if channel_data.is_some() || channel.is_some() || user_data.is_some() {
129 let mut extra = AHashMap::new();
130 for (k, val) in obj.iter() {
131 if k != "channel_data" && k != "channel" && k != "user_data" {
132 extra.insert(
133 k.to_string(),
134 serde_json_value_to_sonic(val.clone()).map_err(D::Error::custom)?,
135 );
136 }
137 }
138 return Ok(MessageData::Structured {
139 channel_data,
140 channel,
141 user_data,
142 extra,
143 });
144 }
145 }
146 Ok(MessageData::Json(
147 serde_json_value_to_sonic(v).map_err(D::Error::custom)?,
148 ))
149 }
150}
151
152fn serde_json_value_to_sonic(value: JsonValue) -> Result<Value, String> {
153 let encoded = serde_json::to_string(&value)
154 .map_err(|err| format!("failed to encode json value for MessageData: {err}"))?;
155 sonic_rs::from_str(&encoded)
156 .map_err(|err| format!("failed to decode json value for MessageData: {err}"))
157}
158
/// Error payload with an optional numeric code and a message.
///
/// `code` has no `skip_serializing_if`, so `None` is serialized as `null`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorData {
    pub code: Option<u16>,
    pub message: String,
}
164
/// A Pusher protocol message as serialized to / deserialized from JSON.
///
/// Every field is optional and omitted from serialized output when `None`.
/// Fields use their Rust names on the wire, except `delta_sequence` and
/// `delta_conflation_key`, which are renamed to `__delta_seq` and
/// `__conflation_key`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PusherMessage {
    /// Event name, e.g. `pusher:ping` or `pusher_internal:member_added`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub event: Option<String>,
    /// Channel the event is scoped to, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub channel: Option<String>,
    /// Event payload; see `MessageData` for the accepted shapes.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<MessageData>,
    /// NOTE(review): no constructor in this file sets `name`; its role should
    /// be confirmed against callers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Message-level user id. `member_added` / `member_removed` leave this
    /// `None` and put the user id inside `data` instead.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_id: Option<String>,
    /// Key/value tags. A `BTreeMap`, so they serialize in sorted key order.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tags: Option<BTreeMap<String, String>>,
    /// NOTE(review): `sequence`, `conflation_key`, `message_id`, `stream_id`
    /// and `serial` are never set by constructors in this file; their
    /// semantics are defined by callers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sequence: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub conflation_key: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub message_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub serial: Option<u64>,
    /// Top-level idempotency key. Distinct from `extras.idempotency_key`,
    /// which is what `PusherMessage::extras_idempotency_key` reads.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub idempotency_key: Option<String>,
    /// Per-message extras; `PusherMessage::should_include_extras` reports
    /// which protocol version carries them.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub extras: Option<MessageExtras>,
    /// Serialized as `__delta_seq`.
    #[serde(rename = "__delta_seq", skip_serializing_if = "Option::is_none")]
    pub delta_sequence: Option<u64>,
    /// Serialized as `__conflation_key`.
    #[serde(rename = "__conflation_key", skip_serializing_if = "Option::is_none")]
    pub delta_conflation_key: Option<String>,
}
216
/// A publish request: an event `name` and `data` addressed to a single
/// `channel` and/or a list of `channels`, plus optional delivery options.
/// Every field is omitted from serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PusherApiMessage {
    /// Event name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Payload: a string or arbitrary JSON (see `ApiMessageData`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<ApiMessageData>,
    /// Single target channel.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub channel: Option<String>,
    /// Multiple target channels.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub channels: Option<Vec<String>>,
    /// NOTE(review): presumably the socket to exclude from delivery — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub socket_id: Option<String>,
    /// NOTE(review): presumably a comma-separated list of attributes to
    /// return (compare `InfoQueryParser`) — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub info: Option<String>,
    /// Key/value tags (an unordered `AHashMap`, unlike `PusherMessage::tags`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tags: Option<AHashMap<String, String>>,
    /// NOTE(review): presumably a per-message delta-compression toggle — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub delta: Option<bool>,
    /// Top-level idempotency key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub idempotency_key: Option<String>,
    /// Per-message extras; see `MessageExtras`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub extras: Option<MessageExtras>,
}
248
/// Request body carrying several `PusherApiMessage`s under a `batch` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchPusherApiMessage {
    pub batch: Vec<PusherApiMessage>,
}
253
/// `data` of a `PusherApiMessage`.
///
/// `untagged`, with `String` tried first: a JSON string always becomes
/// `String`, and any other JSON value becomes `Json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ApiMessageData {
    String(String),
    Json(Value),
}
260
/// Minimal message view with only `channel`, `event` and `data`, each
/// omitted from serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SentPusherMessage {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub channel: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub event: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<MessageData>,
}
270
271impl MessageData {
273 pub fn as_string(&self) -> Option<&str> {
274 match self {
275 MessageData::String(s) => Some(s),
276 _ => None,
277 }
278 }
279
280 pub fn into_string(self) -> Option<String> {
281 match self {
282 MessageData::String(s) => Some(s),
283 _ => None,
284 }
285 }
286
287 pub fn as_value(&self) -> Option<&Value> {
288 match self {
289 MessageData::Structured { extra, .. } => extra.values().next(),
290 _ => None,
291 }
292 }
293}
294
295impl From<String> for MessageData {
296 fn from(s: String) -> Self {
297 MessageData::String(s)
298 }
299}
300
301impl From<Value> for MessageData {
302 fn from(v: Value) -> Self {
303 MessageData::Json(v)
304 }
305}
306
307impl PusherMessage {
308 pub fn is_protocol_ping_or_pong(&self) -> bool {
309 let Some(event) = self.event.as_deref() else {
310 return false;
311 };
312
313 matches!(
314 ProtocolVersion::parse_any_protocol_event(event),
315 Some(("ping", _)) | Some(("pong", _))
316 )
317 }
318
319 pub fn connection_established(socket_id: String, activity_timeout: u64) -> Self {
320 Self {
321 event: Some("pusher:connection_established".to_string()),
322 data: Some(MessageData::from(
323 json!({
324 "socket_id": socket_id,
325 "activity_timeout": activity_timeout })
327 .to_string(),
328 )),
329 channel: None,
330 name: None,
331 user_id: None,
332 sequence: None,
333 conflation_key: None,
334 tags: None,
335 message_id: None,
336 stream_id: None,
337 serial: None,
338 idempotency_key: None,
339 extras: None,
340 delta_sequence: None,
341 delta_conflation_key: None,
342 }
343 }
344 pub fn subscription_succeeded(channel: String, presence_data: Option<PresenceData>) -> Self {
345 let data_obj = if let Some(data) = presence_data {
346 json!({
347 "presence": {
348 "ids": data.ids,
349 "hash": data.hash,
350 "count": data.count
351 }
352 })
353 } else {
354 json!({})
355 };
356
357 Self {
358 event: Some("pusher_internal:subscription_succeeded".to_string()),
359 channel: Some(channel),
360 data: Some(MessageData::String(data_obj.to_string())),
361 name: None,
362 user_id: None,
363 sequence: None,
364 conflation_key: None,
365 tags: None,
366 message_id: None,
367 stream_id: None,
368 serial: None,
369 idempotency_key: None,
370 extras: None,
371 delta_sequence: None,
372 delta_conflation_key: None,
373 }
374 }
375
376 pub fn error(code: u16, message: String, channel: Option<String>) -> Self {
377 Self {
378 event: Some("pusher:error".to_string()),
379 data: Some(MessageData::Json(json!({
380 "code": code,
381 "message": message
382 }))),
383 channel,
384 name: None,
385 user_id: None,
386 sequence: None,
387 conflation_key: None,
388 tags: None,
389 message_id: None,
390 stream_id: None,
391 serial: None,
392 idempotency_key: None,
393 extras: None,
394 delta_sequence: None,
395 delta_conflation_key: None,
396 }
397 }
398
399 pub fn ping() -> Self {
400 Self {
401 event: Some("pusher:ping".to_string()),
402 data: None,
403 channel: None,
404 name: None,
405 user_id: None,
406 sequence: None,
407 conflation_key: None,
408 tags: None,
409 message_id: None,
410 stream_id: None,
411 serial: None,
412 idempotency_key: None,
413 extras: None,
414 delta_sequence: None,
415 delta_conflation_key: None,
416 }
417 }
418 pub fn channel_event<S: Into<String>>(event: S, channel: S, data: Value) -> Self {
419 Self {
420 event: Some(event.into()),
421 channel: Some(channel.into()),
422 data: Some(MessageData::String(data.to_string())),
423 name: None,
424 user_id: None,
425 sequence: None,
426 conflation_key: None,
427 tags: None,
428 message_id: None,
429 stream_id: None,
430 serial: None,
431 idempotency_key: None,
432 extras: None,
433 delta_sequence: None,
434 delta_conflation_key: None,
435 }
436 }
437
438 pub fn member_added(channel: String, user_id: String, user_info: Option<Value>) -> Self {
439 Self {
440 event: Some("pusher_internal:member_added".to_string()),
441 channel: Some(channel),
442 data: Some(MessageData::String(
444 json!({
445 "user_id": user_id,
446 "user_info": user_info.unwrap_or_else(|| json!({}))
447 })
448 .to_string(),
449 )),
450 name: None,
451 user_id: None,
452 sequence: None,
453 conflation_key: None,
454 tags: None,
455 message_id: None,
456 stream_id: None,
457 serial: None,
458 idempotency_key: None,
459 extras: None,
460 delta_sequence: None,
461 delta_conflation_key: None,
462 }
463 }
464
465 pub fn member_removed(channel: String, user_id: String) -> Self {
466 Self {
467 event: Some("pusher_internal:member_removed".to_string()),
468 channel: Some(channel),
469 data: Some(MessageData::String(
471 json!({
472 "user_id": user_id
473 })
474 .to_string(),
475 )),
476 name: None,
477 user_id: None,
478 sequence: None,
479 conflation_key: None,
480 tags: None,
481 message_id: None,
482 stream_id: None,
483 serial: None,
484 idempotency_key: None,
485 extras: None,
486 delta_sequence: None,
487 delta_conflation_key: None,
488 }
489 }
490
491 pub fn pong() -> Self {
493 Self {
494 event: Some("pusher:pong".to_string()),
495 data: None,
496 channel: None,
497 name: None,
498 user_id: None,
499 sequence: None,
500 conflation_key: None,
501 tags: None,
502 message_id: None,
503 stream_id: None,
504 serial: None,
505 idempotency_key: None,
506 extras: None,
507 delta_sequence: None,
508 delta_conflation_key: None,
509 }
510 }
511
512 pub fn channel_info(
514 occupied: bool,
515 subscription_count: Option<u64>,
516 user_count: Option<u64>,
517 cache_data: Option<(String, Duration)>,
518 ) -> Value {
519 let mut response = json!({
520 "occupied": occupied
521 });
522
523 if let Some(count) = subscription_count {
524 response["subscription_count"] = json!(count);
525 }
526
527 if let Some(count) = user_count {
528 response["user_count"] = json!(count);
529 }
530
531 if let Some((data, ttl)) = cache_data {
532 response["cache"] = json!({
533 "data": data,
534 "ttl": ttl.as_secs()
535 });
536 }
537
538 response
539 }
540
541 pub fn channels_list(channels_info: AHashMap<String, Value>) -> Value {
543 json!({
544 "channels": channels_info
545 })
546 }
547
548 pub fn user_list(user_ids: Vec<String>) -> Value {
550 let users = user_ids
551 .into_iter()
552 .map(|id| json!({ "id": id }))
553 .collect::<Vec<_>>();
554
555 json!({ "users": users })
556 }
557
558 pub fn batch_response(batch_info: Vec<Value>) -> Value {
560 json!({ "batch": batch_info })
561 }
562
563 pub fn success_response() -> Value {
565 json!({ "ok": true })
566 }
567
568 pub fn watchlist_online_event(user_ids: Vec<String>) -> Self {
569 Self {
570 event: Some("online".to_string()),
571 channel: None, name: None,
573 data: Some(MessageData::Json(json!({
574 "user_ids": user_ids
575 }))),
576 user_id: None,
577 sequence: None,
578 conflation_key: None,
579 tags: None,
580 message_id: None,
581 stream_id: None,
582 serial: None,
583 idempotency_key: None,
584 extras: None,
585 delta_sequence: None,
586 delta_conflation_key: None,
587 }
588 }
589
590 pub fn watchlist_offline_event(user_ids: Vec<String>) -> Self {
591 Self {
592 event: Some("offline".to_string()),
593 channel: None,
594 name: None,
595 data: Some(MessageData::Json(json!({
596 "user_ids": user_ids
597 }))),
598 user_id: None,
599 sequence: None,
600 conflation_key: None,
601 tags: None,
602 message_id: None,
603 stream_id: None,
604 serial: None,
605 idempotency_key: None,
606 extras: None,
607 delta_sequence: None,
608 delta_conflation_key: None,
609 }
610 }
611
612 pub fn cache_miss_event(channel: String) -> Self {
613 Self {
614 event: Some("pusher:cache_miss".to_string()),
615 channel: Some(channel),
616 data: Some(MessageData::String("{}".to_string())),
617 name: None,
618 user_id: None,
619 sequence: None,
620 conflation_key: None,
621 tags: None,
622 message_id: None,
623 stream_id: None,
624 serial: None,
625 idempotency_key: None,
626 extras: None,
627 delta_sequence: None,
628 delta_conflation_key: None,
629 }
630 }
631
632 pub fn signin_success(user_data: String) -> Self {
633 Self {
634 event: Some("pusher:signin_success".to_string()),
635 data: Some(MessageData::Json(json!({
636 "user_data": user_data
637 }))),
638 channel: None,
639 name: None,
640 user_id: None,
641 sequence: None,
642 conflation_key: None,
643 tags: None,
644 message_id: None,
645 stream_id: None,
646 serial: None,
647 idempotency_key: None,
648 extras: None,
649 delta_sequence: None,
650 delta_conflation_key: None,
651 }
652 }
653
654 pub fn delta_message(
656 channel: String,
657 event: String,
658 delta_base64: String,
659 base_sequence: u32,
660 target_sequence: u32,
661 algorithm: &str,
662 ) -> Self {
663 Self {
664 event: Some("pusher:delta".to_string()),
665 channel: Some(channel.clone()),
666 data: Some(MessageData::String(
667 json!({
668 "channel": channel,
669 "event": event,
670 "delta": delta_base64,
671 "base_seq": base_sequence,
672 "target_seq": target_sequence,
673 "algorithm": algorithm,
674 })
675 .to_string(),
676 )),
677 name: None,
678 user_id: None,
679 sequence: None,
680 conflation_key: None,
681 tags: None,
682 message_id: None,
683 stream_id: None,
684 serial: None,
685 idempotency_key: None,
686 extras: None,
687 delta_sequence: None,
688 delta_conflation_key: None,
689 }
690 }
691
692 pub fn rewrite_prefix(&mut self, version: ProtocolVersion) {
695 if let Some(ref event) = self.event {
696 self.event = Some(version.rewrite_event_prefix(event));
697 }
698 }
699
700 pub fn is_ephemeral(&self) -> bool {
702 self.extras
703 .as_ref()
704 .and_then(|e| e.ephemeral)
705 .unwrap_or(false)
706 }
707
708 pub fn extras_idempotency_key(&self) -> Option<&str> {
710 self.extras
711 .as_ref()
712 .and_then(|e| e.idempotency_key.as_deref())
713 }
714
715 pub fn should_echo(&self, connection_default: bool) -> bool {
718 self.extras
719 .as_ref()
720 .and_then(|e| e.echo)
721 .unwrap_or(connection_default)
722 }
723
724 pub fn filter_headers(&self) -> Option<&HashMap<String, ExtrasValue>> {
726 self.extras.as_ref().and_then(|e| e.headers.as_ref())
727 }
728
729 pub fn should_include_extras(protocol: &ProtocolVersion) -> bool {
731 matches!(protocol, ProtocolVersion::V2)
732 }
733
734 pub fn add_base_sequence(mut self, base_sequence: u32) -> Self {
736 if let Some(MessageData::String(ref data_str)) = self.data
737 && let Ok(mut data_obj) = sonic_rs::from_str::<Value>(data_str)
738 && let Some(obj) = data_obj.as_object_mut()
739 {
740 obj.insert("__delta_base_seq", json!(base_sequence));
741 self.data = Some(MessageData::String(data_obj.to_string()));
742 }
743 self
744 }
745
746 pub fn delta_compression_enabled(default_algorithm: &str) -> Self {
748 Self {
749 event: Some("pusher:delta_compression_enabled".to_string()),
750 data: Some(MessageData::Json(json!({
751 "enabled": true,
752 "default_algorithm": default_algorithm,
753 }))),
754 channel: None,
755 name: None,
756 user_id: None,
757 sequence: None,
758 conflation_key: None,
759 tags: None,
760 message_id: None,
761 stream_id: None,
762 serial: None,
763 idempotency_key: None,
764 extras: None,
765 delta_sequence: None,
766 delta_conflation_key: None,
767 }
768 }
769}
770
/// Helpers for interpreting a comma-separated `info` query parameter such as
/// `"user_count,subscription_count"`.
pub trait InfoQueryParser {
    /// Splits the parameter into attribute names. Whitespace around each entry
    /// is trimmed and empty entries are dropped; an absent parameter yields an
    /// empty list.
    fn parse_info(&self) -> Vec<&str>;
    /// Whether `user_count` was requested.
    fn wants_user_count(&self) -> bool;
    /// Whether `subscription_count` was requested.
    fn wants_subscription_count(&self) -> bool;
    /// Whether `cache` was requested.
    fn wants_cache(&self) -> bool;
}

/// Returns `true` when the comma-separated `info` list contains `attribute`,
/// ignoring whitespace around each entry. Avoids building a `Vec` per query.
fn info_requests(info: Option<&String>, attribute: &str) -> bool {
    info.is_some_and(|list| list.split(',').any(|entry| entry.trim() == attribute))
}

impl InfoQueryParser for Option<&String> {
    fn parse_info(&self) -> Vec<&str> {
        self.map(|list| {
            list.split(',')
                .map(str::trim)
                .filter(|entry| !entry.is_empty())
                .collect::<Vec<_>>()
        })
        .unwrap_or_default()
    }

    fn wants_user_count(&self) -> bool {
        info_requests(*self, "user_count")
    }

    fn wants_subscription_count(&self) -> bool {
        info_requests(*self, "subscription_count")
    }

    fn wants_cache(&self) -> bool {
        info_requests(*self, "cache")
    }
}
797
#[cfg(test)]
mod tests {
    use super::PusherMessage;
    use crate::protocol_version::ProtocolVersion;

    /// Ping and pong must be recognised both with their default prefix and
    /// after being rewritten for protocol V2.
    #[test]
    fn protocol_heartbeat_detection_matches_both_prefix_families() {
        for mut heartbeat in [PusherMessage::ping(), PusherMessage::pong()] {
            assert!(heartbeat.is_protocol_ping_or_pong());

            heartbeat.rewrite_prefix(ProtocolVersion::V2);
            assert!(heartbeat.is_protocol_ping_or_pong());
        }
    }

    /// An ordinary channel event is never treated as a heartbeat.
    #[test]
    fn protocol_heartbeat_detection_ignores_regular_messages() {
        let payload = sonic_rs::json!({"text": "hello"});
        let chat = PusherMessage::channel_event("chat.message", "room", payload);

        assert!(!chat.is_protocol_ping_or_pong());
    }
}