rusher_core/
lib.rs

1use std::{collections::HashMap, fmt::Display, str::FromStr};
2
3use rand::{distributions, prelude::Distribution};
4use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
5
6pub mod signature;
7
8#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
9pub struct SocketId(String);
10
11impl Display for SocketId {
12    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
13        write!(f, "{}", self.0)
14    }
15}
16
17impl Distribution<SocketId> for distributions::Standard {
18    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> SocketId {
19        let digits = distributions::Uniform::from(0..=9)
20            .sample_iter(rng)
21            .take(32)
22            .map(|s| s.to_string())
23            .collect::<String>();
24        let (p1, p2) = digits.split_at(16);
25        SocketId(format!("{p1}.{p2}"))
26    }
27}
28
29#[derive(Debug, Clone)]
30pub enum SocketIdParseError {
31    InvalidSocketId,
32}
33
34impl FromStr for SocketId {
35    type Err = SocketIdParseError;
36
37    fn from_str(s: &str) -> Result<Self, Self::Err> {
38        match s.find('.') {
39            Some(index) if index > 0 && index < s.len() => Ok(SocketId(s.to_owned())),
40            _ => Err(SocketIdParseError::InvalidSocketId),
41        }
42    }
43}
44
45impl AsRef<str> for SocketId {
46    fn as_ref(&self) -> &str {
47        &self.0
48    }
49}
50
/// A channel name, classified by its prefix when parsed with `FromStr`.
///
/// Every variant stores the complete name as given (e.g. `Private("private-foo")`, not
/// `Private("foo")`), so `as_ref` and `Display` return the original text.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ChannelName {
    /// Any non-empty name without a recognised prefix.
    Public(String),
    /// `private-<name>`.
    Private(String),
    /// `presence-<name>`.
    Presence(String),
    /// `private-encrypted-<name>`.
    Encrypted(String),
}
58
59impl AsRef<str> for ChannelName {
60    fn as_ref(&self) -> &str {
61        match self {
62            ChannelName::Public(ref name) => name,
63            ChannelName::Private(ref name) => name,
64            ChannelName::Presence(ref name) => name,
65            ChannelName::Encrypted(ref name) => name,
66        }
67    }
68}
69
/// Error returned by `ChannelName::from_str`. The parser only produces it for the empty string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChannelNameParseError {
    /// The input could not be used as a channel name.
    InvalidChannelName,
}
74
75impl Display for ChannelNameParseError {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        match self {
78            ChannelNameParseError::InvalidChannelName => f.write_str("Invalid channel name"),
79        }
80    }
81}
82
83impl FromStr for ChannelName {
84    type Err = ChannelNameParseError;
85
86    fn from_str(s: &str) -> Result<Self, Self::Err> {
87        match s.splitn(3, '-').collect::<Vec<&str>>().as_slice() {
88            ["private", "encrypted", name, ..] if !name.is_empty() => {
89                Ok(ChannelName::Encrypted(s.to_owned()))
90            }
91            ["private", name, ..] if !name.is_empty() => Ok(ChannelName::Private(s.to_owned())),
92            ["presence", name, ..] if !name.is_empty() => Ok(ChannelName::Presence(s.to_owned())),
93            _ if !s.is_empty() => Ok(ChannelName::Public(s.to_owned())),
94            _ => Err(ChannelNameParseError::InvalidChannelName),
95        }
96    }
97}
98
99impl Display for ChannelName {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        match self {
102            ChannelName::Public(channel) => f.write_str(channel),
103            ChannelName::Private(channel) => f.write_str(channel),
104            ChannelName::Presence(channel) => f.write_str(channel),
105            ChannelName::Encrypted(channel) => f.write_str(channel),
106        }
107    }
108}
109
110impl Serialize for ChannelName {
111    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
112        serializer.serialize_str(self.as_ref())
113    }
114}
115
116impl<'de> Deserialize<'de> for ChannelName {
117    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
118        let name = String::deserialize(deserializer)?;
119        FromStr::from_str(&name).map_err(de::Error::custom)
120    }
121}
122
/// A message received from a client over the websocket.
///
/// Deserialization goes through `ClientEventJson` (`#[serde(from = ...)]`). The `pusher:*`
/// protocol events are tried first; anything else with top-level `event`/`channel`/`data`
/// fields becomes [`ClientEvent::ChannelEvent`]. Serialization uses the derived adjacently
/// tagged form (`event` + `data`), except `ChannelEvent`, which is `untagged` and so is
/// written as its own fields.
///
/// NOTE(review): a message whose `event` starts with `pusher:` but whose payload does not
/// match the protocol shape can still deserialize as `ChannelEvent` (the untagged fallback).
/// Callers should validate the event name (e.g. require a `client-` prefix) before
/// relaying it — confirm where this is enforced.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[serde(from = "ClientEventJson", tag = "event", content = "data")]
pub enum ClientEvent {
    /// `pusher:signin`. `user_data` is kept as the string the client sent; it is not
    /// decoded here.
    #[serde(rename = "pusher:signin")]
    Signin { auth: String, user_data: String },
    /// `pusher:subscribe`. `auth` and `channel_data` may be absent.
    #[serde(rename = "pusher:subscribe")]
    Subscribe {
        channel: ChannelName,
        auth: Option<String>,
        channel_data: Option<serde_json::Value>,
    },
    /// `pusher:unsubscribe`.
    #[serde(rename = "pusher:unsubscribe")]
    Unsubscribe { channel: ChannelName },
    /// `pusher:ping`. Any `data` sent with the ping is discarded during conversion.
    #[serde(rename = "pusher:ping")]
    Ping,
    /// Any other event addressed to a channel; `data` is raw JSON (not a JSON string).
    #[serde(untagged)]
    ChannelEvent {
        event: String,
        channel: ChannelName,
        data: serde_json::Value,
    },
}
145
/// Wire shape of the `pusher:*` events a client may send. It is adjacently tagged:
/// `event` names the variant and `data` holds its fields. It is used only as an
/// intermediate when deserializing `ClientEvent`.
#[derive(Debug, PartialEq, Deserialize)]
#[serde(tag = "event", content = "data")]
enum PusherClientEvent {
    #[serde(rename = "pusher:signin")]
    Signin { auth: String, user_data: String },
    #[serde(rename = "pusher:subscribe")]
    Subscribe {
        channel: ChannelName,
        auth: Option<String>,
        channel_data: Option<serde_json::Value>,
    },
    #[serde(rename = "pusher:unsubscribe")]
    Unsubscribe { channel: ChannelName },
    // A struct variant with an optional field, so `{"event":"pusher:ping","data":{}}`
    // parses; the payload itself is ignored.
    #[serde(rename = "pusher:ping")]
    Ping { data: Option<serde_json::Value> },
}
162
/// Wire shape of a non-protocol client event: top-level `event`, `channel` and `data`,
/// with `data` as raw JSON.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct CustomClientEvent {
    event: String,
    channel: ChannelName,
    data: serde_json::Value,
}
169
/// Untagged union used to deserialize `ClientEvent`. Serde tries the variants in
/// declaration order, so protocol events win over custom events.
#[derive(Debug, PartialEq, Deserialize)]
#[serde(untagged)]
enum ClientEventJson {
    PusherEvent(PusherClientEvent),
    CustomEvent(CustomClientEvent),
}
176
177impl From<ClientEventJson> for ClientEvent {
178    fn from(json: ClientEventJson) -> Self {
179        use ClientEventJson::*;
180        use PusherClientEvent::*;
181        match json {
182            PusherEvent(Signin { auth, user_data }) => ClientEvent::Signin { auth, user_data },
183            PusherEvent(Subscribe {
184                channel,
185                auth,
186                channel_data,
187            }) => ClientEvent::Subscribe {
188                channel,
189                auth,
190                channel_data,
191            },
192            PusherEvent(Unsubscribe { channel }) => ClientEvent::Unsubscribe { channel },
193            PusherEvent(Ping { .. }) => ClientEvent::Ping,
194            CustomEvent(CustomClientEvent {
195                event,
196                channel,
197                data,
198            }) => ClientEvent::ChannelEvent {
199                event,
200                channel,
201                data,
202            },
203        }
204    }
205}
206
/// Payload of `pusher:signin_success`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SigninInformation {
    /// Encoded on the wire as a JSON string (via `json_string`), not as a nested object.
    #[serde(with = "json_string")]
    pub user_data: UserData,
}
212
/// Presence-channel membership snapshot, carried (optionally) by
/// `pusher_internal:subscription_succeeded`.
///
/// NOTE(review): the fields are private and no constructor is visible here. Confirm how
/// a non-`None` payload is meant to be built.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PresenceInformation {
    // Presumably the ids of the current members — TODO confirm.
    ids: Vec<String>,
    // Presumably member id -> that member's info — TODO confirm.
    hash: HashMap<String, HashMap<String, String>>,
    // Presumably the member count (== `ids.len()`) — TODO confirm.
    count: u32,
}
219
/// A presence-channel member, serialized as `{"user_id": ..., "user_info": ...}`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PresenceUser {
    #[serde(rename = "user_id")]
    id: String,
    #[serde(rename = "user_info")]
    info: serde_json::Value,
}
227
/// Payload of `pusher_internal:member_removed`, serialized as `{"user_id": ...}`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct RemovedMember {
    #[serde(rename = "user_id")]
    id: String,
}
233
/// A channel event sent by the server (the tests use a `client-message` event).
///
/// On the wire this is `{"event", "channel", "data", "user_id"?}`, where `data` is
/// JSON-encoded into a string.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CustomEvent {
    pub event: String,
    pub channel: ChannelName,
    // Stringified JSON on the wire, decoded back to a `Value` on deserialization.
    #[serde(with = "json_string")]
    pub data: serde_json::Value,
    // Omitted entirely from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_id: Option<String>,
}
243
/// Payload of `pusher:connection_established`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ConnectionInfo {
    /// Id assigned to this connection.
    pub socket_id: SocketId,
    /// Presumably the inactivity timeout in seconds (u8, so at most 255) — TODO confirm units.
    pub activity_timeout: u8,
}
249
/// User record returned in `pusher:signin_success`.
///
/// Optional fields have no `skip_serializing_if`, so `None` serializes as `null`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct UserData {
    /// User id.
    pub id: String,
    /// Arbitrary JSON describing the user.
    pub user_info: Option<serde_json::Value>,
    /// Presumably ids of other users this user watches — TODO confirm.
    pub watchlist: Option<Vec<String>>,
}
256
/// A message the server sends to a client.
///
/// Both serialization and deserialization go through `ServerEventJson`
/// (`#[serde(from = ..., into = ...)]`). The wire format is therefore defined by
/// `PusherServerEvent` and `CustomEvent`; the variant/field attributes below mirror them
/// and need to be kept in sync by hand.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(from = "ServerEventJson", into = "ServerEventJson")]
pub enum ServerEvent {
    /// Sent once after the websocket is accepted.
    #[serde(rename = "pusher:connection_established")]
    ConnectionEstablished {
        #[serde(with = "json_string")]
        data: ConnectionInfo,
    },

