1use std::{
2 collections::{HashMap, HashSet},
3 sync::{
4 Arc,
5 atomic::{AtomicU64, Ordering},
6 },
7 time::Instant,
8};
9
10use chrono::Utc;
11use tokio::sync::mpsc;
12
13use crate::protocol::{DEFAULT_EVENT, ServerFrame};
14
15use super::{
16 Channel, ChannelName, ConnectionId, ConnectionMeta, DisconnectReason, Event, Payload,
17 RealtimeConfig, RealtimeError, SessionAuth, UserId,
18 policy::{ChannelPolicy, DefaultChannelPolicy},
19 session,
20};
21
22const HUB_QUEUE_SIZE: usize = 4096;
23const INBOUND_QUEUE_SIZE: usize = 4096;
24
25pub type SubscriptionId = u64;
26type ChannelHandler = Arc<dyn Fn(Payload) + Send + Sync>;
27type GlobalHandler = Arc<dyn Fn(Channel, Payload) + Send + Sync>;
28type ChannelEventHandler = Arc<dyn Fn(Event, Payload) + Send + Sync>;
29type GlobalEventHandler = Arc<dyn Fn(Channel, Event, Payload) + Send + Sync>;
30type ChannelHandlers =
31 Arc<std::sync::Mutex<HashMap<Channel, HashMap<SubscriptionId, ChannelHandler>>>>;
32type GlobalHandlers = Arc<std::sync::Mutex<HashMap<SubscriptionId, GlobalHandler>>>;
33type ChannelEventHandlers =
34 Arc<std::sync::Mutex<HashMap<Channel, HashMap<SubscriptionId, ChannelEventHandler>>>>;
35type GlobalEventHandlers = Arc<std::sync::Mutex<HashMap<SubscriptionId, GlobalEventHandler>>>;
36
/// A client emit forwarded by the hub to the handler dispatcher.
///
/// Built in `SocketServer::handle_emit` only after the rate-limit, policy and
/// membership checks pass; consumed by the task started in `spawn_inbound_dispatcher`.
#[derive(Clone)]
pub(crate) struct InboundMessage {
    /// Channel name the client emitted on.
    pub channel: Channel,
    /// Event name supplied by the client.
    pub event: Event,
    /// Event payload; the dispatcher clones it once per handler invocation.
    pub payload: Payload,
}
43
/// Cloneable front-end to the realtime hub.
///
/// Clones share the hub sender and every handler registry (all behind `Arc`), so a
/// handler registered through one clone is visible to all of them. When realtime is
/// disabled `tx` is `None` and sends become no-ops.
#[derive(Clone)]
pub struct SocketServerHandle {
    config: RealtimeConfig,
    /// Command queue into the hub task; `None` when realtime is disabled.
    tx: Option<mpsc::Sender<HubCommand>>,
    /// Handlers from `on_message`, keyed by channel name then subscription id.
    channel_handlers: ChannelHandlers,
    /// Handlers from `on_messages`, keyed by subscription id.
    global_handlers: GlobalHandlers,
    /// Handlers from `on_channel_event`, keyed by channel name then subscription id.
    channel_event_handlers: ChannelEventHandlers,
    /// Handlers from `on_events`, keyed by subscription id.
    global_event_handlers: GlobalEventHandlers,
    /// Shared counter that hands out ids for every `on_*` registration.
    next_subscription_id: Arc<AtomicU64>,
}
54
55impl SocketServerHandle {
56 pub fn spawn(config: RealtimeConfig) -> Self {
57 Self::spawn_with_policy(config, Arc::new(DefaultChannelPolicy))
58 }
59
60 pub fn spawn_with_policy(config: RealtimeConfig, policy: Arc<dyn ChannelPolicy>) -> Self {
61 let channel_handlers: ChannelHandlers = Arc::new(std::sync::Mutex::new(HashMap::new()));
62 let global_handlers: GlobalHandlers = Arc::new(std::sync::Mutex::new(HashMap::new()));
63 let channel_event_handlers: ChannelEventHandlers =
64 Arc::new(std::sync::Mutex::new(HashMap::new()));
65 let global_event_handlers: GlobalEventHandlers =
66 Arc::new(std::sync::Mutex::new(HashMap::new()));
67 let next_subscription_id = Arc::new(AtomicU64::new(1));
68
69 if !config.enabled {
70 return Self {
71 config,
72 tx: None,
73 channel_handlers,
74 global_handlers,
75 channel_event_handlers,
76 global_event_handlers,
77 next_subscription_id,
78 };
79 }
80
81 let (tx, rx) = mpsc::channel(HUB_QUEUE_SIZE);
82 let (inbound_tx, inbound_rx) = mpsc::channel(INBOUND_QUEUE_SIZE);
83 let mut hub = SocketServer::new(config.clone(), rx, policy, Some(inbound_tx));
84 tokio::spawn(async move {
85 hub.run().await;
86 });
87 spawn_inbound_dispatcher(
88 inbound_rx,
89 Arc::clone(&channel_handlers),
90 Arc::clone(&global_handlers),
91 Arc::clone(&channel_event_handlers),
92 Arc::clone(&global_event_handlers),
93 );
94
95 Self {
96 config,
97 tx: Some(tx),
98 channel_handlers,
99 global_handlers,
100 channel_event_handlers,
101 global_event_handlers,
102 next_subscription_id,
103 }
104 }
105
106 pub fn disabled(config: RealtimeConfig) -> Self {
107 Self {
108 config,
109 tx: None,
110 channel_handlers: Arc::new(std::sync::Mutex::new(HashMap::new())),
111 global_handlers: Arc::new(std::sync::Mutex::new(HashMap::new())),
112 channel_event_handlers: Arc::new(std::sync::Mutex::new(HashMap::new())),
113 global_event_handlers: Arc::new(std::sync::Mutex::new(HashMap::new())),
114 next_subscription_id: Arc::new(AtomicU64::new(1)),
115 }
116 }
117
118 pub fn is_enabled(&self) -> bool {
119 self.config.enabled && self.tx.is_some()
120 }
121
122 pub fn max_message_bytes(&self) -> usize {
123 self.config.max_message_bytes
124 }
125
126 pub async fn serve_socket(&self, socket: axum::extract::ws::WebSocket, auth: SessionAuth) {
127 let Some(hub_tx) = self.tx.clone() else {
128 return;
129 };
130 session::run_socket_session(socket, auth, hub_tx, self.config.clone()).await;
131 }
132
133 pub async fn send(
134 &self,
135 channel_name: impl Into<Channel>,
136 message: Payload,
137 ) -> Result<(), RealtimeError> {
138 self.send_event(channel_name, DEFAULT_EVENT, message).await
139 }
140
141 pub async fn send_to_user(
142 &self,
143 user_id: impl Into<UserId>,
144 message: Payload,
145 ) -> Result<(), RealtimeError> {
146 self.send_event_to_user(user_id, DEFAULT_EVENT, message)
147 .await
148 }
149
150 pub async fn send_event(
151 &self,
152 channel_name: impl Into<Channel>,
153 event: impl Into<Event>,
154 payload: Payload,
155 ) -> Result<(), RealtimeError> {
156 let Some(tx) = &self.tx else {
157 return Ok(());
158 };
159 let channel_name = channel_name.into();
160 let channel = ChannelName::parse(&channel_name)?;
161 tx.send(HubCommand::SendToChannel {
162 channel,
163 event: event.into(),
164 payload,
165 })
166 .await
167 .map_err(|_| RealtimeError::internal("realtime hub is unavailable"))
168 }
169
170 pub async fn send_event_to_user(
171 &self,
172 user_id: impl Into<UserId>,
173 event: impl Into<Event>,
174 payload: Payload,
175 ) -> Result<(), RealtimeError> {
176 let Some(tx) = &self.tx else {
177 return Ok(());
178 };
179 tx.send(HubCommand::SendToUser {
180 user_id: user_id.into(),
181 event: event.into(),
182 payload,
183 })
184 .await
185 .map_err(|_| RealtimeError::internal("realtime hub is unavailable"))
186 }
187
188 pub async fn emit_to_user(
189 &self,
190 user_id: impl Into<UserId>,
191 event: impl Into<Event>,
192 payload: Payload,
193 ) -> Result<(), RealtimeError> {
194 self.send_event_to_user(user_id, event, payload).await
195 }
196
197 pub fn on_message<F>(&self, channel: &str, handler: F) -> SubscriptionId
198 where
199 F: Fn(Payload) + Send + Sync + 'static,
200 {
201 let id = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
202 let mut handlers = self
203 .channel_handlers
204 .lock()
205 .expect("channel handler mutex poisoned");
206 handlers
207 .entry(channel.to_string())
208 .or_default()
209 .insert(id, Arc::new(handler));
210 id
211 }
212
213 pub fn on_messages<F>(&self, handler: F) -> SubscriptionId
214 where
215 F: Fn(Channel, Payload) + Send + Sync + 'static,
216 {
217 let id = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
218 self.global_handlers
219 .lock()
220 .expect("global handler mutex poisoned")
221 .insert(id, Arc::new(handler));
222 id
223 }
224
225 pub fn on_channel_event<F>(&self, channel: &str, handler: F) -> SubscriptionId
226 where
227 F: Fn(Event, Payload) + Send + Sync + 'static,
228 {
229 let id = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
230 let mut handlers = self
231 .channel_event_handlers
232 .lock()
233 .expect("channel event handler mutex poisoned");
234 handlers
235 .entry(channel.to_string())
236 .or_default()
237 .insert(id, Arc::new(handler));
238 id
239 }
240
241 pub fn on_events<F>(&self, handler: F) -> SubscriptionId
242 where
243 F: Fn(Channel, Event, Payload) + Send + Sync + 'static,
244 {
245 let id = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
246 self.global_event_handlers
247 .lock()
248 .expect("global event handler mutex poisoned")
249 .insert(id, Arc::new(handler));
250 id
251 }
252
253 pub fn off(&self, id: SubscriptionId) -> bool {
254 let mut removed = false;
255
256 let mut global = self
257 .global_handlers
258 .lock()
259 .expect("global handler mutex poisoned");
260 if global.remove(&id).is_some() {
261 removed = true;
262 }
263 drop(global);
264
265 let mut channels = self
266 .channel_handlers
267 .lock()
268 .expect("channel handler mutex poisoned");
269 for handlers in channels.values_mut() {
270 if handlers.remove(&id).is_some() {
271 removed = true;
272 }
273 }
274
275 let mut global_events = self
276 .global_event_handlers
277 .lock()
278 .expect("global event handler mutex poisoned");
279 if global_events.remove(&id).is_some() {
280 removed = true;
281 }
282 drop(global_events);
283
284 let mut channel_events = self
285 .channel_event_handlers
286 .lock()
287 .expect("channel event handler mutex poisoned");
288 for handlers in channel_events.values_mut() {
289 if handlers.remove(&id).is_some() {
290 removed = true;
291 }
292 }
293
294 removed
295 }
296}
297
/// Commands processed, in order, by the hub task (`SocketServer::run`).
///
/// `SendToChannel` and `SendToUser` come from `SocketServerHandle`. The remaining
/// variants are presumably sent by the session code (`session::run_socket_session`
/// receives the hub sender); confirm there.
pub(crate) enum HubCommand {
    /// Admit a connection; `outbound_tx` is the queue the hub writes its frames to.
    Register {
        meta: ConnectionMeta,
        outbound_tx: mpsc::Sender<ServerFrame>,
    },
    /// Drop a connection and all its memberships; `reason` is only logged.
    Unregister {
        conn_id: ConnectionId,
        reason: DisconnectReason,
    },
    /// Client join request; answered with an ack carrying `req_id`.
    Join {
        conn_id: ConnectionId,
        channel: ChannelName,
        req_id: String,
    },
    /// Client leave request; answered with an ack carrying `req_id`.
    Leave {
        conn_id: ConnectionId,
        channel: ChannelName,
        req_id: String,
    },
    /// Client publish to a joined channel; answered with an ack carrying `req_id`.
    Emit {
        conn_id: ConnectionId,
        channel: ChannelName,
        event: Event,
        payload: Payload,
        req_id: String,
    },
    /// Client keep-alive; answered with a pong carrying `req_id`.
    Ping {
        conn_id: ConnectionId,
        req_id: String,
    },
    /// Server-originated broadcast to every member of `channel`.
    SendToChannel {
        channel: ChannelName,
        event: Event,
        payload: Payload,
    },
    /// Server-originated message to every connection of `user_id`.
    SendToUser {
        user_id: UserId,
        event: Event,
        payload: Payload,
    },
}
339
/// State owned exclusively by the hub task.
///
/// The value is moved into the task spawned by `spawn_with_policy` and mutated only
/// from `run`'s command loop, so no locking is needed. The three membership indexes are
/// kept in sync by `register`, `unregister`, `join_internal` and `leave_internal`.
struct SocketServer {
    config: RealtimeConfig,
    /// Incoming commands from handles and sessions.
    rx: mpsc::Receiver<HubCommand>,
    /// Decides which joins and publishes are allowed.
    policy: Arc<dyn ChannelPolicy>,
    /// Queue to the handler dispatcher; set to `None` once the dispatcher is gone.
    inbound_tx: Option<mpsc::Sender<InboundMessage>>,
    /// Every registered connection.
    connections: HashMap<ConnectionId, ConnectionState>,
    /// user id -> that user's live connections.
    users: HashMap<UserId, HashSet<ConnectionId>>,
    /// channel -> member connections.
    channels: HashMap<ChannelName, HashSet<ConnectionId>>,
    /// connection -> joined channels (reverse of `channels`).
    connection_channels: HashMap<ConnectionId, HashSet<ChannelName>>,
}
350
/// Hub-side record for one registered connection.
struct ConnectionState {
    /// Identity supplied at registration, passed to the channel policy.
    meta: ConnectionMeta,
    /// Bounded queue to the connection's session; a full queue disconnects it as a slow consumer.
    outbound_tx: mpsc::Sender<ServerFrame>,
    /// Per-connection join/emit rate-limit counters.
    rate: ConnectionRateState,
}
356
/// Fixed one-second rate-limit windows (see `allow_within_window`) for joins and emits.
struct ConnectionRateState {
    /// Start of the current join window.
    join_window_started_at: Instant,
    /// Joins accepted in the current window.
    joins_in_window: u32,
    /// Start of the current emit window.
    emit_window_started_at: Instant,
    /// Emits accepted in the current window.
    emits_in_window: u32,
}
363
impl SocketServer {
    /// Creates an empty hub that reads commands from `rx`.
    ///
    /// When `inbound_tx` is present, every accepted client emit is forwarded through it
    /// for handler dispatch (see `publish_inbound`).
    fn new(
        config: RealtimeConfig,
        rx: mpsc::Receiver<HubCommand>,
        policy: Arc<dyn ChannelPolicy>,
        inbound_tx: Option<mpsc::Sender<InboundMessage>>,
    ) -> Self {
        Self {
            config,
            rx,
            policy,
            inbound_tx,
            connections: HashMap::new(),
            users: HashMap::new(),
            channels: HashMap::new(),
            connection_channels: HashMap::new(),
        }
    }

