1use crate::ServerMessage;
2use std::{
3 collections::{BTreeMap, HashMap, HashSet},
4 future::Future,
5 pin::Pin,
6 sync::{Arc, Mutex},
7};
8use tokio::sync::broadcast;
9
/// Per-topic channel capacity used by `PubSub::default` (1024 messages).
const DEFAULT_TOPIC_CAPACITY: usize = 1 << 10;
11type TopicSenders = HashMap<String, broadcast::Sender<PubSubMessage>>;
12type NodePresenceMap = HashMap<String, HashSet<String>>;
13type TopicPresenceMap = HashMap<String, NodePresenceMap>;
14
/// How far a published message can travel, as advertised by a backend.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PubSubDeliveryScope {
    /// Delivery stays within the current process (the built-in in-process backend).
    LocalProcess,
    /// Delivery may reach subscribers attached to other `PubSub` instances,
    /// e.g. on other nodes of a cluster.
    Cluster,
}
23
/// Ordering promise a backend makes for messages published to a single topic.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PubSubOrdering {
    /// The backend claims each subscriber sees a topic's messages in publish order.
    PerTopicOrdered,
    /// The backend makes no ordering claim.
    BestEffort,
}
32
/// Whether a backend needs a session to stay on the node that owns its subscriptions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionAffinityRequirement {
    /// No routing constraint is advertised.
    None,
    /// Subscriptions live in the serving process, so (presumably) a session must
    /// keep being routed to that same node — sticky sessions.
    StatefulSessionRequired,
}
41
42#[derive(Debug, Clone, PartialEq, Eq)]
44pub struct PubSubCapabilities {
45 pub backend: String,
46 pub delivery_scope: PubSubDeliveryScope,
47 pub ordering: PubSubOrdering,
48 pub session_affinity: SessionAffinityRequirement,
49 pub presence_tracking: bool,
50}
51
52impl PubSubCapabilities {
53 fn in_process() -> Self {
54 Self {
55 backend: "in_process".to_string(),
56 delivery_scope: PubSubDeliveryScope::LocalProcess,
57 ordering: PubSubOrdering::PerTopicOrdered,
58 session_affinity: SessionAffinityRequirement::StatefulSessionRequired,
59 presence_tracking: true,
60 }
61 }
62}
63
/// Point-in-time view of which nodes have sessions registered on a topic.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PubSubPresenceSnapshot {
    /// Topic the snapshot was taken for.
    pub topic: String,
    /// Number of sessions across all nodes.
    pub total_sessions: usize,
    /// Session count per node id, ordered by node id.
    pub by_node: BTreeMap<String, usize>,
}
71
72impl PubSubPresenceSnapshot {
73 fn empty(topic: &str) -> Self {
74 Self {
75 topic: topic.to_string(),
76 total_sessions: 0,
77 by_node: BTreeMap::new(),
78 }
79 }
80}
81
/// Why [`PubSubSubscription::recv`] did not yield a message.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be
/// propagated with `?` into boxed or application-level error types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PubSubReceiveError {
    /// The subscription's source is closed (for the broadcast-backed handle:
    /// every sender was dropped), so no further messages will arrive.
    Closed,
    /// The subscriber fell behind and this many messages were skipped.
    Lagged(u64),
}

impl std::fmt::Display for PubSubReceiveError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Closed => f.write_str("pubsub subscription closed"),
            Self::Lagged(skipped) => {
                write!(f, "pubsub subscriber lagged and skipped {skipped} message(s)")
            }
        }
    }
}