    /// Reply to a successful `pusher:signin`.
    #[serde(rename = "pusher:signin_success")]
    SigninSucceeded {
        data: SigninInformation,
    },

    /// Protocol error; `code` is optional.
    #[serde(rename = "pusher:error")]
    Error {
        message: String,
        code: Option<u16>,
    },

    /// Reply to `pusher:ping`.
    #[serde(rename = "pusher:pong")]
    Pong,

    /// Subscription accepted; `data` carries presence information when present.
    #[serde(rename = "pusher_internal:subscription_succeeded")]
    SubscriptionSucceeded {
        channel: ChannelName,
        #[serde(with = "json_string")]
        data: Option<PresenceInformation>,
    },

    /// A member joined a presence channel.
    #[serde(rename = "pusher_internal:member_added")]
    MemberAdded {
        channel: ChannelName,
        #[serde(with = "json_string")]
        data: PresenceUser,
    },

    /// A member left a presence channel.
    #[serde(rename = "pusher_internal:member_removed")]
    MemberRemoved {
        channel: ChannelName,
        #[serde(with = "json_string")]
        data: RemovedMember,
    },

    /// Any non-protocol channel event.
    ChannelEvent(CustomEvent),
}
303
304impl ServerEvent {
305    pub fn signin_succeeded(user_data: UserData) -> Self {
306        Self::SigninSucceeded {
307            data: SigninInformation { user_data },
308        }
309    }
310
311    pub fn subscription_succeeded(channel: impl Into<ChannelName>) -> Self {
312        Self::SubscriptionSucceeded {
313            channel: channel.into(),
314            data: None,
315        }
316    }
317
318    pub fn custom_event(
319        event: impl Into<String>,
320        channel: impl Into<ChannelName>,
321        data: impl Into<serde_json::Value>,
322        user_id: impl Into<Option<String>>,
323    ) -> Self {
324        Self::ChannelEvent(CustomEvent {
325            event: event.into(),
326            channel: channel.into(),
327            data: data.into(),
328            user_id: user_id.into(),
329        })
330    }
331
332    pub fn invalid_signature_error() -> Self {
333        Self::error("Invalid signature", 409)
334    }
335
336    pub fn authentication_error(message: impl Into<String>) -> Self {
337        Self::error(message, 409)
338    }
339
340    pub fn error(message: impl Into<String>, code: impl Into<Option<u16>>) -> Self {
341        Self::Error {
342            message: message.into(),
343            code: code.into(),
344        }
345    }
346}
347
348impl From<ServerEventJson> for ServerEvent {
349    fn from(json: ServerEventJson) -> Self {
350        use PusherServerEvent::*;
351        use ServerEventJson::*;
352        match json {
353            PusherEvent(ConnectionEstablished { data }) => {
354                ServerEvent::ConnectionEstablished { data }
355            }
356            PusherEvent(SigninSucceeded { data }) => ServerEvent::SigninSucceeded { data },
357            PusherEvent(Error { message, code }) => ServerEvent::Error { message, code },
358            PusherEvent(Pong) => ServerEvent::Pong,
359            PusherEvent(SubscriptionSucceeded { channel, data }) => {
360                ServerEvent::SubscriptionSucceeded { channel, data }
361            }
362            PusherEvent(MemberAdded { channel, data }) => {
363                ServerEvent::MemberAdded { channel, data }
364            }
365            PusherEvent(MemberRemoved { channel, data }) => {
366                ServerEvent::MemberRemoved { channel, data }
367            }
368            UserEvent(event) => ServerEvent::ChannelEvent(event),
369        }
370    }
371}
372
373impl From<ServerEvent> for ServerEventJson {
374    fn from(value: ServerEvent) -> Self {
375        use ServerEvent::*;
376        use ServerEventJson::*;
377        match value {
378            ConnectionEstablished { data } => {
379                PusherEvent(PusherServerEvent::ConnectionEstablished { data })
380            }
381            SigninSucceeded { data } => PusherEvent(PusherServerEvent::SigninSucceeded { data }),
382            Error { message, code } => PusherEvent(PusherServerEvent::Error { message, code }),
383            Pong => PusherEvent(PusherServerEvent::Pong),
384            SubscriptionSucceeded { channel, data } => {
385                PusherEvent(PusherServerEvent::SubscriptionSucceeded { channel, data })
386            }
387            MemberRemoved { channel, data } => {
388                PusherEvent(PusherServerEvent::MemberRemoved { channel, data })
389            }
390            MemberAdded { channel, data } => {
391                PusherEvent(PusherServerEvent::MemberAdded { channel, data })
392            }
393            ChannelEvent(event) => UserEvent(event),
394        }
395    }
396}
397
/// Untagged wire representation of `ServerEvent`. On deserialization, serde tries the
/// protocol events first and falls back to a custom channel event.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
enum ServerEventJson {
    PusherEvent(PusherServerEvent),
    UserEvent(CustomEvent),
}
404
/// Wire shape of the protocol events the server sends.
///
/// Internally tagged by `event`, so each variant's fields sit beside `event` at the top
/// level, e.g. `{"event": "pusher_internal:member_removed", "channel": ..., "data": ...}`.
/// `data` fields marked `json_string` are JSON-encoded into a string.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum PusherServerEvent {
    #[serde(rename = "pusher:connection_established")]
    ConnectionEstablished {
        #[serde(with = "json_string")]
        data: ConnectionInfo,
    },

    // `data` is a nested object here; only its `user_data` field is stringified.
    #[serde(rename = "pusher:signin_success")]
    SigninSucceeded { data: SigninInformation },

    #[serde(rename = "pusher:error")]
    Error { message: String, code: Option<u16> },

    // Serialized as just `{"event": "pusher:pong"}`.
    #[serde(rename = "pusher:pong")]
    Pong,

    // `None` is encoded as the string "null".
    #[serde(rename = "pusher_internal:subscription_succeeded")]
    SubscriptionSucceeded {
        channel: ChannelName,
        #[serde(with = "json_string")]
        data: Option<PresenceInformation>,
    },

    #[serde(rename = "pusher_internal:member_added")]
    MemberAdded {
        channel: ChannelName,
        #[serde(with = "json_string")]
        data: PresenceUser,
    },

    #[serde(rename = "pusher_internal:member_removed")]
    MemberRemoved {
        channel: ChannelName,
        #[serde(with = "json_string")]
        data: RemovedMember,
    },
}
444
445mod json_string {
446    use serde::{
447        de::{self, DeserializeOwned},
448        ser::{self, Serialize, Serializer},
449        Deserialize, Deserializer,
450    };
451
452    pub fn serialize<T: Serialize, S: Serializer>(
453        value: &T,
454        serializer: S,
455    ) -> Result<S::Ok, S::Error> {
456        let json = serde_json::to_string(value).map_err(ser::Error::custom)?;
457        json.serialize(serializer)
458    }
459
460    pub fn deserialize<'de, T: DeserializeOwned, D: Deserializer<'de>>(
461        deserializer: D,
462    ) -> Result<T, D::Error> {
463        let json = String::deserialize(deserializer)?;
464        serde_json::from_str(&json).map_err(de::Error::custom)
465    }
466}
467
#[cfg(test)]
mod tests {
    use super::*;
    use rand::distributions::{Distribution, Standard};
    use rand::rngs::mock::StepRng;
    use serde_json::json;