    /// Processes commands one at a time until every `HubCommand` sender has been dropped.
    async fn run(&mut self) {
        while let Some(command) = self.rx.recv().await {
            self.handle_command(command);
        }
    }

    /// Routes one command to its handler.
    ///
    /// All handlers are synchronous, so each command is fully applied before the next is
    /// received.
    fn handle_command(&mut self, command: HubCommand) {
        match command {
            HubCommand::Register { meta, outbound_tx } => self.register(meta, outbound_tx),
            HubCommand::Unregister { conn_id, reason } => self.unregister(conn_id, reason),
            HubCommand::Join {
                conn_id,
                channel,
                req_id,
            } => self.handle_join(conn_id, channel, req_id),
            HubCommand::Leave {
                conn_id,
                channel,
                req_id,
            } => self.handle_leave(conn_id, channel, req_id),
            HubCommand::Emit {
                conn_id,
                channel,
                event,
                payload,
                req_id,
            } => self.handle_emit(conn_id, channel, event, payload, req_id),
            HubCommand::Ping { conn_id, req_id } => self.handle_ping(conn_id, req_id),
            HubCommand::SendToChannel {
                channel,
                event,
                payload,
            } => self.handle_send_to_channel(channel, event, payload),
            HubCommand::SendToUser {
                user_id,
                event,
                payload,
            } => self.handle_send_to_user(user_id, event, payload),
        }
    }