impl std::error::Error for PubSubReceiveError {}
88
89type PubSubRecvFuture<'a> =
90 Pin<Box<dyn Future<Output = Result<PubSubMessage, PubSubReceiveError>> + Send + 'a>>;
91
92pub trait PubSubSubscriptionHandle: Send {
94 fn recv(&mut self) -> PubSubRecvFuture<'_>;
95}
96
97struct BroadcastSubscriptionHandle {
98 receiver: broadcast::Receiver<PubSubMessage>,
99}
100
101impl PubSubSubscriptionHandle for BroadcastSubscriptionHandle {
102 fn recv(&mut self) -> PubSubRecvFuture<'_> {
103 Box::pin(async move {
104 self.receiver.recv().await.map_err(|err| match err {
105 broadcast::error::RecvError::Closed => PubSubReceiveError::Closed,
106 broadcast::error::RecvError::Lagged(skipped) => PubSubReceiveError::Lagged(skipped),
107 })
108 })
109 }
110}
111
112pub trait PubSubBackend: Send + Sync {
114 fn subscribe(&self, topic: &str) -> PubSubSubscription;
115 fn broadcast(&self, topic: &str, messages: Vec<ServerMessage>) -> usize;
116 fn capabilities(&self) -> PubSubCapabilities;
117
118 fn register_presence(&self, _topic: &str, _session_id: &str, _node_id: &str) {}
119
120 fn unregister_presence(&self, _topic: &str, _session_id: &str, _node_id: &str) {}
121
122 fn presence_snapshot(&self, topic: &str) -> PubSubPresenceSnapshot {
123 PubSubPresenceSnapshot::empty(topic)
124 }
125}
126
127#[derive(Debug)]
128struct InProcessPubSubBackend {
129 topics: Arc<Mutex<TopicSenders>>,
130 presence: Arc<Mutex<TopicPresenceMap>>,
131 topic_capacity: usize,
132}
133
134impl InProcessPubSubBackend {
135 fn new(topic_capacity: usize) -> Self {
136 Self {
137 topics: Arc::new(Mutex::new(HashMap::new())),
138 presence: Arc::new(Mutex::new(HashMap::new())),
139 topic_capacity,
140 }
141 }
142
143 fn sender_for(&self, topic: &str) -> broadcast::Sender<PubSubMessage> {
144 let mut topics = self.topics.lock().expect("pubsub topic mutex poisoned");
145 topics
146 .entry(topic.to_string())
147 .or_insert_with(|| {
148 let (sender, _) = broadcast::channel(self.topic_capacity);
149 sender
150 })
151 .clone()
152 }
153}
154
155impl PubSubBackend for InProcessPubSubBackend {
156 fn subscribe(&self, topic: &str) -> PubSubSubscription {
157 let sender = self.sender_for(topic);
158 PubSubSubscription::new(BroadcastSubscriptionHandle {
159 receiver: sender.subscribe(),
160 })
161 }
162
163 fn broadcast(&self, topic: &str, messages: Vec<ServerMessage>) -> usize {
164 let sender = self.sender_for(topic);
165 sender
166 .send(PubSubMessage {
167 topic: topic.to_string(),
168 messages,
169 })
170 .unwrap_or_default()
171 }
172
173 fn capabilities(&self) -> PubSubCapabilities {
174 PubSubCapabilities::in_process()
175 }
176
177 fn register_presence(&self, topic: &str, session_id: &str, node_id: &str) {
178 let mut presence = self
179 .presence
180 .lock()
181 .expect("pubsub presence mutex poisoned");
182 presence
183 .entry(topic.to_string())
184 .or_default()
185 .entry(node_id.to_string())
186 .or_default()
187 .insert(session_id.to_string());
188 }
189
190 fn unregister_presence(&self, topic: &str, session_id: &str, node_id: &str) {
191 let mut presence = self
192 .presence
193 .lock()
194 .expect("pubsub presence mutex poisoned");
195 let mut remove_topic = false;
196 if let Some(by_node) = presence.get_mut(topic) {
197 if let Some(sessions) = by_node.get_mut(node_id) {
198 sessions.remove(session_id);
199 if sessions.is_empty() {
200 by_node.remove(node_id);
201 }
202 }
203 remove_topic = by_node.is_empty();
204 }
205 if remove_topic {
206 presence.remove(topic);
207 }
208 }
209
210 fn presence_snapshot(&self, topic: &str) -> PubSubPresenceSnapshot {
211 let presence = self
212 .presence
213 .lock()
214 .expect("pubsub presence mutex poisoned");
215 let Some(by_node) = presence.get(topic) else {
216 return PubSubPresenceSnapshot::empty(topic);
217 };
218 let mut snapshot = PubSubPresenceSnapshot {
219 topic: topic.to_string(),
220 total_sessions: 0,
221 by_node: BTreeMap::new(),
222 };
223 for (node_id, sessions) in by_node {
224 snapshot.total_sessions += sessions.len();
225 snapshot.by_node.insert(node_id.clone(), sessions.len());
226 }
227 snapshot
228 }
229}
230
231#[derive(Clone)]
233pub struct PubSub {
234 backend: Arc<dyn PubSubBackend>,
235}
236
237impl std::fmt::Debug for PubSub {
238 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239 f.debug_struct("PubSub")
240 .field("capabilities", &self.capabilities())
241 .finish()
242 }
243}
244
245impl Default for PubSub {
246 fn default() -> Self {
247 Self::new(DEFAULT_TOPIC_CAPACITY)
248 }
249}
250
251impl PubSub {
252 pub fn new(topic_capacity: usize) -> Self {
254 Self::with_backend(InProcessPubSubBackend::new(topic_capacity))
255 }
256
257 pub fn with_backend<B>(backend: B) -> Self
259 where
260 B: PubSubBackend + 'static,
261 {
262 Self {
263 backend: Arc::new(backend),
264 }
265 }
266
267 pub fn subscribe(&self, topic: impl Into<String>) -> PubSubSubscription {
269 let topic = topic.into();
270 self.backend.subscribe(&topic)
271 }
272
273 pub fn broadcast(&self, topic: impl Into<String>, messages: Vec<ServerMessage>) -> usize {
275 let topic = topic.into();
276 self.backend.broadcast(&topic, messages)
277 }
278
279 pub fn capabilities(&self) -> PubSubCapabilities {
281 self.backend.capabilities()
282 }
283
284 pub fn register_presence(
286 &self,
287 topic: impl Into<String>,
288 session_id: impl Into<String>,
289 node_id: impl Into<String>,
290 ) {
291 let topic = topic.into();
292 let session_id = session_id.into();
293 let node_id = node_id.into();
294 self.backend
295 .register_presence(&topic, &session_id, &node_id);
296 }
297
298 pub fn unregister_presence(
300 &self,
301 topic: impl Into<String>,
302 session_id: impl Into<String>,
303 node_id: impl Into<String>,
304 ) {
305 let topic = topic.into();
306 let session_id = session_id.into();
307 let node_id = node_id.into();
308 self.backend
309 .unregister_presence(&topic, &session_id, &node_id);
310 }
311
312 pub fn presence_snapshot(&self, topic: impl Into<String>) -> PubSubPresenceSnapshot {
314 let topic = topic.into();
315 self.backend.presence_snapshot(&topic)
316 }
317}
318
319#[derive(Debug, Clone, PartialEq)]
321pub struct PubSubMessage {
322 pub topic: String,
323 pub messages: Vec<ServerMessage>,
324}
325
326pub struct PubSubSubscription {
328 inner: Box<dyn PubSubSubscriptionHandle>,
329}
330
331impl PubSubSubscription {
332 pub fn new<H>(handle: H) -> Self
334 where
335 H: PubSubSubscriptionHandle + 'static,
336 {
337 Self {
338 inner: Box::new(handle),
339 }
340 }
341
342 pub async fn recv(&mut self) -> Result<PubSubMessage, PubSubReceiveError> {
343 self.inner.recv().await
344 }
345}
346
347#[derive(Debug, Clone, PartialEq)]
349pub enum PubSubCommand {
350 Subscribe {
351 topic: String,
352 },
353 Broadcast {
354 topic: String,
355 messages: Vec<ServerMessage>,
356 },
357}
358
#[cfg(test)]
mod tests {
    use super::{
        BroadcastSubscriptionHandle, PubSub, PubSubBackend, PubSubCapabilities,
        PubSubDeliveryScope, PubSubMessage, PubSubOrdering, PubSubReceiveError, PubSubSubscription,
        SessionAffinityRequirement,
    };
    use crate::ServerMessage;
    use std::{
        collections::HashMap,
        sync::{Arc, Mutex},
    };
    use tokio::sync::broadcast;