    #[test]
    fn parse_channel_name() {
        assert_eq!(Ok(ChannelName::Public("lol".to_owned())), "lol".parse());
        assert_eq!(
            Ok(ChannelName::Private("private-lol".to_owned())),
            "private-lol".parse()
        );
        assert_eq!(
            Ok(ChannelName::Presence("presence-lol".to_owned())),
            "presence-lol".parse()
        );
        assert_eq!(
            Ok(ChannelName::Encrypted("private-encrypted-lol".to_owned())),
            "private-encrypted-lol".parse()
        );
        assert_eq!(
            Err(ChannelNameParseError::InvalidChannelName),
            "".parse::<ChannelName>()
        );
    }

    /// Channel names travel as bare strings and are re-classified on the way back in.
    #[test]
    fn channel_name_serde_round_trip() {
        let channel: ChannelName = "presence-room".parse().unwrap();
        let serialized = serde_json::to_value(&channel).unwrap();
        assert_eq!(json!("presence-room"), serialized);
        assert_eq!(
            channel,
            serde_json::from_value::<ChannelName>(serialized).unwrap()
        );
        assert!(serde_json::from_value::<ChannelName>(json!("")).is_err());
    }

    #[test]
    fn parse_socket_id() {
        assert_eq!(
            SocketId("1234.5678".to_owned()),
            "1234.5678".parse::<SocketId>().unwrap()
        );
        assert!("".parse::<SocketId>().is_err());
        assert!("12345678".parse::<SocketId>().is_err());
        assert!(".5678".parse::<SocketId>().is_err());
    }