    /// Admits a connection, or rejects it with `capacity_exceeded` once
    /// `max_connections` are registered.
    ///
    /// On success the client receives a `connected` frame. It is then auto-joined to
    /// its private `user:{user_id}` channel and receives a `Joined` frame for it. The
    /// auto-join counts toward `max_channels_per_connection` in `handle_join`.
    fn register(&mut self, meta: ConnectionMeta, outbound_tx: mpsc::Sender<ServerFrame>) {
        if self.connections.len() >= self.config.max_connections {
            // Best effort: the connection is never stored, so there is nothing to clean up.
            let _ = outbound_tx.try_send(ServerFrame::error(
                "capacity_exceeded",
                "Realtime server is at capacity",
            ));
            return;
        }

        let conn_id = meta.id;
        let user_id = meta.user_id.clone();
        let now = Instant::now();

        self.connections.insert(
            conn_id,
            ConnectionState {
                meta,
                outbound_tx: outbound_tx.clone(),
                rate: ConnectionRateState {
                    join_window_started_at: now,
                    joins_in_window: 0,
                    emit_window_started_at: now,
                    emits_in_window: 0,
                },
            },
        );

        self.users
            .entry(user_id.clone())
            .or_default()
            .insert(conn_id);

        // Sent directly rather than through `send_frame`: a full or closed queue here is
        // ignored instead of unregistering the connection.
        let _ = outbound_tx.try_send(ServerFrame::connected(conn_id.to_string(), user_id.clone()));

        let private_channel = ChannelName(format!("user:{user_id}"));
        if let Some(reason) = self.join_internal(conn_id, private_channel.clone()) {
            self.unregister(conn_id, reason);
            return;
        }
        self.send_frame(
            conn_id,
            ServerFrame::Joined {
                id: uuid::Uuid::new_v4().to_string(),
                channel: private_channel.to_string(),
                ts: Utc::now().timestamp(),
            },
        );
    }