    /// Every subscriber of a topic receives a broadcast, and `broadcast`
    /// reports how many subscribers it reached.
    #[tokio::test]
    async fn in_process_pubsub_broadcasts_to_subscribers() {
        let pubsub = PubSub::default();
        let mut first = pubsub.subscribe("chat:lobby");
        let mut second = pubsub.subscribe("chat:lobby");

        assert_eq!(
            pubsub.broadcast(
                "chat:lobby",
                vec![ServerMessage::Redirect {
                    to: "/ok".to_string()
                }]
            ),
            2
        );

        assert_eq!(first.recv().await.unwrap().topic, "chat:lobby");
        assert_eq!(
            second.recv().await.unwrap().messages,
            vec![ServerMessage::Redirect {
                to: "/ok".to_string()
            }]
        );
    }

    /// The built-in backend advertises local-process delivery and counts
    /// presence per node. (Renamed: it asserts `LocalProcess`, not cluster scope.)
    #[test]
    fn in_process_pubsub_reports_local_capabilities_and_presence() {
        let pubsub = PubSub::default();
        let capabilities = pubsub.capabilities();
        assert_eq!(capabilities.backend, "in_process");
        assert_eq!(
            capabilities.delivery_scope,
            PubSubDeliveryScope::LocalProcess
        );
        assert_eq!(capabilities.ordering, PubSubOrdering::PerTopicOrdered);
        assert_eq!(
            capabilities.session_affinity,
            SessionAffinityRequirement::StatefulSessionRequired
        );
        assert!(capabilities.presence_tracking);

        pubsub.register_presence("chat:lobby", "s1", "node-a");
        pubsub.register_presence("chat:lobby", "s2", "node-a");
        pubsub.register_presence("chat:lobby", "s3", "node-b");
        let snapshot = pubsub.presence_snapshot("chat:lobby");
        assert_eq!(snapshot.topic, "chat:lobby");
        assert_eq!(snapshot.total_sessions, 3);
        assert_eq!(snapshot.by_node.get("node-a"), Some(&2));
        assert_eq!(snapshot.by_node.get("node-b"), Some(&1));

        pubsub.unregister_presence("chat:lobby", "s2", "node-a");
        let after = pubsub.presence_snapshot("chat:lobby");
        assert_eq!(after.total_sessions, 2);
        assert_eq!(after.by_node.get("node-a"), Some(&1));
    }