    /// Generated ids are `<16 digits>.<16 digits>` and parse back to themselves.
    #[test]
    fn sampled_socket_id_shape() {
        // Deterministic RNG so the test does not depend on the `std_rng` feature.
        let mut rng = StepRng::new(0, 1);
        let id: SocketId = Standard.sample(&mut rng);
        let text = id.to_string();
        let (head, tail) = text.split_once('.').expect("sampled socket id has a dot");
        assert_eq!(16, head.len());
        assert_eq!(16, tail.len());
        assert!(head.chars().chain(tail.chars()).all(|c| c.is_ascii_digit()));
        assert_eq!(id, text.parse::<SocketId>().unwrap());
    }

    #[test]
    fn test_connection_established() {
        let event = ServerEvent::ConnectionEstablished {
            data: ConnectionInfo {
                socket_id: "1234.5678".parse().unwrap(),
                activity_timeout: 120,
            },
        };

        let serialized = serde_json::to_value(&event).unwrap();

        let expected = json!({
            "event": "pusher:connection_established",
            "data": r#"{"socket_id":"1234.5678","activity_timeout":120}"#,
        });

        assert_eq!(expected, serialized);

        let deserialized = serde_json::from_value::<ServerEvent>(expected).unwrap();

        assert_eq!(event, deserialized);
    }