    /// Removes a connection and every index entry that references it.
    ///
    /// Unknown ids are a no-op, so repeated calls are harmless. No frame is sent. The
    /// stored `outbound_tx` is dropped with the state; presumably the session notices
    /// its queue closing (TODO confirm in `session`).
    fn unregister(&mut self, conn_id: ConnectionId, reason: DisconnectReason) {
        let Some(existing) = self.connections.remove(&conn_id) else {
            return;
        };

        tracing::debug!(
            conn_id = %conn_id,
            user_id = %existing.meta.user_id,
            reason = ?reason,
            "realtime connection disconnected"
        );

        if let Some(user_set) = self.users.get_mut(&existing.meta.user_id) {
            user_set.remove(&conn_id);
            if user_set.is_empty() {
                self.users.remove(&existing.meta.user_id);
            }
        }

        if let Some(joined) = self.connection_channels.remove(&conn_id) {
            for channel in joined {
                if let Some(member_set) = self.channels.get_mut(&channel) {
                    member_set.remove(&conn_id);
                    // Prune empty channels so `channels` only holds live memberships.
                    if member_set.is_empty() {
                        self.channels.remove(&channel);
                    }
                }
            }
        }
    }

    /// Handles a client join request.
    ///
    /// Checks run in this order: rate limit, policy, already-joined (acked as success,
    /// idempotent), per-connection channel limit. On success the client gets `ack_ok`
    /// followed by a `Joined` frame.
    ///
    /// NOTE(review): for an unknown `conn_id`, `check_join_rate` returns `false`, so the
    /// "rate limited" log fires even though the frame goes nowhere.
    fn handle_join(&mut self, conn_id: ConnectionId, channel: ChannelName, req_id: String) {
        tracing::debug!(
            conn_id = %conn_id,
            channel = %channel,
            req_id = %req_id,
            "realtime join requested"
        );

        if !self.check_join_rate(conn_id) {
            tracing::debug!(
                conn_id = %conn_id,
                channel = %channel,
                req_id = %req_id,
                "realtime join denied: rate limited"
            );
            self.send_frame(
                conn_id,
                ServerFrame::ack_err(req_id, "rate_limited", "Join rate limit exceeded"),
            );
            return;
        }

        let Some(meta) = self.connections.get(&conn_id).map(|conn| conn.meta.clone()) else {
            return;
        };

        if let Err(err) = self.policy.can_join(&meta, &channel) {
            tracing::debug!(
                conn_id = %conn_id,
                user_id = %meta.user_id,
                channel = %channel,
                req_id = %req_id,
                reason = %err,
                "realtime join denied by policy"
            );
            self.send_frame(
                conn_id,
                ServerFrame::ack_err(req_id, "forbidden_channel", err.message()),
            );
            return;
        }

        // Re-joining is acknowledged as success without sending another `Joined` frame.
        if self
            .connection_channels
            .get(&conn_id)
            .is_some_and(|set| set.contains(&channel))
        {
            tracing::debug!(
                conn_id = %conn_id,
                channel = %channel,
                req_id = %req_id,
                "realtime join acknowledged: already joined"
            );
            self.send_frame(conn_id, ServerFrame::ack_ok(req_id));
            return;
        }

        // The count includes the private `user:` channel joined in `register`.
        if self
            .connection_channels
            .get(&conn_id)
            .map(|set| set.len())
            .unwrap_or(0)
            >= self.config.max_channels_per_connection
        {
            tracing::debug!(
                conn_id = %conn_id,
                channel = %channel,
                req_id = %req_id,
                "realtime join denied: channel limit exceeded"
            );
            self.send_frame(
                conn_id,
                ServerFrame::ack_err(
                    req_id,
                    "channel_limit_exceeded",
                    "Maximum channels per connection reached",
                ),
            );
            return;
        }

        if let Some(reason) = self.join_internal(conn_id, channel.clone()) {
            tracing::debug!(
                conn_id = %conn_id,
                channel = %channel,
                reason = ?reason,
                "realtime join caused disconnect"
            );
            self.unregister(conn_id, reason);
            return;
        }

        tracing::debug!(
            conn_id = %conn_id,
            user_id = %meta.user_id,
            channel = %channel,
            req_id = %req_id,
            "realtime join succeeded"
        );

        self.send_frame(conn_id, ServerFrame::ack_ok(req_id));
        self.send_frame(
            conn_id,
            ServerFrame::Joined {
                id: uuid::Uuid::new_v4().to_string(),
                channel: channel.to_string(),
                ts: Utc::now().timestamp(),
            },
        );
    }

