1use bytes::Bytes;
2use futures_util::stream::SplitSink;
3use futures_util::{SinkExt, StreamExt};
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, HashSet};
6use std::hash::{Hash, Hasher};
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::select;
10use tokio::sync::{Mutex, RwLock};
11use tokio::time::interval;
12use axum::Error;
13use axum::extract::ws::{Message, WebSocket};
14
/// Period of the heartbeat timer in `signaling_conn`: on each tick a WebSocket ping is sent,
/// and if `pong_received` was not set since the previous tick the connection is closed.
const PING_TIMEOUT: Duration = Duration::from_millis(30_000);
16
/// Shared registry of signaling topics and the connections subscribed to each of them.
///
/// Connections are attached to it through [`signaling_conn`]. The inner [`Topics`] value is an
/// `Arc`, so cloning the service is cheap and every clone refers to the same topic map.
#[derive(Debug, Clone)]
pub struct SignalingService(Topics);
74
75impl SignalingService {
76 pub fn new() -> Self {
77 SignalingService(Arc::new(RwLock::new(Default::default())))
78 }
79
80 pub async fn publish(&self, topic: &str, msg: Message) -> Result<(), Error> {
81 let mut failed = Vec::new();
82 {
83 let topics = self.0.read().await;
84 if let Some(subs) = topics.get(topic) {
85 let client_count = subs.len();
86 tracing::info!("publishing message to {client_count} clients: {msg:?}");
87 for sub in subs {
88 if let Err(e) = sub.try_send(msg.clone()).await {
89 tracing::info!("failed to send {msg:?}: {e}");
90 failed.push(sub.clone());
91 }
92 }
93 }
94 }
95 if !failed.is_empty() {
96 let mut topics = self.0.write().await;
97 if let Some(subs) = topics.get_mut(topic) {
98 for f in failed {
99 subs.remove(&f);
100 }
101 }
102 }
103 Ok(())
104 }
105
106 pub async fn close_topic(&self, topic: &str) -> Result<(), Error> {
107 let mut topics = self.0.write().await;
108 if let Some(subs) = topics.remove(topic) {
109 for sub in subs {
110 if let Err(e) = sub.close().await {
111 tracing::warn!("failed to close connection on topic '{topic}': {e}");
112 }
113 }
114 }
115 Ok(())
116 }
117
118 pub async fn close(self) -> Result<(), Error> {
119 let mut topics = self.0.write_owned().await;
120 let mut all_conns = HashSet::new();
121 for (_, subs) in topics.drain() {
122 for sub in subs {
123 all_conns.insert(sub);
124 }
125 }
126
127 for conn in all_conns {
128 if let Err(e) = conn.close().await {
129 tracing::warn!("failed to close connection: {e}");
130 }
131 }
132
133 Ok(())
134 }
135}
136
137impl Default for SignalingService {
138 fn default() -> Self {
139 Self::new()
140 }
141}
142
/// Shared, lock-protected map from topic name to the set of connections subscribed to it.
type Topics = Arc<RwLock<HashMap<Arc<str>, HashSet<WsSink>>>>;
144
/// Cloneable handle to the sending half of one client's WebSocket.
///
/// All clones share a single sink behind an `Arc`, and an async `Mutex` serializes writes
/// coming from different tasks. Equality and hashing use the `Arc` pointer, so the same
/// connection can be stored in several topic sets and removed from each of them.
#[derive(Debug, Clone)]
struct WsSink(Arc<Mutex<SplitSink<WebSocket, Message>>>);
147
148impl WsSink {
149 fn new(sink: SplitSink<WebSocket, Message>) -> Self {
150 WsSink(Arc::new(Mutex::new(sink)))
151 }
152
153 async fn try_send(&self, msg: Message) -> Result<(), Error> {
154 let mut sink = self.0.lock().await;
155 if let Err(e) = sink.send(msg).await {
156 sink.close().await?;
157 Err(e)
158 } else {
159 Ok(())
160 }
161 }
162
163 async fn close(&self) -> Result<(), Error> {
164 let mut sink = self.0.lock().await;
165 sink.close().await
166 }
167}
168
169impl Hash for WsSink {
170 fn hash<H: Hasher>(&self, state: &mut H) {
171 let ptr = Arc::as_ptr(&self.0) as usize;
172 ptr.hash(state);
173 }
174}
175
176impl PartialEq<Self> for WsSink {
177 fn eq(&self, other: &Self) -> bool {
178 Arc::ptr_eq(&self.0, &other.0)
179 }
180}
181
182impl Eq for WsSink {}
183
184pub async fn signaling_conn(ws: WebSocket, service: SignalingService) -> Result<(), Error> {
187 let mut topics: Topics = service.0;
188 let (sink, mut stream) = ws.split();
189 let ws = WsSink::new(sink);
190 let mut ping_interval = interval(PING_TIMEOUT);
191 let mut state = ConnState::default();
192 loop {
193 select! {
194 _ = ping_interval.tick() => {
195 if !state.pong_received {
196 ws.close().await?;
197 drop(ping_interval);
198 return Ok(());
199 } else {
200 state.pong_received = false;
201 if let Err(e) = ws.try_send(Message::Ping(Bytes::default())).await {
202 ws.close().await?;
203 return Err(e);
204 }
205 }
206 },
207 res = stream.next() => {
208 match res {
209 None => {
210 ws.close().await?;
211 return Ok(());
212 },
213 Some(Err(e)) => {
214 ws.close().await?;
215 return Err(e);
216 },
217 Some(Ok(msg)) => {
218 process_msg(msg, &ws, &mut state, &mut topics).await?;
219 }
220 }
221 }
222 }
223 }
224}
225
/// Application-level ping signal sent to a client (in reply to its `pong` signal).
const PING_MSG: &str = r#"{"type":"ping"}"#;
/// Application-level pong signal sent to a client in reply to its `ping` signal.
const PONG_MSG: &str = r#"{"type":"pong"}"#;
228
229async fn process_msg(
230 msg: Message,
231 ws: &WsSink,
232 state: &mut ConnState,
233 topics: &mut Topics,
234) -> Result<(), Error> {
235 match msg {
236 Message::Text(txt) => {
237 let json = txt.as_str();
238 let msg = serde_json::from_str(json).unwrap();
239 match msg {
240 Signal::Subscribe {
241 topics: topic_names,
242 } => {
243 if !topic_names.is_empty() {
244 let mut topics = topics.write().await;
245 for topic in topic_names {
246 tracing::trace!("subscribing new client to '{topic}'");
247 if let Some((key, _)) = topics.get_key_value(topic) {
248 state.subscribed_topics.insert(key.clone());
249 let subs = topics.get_mut(topic).unwrap();
250 subs.insert(ws.clone());
251 } else {
252 let topic: Arc<str> = topic.into();
253 state.subscribed_topics.insert(topic.clone());
254 let mut subs = HashSet::new();
255 subs.insert(ws.clone());
256 topics.insert(topic, subs);
257 };
258 }
259 }
260 }
261 Signal::Unsubscribe {
262 topics: topic_names,
263 } => {
264 if !topic_names.is_empty() {
265 let mut topics = topics.write().await;
266 for topic in topic_names {
267 if let Some(subs) = topics.get_mut(topic) {
268 tracing::trace!("unsubscribing client from '{topic}'");
269 subs.remove(ws);
270 }
271 }
272 }
273 }
274 Signal::Publish { topic } => {
275 let mut failed = Vec::new();
276 {
277 let topics = topics.read().await;
278 if let Some(receivers) = topics.get(topic) {
279 let client_count = receivers.len();
280 tracing::trace!(
281 "publishing on {client_count} clients at '{topic}': {json}"
282 );
283 for receiver in receivers.iter() {
284 if let Err(e) = receiver.try_send(Message::text(json)).await {
285 tracing::info!(
286 "failed to publish message {json} on '{topic}': {e}"
287 );
288 failed.push(receiver.clone());
289 }
290 }
291 }
292 }
293 if !failed.is_empty() {
294 let mut topics = topics.write().await;
295 if let Some(receivers) = topics.get_mut(topic) {
296 for f in failed {
297 receivers.remove(&f);
298 }
299 }
300 }
301 }
302 Signal::Ping => {
303 ws.try_send(Message::text(PONG_MSG)).await?;
304 }
305 Signal::Pong => {
306 ws.try_send(Message::text(PING_MSG)).await?;
307 }
308 }
309 },
310 Message::Close(_close_frame) => {
311 let mut topics = topics.write().await;
312 for topic in state.subscribed_topics.drain() {
313 if let Some(subs) = topics.get_mut(&topic) {
314 subs.remove(ws);
315 if subs.is_empty() {
316 topics.remove(&topic);
317 }
318 }
319 }
320 state.closed = true;
321 },
322 Message::Ping(_bytes) => {
323 ws.try_send(Message::Ping(Bytes::default())).await?;
324 },
325 _ => {}
326
327 }
328 Ok(())
329}
330
/// Per-connection bookkeeping for one signaling client.
#[derive(Debug)]
struct ConnState {
    /// Set once a WebSocket Close frame has been processed.
    closed: bool,
    /// Cleared whenever a heartbeat ping is sent; if it is still false on the next tick, the
    /// connection is closed.
    pong_received: bool,
    /// Topics this connection joined, used to unsubscribe it when a Close frame arrives.
    subscribed_topics: HashSet<Arc<str>>,
}

impl Default for ConnState {
    /// Returns the state of a fresh connection: open, with no subscriptions, and with
    /// `pong_received` set so the first heartbeat tick does not close it.
    fn default() -> Self {
        Self {
            subscribed_topics: HashSet::default(),
            pong_received: true,
            closed: false,
        }
    }
}
347
348#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
349#[serde(tag = "type")]
350pub(crate) enum Signal<'a> {
351 #[serde(rename = "publish")]
352 Publish { topic: &'a str },
353 #[serde(rename = "subscribe")]
354 Subscribe { topics: Vec<&'a str> },
355 #[serde(rename = "unsubscribe")]
356 Unsubscribe { topics: Vec<&'a str> },
357 #[serde(rename = "ping")]
358 Ping,
359 #[serde(rename = "pong")]
360 Pong,
361}