    /// Broadcasting to a topic with no subscribers reaches nobody, and the
    /// topic still works for subscribers that arrive later or after all
    /// earlier subscribers have gone away.
    #[tokio::test]
    async fn broadcast_without_subscribers_reaches_nobody_and_topic_stays_usable() {
        let pubsub = PubSub::default();
        assert_eq!(pubsub.broadcast("chat:empty", Vec::new()), 0);

        let mut late = pubsub.subscribe("chat:empty");
        assert_eq!(
            pubsub.broadcast(
                "chat:empty",
                vec![ServerMessage::Redirect {
                    to: "/late".to_string()
                }]
            ),
            1
        );
        assert_eq!(late.recv().await.unwrap().topic, "chat:empty");

        drop(late);
        assert_eq!(pubsub.broadcast("chat:empty", Vec::new()), 0);

        let mut again = pubsub.subscribe("chat:empty");
        assert_eq!(pubsub.broadcast("chat:empty", Vec::new()), 1);
        assert!(again.recv().await.unwrap().messages.is_empty());
    }

    /// Test-only stand-in for a cluster transport: one topic table shared by
    /// several backend instances.
    #[derive(Debug, Clone)]
    struct SharedHub {
        topics: Arc<Mutex<HashMap<String, broadcast::Sender<PubSubMessage>>>>,
    }

    impl SharedHub {
        fn new() -> Self {
            Self {
                topics: Arc::new(Mutex::new(HashMap::new())),
            }
        }

        /// Returns the shared sender for `topic`, creating it on first use.
        fn sender_for(&self, topic: &str) -> broadcast::Sender<PubSubMessage> {
            let mut topics = self.topics.lock().expect("hub mutex poisoned");
            topics
                .entry(topic.to_string())
                .or_insert_with(|| {
                    let (tx, _) = broadcast::channel(256);
                    tx
                })
                .clone()
        }
    }

    /// Backend whose instances all publish into the same [`SharedHub`].
    #[derive(Debug, Clone)]
    struct MockClusterBackend {
        hub: SharedHub,
    }

    impl PubSubBackend for MockClusterBackend {
        fn subscribe(&self, topic: &str) -> PubSubSubscription {
            let receiver = self.hub.sender_for(topic).subscribe();
            PubSubSubscription::new(BroadcastSubscriptionHandle { receiver })
        }

        fn broadcast(&self, topic: &str, messages: Vec<ServerMessage>) -> usize {
            self.hub
                .sender_for(topic)
                .send(PubSubMessage {
                    topic: topic.to_string(),
                    messages,
                })
                .unwrap_or_default()
        }

        fn capabilities(&self) -> PubSubCapabilities {
            PubSubCapabilities {
                backend: "mock_cluster".to_string(),
                delivery_scope: PubSubDeliveryScope::Cluster,
                ordering: PubSubOrdering::BestEffort,
                session_affinity: SessionAffinityRequirement::StatefulSessionRequired,
                presence_tracking: false,
            }
        }
    }