    /// Handles a client leave request.
    ///
    /// Non-members get `channel_not_joined`. Members get `ack_ok` followed by a `Left`
    /// frame.
    ///
    /// NOTE(review): this also allows leaving the private `user:` channel.
    /// `handle_send_to_user` targets `users`, not channel membership, so direct delivery
    /// is unaffected.
    fn handle_leave(&mut self, conn_id: ConnectionId, channel: ChannelName, req_id: String) {
        tracing::debug!(
            conn_id = %conn_id,
            channel = %channel,
            req_id = %req_id,
            "realtime leave requested"
        );
        let was_member = self
            .connection_channels
            .get(&conn_id)
            .is_some_and(|set| set.contains(&channel));
        if !was_member {
            tracing::debug!(
                conn_id = %conn_id,
                channel = %channel,
                req_id = %req_id,
                "realtime leave denied: not joined"
            );
            self.send_frame(
                conn_id,
                ServerFrame::ack_err(req_id, "channel_not_joined", "Not a member of channel"),
            );
            return;
        }

        self.leave_internal(conn_id, &channel);
        tracing::debug!(
            conn_id = %conn_id,
            channel = %channel,
            req_id = %req_id,
            "realtime leave succeeded"
        );
        self.send_frame(conn_id, ServerFrame::ack_ok(req_id));
        self.send_frame(
            conn_id,
            ServerFrame::Left {
                id: uuid::Uuid::new_v4().to_string(),
                channel: channel.to_string(),
                ts: Utc::now().timestamp(),
            },
        );
    }