    #[test]
    fn test_member_removed() {
        let event = ServerEvent::MemberRemoved {
            channel: "channel".parse().unwrap(),
            data: RemovedMember {
                id: "lolwut".to_owned(),
            },
        };

        let serialized = serde_json::to_value(&event).unwrap();

        let expected = json!({
            "event": "pusher_internal:member_removed",
            "channel": "channel",
            "data": r#"{"user_id":"lolwut"}"#,
        });

        assert_eq!(expected, serialized);

        let deserialized = serde_json::from_value(expected).unwrap();

        assert_eq!(event, deserialized);
    }

    #[test]
    fn test_custom_event() {
        let event = ServerEvent::ChannelEvent(CustomEvent {
            event: "client-message".to_owned(),
            channel: "channel".parse().unwrap(),
            data: json!({ "some": "data" }),
            user_id: Some("user".to_owned()),
        });

        let serialized = serde_json::to_value(&event).unwrap();

        let expected = json!({
            "event": "client-message",
            "channel": "channel",
            "data": r#"{"some":"data"}"#,
            "user_id": "user",
        });

        assert_eq!(expected, serialized);

        let deserialized = serde_json::from_value(expected).unwrap();

        assert_eq!(event, deserialized);
    }

    #[test]
    fn test_deserialize_ping() {
        let event = ClientEvent::Ping;
        let serialized = json!({ "event": "pusher:ping", "data": {} });
        let deserialized = serde_json::from_value::<ClientEvent>(serialized).unwrap();
        assert_eq!(event, deserialized);
    }

