zerodds_websocket_bridge/daemon/
router.rs1use std::collections::BTreeMap;
16use std::string::String;
17use std::sync::mpsc;
18use std::vec::Vec;
19
/// Message delivered from the [`Router`] to a connection's channel.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare messages
/// directly (`assert_eq!`, `==`) instead of destructuring them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RouterMsg {
    /// A sample that was dispatched on a topic.
    Sample {
        /// Name of the topic the sample was dispatched on.
        topic: String,
        /// Payload bytes of the sample.
        payload: Vec<u8>,
    },
    /// Sent to every registered connection by `Router::broadcast_shutdown`.
    Shutdown,
}
33
34#[derive(Debug, Default)]
36pub struct Router {
37 subs: BTreeMap<String, Vec<u64>>,
39 conns: BTreeMap<u64, mpsc::Sender<RouterMsg>>,
41}
42
43impl Router {
44 #[must_use]
46 pub fn new() -> Self {
47 Self::default()
48 }
49
50 pub fn register_connection(&mut self, id: u64, sender: mpsc::Sender<RouterMsg>) {
52 self.conns.insert(id, sender);
53 }
54
55 pub fn deregister_connection(&mut self, id: u64) {
57 self.conns.remove(&id);
58 for subs in self.subs.values_mut() {
59 subs.retain(|c| *c != id);
60 }
61 }
62
63 pub fn subscribe(&mut self, conn_id: u64, topic: String) {
65 let entry = self.subs.entry(topic).or_default();
66 if !entry.contains(&conn_id) {
67 entry.push(conn_id);
68 }
69 }
70
71 pub fn unsubscribe(&mut self, conn_id: u64, topic: &str) {
73 if let Some(list) = self.subs.get_mut(topic) {
74 list.retain(|c| *c != conn_id);
75 }
76 }
77
78 pub fn dispatch(&mut self, topic: &str, payload: Vec<u8>) -> usize {
82 let Some(subs) = self.subs.get(topic).cloned() else {
83 return 0;
84 };
85 let mut delivered = 0usize;
86 for conn_id in subs {
87 if let Some(sender) = self.conns.get(&conn_id) {
88 let msg = RouterMsg::Sample {
89 topic: topic.to_string(),
90 payload: payload.clone(),
91 };
92 if sender.send(msg).is_ok() {
93 delivered += 1;
94 } else {
95 self.conns.remove(&conn_id);
97 }
98 }
99 }
100 delivered
101 }
102
103 pub fn broadcast_shutdown(&self) {
105 for sender in self.conns.values() {
106 let _ = sender.send(RouterMsg::Shutdown);
107 }
108 }
109
110 #[must_use]
112 pub fn connection_count(&self) -> usize {
113 self.conns.len()
114 }
115}
116
// Unit tests for `Router`. They only go through the public API.
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use std::sync::mpsc::channel;

    /// A subscribed, registered connection receives the dispatched sample.
    #[test]
    fn dispatch_to_subscribed_connection() {
        let mut router = Router::new();
        let (tx, rx) = channel();
        router.register_connection(1, tx);
        router.subscribe(1, "Trade".to_string());
        let n = router.dispatch("Trade", b"PAYLOAD".to_vec());
        assert_eq!(n, 1);
        match rx.recv().unwrap() {
            RouterMsg::Sample { topic, payload } => {
                assert_eq!(topic, "Trade");
                assert_eq!(payload, b"PAYLOAD");
            }
            other => panic!("unexpected msg {other:?}"),
        }
    }

    /// Dispatching on a topic nobody subscribed to delivers nothing.
    #[test]
    fn dispatch_to_no_subscribers_is_zero() {
        let mut router = Router::new();
        let n = router.dispatch("Empty", b"x".to_vec());
        assert_eq!(n, 0);
    }

    /// Subscribing twice must not cause the sample to be delivered twice.
    #[test]
    fn duplicate_subscribe_delivers_once() {
        let mut router = Router::new();
        let (tx, rx) = channel();
        router.register_connection(1, tx);
        router.subscribe(1, "T".to_string());
        router.subscribe(1, "T".to_string());
        assert_eq!(router.dispatch("T", b"x".to_vec()), 1);
        assert!(rx.try_recv().is_ok());
        assert!(rx.try_recv().is_err());
    }

    /// After `unsubscribe` the connection no longer receives the topic.
    #[test]
    fn unsubscribe_stops_delivery() {
        let mut router = Router::new();
        let (tx, rx) = channel();
        router.register_connection(2, tx);
        router.subscribe(2, "T".to_string());
        router.unsubscribe(2, "T");
        let n = router.dispatch("T", b"x".to_vec());
        assert_eq!(n, 0);
        assert!(rx.try_recv().is_err());
    }

    /// Deregistering drops the connection AND its subscriptions: a new
    /// connection that reuses the id must not receive the old topic.
    #[test]
    fn deregister_removes_subscription() {
        let mut router = Router::new();
        let (tx, _rx) = channel();
        router.register_connection(3, tx);
        router.subscribe(3, "T".to_string());
        router.deregister_connection(3);
        assert_eq!(router.connection_count(), 0);

        let (tx2, rx2) = channel();
        router.register_connection(3, tx2);
        assert_eq!(router.dispatch("T", b"x".to_vec()), 0);
        assert!(rx2.try_recv().is_err());
    }

    /// A connection whose receiver is gone is not counted as delivered and is
    /// removed from the router.
    #[test]
    fn dispatch_drops_dead_connection() {
        let mut router = Router::new();
        let (tx, rx) = channel();
        router.register_connection(4, tx);
        router.subscribe(4, "T".to_string());
        drop(rx);
        assert_eq!(router.dispatch("T", b"x".to_vec()), 0);
        assert_eq!(router.connection_count(), 0);
    }

    /// Every registered connection gets the shutdown message.
    #[test]
    fn shutdown_broadcasts_to_all() {
        let mut router = Router::new();
        let (tx_a, rx_a) = channel();
        let (tx_b, rx_b) = channel();
        router.register_connection(7, tx_a);
        router.register_connection(8, tx_b);
        router.broadcast_shutdown();
        assert!(matches!(rx_a.recv().unwrap(), RouterMsg::Shutdown));
        assert!(matches!(rx_b.recv().unwrap(), RouterMsg::Shutdown));
    }
}