    /// Handles a client publish.
    ///
    /// Checks run in this order: emit rate limit, `can_publish` policy, sender
    /// membership. The message is then forwarded to the handler dispatcher before
    /// fan-out. Fan-out skips the sender unless the channel is an `echo:` channel.
    /// Finally the sender gets `ack_ok`; the ack is skipped if fan-out unregistered the
    /// sender (possible on echo channels when its queue is full).
    fn handle_emit(
        &mut self,
        conn_id: ConnectionId,
        channel: ChannelName,
        event: Event,
        payload: Payload,
        req_id: String,
    ) {
        if !self.check_emit_rate(conn_id) {
            self.send_frame(
                conn_id,
                ServerFrame::ack_err(req_id, "rate_limited", "Emit rate limit exceeded"),
            );
            return;
        }

        let Some(meta) = self.connections.get(&conn_id).map(|conn| conn.meta.clone()) else {
            return;
        };

        if let Err(err) = self.policy.can_publish(&meta, &channel, &event) {
            self.send_frame(
                conn_id,
                ServerFrame::ack_err(req_id, "forbidden_channel", err.message()),
            );
            return;
        }

        let sender_is_member = self
            .connection_channels
            .get(&conn_id)
            .is_some_and(|set| set.contains(&channel));
        if !sender_is_member {
            self.send_frame(
                conn_id,
                ServerFrame::ack_err(req_id, "channel_not_joined", "Join channel before emitting"),
            );
            return;
        }

        // Snapshot members: `send_frame` may unregister recipients while we iterate.
        let recipients = self.channels.get(&channel).cloned().unwrap_or_default();
        let include_sender = should_echo_to_sender(&channel);
        self.publish_inbound(InboundMessage {
            channel: channel.to_string(),
            event: event.clone(),
            payload: payload.clone(),
        });
        let event_frame = ServerFrame::event(
            channel.to_string(),
            event,
            payload,
            Some(meta.user_id.clone()),
        );
        for recipient_id in recipients {
            if recipient_id == conn_id && !include_sender {
                continue;
            }
            self.send_frame(recipient_id, event_frame.clone());
        }

        self.send_frame(conn_id, ServerFrame::ack_ok(req_id));
    }

    /// Replies to a keep-alive with a pong echoing `req_id`.
    fn handle_ping(&mut self, conn_id: ConnectionId, req_id: String) {
        self.send_frame(conn_id, ServerFrame::pong(req_id));
    }

    /// Fans a server-originated event out to every member of `channel`.
    ///
    /// Frames carry no sender user id, and the message is not forwarded to the handler
    /// dispatcher.
    fn handle_send_to_channel(&mut self, channel: ChannelName, event: Event, payload: Payload) {
        let Some(conn_ids) = self.channels.get(&channel).cloned() else {
            return;
        };

        let frame = ServerFrame::event(channel.to_string(), event, payload, None);
        for conn_id in conn_ids {
            self.send_frame(conn_id, frame.clone());
        }
    }

    /// Sends a server-originated event to every live connection of `user_id`.
    ///
    /// Frames are labelled with the private `user:{user_id}` channel.
    fn handle_send_to_user(&mut self, user_id: UserId, event: Event, payload: Payload) {
        let Some(conn_ids) = self.users.get(&user_id).cloned() else {
            return;
        };

        let channel = format!("user:{user_id}");
        let frame = ServerFrame::event(channel, event, payload, None);
        for conn_id in conn_ids {
            self.send_frame(conn_id, frame.clone());
        }
    }

