Skip to main content

realtime/server/
hub.rs

1use std::{
2    collections::{HashMap, HashSet},
3    sync::{
4        Arc,
5        atomic::{AtomicU64, Ordering},
6    },
7    time::Instant,
8};
9
10use chrono::Utc;
11use tokio::sync::mpsc;
12
13use crate::protocol::{DEFAULT_EVENT, ServerFrame};
14
15use super::{
16    Channel, ChannelName, ConnectionId, ConnectionMeta, DisconnectReason, Event, Payload,
17    RealtimeConfig, RealtimeError, SessionAuth, UserId,
18    policy::{ChannelPolicy, DefaultChannelPolicy},
19    session,
20};
21
22const HUB_QUEUE_SIZE: usize = 4096;
23const INBOUND_QUEUE_SIZE: usize = 4096;
24
25pub type SubscriptionId = u64;
26type ChannelHandler = Arc<dyn Fn(Payload) + Send + Sync>;
27type GlobalHandler = Arc<dyn Fn(Channel, Payload) + Send + Sync>;
28type ChannelEventHandler = Arc<dyn Fn(Event, Payload) + Send + Sync>;
29type GlobalEventHandler = Arc<dyn Fn(Channel, Event, Payload) + Send + Sync>;
30type ChannelHandlers =
31    Arc<std::sync::Mutex<HashMap<Channel, HashMap<SubscriptionId, ChannelHandler>>>>;
32type GlobalHandlers = Arc<std::sync::Mutex<HashMap<SubscriptionId, GlobalHandler>>>;
33type ChannelEventHandlers =
34    Arc<std::sync::Mutex<HashMap<Channel, HashMap<SubscriptionId, ChannelEventHandler>>>>;
35type GlobalEventHandlers = Arc<std::sync::Mutex<HashMap<SubscriptionId, GlobalEventHandler>>>;
36
37#[derive(Clone)]
38pub(crate) struct InboundMessage {
39    pub channel: Channel,
40    pub event: Event,
41    pub payload: Payload,
42}
43
44#[derive(Clone)]
45pub struct SocketServerHandle {
46    config: RealtimeConfig,
47    tx: Option<mpsc::Sender<HubCommand>>,
48    channel_handlers: ChannelHandlers,
49    global_handlers: GlobalHandlers,
50    channel_event_handlers: ChannelEventHandlers,
51    global_event_handlers: GlobalEventHandlers,
52    next_subscription_id: Arc<AtomicU64>,
53}
54
55impl SocketServerHandle {
56    pub fn spawn(config: RealtimeConfig) -> Self {
57        Self::spawn_with_policy(config, Arc::new(DefaultChannelPolicy))
58    }
59
60    pub fn spawn_with_policy(config: RealtimeConfig, policy: Arc<dyn ChannelPolicy>) -> Self {
61        let channel_handlers: ChannelHandlers = Arc::new(std::sync::Mutex::new(HashMap::new()));
62        let global_handlers: GlobalHandlers = Arc::new(std::sync::Mutex::new(HashMap::new()));
63        let channel_event_handlers: ChannelEventHandlers =
64            Arc::new(std::sync::Mutex::new(HashMap::new()));
65        let global_event_handlers: GlobalEventHandlers =
66            Arc::new(std::sync::Mutex::new(HashMap::new()));
67        let next_subscription_id = Arc::new(AtomicU64::new(1));
68
69        if !config.enabled {
70            return Self {
71                config,
72                tx: None,
73                channel_handlers,
74                global_handlers,
75                channel_event_handlers,
76                global_event_handlers,
77                next_subscription_id,
78            };
79        }
80
81        let (tx, rx) = mpsc::channel(HUB_QUEUE_SIZE);
82        let (inbound_tx, inbound_rx) = mpsc::channel(INBOUND_QUEUE_SIZE);
83        let mut hub = SocketServer::new(config.clone(), rx, policy, Some(inbound_tx));
84        tokio::spawn(async move {
85            hub.run().await;
86        });
87        spawn_inbound_dispatcher(
88            inbound_rx,
89            Arc::clone(&channel_handlers),
90            Arc::clone(&global_handlers),
91            Arc::clone(&channel_event_handlers),
92            Arc::clone(&global_event_handlers),
93        );
94
95        Self {
96            config,
97            tx: Some(tx),
98            channel_handlers,
99            global_handlers,
100            channel_event_handlers,
101            global_event_handlers,
102            next_subscription_id,
103        }
104    }
105
106    pub fn disabled(config: RealtimeConfig) -> Self {
107        Self {
108            config,
109            tx: None,
110            channel_handlers: Arc::new(std::sync::Mutex::new(HashMap::new())),
111            global_handlers: Arc::new(std::sync::Mutex::new(HashMap::new())),
112            channel_event_handlers: Arc::new(std::sync::Mutex::new(HashMap::new())),
113            global_event_handlers: Arc::new(std::sync::Mutex::new(HashMap::new())),
114            next_subscription_id: Arc::new(AtomicU64::new(1)),
115        }
116    }
117
118    pub fn is_enabled(&self) -> bool {
119        self.config.enabled && self.tx.is_some()
120    }
121
122    pub fn max_message_bytes(&self) -> usize {
123        self.config.max_message_bytes
124    }
125
126    pub async fn serve_socket(&self, socket: axum::extract::ws::WebSocket, auth: SessionAuth) {
127        let Some(hub_tx) = self.tx.clone() else {
128            return;
129        };
130        session::run_socket_session(socket, auth, hub_tx, self.config.clone()).await;
131    }
132
133    pub async fn send(
134        &self,
135        channel_name: impl Into<Channel>,
136        message: Payload,
137    ) -> Result<(), RealtimeError> {
138        self.send_event(channel_name, DEFAULT_EVENT, message).await
139    }
140
141    pub async fn send_to_user(
142        &self,
143        user_id: impl Into<UserId>,
144        message: Payload,
145    ) -> Result<(), RealtimeError> {
146        self.send_event_to_user(user_id, DEFAULT_EVENT, message)
147            .await
148    }
149
150    pub async fn send_event(
151        &self,
152        channel_name: impl Into<Channel>,
153        event: impl Into<Event>,
154        payload: Payload,
155    ) -> Result<(), RealtimeError> {
156        let Some(tx) = &self.tx else {
157            return Ok(());
158        };
159        let channel_name = channel_name.into();
160        let channel = ChannelName::parse(&channel_name)?;
161        tx.send(HubCommand::SendToChannel {
162            channel,
163            event: event.into(),
164            payload,
165        })
166        .await
167        .map_err(|_| RealtimeError::internal("realtime hub is unavailable"))
168    }
169
170    pub async fn send_event_to_user(
171        &self,
172        user_id: impl Into<UserId>,
173        event: impl Into<Event>,
174        payload: Payload,
175    ) -> Result<(), RealtimeError> {
176        let Some(tx) = &self.tx else {
177            return Ok(());
178        };
179        tx.send(HubCommand::SendToUser {
180            user_id: user_id.into(),
181            event: event.into(),
182            payload,
183        })
184        .await
185        .map_err(|_| RealtimeError::internal("realtime hub is unavailable"))
186    }
187
188    pub async fn emit_to_user(
189        &self,
190        user_id: impl Into<UserId>,
191        event: impl Into<Event>,
192        payload: Payload,
193    ) -> Result<(), RealtimeError> {
194        self.send_event_to_user(user_id, event, payload).await
195    }
196
197    pub fn on_message<F>(&self, channel: &str, handler: F) -> SubscriptionId
198    where
199        F: Fn(Payload) + Send + Sync + 'static,
200    {
201        let id = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
202        let mut handlers = self
203            .channel_handlers
204            .lock()
205            .expect("channel handler mutex poisoned");
206        handlers
207            .entry(channel.to_string())
208            .or_default()
209            .insert(id, Arc::new(handler));
210        id
211    }
212
213    pub fn on_messages<F>(&self, handler: F) -> SubscriptionId
214    where
215        F: Fn(Channel, Payload) + Send + Sync + 'static,
216    {
217        let id = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
218        self.global_handlers
219            .lock()
220            .expect("global handler mutex poisoned")
221            .insert(id, Arc::new(handler));
222        id
223    }
224
225    pub fn on_channel_event<F>(&self, channel: &str, handler: F) -> SubscriptionId
226    where
227        F: Fn(Event, Payload) + Send + Sync + 'static,
228    {
229        let id = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
230        let mut handlers = self
231            .channel_event_handlers
232            .lock()
233            .expect("channel event handler mutex poisoned");
234        handlers
235            .entry(channel.to_string())
236            .or_default()
237            .insert(id, Arc::new(handler));
238        id
239    }
240
241    pub fn on_events<F>(&self, handler: F) -> SubscriptionId
242    where
243        F: Fn(Channel, Event, Payload) + Send + Sync + 'static,
244    {
245        let id = self.next_subscription_id.fetch_add(1, Ordering::Relaxed);
246        self.global_event_handlers
247            .lock()
248            .expect("global event handler mutex poisoned")
249            .insert(id, Arc::new(handler));
250        id
251    }
252
253    pub fn off(&self, id: SubscriptionId) -> bool {
254        let mut removed = false;
255
256        let mut global = self
257            .global_handlers
258            .lock()
259            .expect("global handler mutex poisoned");
260        if global.remove(&id).is_some() {
261            removed = true;
262        }
263        drop(global);
264
265        let mut channels = self
266            .channel_handlers
267            .lock()
268            .expect("channel handler mutex poisoned");
269        for handlers in channels.values_mut() {
270            if handlers.remove(&id).is_some() {
271                removed = true;
272            }
273        }
274
275        let mut global_events = self
276            .global_event_handlers
277            .lock()
278            .expect("global event handler mutex poisoned");
279        if global_events.remove(&id).is_some() {
280            removed = true;
281        }
282        drop(global_events);
283
284        let mut channel_events = self
285            .channel_event_handlers
286            .lock()
287            .expect("channel event handler mutex poisoned");
288        for handlers in channel_events.values_mut() {
289            if handlers.remove(&id).is_some() {
290                removed = true;
291            }
292        }
293
294        removed
295    }
296}
297
298pub(crate) enum HubCommand {
299    Register {
300        meta: ConnectionMeta,
301        outbound_tx: mpsc::Sender<ServerFrame>,
302    },
303    Unregister {
304        conn_id: ConnectionId,
305        reason: DisconnectReason,
306    },
307    Join {
308        conn_id: ConnectionId,
309        channel: ChannelName,
310        req_id: String,
311    },
312    Leave {
313        conn_id: ConnectionId,
314        channel: ChannelName,
315        req_id: String,
316    },
317    Emit {
318        conn_id: ConnectionId,
319        channel: ChannelName,
320        event: Event,
321        payload: Payload,
322        req_id: String,
323    },
324    Ping {
325        conn_id: ConnectionId,
326        req_id: String,
327    },
328    SendToChannel {
329        channel: ChannelName,
330        event: Event,
331        payload: Payload,
332    },
333    SendToUser {
334        user_id: UserId,
335        event: Event,
336        payload: Payload,
337    },
338}
339
340struct SocketServer {
341    config: RealtimeConfig,
342    rx: mpsc::Receiver<HubCommand>,
343    policy: Arc<dyn ChannelPolicy>,
344    inbound_tx: Option<mpsc::Sender<InboundMessage>>,
345    connections: HashMap<ConnectionId, ConnectionState>,
346    users: HashMap<UserId, HashSet<ConnectionId>>,
347    channels: HashMap<ChannelName, HashSet<ConnectionId>>,
348    connection_channels: HashMap<ConnectionId, HashSet<ChannelName>>,
349}
350
351struct ConnectionState {
352    meta: ConnectionMeta,
353    outbound_tx: mpsc::Sender<ServerFrame>,
354    rate: ConnectionRateState,
355}
356
/// Two independent one-second rate-limit windows: one for joins, one for emits.
struct ConnectionRateState {
    /// Start of the current join window.
    join_window_started_at: Instant,
    /// Joins accepted since `join_window_started_at`.
    joins_in_window: u32,
    /// Start of the current emit window.
    emit_window_started_at: Instant,
    /// Emits accepted since `emit_window_started_at`.
    emits_in_window: u32,
}
363
364impl SocketServer {
365    fn new(
366        config: RealtimeConfig,
367        rx: mpsc::Receiver<HubCommand>,
368        policy: Arc<dyn ChannelPolicy>,
369        inbound_tx: Option<mpsc::Sender<InboundMessage>>,
370    ) -> Self {
371        Self {
372            config,
373            rx,
374            policy,
375            inbound_tx,
376            connections: HashMap::new(),
377            users: HashMap::new(),
378            channels: HashMap::new(),
379            connection_channels: HashMap::new(),
380        }
381    }
382
383    async fn run(&mut self) {
384        while let Some(command) = self.rx.recv().await {
385            self.handle_command(command);
386        }
387    }
388
389    fn handle_command(&mut self, command: HubCommand) {
390        match command {
391            HubCommand::Register { meta, outbound_tx } => self.register(meta, outbound_tx),
392            HubCommand::Unregister { conn_id, reason } => self.unregister(conn_id, reason),
393            HubCommand::Join {
394                conn_id,
395                channel,
396                req_id,
397            } => self.handle_join(conn_id, channel, req_id),
398            HubCommand::Leave {
399                conn_id,
400                channel,
401                req_id,
402            } => self.handle_leave(conn_id, channel, req_id),
403            HubCommand::Emit {
404                conn_id,
405                channel,
406                event,
407                payload,
408                req_id,
409            } => self.handle_emit(conn_id, channel, event, payload, req_id),
410            HubCommand::Ping { conn_id, req_id } => self.handle_ping(conn_id, req_id),
411            HubCommand::SendToChannel {
412                channel,
413                event,
414                payload,
415            } => self.handle_send_to_channel(channel, event, payload),
416            HubCommand::SendToUser {
417                user_id,
418                event,
419                payload,
420            } => self.handle_send_to_user(user_id, event, payload),
421        }
422    }
423
424    fn register(&mut self, meta: ConnectionMeta, outbound_tx: mpsc::Sender<ServerFrame>) {
425        if self.connections.len() >= self.config.max_connections {
426            let _ = outbound_tx.try_send(ServerFrame::error(
427                "capacity_exceeded",
428                "Realtime server is at capacity",
429            ));
430            return;
431        }
432
433        let conn_id = meta.id;
434        let user_id = meta.user_id.clone();
435        let now = Instant::now();
436
437        self.connections.insert(
438            conn_id,
439            ConnectionState {
440                meta,
441                outbound_tx: outbound_tx.clone(),
442                rate: ConnectionRateState {
443                    join_window_started_at: now,
444                    joins_in_window: 0,
445                    emit_window_started_at: now,
446                    emits_in_window: 0,
447                },
448            },
449        );
450
451        self.users
452            .entry(user_id.clone())
453            .or_default()
454            .insert(conn_id);
455
456        let _ = outbound_tx.try_send(ServerFrame::connected(conn_id.to_string(), user_id.clone()));
457
458        let private_channel = ChannelName(format!("user:{user_id}"));
459        if let Some(reason) = self.join_internal(conn_id, private_channel.clone()) {
460            self.unregister(conn_id, reason);
461            return;
462        }
463        self.send_frame(
464            conn_id,
465            ServerFrame::Joined {
466                id: uuid::Uuid::new_v4().to_string(),
467                channel: private_channel.to_string(),
468                ts: Utc::now().timestamp(),
469            },
470        );
471    }
472
473    fn unregister(&mut self, conn_id: ConnectionId, reason: DisconnectReason) {
474        let Some(existing) = self.connections.remove(&conn_id) else {
475            return;
476        };
477
478        tracing::debug!(
479            conn_id = %conn_id,
480            user_id = %existing.meta.user_id,
481            reason = ?reason,
482            "realtime connection disconnected"
483        );
484
485        if let Some(user_set) = self.users.get_mut(&existing.meta.user_id) {
486            user_set.remove(&conn_id);
487            if user_set.is_empty() {
488                self.users.remove(&existing.meta.user_id);
489            }
490        }
491
492        if let Some(joined) = self.connection_channels.remove(&conn_id) {
493            for channel in joined {
494                if let Some(member_set) = self.channels.get_mut(&channel) {
495                    member_set.remove(&conn_id);
496                    if member_set.is_empty() {
497                        self.channels.remove(&channel);
498                    }
499                }
500            }
501        }
502    }
503
504    fn handle_join(&mut self, conn_id: ConnectionId, channel: ChannelName, req_id: String) {
505        tracing::debug!(
506            conn_id = %conn_id,
507            channel = %channel,
508            req_id = %req_id,
509            "realtime join requested"
510        );
511
512        if !self.check_join_rate(conn_id) {
513            tracing::debug!(
514                conn_id = %conn_id,
515                channel = %channel,
516                req_id = %req_id,
517                "realtime join denied: rate limited"
518            );
519            self.send_frame(
520                conn_id,
521                ServerFrame::ack_err(req_id, "rate_limited", "Join rate limit exceeded"),
522            );
523            return;
524        }
525
526        let Some(meta) = self.connections.get(&conn_id).map(|conn| conn.meta.clone()) else {
527            return;
528        };
529
530        if let Err(err) = self.policy.can_join(&meta, &channel) {
531            tracing::debug!(
532                conn_id = %conn_id,
533                user_id = %meta.user_id,
534                channel = %channel,
535                req_id = %req_id,
536                reason = %err,
537                "realtime join denied by policy"
538            );
539            self.send_frame(
540                conn_id,
541                ServerFrame::ack_err(req_id, "forbidden_channel", err.message()),
542            );
543            return;
544        }
545
546        if self
547            .connection_channels
548            .get(&conn_id)
549            .is_some_and(|set| set.contains(&channel))
550        {
551            tracing::debug!(
552                conn_id = %conn_id,
553                channel = %channel,
554                req_id = %req_id,
555                "realtime join acknowledged: already joined"
556            );
557            self.send_frame(conn_id, ServerFrame::ack_ok(req_id));
558            return;
559        }
560
561        if self
562            .connection_channels
563            .get(&conn_id)
564            .map(|set| set.len())
565            .unwrap_or(0)
566            >= self.config.max_channels_per_connection
567        {
568            tracing::debug!(
569                conn_id = %conn_id,
570                channel = %channel,
571                req_id = %req_id,
572                "realtime join denied: channel limit exceeded"
573            );
574            self.send_frame(
575                conn_id,
576                ServerFrame::ack_err(
577                    req_id,
578                    "channel_limit_exceeded",
579                    "Maximum channels per connection reached",
580                ),
581            );
582            return;
583        }
584
585        if let Some(reason) = self.join_internal(conn_id, channel.clone()) {
586            tracing::debug!(
587                conn_id = %conn_id,
588                channel = %channel,
589                reason = ?reason,
590                "realtime join caused disconnect"
591            );
592            self.unregister(conn_id, reason);
593            return;
594        }
595
596        tracing::debug!(
597            conn_id = %conn_id,
598            user_id = %meta.user_id,
599            channel = %channel,
600            req_id = %req_id,
601            "realtime join succeeded"
602        );
603
604        self.send_frame(conn_id, ServerFrame::ack_ok(req_id));
605        self.send_frame(
606            conn_id,
607            ServerFrame::Joined {
608                id: uuid::Uuid::new_v4().to_string(),
609                channel: channel.to_string(),
610                ts: Utc::now().timestamp(),
611            },
612        );
613    }
614
615    fn handle_leave(&mut self, conn_id: ConnectionId, channel: ChannelName, req_id: String) {
616        tracing::debug!(
617            conn_id = %conn_id,
618            channel = %channel,
619            req_id = %req_id,
620            "realtime leave requested"
621        );
622        let was_member = self
623            .connection_channels
624            .get(&conn_id)
625            .is_some_and(|set| set.contains(&channel));
626        if !was_member {
627            tracing::debug!(
628                conn_id = %conn_id,
629                channel = %channel,
630                req_id = %req_id,
631                "realtime leave denied: not joined"
632            );
633            self.send_frame(
634                conn_id,
635                ServerFrame::ack_err(req_id, "channel_not_joined", "Not a member of channel"),
636            );
637            return;
638        }
639
640        self.leave_internal(conn_id, &channel);
641        tracing::debug!(
642            conn_id = %conn_id,
643            channel = %channel,
644            req_id = %req_id,
645            "realtime leave succeeded"
646        );
647        self.send_frame(conn_id, ServerFrame::ack_ok(req_id));
648        self.send_frame(
649            conn_id,
650            ServerFrame::Left {
651                id: uuid::Uuid::new_v4().to_string(),
652                channel: channel.to_string(),
653                ts: Utc::now().timestamp(),
654            },
655        );
656    }
657
658    fn handle_emit(
659        &mut self,
660        conn_id: ConnectionId,
661        channel: ChannelName,
662        event: Event,
663        payload: Payload,
664        req_id: String,
665    ) {
666        if !self.check_emit_rate(conn_id) {
667            self.send_frame(
668                conn_id,
669                ServerFrame::ack_err(req_id, "rate_limited", "Emit rate limit exceeded"),
670            );
671            return;
672        }
673
674        let Some(meta) = self.connections.get(&conn_id).map(|conn| conn.meta.clone()) else {
675            return;
676        };
677
678        if let Err(err) = self.policy.can_publish(&meta, &channel, &event) {
679            self.send_frame(
680                conn_id,
681                ServerFrame::ack_err(req_id, "forbidden_channel", err.message()),
682            );
683            return;
684        }
685
686        let sender_is_member = self
687            .connection_channels
688            .get(&conn_id)
689            .is_some_and(|set| set.contains(&channel));
690        if !sender_is_member {
691            self.send_frame(
692                conn_id,
693                ServerFrame::ack_err(req_id, "channel_not_joined", "Join channel before emitting"),
694            );
695            return;
696        }
697
698        let recipients = self.channels.get(&channel).cloned().unwrap_or_default();
699        let include_sender = should_echo_to_sender(&channel);
700        self.publish_inbound(InboundMessage {
701            channel: channel.to_string(),
702            event: event.clone(),
703            payload: payload.clone(),
704        });
705        let event_frame = ServerFrame::event(
706            channel.to_string(),
707            event,
708            payload,
709            Some(meta.user_id.clone()),
710        );
711        for recipient_id in recipients {
712            if recipient_id == conn_id && !include_sender {
713                continue;
714            }
715            self.send_frame(recipient_id, event_frame.clone());
716        }
717
718        self.send_frame(conn_id, ServerFrame::ack_ok(req_id));
719    }
720
721    fn handle_ping(&mut self, conn_id: ConnectionId, req_id: String) {
722        self.send_frame(conn_id, ServerFrame::pong(req_id));
723    }
724
725    fn handle_send_to_channel(&mut self, channel: ChannelName, event: Event, payload: Payload) {
726        let Some(conn_ids) = self.channels.get(&channel).cloned() else {
727            return;
728        };
729
730        let frame = ServerFrame::event(channel.to_string(), event, payload, None);
731        for conn_id in conn_ids {
732            self.send_frame(conn_id, frame.clone());
733        }
734    }
735
736    fn handle_send_to_user(&mut self, user_id: UserId, event: Event, payload: Payload) {
737        let Some(conn_ids) = self.users.get(&user_id).cloned() else {
738            return;
739        };
740
741        let channel = format!("user:{user_id}");
742        let frame = ServerFrame::event(channel, event, payload, None);
743        for conn_id in conn_ids {
744            self.send_frame(conn_id, frame.clone());
745        }
746    }
747
748    fn publish_inbound(&mut self, message: InboundMessage) {
749        let Some(tx) = &self.inbound_tx else {
750            return;
751        };
752
753        match tx.try_send(message) {
754            Ok(_) => {}
755            Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
756                tracing::debug!("realtime inbound dispatch queue is full; dropping message");
757            }
758            Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
759                self.inbound_tx = None;
760            }
761        }
762    }
763
764    fn check_join_rate(&mut self, conn_id: ConnectionId) -> bool {
765        let Some(state) = self.connections.get_mut(&conn_id) else {
766            return false;
767        };
768        allow_within_window(
769            &mut state.rate.join_window_started_at,
770            &mut state.rate.joins_in_window,
771            self.config.join_rate_per_sec,
772        )
773    }
774
775    fn check_emit_rate(&mut self, conn_id: ConnectionId) -> bool {
776        let Some(state) = self.connections.get_mut(&conn_id) else {
777            return false;
778        };
779        allow_within_window(
780            &mut state.rate.emit_window_started_at,
781            &mut state.rate.emits_in_window,
782            self.config.emit_rate_per_sec,
783        )
784    }
785
786    fn join_internal(
787        &mut self,
788        conn_id: ConnectionId,
789        channel: ChannelName,
790    ) -> Option<DisconnectReason> {
791        self.connection_channels
792            .entry(conn_id)
793            .or_default()
794            .insert(channel.clone());
795        self.channels.entry(channel).or_default().insert(conn_id);
796        None
797    }
798
799    fn leave_internal(&mut self, conn_id: ConnectionId, channel: &ChannelName) {
800        if let Some(set) = self.connection_channels.get_mut(&conn_id) {
801            set.remove(channel);
802            if set.is_empty() {
803                self.connection_channels.remove(&conn_id);
804            }
805        }
806
807        if let Some(set) = self.channels.get_mut(channel) {
808            set.remove(&conn_id);
809            if set.is_empty() {
810                self.channels.remove(channel);
811            }
812        }
813    }
814
815    fn send_frame(&mut self, conn_id: ConnectionId, frame: ServerFrame) {
816        let Some(outbound_tx) = self
817            .connections
818            .get(&conn_id)
819            .map(|connection| connection.outbound_tx.clone())
820        else {
821            return;
822        };
823
824        match outbound_tx.try_send(frame) {
825            Ok(_) => {}
826            Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
827                self.unregister(conn_id, DisconnectReason::SlowConsumer);
828            }
829            Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
830                self.unregister(conn_id, DisconnectReason::SocketError);
831            }
832        }
833    }
834}
835
/// Fixed-window rate limiter: admits at most `max_per_sec` events per one-second window.
///
/// Once at least a whole second has elapsed since `start`, the window restarts at the
/// current instant with a zero count. Returns `true` and increments `count` when the
/// event is admitted; returns `false` and leaves `count` unchanged otherwise.
fn allow_within_window(start: &mut Instant, count: &mut u32, max_per_sec: u32) -> bool {
    let now = Instant::now();
    let window_expired = now.duration_since(*start).as_secs() >= 1;
    if window_expired {
        *start = now;
        *count = 0;
    }

    let allowed = *count < max_per_sec;
    if allowed {
        // Cannot overflow: `*count < max_per_sec <= u32::MAX`.
        *count += 1;
    }
    allowed
}
848
849fn should_echo_to_sender(channel: &ChannelName) -> bool {
850    channel.as_str().starts_with("echo:")
851}
852
853fn spawn_inbound_dispatcher(
854    mut inbound_rx: mpsc::Receiver<InboundMessage>,
855    channel_handlers: ChannelHandlers,
856    global_handlers: GlobalHandlers,
857    channel_event_handlers: ChannelEventHandlers,
858    global_event_handlers: GlobalEventHandlers,
859) {
860    tokio::spawn(async move {
861        while let Some(message) = inbound_rx.recv().await {
862            dispatch_channel_handlers(&channel_handlers, &message.channel, &message.payload);
863            dispatch_global_handlers(&global_handlers, &message.channel, &message.payload);
864            dispatch_channel_event_handlers(
865                &channel_event_handlers,
866                &message.channel,
867                &message.event,
868                &message.payload,
869            );
870            dispatch_global_event_handlers(
871                &global_event_handlers,
872                &message.channel,
873                &message.event,
874                &message.payload,
875            );
876        }
877    });
878}
879
880fn dispatch_channel_handlers(handlers: &ChannelHandlers, channel: &str, message: &Payload) {
881    let callbacks: Vec<ChannelHandler> = {
882        let guard = handlers.lock().expect("channel handler mutex poisoned");
883        guard
884            .get(channel)
885            .map(|entries| entries.values().cloned().collect())
886            .unwrap_or_default()
887    };
888
889    for callback in callbacks {
890        callback(message.clone());
891    }
892}
893
894fn dispatch_global_handlers(handlers: &GlobalHandlers, channel: &str, message: &Payload) {
895    let callbacks: Vec<GlobalHandler> = {
896        let guard = handlers.lock().expect("global handler mutex poisoned");
897        guard.values().cloned().collect()
898    };
899
900    for callback in callbacks {
901        callback(channel.to_string(), message.clone());
902    }
903}
904
905fn dispatch_channel_event_handlers(
906    handlers: &ChannelEventHandlers,
907    channel: &str,
908    event: &str,
909    message: &Payload,
910) {
911    let callbacks: Vec<ChannelEventHandler> = {
912        let guard = handlers
913            .lock()
914            .expect("channel event handler mutex poisoned");
915        guard
916            .get(channel)
917            .map(|entries| entries.values().cloned().collect())
918            .unwrap_or_default()
919    };
920
921    for callback in callbacks {
922        callback(event.to_string(), message.clone());
923    }
924}
925
926fn dispatch_global_event_handlers(
927    handlers: &GlobalEventHandlers,
928    channel: &str,
929    event: &str,
930    message: &Payload,
931) {
932    let callbacks: Vec<GlobalEventHandler> = {
933        let guard = handlers
934            .lock()
935            .expect("global event handler mutex poisoned");
936        guard.values().cloned().collect()
937    };
938
939    for callback in callbacks {
940        callback(channel.to_string(), event.to_string(), message.clone());
941    }
942}
943
#[cfg(test)]
mod tests {
    //! Unit tests for echo routing and the local handler dispatch helpers.