    /// A message broadcast through one `PubSub` reaches a subscriber on another
    /// `PubSub` when both use backends that share a hub.
    #[tokio::test]
    async fn custom_backend_can_fanout_across_multiple_pubsub_instances() {
        let hub = SharedHub::new();
        let node_a = PubSub::with_backend(MockClusterBackend { hub: hub.clone() });
        let node_b = PubSub::with_backend(MockClusterBackend { hub });

        let mut subscription = node_a.subscribe("cluster:lobby");
        assert_eq!(
            node_b.broadcast(
                "cluster:lobby",
                vec![ServerMessage::Error {
                    message: "hello".to_string(),
                    code: Some("cluster".to_string()),
                }]
            ),
            1
        );

        let delivered = subscription.recv().await.unwrap();
        assert_eq!(delivered.topic, "cluster:lobby");
        assert_eq!(delivered.messages.len(), 1);
        match &delivered.messages[0] {
            ServerMessage::Error { message, code } => {
                assert_eq!(message, "hello");
                assert_eq!(code.as_deref(), Some("cluster"));
            }
            other => panic!("unexpected payload: {other:?}"),
        }
    }

    /// `Debug` output mentions capabilities, and removing the last session
    /// leaves an empty snapshot.
    #[test]
    fn pubsub_debug_and_presence_cleanup_cover_additional_branches() {
        let pubsub = PubSub::default();
        assert!(format!("{pubsub:?}").contains("capabilities"));

        pubsub.register_presence("chat:lobby", "session-1", "node-a");
        pubsub.unregister_presence("chat:lobby", "session-1", "node-a");
        let snapshot = pubsub.presence_snapshot("chat:lobby");
        assert_eq!(snapshot.total_sessions, 0);
        assert!(snapshot.by_node.is_empty());
    }

    /// Backend that relies on the trait's default (no-op) presence methods.
    #[derive(Debug, Clone, Default)]
    struct NoPresenceBackend;

    impl PubSubBackend for NoPresenceBackend {
        fn subscribe(&self, _topic: &str) -> PubSubSubscription {
            let (_sender, receiver) = broadcast::channel(8);
            PubSubSubscription::new(BroadcastSubscriptionHandle { receiver })
        }

        fn broadcast(&self, _topic: &str, _messages: Vec<ServerMessage>) -> usize {
            0
        }

        fn capabilities(&self) -> PubSubCapabilities {
            PubSubCapabilities {
                backend: "no_presence".to_string(),
                delivery_scope: PubSubDeliveryScope::LocalProcess,
                ordering: PubSubOrdering::BestEffort,
                session_affinity: SessionAffinityRequirement::None,
                presence_tracking: false,
            }
        }
    }

    /// The default presence methods ignore registrations and report an empty snapshot.
    #[test]
    fn default_presence_methods_return_empty_snapshot() {
        let pubsub = PubSub::with_backend(NoPresenceBackend);
        pubsub.register_presence("topic", "session-1", "node-a");
        pubsub.unregister_presence("topic", "session-1", "node-a");
        let snapshot = pubsub.presence_snapshot("topic");
        assert_eq!(snapshot.topic, "topic");
        assert_eq!(snapshot.total_sessions, 0);
        assert!(snapshot.by_node.is_empty());
    }

    /// Tokio's `Lagged` and `Closed` receive errors map to the matching
    /// [`PubSubReceiveError`] variants.
    #[tokio::test]
    async fn broadcast_subscription_maps_lagged_and_closed_recv_errors() {
        let (sender, receiver) = broadcast::channel(1);
        let mut lagged = PubSubSubscription::new(BroadcastSubscriptionHandle { receiver });
        sender
            .send(PubSubMessage {
                topic: "topic".to_string(),
                messages: vec![ServerMessage::Redirect {
                    to: "/a".to_string(),
                }],
            })
            .unwrap();
        sender
            .send(PubSubMessage {
                topic: "topic".to_string(),
                messages: vec![ServerMessage::Redirect {
                    to: "/b".to_string(),
                }],
            })
            .unwrap();

        assert!(matches!(
            lagged.recv().await,
            Err(PubSubReceiveError::Lagged(_))
        ));

        let (sender2, receiver2) = broadcast::channel(1);
        drop(sender2);
        let mut closed = PubSubSubscription::new(BroadcastSubscriptionHandle {
            receiver: receiver2,
        });
        assert!(matches!(
            closed.recv().await,
            Err(PubSubReceiveError::Closed)
        ));
    }
}
598}