    /// Forwards an accepted client emit to the handler dispatcher without blocking.
    ///
    /// When the queue is full the message is dropped, with a debug log. When the
    /// dispatcher has gone away, forwarding is disabled for the rest of the hub's life.
    fn publish_inbound(&mut self, message: InboundMessage) {
        let Some(tx) = &self.inbound_tx else {
            return;
        };

        match tx.try_send(message) {
            Ok(_) => {}
            Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
                tracing::debug!("realtime inbound dispatch queue is full; dropping message");
            }
            Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
                self.inbound_tx = None;
            }
        }
    }

    /// Consumes one join token for `conn_id`.
    ///
    /// Returns `false` when the budget is exhausted or the connection is unknown.
    fn check_join_rate(&mut self, conn_id: ConnectionId) -> bool {
        let Some(state) = self.connections.get_mut(&conn_id) else {
            return false;
        };
        allow_within_window(
            &mut state.rate.join_window_started_at,
            &mut state.rate.joins_in_window,
            self.config.join_rate_per_sec,
        )
    }

    /// Consumes one emit token for `conn_id`.
    ///
    /// Returns `false` when the budget is exhausted or the connection is unknown.
    fn check_emit_rate(&mut self, conn_id: ConnectionId) -> bool {
        let Some(state) = self.connections.get_mut(&conn_id) else {
            return false;
        };
        allow_within_window(
            &mut state.rate.emit_window_started_at,
            &mut state.rate.emits_in_window,
            self.config.emit_rate_per_sec,
        )
    }

    /// Records `conn_id` as a member of `channel` in both membership indexes.
    ///
    /// Currently always returns `None`. Callers treat `Some(reason)` as a request to
    /// unregister the connection.
    fn join_internal(
        &mut self,
        conn_id: ConnectionId,
        channel: ChannelName,
    ) -> Option<DisconnectReason> {
        self.connection_channels
            .entry(conn_id)
            .or_default()
            .insert(channel.clone());
        self.channels.entry(channel).or_default().insert(conn_id);
        None
    }

    /// Removes `conn_id` from `channel` in both membership indexes.
    ///
    /// Sets left empty by the removal are pruned.
    fn leave_internal(&mut self, conn_id: ConnectionId, channel: &ChannelName) {
        if let Some(set) = self.connection_channels.get_mut(&conn_id) {
            set.remove(channel);
            if set.is_empty() {
                self.connection_channels.remove(&conn_id);
            }
        }

        if let Some(set) = self.channels.get_mut(channel) {
            set.remove(&conn_id);
            if set.is_empty() {
                self.channels.remove(channel);
            }
        }
    }

    /// Queues `frame` for `conn_id` without blocking the hub.
    ///
    /// Unknown connections are ignored. A full queue unregisters the connection as a
    /// `SlowConsumer`; a closed queue unregisters it with `SocketError`.
    fn send_frame(&mut self, conn_id: ConnectionId, frame: ServerFrame) {
        // Clone the sender so the `connections` borrow ends before `unregister` needs `&mut self`.
        let Some(outbound_tx) = self
            .connections
            .get(&conn_id)
            .map(|connection| connection.outbound_tx.clone())
        else {
            return;
        };

        match outbound_tx.try_send(frame) {
            Ok(_) => {}
            Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
                self.unregister(conn_id, DisconnectReason::SlowConsumer);
            }
            Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
                self.unregister(conn_id, DisconnectReason::SocketError);
            }
        }
    }
}
835
/// Fixed-window rate limiter: admits at most `max_per_sec` events per window.
///
/// A window lasts one whole second from `start`. Once it has elapsed, `start` moves to
/// now and `count` resets. Returns `true` and increments `count` when the event fits
/// the budget. Otherwise it returns `false` and leaves `count` unchanged. A budget of 0
/// denies everything.
fn allow_within_window(start: &mut Instant, count: &mut u32, max_per_sec: u32) -> bool {
    let now = Instant::now();
    let window_elapsed = now.duration_since(*start).as_secs() >= 1;
    if window_elapsed {
        *start = now;
        *count = 0;
    }

    let has_budget = *count < max_per_sec;
    if has_budget {
        *count += 1;
    }
    has_budget
}
848
849fn should_echo_to_sender(channel: &ChannelName) -> bool {
850 channel.as_str().starts_with("echo:")
851}
852
853fn spawn_inbound_dispatcher(
854 mut inbound_rx: mpsc::Receiver<InboundMessage>,
855 channel_handlers: ChannelHandlers,
856 global_handlers: GlobalHandlers,
857 channel_event_handlers: ChannelEventHandlers,
858 global_event_handlers: GlobalEventHandlers,
859) {
860 tokio::spawn(async move {
861 while let Some(message) = inbound_rx.recv().await {
862 dispatch_channel_handlers(&channel_handlers, &message.channel, &message.payload);
863 dispatch_global_handlers(&global_handlers, &message.channel, &message.payload);
864 dispatch_channel_event_handlers(
865 &channel_event_handlers,
866 &message.channel,
867 &message.event,
868 &message.payload,
869 );
870 dispatch_global_event_handlers(
871 &global_event_handlers,
872 &message.channel,
873 &message.event,
874 &message.payload,
875 );
876 }
877 });
878}
879
880fn dispatch_channel_handlers(handlers: &ChannelHandlers, channel: &str, message: &Payload) {
881 let callbacks: Vec<ChannelHandler> = {
882 let guard = handlers.lock().expect("channel handler mutex poisoned");
883 guard
884 .get(channel)
885 .map(|entries| entries.values().cloned().collect())
886 .unwrap_or_default()
887 };
888
889 for callback in callbacks {
890 callback(message.clone());
891 }
892}
893
894fn dispatch_global_handlers(handlers: &GlobalHandlers, channel: &str, message: &Payload) {
895 let callbacks: Vec<GlobalHandler> = {
896 let guard = handlers.lock().expect("global handler mutex poisoned");
897 guard.values().cloned().collect()
898 };
899
900 for callback in callbacks {
901 callback(channel.to_string(), message.clone());
902 }
903}
904
905fn dispatch_channel_event_handlers(
906 handlers: &ChannelEventHandlers,
907 channel: &str,
908 event: &str,
909 message: &Payload,
910) {
911 let callbacks: Vec<ChannelEventHandler> = {
912 let guard = handlers
913 .lock()
914 .expect("channel event handler mutex poisoned");
915 guard
916 .get(channel)
917 .map(|entries| entries.values().cloned().collect())
918 .unwrap_or_default()
919 };
920
921 for callback in callbacks {
922 callback(event.to_string(), message.clone());
923 }
924}
925
926fn dispatch_global_event_handlers(
927 handlers: &GlobalEventHandlers,
928 channel: &str,
929 event: &str,
930 message: &Payload,
931) {
932 let callbacks: Vec<GlobalEventHandler> = {
933 let guard = handlers
934 .lock()
935 .expect("global event handler mutex poisoned");
936 guard.values().cloned().collect()
937 };
938
939 for callback in callbacks {
940 callback(channel.to_string(), event.to_string(), message.clone());
941 }
942}
943
#[cfg(test)]
mod tests {
    //! Unit tests for the hub's pure helpers: echo routing, handler dispatch and the
    //! fixed-window rate limiter.