    use std::{
        collections::HashMap,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    use serde_json::json;

    use super::{
        ChannelEventHandlers, ChannelHandlers, GlobalEventHandlers, GlobalHandlers,
        dispatch_channel_event_handlers, dispatch_channel_handlers, dispatch_global_event_handlers,
        dispatch_global_handlers, should_echo_to_sender,
    };
    use crate::server::ChannelName;

    #[test]
    fn echo_channel_includes_sender() {
        let channel = ChannelName::parse("echo:room").expect("channel should parse");
        assert!(should_echo_to_sender(&channel));
    }

    #[test]
    fn non_echo_channel_excludes_sender() {
        let channel = ChannelName::parse("public:lobby").expect("channel should parse");
        assert!(!should_echo_to_sender(&channel));
    }

    #[test]
    fn dispatch_channel_handlers_only_targets_matching_channel() {
        let handlers: ChannelHandlers = Arc::new(std::sync::Mutex::new(HashMap::new()));
        let count = Arc::new(AtomicUsize::new(0));
        let count_for_handler = Arc::clone(&count);
        handlers
            .lock()
            .expect("channel handlers lock")
            .entry("chat:room:1".to_string())
            .or_default()
            .insert(
                1,
                Arc::new(move |_| {
                    count_for_handler.fetch_add(1, Ordering::Relaxed);
                }),
            );

        dispatch_channel_handlers(&handlers, "chat:room:1", &json!({"text":"hello"}));
        dispatch_channel_handlers(&handlers, "chat:room:2", &json!({"text":"hello"}));

        assert_eq!(count.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn dispatch_global_handlers_receives_channel_and_message() {
        let handlers: GlobalHandlers = Arc::new(std::sync::Mutex::new(HashMap::new()));
        let count = Arc::new(AtomicUsize::new(0));
        let count_for_handler = Arc::clone(&count);
        handlers.lock().expect("global handlers lock").insert(
            1,
            Arc::new(move |channel, payload| {
                assert_eq!(channel, "chat:room:1");
                assert_eq!(payload["text"], "hello");
                count_for_handler.fetch_add(1, Ordering::Relaxed);
            }),
        );

        dispatch_global_handlers(&handlers, "chat:room:1", &json!({"text":"hello"}));

        assert_eq!(count.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn dispatch_global_handlers_releases_lock_before_invoking_handlers() {
        let handlers: GlobalHandlers = Arc::new(std::sync::Mutex::new(HashMap::new()));
        // A `Weak` avoids an `Arc` cycle between the registry and the handler it stores.
        let registry = Arc::downgrade(&handlers);
        let count = Arc::new(AtomicUsize::new(0));
        let count_for_handler = Arc::clone(&count);
        handlers.lock().expect("global handlers lock").insert(
            1,
            Arc::new(move |_, _| {
                let registry = registry
                    .upgrade()
                    .expect("handler registry should outlive dispatch");
                // If dispatch still held the lock, a re-entrant `lock()` here would never
                // return; `try_lock` makes such a regression fail fast instead of hanging.
                assert!(
                    registry.try_lock().is_ok(),
                    "handler registry lock must not be held while handlers run"
                );
                count_for_handler.fetch_add(1, Ordering::Relaxed);
            }),
        );

        dispatch_global_handlers(&handlers, "chat:room:1", &json!({"text":"hello"}));

        assert_eq!(count.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn dispatch_channel_event_handlers_receives_event_name() {
        let handlers: ChannelEventHandlers = Arc::new(std::sync::Mutex::new(HashMap::new()));
        let count = Arc::new(AtomicUsize::new(0));
        let count_for_handler = Arc::clone(&count);
        handlers
            .lock()
            .expect("channel event handlers lock")
            .entry("chat:room:1".to_string())
            .or_default()
            .insert(
                1,
                Arc::new(move |event, payload| {
                    assert_eq!(event, "chat.typing");
                    assert_eq!(payload["typing"], true);
                    count_for_handler.fetch_add(1, Ordering::Relaxed);
                }),
            );

        dispatch_channel_event_handlers(
            &handlers,
            "chat:room:1",
            "chat.typing",
            &json!({"typing": true}),
        );

        assert_eq!(count.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn dispatch_channel_event_handlers_ignores_other_channels() {
        let handlers: ChannelEventHandlers = Arc::new(std::sync::Mutex::new(HashMap::new()));
        let count = Arc::new(AtomicUsize::new(0));
        let count_for_handler = Arc::clone(&count);
        handlers
            .lock()
            .expect("channel event handlers lock")
            .entry("chat:room:1".to_string())
            .or_default()
            .insert(
                1,
                Arc::new(move |_, _| {
                    count_for_handler.fetch_add(1, Ordering::Relaxed);
                }),
            );

        dispatch_channel_event_handlers(
            &handlers,
            "chat:room:2",
            "chat.typing",
            &json!({"typing": true}),
        );

        assert_eq!(count.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn dispatch_global_event_handlers_receives_channel_event_and_message() {
        let handlers: GlobalEventHandlers = Arc::new(std::sync::Mutex::new(HashMap::new()));
        let count = Arc::new(AtomicUsize::new(0));
        let count_for_handler = Arc::clone(&count);
        handlers.lock().expect("global event handlers lock").insert(
            1,
            Arc::new(move |channel, event, payload| {
                assert_eq!(channel, "chat:room:1");
                assert_eq!(event, "chat.message");
                assert_eq!(payload["text"], "hello");
                count_for_handler.fetch_add(1, Ordering::Relaxed);
            }),
        );

        dispatch_global_event_handlers(
            &handlers,
            "chat:room:1",
            "chat.message",
            &json!({"text":"hello"}),
        );

        assert_eq!(count.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn dispatch_global_event_handlers_invokes_every_subscriber() {
        let handlers: GlobalEventHandlers = Arc::new(std::sync::Mutex::new(HashMap::new()));
        let count = Arc::new(AtomicUsize::new(0));
        for id in 1..=3 {
            let count_for_handler = Arc::clone(&count);
            handlers.lock().expect("global event handlers lock").insert(
                id,
                Arc::new(move |_, _, _| {
                    count_for_handler.fetch_add(1, Ordering::Relaxed);
                }),
            );
        }

        dispatch_global_event_handlers(
            &handlers,
            "chat:room:1",
            "chat.message",
            &json!({"text":"hello"}),
        );

        assert_eq!(count.load(Ordering::Relaxed), 3);
    }
}