1use std::{collections::HashMap, fmt::Display, str::FromStr};
2
3use rand::{distributions, prelude::Distribution};
4use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
5
6pub mod signature;
7
/// Identifier of a single client connection, kept in its textual form
/// (two runs of digits joined by a `.`, e.g. `1234.5678`).
///
/// Values built through `FromStr` are validated. Values built through the
/// derived `Deserialize` are taken as-is.
/// NOTE(review): confirm that accepting unvalidated ids on deserialization is intended.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SocketId(String);
10
11impl Display for SocketId {
12 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
13 write!(f, "{}", self.0)
14 }
15}
16
17impl Distribution<SocketId> for distributions::Standard {
18 fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> SocketId {
19 let digits = distributions::Uniform::from(0..=9)
20 .sample_iter(rng)
21 .take(32)
22 .map(|s| s.to_string())
23 .collect::<String>();
24 let (p1, p2) = digits.split_at(16);
25 SocketId(format!("{p1}.{p2}"))
26 }
27}
28
29#[derive(Debug, Clone)]
30pub enum SocketIdParseError {
31 InvalidSocketId,
32}
33
34impl FromStr for SocketId {
35 type Err = SocketIdParseError;
36
37 fn from_str(s: &str) -> Result<Self, Self::Err> {
38 match s.find('.') {
39 Some(index) if index > 0 && index < s.len() => Ok(SocketId(s.to_owned())),
40 _ => Err(SocketIdParseError::InvalidSocketId),
41 }
42 }
43}
44
45impl AsRef<str> for SocketId {
46 fn as_ref(&self) -> &str {
47 &self.0
48 }
49}
50
/// A channel name, classified by the reserved prefix it carries.
///
/// Every variant stores the *full* channel name, prefix included
/// (e.g. `Private("private-foo")`, not `Private("foo")`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ChannelName {
    /// A non-empty name that starts with none of the reserved prefixes.
    Public(String),
    /// A name starting with `private-` (but not `private-encrypted-`).
    Private(String),
    /// A name starting with `presence-`.
    Presence(String),
    /// A name starting with `private-encrypted-`.
    Encrypted(String),
}

impl AsRef<str> for ChannelName {
    /// Returns the full channel name, including any prefix.
    fn as_ref(&self) -> &str {
        match self {
            ChannelName::Public(name)
            | ChannelName::Private(name)
            | ChannelName::Presence(name)
            | ChannelName::Encrypted(name) => name,
        }
    }
}

/// Error returned when a string cannot be parsed into a [`ChannelName`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChannelNameParseError {
    /// The name was empty, or was a reserved prefix with nothing after it.
    InvalidChannelName,
}

impl Display for ChannelNameParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ChannelNameParseError::InvalidChannelName => f.write_str("Invalid channel name"),
        }
    }
}

impl FromStr for ChannelName {
    type Err = ChannelNameParseError;

    /// Classifies `s` by prefix. The variant keeps the whole of `s`.
    ///
    /// The variant is chosen as follows:
    /// - `private-encrypted-<x>` becomes `Encrypted`;
    /// - `private-<x>` becomes `Private`;
    /// - `presence-<x>` becomes `Presence`;
    /// - anything else becomes `Public`.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelNameParseError::InvalidChannelName`] when `s` is empty, or
    /// when `s` is one of the reserved prefixes with an empty `<x>`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let name = s.to_owned();
        // Most specific prefix first: every `private-encrypted-` name also starts
        // with `private-`.
        let (suffix, channel) = if let Some(suffix) = s.strip_prefix("private-encrypted-") {
            (suffix, ChannelName::Encrypted(name))
        } else if let Some(suffix) = s.strip_prefix("private-") {
            (suffix, ChannelName::Private(name))
        } else if let Some(suffix) = s.strip_prefix("presence-") {
            (suffix, ChannelName::Presence(name))
        } else {
            (s, ChannelName::Public(name))
        };

        // A reserved prefix followed by nothing is rejected rather than falling
        // back to `Public`. A name that carries a reserved prefix (including
        // forms like "private--x") is never downgraded to a public channel.
        if suffix.is_empty() {
            Err(ChannelNameParseError::InvalidChannelName)
        } else {
            Ok(channel)
        }
    }
}