    use std::{
        collections::HashMap,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        time::{Duration, Instant},
    };

    use serde_json::json;

    use super::{
        ChannelEventHandlers, ChannelHandlers, GlobalEventHandlers, GlobalHandlers,
        allow_within_window, dispatch_channel_event_handlers, dispatch_channel_handlers,
        dispatch_global_event_handlers, dispatch_global_handlers, should_echo_to_sender,
    };
    use crate::server::ChannelName;

    #[test]
    fn echo_channel_includes_sender() {
        let channel = ChannelName::parse("echo:room").expect("channel should parse");
        assert!(should_echo_to_sender(&channel));
    }

    #[test]
    fn non_echo_channel_excludes_sender() {
        let channel = ChannelName::parse("public:lobby").expect("channel should parse");
        assert!(!should_echo_to_sender(&channel));
    }

    #[test]
    fn dispatch_channel_handlers_only_targets_matching_channel() {
        let handlers: ChannelHandlers = Arc::new(std::sync::Mutex::new(HashMap::new()));
        let count = Arc::new(AtomicUsize::new(0));
        let count_for_handler = Arc::clone(&count);
        handlers
            .lock()
            .expect("channel handlers lock")
            .entry("chat:room:1".to_string())
            .or_default()
            .insert(
                1,
                Arc::new(move |_| {
                    count_for_handler.fetch_add(1, Ordering::Relaxed);
                }),
            );

        dispatch_channel_handlers(&handlers, "chat:room:1", &json!({"text":"hello"}));
        dispatch_channel_handlers(&handlers, "chat:room:2", &json!({"text":"hello"}));

        assert_eq!(count.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn dispatch_global_handlers_receives_channel_and_message() {
        let handlers: GlobalHandlers = Arc::new(std::sync::Mutex::new(HashMap::new()));
        let count = Arc::new(AtomicUsize::new(0));
        let count_for_handler = Arc::clone(&count);
        handlers.lock().expect("global handlers lock").insert(
            1,
            Arc::new(move |channel, payload| {
                assert_eq!(channel, "chat:room:1");
                assert_eq!(payload["text"], "hello");
                count_for_handler.fetch_add(1, Ordering::Relaxed);
            }),
        );

        dispatch_global_handlers(&handlers, "chat:room:1", &json!({"text":"hello"}));

        assert_eq!(count.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn dispatch_channel_event_handlers_receives_event_name() {
        let handlers: ChannelEventHandlers = Arc::new(std::sync::Mutex::new(HashMap::new()));
        let count = Arc::new(AtomicUsize::new(0));
        let count_for_handler = Arc::clone(&count);
        handlers
            .lock()
            .expect("channel event handlers lock")
            .entry("chat:room:1".to_string())
            .or_default()
            .insert(
                1,
                Arc::new(move |event, payload| {
                    assert_eq!(event, "chat.typing");
                    assert_eq!(payload["typing"], true);
                    count_for_handler.fetch_add(1, Ordering::Relaxed);
                }),
            );

        dispatch_channel_event_handlers(
            &handlers,
            "chat:room:1",
            "chat.typing",
            &json!({"typing": true}),
        );

        assert_eq!(count.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn dispatch_global_event_handlers_receives_channel_event_and_message() {
        let handlers: GlobalEventHandlers = Arc::new(std::sync::Mutex::new(HashMap::new()));
        let count = Arc::new(AtomicUsize::new(0));
        let count_for_handler = Arc::clone(&count);
        handlers.lock().expect("global event handlers lock").insert(
            1,
            Arc::new(move |channel, event, payload| {
                assert_eq!(channel, "chat:room:1");
                assert_eq!(event, "chat.message");
                assert_eq!(payload["text"], "hello");
                count_for_handler.fetch_add(1, Ordering::Relaxed);
            }),
        );

        dispatch_global_event_handlers(
            &handlers,
            "chat:room:1",
            "chat.message",
            &json!({"text":"hello"}),
        );

        assert_eq!(count.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn rate_window_allows_up_to_budget_then_denies() {
        let mut started_at = Instant::now();
        let mut count = 0;
        assert!(allow_within_window(&mut started_at, &mut count, 2));
        assert!(allow_within_window(&mut started_at, &mut count, 2));
        assert!(!allow_within_window(&mut started_at, &mut count, 2));
        // A denied attempt must not consume budget.
        assert_eq!(count, 2);
    }

    #[test]
    fn rate_window_resets_after_one_second() {
        // `checked_sub` fails only if the monotonic clock started less than 2s ago.
        let Some(expired) = Instant::now().checked_sub(Duration::from_secs(2)) else {
            return;
        };
        let mut started_at = expired;
        let mut count = 5;
        assert!(allow_within_window(&mut started_at, &mut count, 2));
        assert_eq!(count, 1);
        assert!(started_at > expired);
    }

    #[test]
    fn rate_window_with_zero_budget_always_denies() {
        let mut started_at = Instant::now();
        let mut count = 0;
        assert!(!allow_within_window(&mut started_at, &mut count, 0));
        assert_eq!(count, 0);
    }
}