    #[test]
    fn test_deserialize_signin() {
        let event = ClientEvent::Signin {
            auth: "1234".to_owned(),
            user_data: serde_json::to_string(&UserData {
                id: "user1".to_owned(),
                user_info: Some(json!({ "lol": "wut" })),
                watchlist: Some(vec!["user2".to_owned(), "user3".to_owned()]),
            })
            .unwrap(),
        };

        let serialized = json!({
            "event": "pusher:signin",
            "data": {
                "auth": "1234",
                "user_data": serde_json::to_string(&json!({
                    "id": "user1",
                    "user_info": { "lol": "wut" },
                    "watchlist": ["user2", "user3"],
                })).unwrap(),
            },
        });
        let deserialized = serde_json::from_value::<ClientEvent>(serialized).unwrap();
        assert_eq!(event, deserialized);
    }

    #[test]
    fn test_deserialize_subscribe() {
        let event = ClientEvent::Subscribe {
            channel: "lolwut".parse().unwrap(),
            auth: None,
            channel_data: Some(json!({ "lol": "wut" })),
        };
        let serialized = json!({
            "event": "pusher:subscribe",
            "data": {
                "channel": "lolwut",
                "channel_data": { "lol": "wut" },
            },
        });
        let deserialized = serde_json::from_value::<ClientEvent>(serialized).unwrap();
        assert_eq!(event, deserialized);
    }

    #[test]
    fn test_deserialize_unsubscribe() {
        let event = ClientEvent::Unsubscribe {
            channel: "lolwut".parse().unwrap(),
        };
        let serialized = json!({
            "event": "pusher:unsubscribe",
            "data": {
                "channel": "lolwut",
            },
        });
        let deserialized = serde_json::from_value::<ClientEvent>(serialized).unwrap();
        assert_eq!(event, deserialized);
    }

    #[test]
    fn test_deserialize_channel_event() {
        let event = ClientEvent::ChannelEvent {
            event: "client-lolwut".to_owned(),
            channel: "lolwut".parse().unwrap(),
            data: json!({ "lol": "wut" }),
        };
        let serialized = json!({
            "event": "client-lolwut",
            "channel": "lolwut",
            "data": { "lol": "wut" },
        });
        let deserialized = serde_json::from_value::<ClientEvent>(serialized).unwrap();
        assert_eq!(event, deserialized);
    }
}