impl Display for ChannelName {
    /// Writes the full channel name, including any prefix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_ref())
    }
}
109
110impl Serialize for ChannelName {
111 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
112 serializer.serialize_str(self.as_ref())
113 }
114}
115
116impl<'de> Deserialize<'de> for ChannelName {
117 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
118 let name = String::deserialize(deserializer)?;
119 FromStr::from_str(&name).map_err(de::Error::custom)
120 }
121}
122
/// An event received from a client.
///
/// Deserialization is delegated to [`ClientEventJson`] through
/// `#[serde(from = ...)]`. The `tag`/`content`/`rename`/`untagged` attributes
/// here therefore describe the serialized form:
/// - reserved events are written as `{"event": <name>, "data": {...}}`;
/// - `ChannelEvent` is written untagged, with its fields at the top level.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[serde(from = "ClientEventJson", tag = "event", content = "data")]
pub enum ClientEvent {
    /// `pusher:signin`. `user_data` is kept as the raw string the client sent;
    /// the tests send it JSON-encoded.
    #[serde(rename = "pusher:signin")]
    Signin { auth: String, user_data: String },
    /// `pusher:subscribe`. `auth` and `channel_data` are optional.
    #[serde(rename = "pusher:subscribe")]
    Subscribe {
        channel: ChannelName,
        auth: Option<String>,
        channel_data: Option<serde_json::Value>,
    },
    /// `pusher:unsubscribe`.
    #[serde(rename = "pusher:unsubscribe")]
    Unsubscribe { channel: ChannelName },
    /// `pusher:ping`. Any payload that arrived with the ping is dropped when
    /// converting from `ClientEventJson`.
    #[serde(rename = "pusher:ping")]
    Ping,
    /// Any event that did not parse as a reserved `pusher:*` event.
    /// `data` is kept as raw JSON.
    #[serde(untagged)]
    ChannelEvent {
        event: String,
        channel: ChannelName,
        data: serde_json::Value,
    },
}
145
/// Wire shape of the reserved `pusher:*` client events: an adjacently tagged
/// object `{"event": <name>, "data": {...}}`.
#[derive(Debug, PartialEq, Deserialize)]
#[serde(tag = "event", content = "data")]
enum PusherClientEvent {
    #[serde(rename = "pusher:signin")]
    Signin { auth: String, user_data: String },
    #[serde(rename = "pusher:subscribe")]
    Subscribe {
        channel: ChannelName,
        auth: Option<String>,
        channel_data: Option<serde_json::Value>,
    },
    #[serde(rename = "pusher:unsubscribe")]
    Unsubscribe { channel: ChannelName },
    /// A struct variant with an optional inner `data` field. This lets a ping
    /// whose `data` is an object (e.g. `{}`, as in the tests) deserialize.
    /// NOTE(review): verify that a ping with no `data` key at all is also
    /// accepted; an adjacently tagged struct variant may require the content.
    #[serde(rename = "pusher:ping")]
    Ping { data: Option<serde_json::Value> },
}
162
/// Wire shape of a non-reserved client event. `event`, `channel` and `data` sit
/// at the top level, and `data` is kept as an arbitrary JSON value (not a
/// JSON-encoded string).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct CustomClientEvent {
    event: String,
    channel: ChannelName,
    data: serde_json::Value,
}
169
/// Intermediate deserialization target for [`ClientEvent`].
///
/// The enum is untagged: serde tries `PusherEvent` first and falls back to
/// `CustomEvent` if that fails. A payload that matches neither yields serde's
/// generic "did not match any variant" error rather than a field-specific one.
#[derive(Debug, PartialEq, Deserialize)]
#[serde(untagged)]
enum ClientEventJson {
    PusherEvent(PusherClientEvent),
    CustomEvent(CustomClientEvent),
}
176
177impl From<ClientEventJson> for ClientEvent {
178 fn from(json: ClientEventJson) -> Self {
179 use ClientEventJson::*;
180 use PusherClientEvent::*;
181 match json {
182 PusherEvent(Signin { auth, user_data }) => ClientEvent::Signin { auth, user_data },
183 PusherEvent(Subscribe {
184 channel,
185 auth,
186 channel_data,
187 }) => ClientEvent::Subscribe {
188 channel,
189 auth,
190 channel_data,
191 },
192 PusherEvent(Unsubscribe { channel }) => ClientEvent::Unsubscribe { channel },
193 PusherEvent(Ping { .. }) => ClientEvent::Ping,
194 CustomEvent(CustomClientEvent {
195 event,
196 channel,
197 data,
198 }) => ClientEvent::ChannelEvent {
199 event,
200 channel,
201 data,
202 },
203 }
204 }
205}
206
/// Payload of a sign-in success event.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SigninInformation {
    /// Carried on the wire as a JSON-encoded string (via `json_string`), not as
    /// a nested object.
    #[serde(with = "json_string")]
    pub user_data: UserData,
}
212
/// Presence payload that can accompany a subscription-succeeded event.
///
/// NOTE(review): `hash` maps to `HashMap<String, String>`, so only string-valued
/// user info round-trips. `PresenceUser.info` is an arbitrary `serde_json::Value`;
/// confirm this asymmetry is intended.
/// NOTE(review): the Pusher protocol nests `ids`/`hash`/`count` under a
/// `"presence"` key. This struct serializes them at the top level; verify
/// against the clients in use.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PresenceInformation {
    /// Presumably the ids of the current members — TODO confirm.
    ids: Vec<String>,
    /// Presumably per-member info keyed by member id — TODO confirm.
    hash: HashMap<String, HashMap<String, String>>,
    /// Presumably the number of members — TODO confirm.
    count: u32,
}
219
/// A presence-channel member, serialized as `{"user_id": ..., "user_info": ...}`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PresenceUser {
    #[serde(rename = "user_id")]
    id: String,
    /// Arbitrary JSON describing the member.
    #[serde(rename = "user_info")]
    info: serde_json::Value,
}
227
/// Identifies a member that left a presence channel; serialized as `{"user_id": ...}`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct RemovedMember {
    #[serde(rename = "user_id")]
    id: String,
}
233
/// A channel event sent from the server to clients.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CustomEvent {
    pub event: String,
    pub channel: ChannelName,
    /// Carried on the wire as a JSON-encoded string (via `json_string`).
    #[serde(with = "json_string")]
    pub data: serde_json::Value,
    /// Omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_id: Option<String>,
}
243
/// Payload of a connection-established event.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ConnectionInfo {
    pub socket_id: SocketId,
    /// Activity timeout announced to the client. The `u8` type caps it at 255.
    /// Presumably seconds — TODO confirm.
    pub activity_timeout: u8,
}
249
/// User data supplied at sign-in.
///
/// The optional fields have no `skip_serializing_if`, so `None` is serialized
/// as `null`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct UserData {
    pub id: String,
    pub user_info: Option<serde_json::Value>,
    pub watchlist: Option<Vec<String>>,
}
256
/// An event sent from the server to a client.
///
/// Both serde directions go through [`ServerEventJson`]
/// (`#[serde(from = ..., into = ...)]`). The variant-level `rename`/`with`
/// attributes below mirror [`PusherServerEvent`], but the delegated impls do not
/// use them to produce the wire format.
///
/// NOTE(review): `Error` reaches the wire with `message`/`code` beside `event`
/// (see `PusherServerEvent`). The Pusher protocol documents them inside a `data`
/// object; verify against the clients in use.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(from = "ServerEventJson", into = "ServerEventJson")]
pub enum ServerEvent {
    /// Sent once a connection is accepted; `data` is a JSON-encoded string on the wire.
    #[serde(rename = "pusher:connection_established")]
    ConnectionEstablished {
        #[serde(with = "json_string")]
        data: ConnectionInfo,
    },

    /// Sent when a sign-in succeeds.
    #[serde(rename = "pusher:signin_success")]
    SigninSucceeded {
        data: SigninInformation,
    },

    /// An error with a message and optional numeric code.
    #[serde(rename = "pusher:error")]
    Error {
        message: String,
        code: Option<u16>,
    },

    /// Reply to a client ping.
    #[serde(rename = "pusher:pong")]
    Pong,

    /// Confirms a subscription. `data` carries presence information if any.
    #[serde(rename = "pusher_internal:subscription_succeeded")]
    SubscriptionSucceeded {
        channel: ChannelName,
        #[serde(with = "json_string")]
        data: Option<PresenceInformation>,
    },

    /// Announces a member joining a presence channel.
    #[serde(rename = "pusher_internal:member_added")]
    MemberAdded {
        channel: ChannelName,
        #[serde(with = "json_string")]
        data: PresenceUser,
    },

    /// Announces a member leaving a presence channel.
    #[serde(rename = "pusher_internal:member_removed")]
    MemberRemoved {
        channel: ChannelName,
        #[serde(with = "json_string")]
        data: RemovedMember,
    },

    /// Any non-reserved event delivered on a channel.
    ChannelEvent(CustomEvent),
}
303
304impl ServerEvent {
305 pub fn signin_succeeded(user_data: UserData) -> Self {
306 Self::SigninSucceeded {
307 data: SigninInformation { user_data },
308 }
309 }
310
311 pub fn subscription_succeeded(channel: impl Into<ChannelName>) -> Self {
312 Self::SubscriptionSucceeded {
313 channel: channel.into(),
314 data: None,
315 }
316 }
317
318 pub fn custom_event(
319 event: impl Into<String>,
320 channel: impl Into<ChannelName>,
321 data: impl Into<serde_json::Value>,
322 user_id: impl Into<Option<String>>,
323 ) -> Self {
324 Self::ChannelEvent(CustomEvent {
325 event: event.into(),
326 channel: channel.into(),
327 data: data.into(),
328 user_id: user_id.into(),
329 })
330 }
331
332 pub fn invalid_signature_error() -> Self {
333 Self::error("Invalid signature", 409)
334 }
335
336 pub fn authentication_error(message: impl Into<String>) -> Self {
337 Self::error(message, 409)
338 }
339
340 pub fn error(message: impl Into<String>, code: impl Into<Option<u16>>) -> Self {
341 Self::Error {
342 message: message.into(),
343 code: code.into(),
344 }
345 }
346}
347
348impl From<ServerEventJson> for ServerEvent {
349 fn from(json: ServerEventJson) -> Self {
350 use PusherServerEvent::*;
351 use ServerEventJson::*;
352 match json {
353 PusherEvent(ConnectionEstablished { data }) => {
354 ServerEvent::ConnectionEstablished { data }
355 }
356 PusherEvent(SigninSucceeded { data }) => ServerEvent::SigninSucceeded { data },
357 PusherEvent(Error { message, code }) => ServerEvent::Error { message, code },
358 PusherEvent(Pong) => ServerEvent::Pong,
359 PusherEvent(SubscriptionSucceeded { channel, data }) => {
360 ServerEvent::SubscriptionSucceeded { channel, data }
361 }
362 PusherEvent(MemberAdded { channel, data }) => {
363 ServerEvent::MemberAdded { channel, data }
364 }
365 PusherEvent(MemberRemoved { channel, data }) => {
366 ServerEvent::MemberRemoved { channel, data }
367 }
368 UserEvent(event) => ServerEvent::ChannelEvent(event),
369 }
370 }
371}
372
373impl From<ServerEvent> for ServerEventJson {
374 fn from(value: ServerEvent) -> Self {
375 use ServerEvent::*;
376 use ServerEventJson::*;
377 match value {
378 ConnectionEstablished { data } => {
379 PusherEvent(PusherServerEvent::ConnectionEstablished { data })
380 }
381 SigninSucceeded { data } => PusherEvent(PusherServerEvent::SigninSucceeded { data }),
382 Error { message, code } => PusherEvent(PusherServerEvent::Error { message, code }),
383 Pong => PusherEvent(PusherServerEvent::Pong),
384 SubscriptionSucceeded { channel, data } => {
385 PusherEvent(PusherServerEvent::SubscriptionSucceeded { channel, data })
386 }
387 MemberRemoved { channel, data } => {
388 PusherEvent(PusherServerEvent::MemberRemoved { channel, data })
389 }
390 MemberAdded { channel, data } => {
391 PusherEvent(PusherServerEvent::MemberAdded { channel, data })
392 }
393 ChannelEvent(event) => UserEvent(event),
394 }
395 }
396}
397
/// Wire representation of [`ServerEvent`].
///
/// The enum is untagged:
/// - serialization writes the inner value as-is;
/// - deserialization tries `PusherEvent` first and falls back to `UserEvent`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
enum ServerEventJson {
    PusherEvent(PusherServerEvent),
    UserEvent(CustomEvent),
}
404
/// Reserved server events, internally tagged: the `event` name sits beside the
/// variant's own fields, e.g. `{"event": "...", "channel": ..., "data": ...}`.
/// Unit variants such as `Pong` carry only the `event` key. Fields marked
/// `json_string` are carried as JSON-encoded strings.
///
/// NOTE(review): because of internal tagging, `Error` serializes as
/// `{"event": "pusher:error", "message": ..., "code": ...}`. The Pusher protocol
/// documents `message`/`code` inside a `data` object; confirm clients accept this shape.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum PusherServerEvent {
    #[serde(rename = "pusher:connection_established")]
    ConnectionEstablished {
        #[serde(with = "json_string")]
        data: ConnectionInfo,
    },

    #[serde(rename = "pusher:signin_success")]
    SigninSucceeded { data: SigninInformation },

    #[serde(rename = "pusher:error")]
    Error { message: String, code: Option<u16> },

    #[serde(rename = "pusher:pong")]
    Pong,

    /// `data` of `None` is carried as the JSON-encoded string `"null"`.
    #[serde(rename = "pusher_internal:subscription_succeeded")]
    SubscriptionSucceeded {
        channel: ChannelName,
        #[serde(with = "json_string")]
        data: Option<PresenceInformation>,
    },

    #[serde(rename = "pusher_internal:member_added")]
    MemberAdded {
        channel: ChannelName,
        #[serde(with = "json_string")]
        data: PresenceUser,
    },

    #[serde(rename = "pusher_internal:member_removed")]
    MemberRemoved {
        channel: ChannelName,
        #[serde(with = "json_string")]
        data: RemovedMember,
    },
}
444
445mod json_string {
446 use serde::{
447 de::{self, DeserializeOwned},
448 ser::{self, Serialize, Serializer},
449 Deserialize, Deserializer,
450 };
451
452 pub fn serialize<T: Serialize, S: Serializer>(
453 value: &T,
454 serializer: S,
455 ) -> Result<S::Ok, S::Error> {
456 let json = serde_json::to_string(value).map_err(ser::Error::custom)?;
457 json.serialize(serializer)
458 }
459
460 pub fn deserialize<'de, T: DeserializeOwned, D: Deserializer<'de>>(
461 deserializer: D,
462 ) -> Result<T, D::Error> {
463 let json = String::deserialize(deserializer)?;
464 serde_json::from_str(&json).map_err(de::Error::custom)
465 }
466}
467
// Unit tests: channel-name classification, plus serde round trips of the
// client/server wire formats pinned against literal JSON.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Each reserved prefix maps to its variant; the empty string is rejected.
    #[test]
    fn parse_channel_name() {
        assert_eq!(Ok(ChannelName::Public("lol".to_owned())), "lol".parse());
        assert_eq!(
            Ok(ChannelName::Private("private-lol".to_owned())),
            "private-lol".parse()
        );
        assert_eq!(
            Ok(ChannelName::Presence("presence-lol".to_owned())),
            "presence-lol".parse()
        );
        assert_eq!(
            Ok(ChannelName::Encrypted("private-encrypted-lol".to_owned())),
            "private-encrypted-lol".parse()
        );
        assert_eq!(
            Err(ChannelNameParseError::InvalidChannelName),
            "".parse::<ChannelName>()
        );
    }

    // member_removed round trip: `data` travels as a JSON-encoded string.
    #[test]
    fn test_member_removed() {
        let event = ServerEvent::MemberRemoved {
            channel: "channel".parse().unwrap(),
            data: RemovedMember {
                id: "lolwut".to_owned(),
            },
        };

        let serialized = serde_json::to_value(&event).unwrap();

        let expected = json!({
            "event": "pusher_internal:member_removed",
            "channel": "channel",
            "data": r#"{"user_id":"lolwut"}"#,
        });

        assert_eq!(expected, serialized);

        let deserialized = serde_json::from_value(expected).unwrap();

        assert_eq!(event, deserialized);
    }

    // Server-side channel event round trip: string-encoded `data`, `user_id` present.
    #[test]
    fn test_custom_event() {
        let event = ServerEvent::ChannelEvent(CustomEvent {
            event: "client-message".to_owned(),
            channel: "channel".parse().unwrap(),
            data: json!({ "some": "data" }),
            user_id: Some("user".to_owned()),
        });

        let serialized = serde_json::to_value(&event).unwrap();

        let expected = json!({
            "event": "client-message",
            "channel": "channel",
            "data": r#"{"some":"data"}"#,
            "user_id": "user",
        });

        assert_eq!(expected, serialized);

        let deserialized = serde_json::from_value(expected).unwrap();

        assert_eq!(event, deserialized);
    }

    // A ping carrying an empty object as `data` becomes `ClientEvent::Ping`.
    #[test]
    fn test_deserialize_ping() {
        let event = ClientEvent::Ping;
        let serialized = json!({ "event": "pusher:ping", "data": {} });
        let deserialized = serde_json::from_value::<ClientEvent>(serialized).unwrap();
        assert_eq!(event, deserialized);
    }

    // Sign-in: `user_data` arrives as a JSON-encoded string and is kept as that string.
    #[test]
    fn test_deserialize_signin() {
        let event = ClientEvent::Signin {
            auth: "1234".to_owned(),
            user_data: serde_json::to_string(&UserData {
                id: "user1".to_owned(),
                user_info: Some(json!({ "lol": "wut" })),
                watchlist: Some(vec!["user2".to_owned(), "user3".to_owned()]),
            })
            .unwrap(),
        };

        let serialized = json!({
            "event": "pusher:signin",
            "data": {
                "auth": "1234",
                "user_data": serde_json::to_string(&json!({
                    "id": "user1",
                    "user_info": { "lol": "wut" },
                    "watchlist": ["user2", "user3"],
                })).unwrap(),
            },
        });
        let deserialized = serde_json::from_value::<ClientEvent>(serialized).unwrap();
        assert_eq!(event, deserialized);
    }

    // Subscribe with `auth` omitted: the missing optional field becomes `None`.
    #[test]
    fn test_deserialize_subscribe() {
        let event = ClientEvent::Subscribe {
            channel: "lolwut".parse().unwrap(),
            auth: None,
            channel_data: Some(json!({ "lol": "wut" })),
        };
        let serialized = json!({
            "event": "pusher:subscribe",
            "data": {
                "channel": "lolwut",
                "channel_data": { "lol": "wut" },
            },
        });
        let deserialized = serde_json::from_value::<ClientEvent>(serialized).unwrap();
        assert_eq!(event, deserialized);
    }

    #[test]
    fn test_deserialize_unsubscribe() {
        let event = ClientEvent::Unsubscribe {
            channel: "lolwut".parse().unwrap(),
        };
        let serialized = json!({
            "event": "pusher:unsubscribe",
            "data": {
                "channel": "lolwut",
            },
        });
        let deserialized = serde_json::from_value::<ClientEvent>(serialized).unwrap();
        assert_eq!(event, deserialized);
    }

    // Non-reserved client event: falls through to the custom-event shape with object `data`.
    #[test]
    fn test_deserialize_channel_event() {
        let event = ClientEvent::ChannelEvent {
            event: "client-lolwut".to_owned(),
            channel: "lolwut".parse().unwrap(),
            data: json!({ "lol": "wut" }),
        };
        let serialized = json!({
            "event": "client-lolwut",
            "channel": "lolwut",
            "data": { "lol": "wut" },
        });
        let deserialized = serde_json::from_value::<ClientEvent>(serialized).unwrap();
        assert_eq!(event, deserialized);
    }
}