// yrs_axum/signaling.rs

1use bytes::Bytes;
2use futures_util::stream::SplitSink;
3use futures_util::{SinkExt, StreamExt};
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, HashSet};
6use std::hash::{Hash, Hasher};
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::select;
10use tokio::sync::{Mutex, RwLock};
11use tokio::time::interval;
12use axum::Error;
13use axum::extract::ws::{Message, WebSocket};
14
/// Interval between server-initiated ping frames on a signaling connection.
///
/// `signaling_conn` clears the connection's `pong_received` flag every time it sends a ping;
/// if the flag is still cleared when the next tick fires, the connection is closed.
const PING_TIMEOUT: Duration = Duration::from_millis(30_000);
16
17/// Signaling service is used by y-webrtc protocol in order to exchange WebRTC offerings between
18/// clients subscribing to particular rooms.
19///
20/// # Example
21///
22/// ```rust
23/// use std::net::SocketAddr;
24/// use std::str::FromStr;
25/// use axum::{
26///     Router,
27///     routing::get,
28///     extract::ws::{WebSocket, WebSocketUpgrade},
29///     extract::State,
30///     response::IntoResponse,
31/// };
32/// use yrs_axum::signaling::{SignalingService, signaling_conn};
33///
34/// #[tokio::main]
35/// async fn main() {
36///     let addr = SocketAddr::from_str("0.0.0.0:8000").unwrap();
37///     let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
38///     
39///     let signaling = SignalingService::new();
40///     
41///     let app = Router::new()
42///         .route("/signaling", get(ws_handler))
43///         .with_state(signaling);
44///
45///     let (tx, rx) = tokio::sync::oneshot::channel::<bool>();
46///     tokio::spawn(async move {
47///     axum::serve(listener, app.into_make_service())
48///         .with_graceful_shutdown(async move {
49///           rx.await.unwrap();
50///         })
51///         .await
52///         .unwrap();
53///     });
54///     tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
55///     tx.send(true);
56/// }
57///
58/// async fn ws_handler(
59///     ws: WebSocketUpgrade,
60///     State(svc): State<SignalingService>,
61/// ) -> impl IntoResponse {
62///     ws.on_upgrade(move |socket| peer(socket, svc))
63/// }
64///
65/// async fn peer(ws: WebSocket, svc: SignalingService) {
66///     match signaling_conn(ws, svc).await {
67///         Ok(_) => println!("signaling connection stopped"),
68///         Err(e) => eprintln!("signaling connection failed: {}", e),
69///     }
70/// }
71/// ```
72#[derive(Debug, Clone)]
73pub struct SignalingService(Topics);
74
75impl SignalingService {
76    pub fn new() -> Self {
77        SignalingService(Arc::new(RwLock::new(Default::default())))
78    }
79
80    pub async fn publish(&self, topic: &str, msg: Message) -> Result<(), Error> {
81        let mut failed = Vec::new();
82        {
83            let topics = self.0.read().await;
84            if let Some(subs) = topics.get(topic) {
85                let client_count = subs.len();
86                tracing::info!("publishing message to {client_count} clients: {msg:?}");
87                for sub in subs {
88                    if let Err(e) = sub.try_send(msg.clone()).await {
89                        tracing::info!("failed to send {msg:?}: {e}");
90                        failed.push(sub.clone());
91                    }
92                }
93            }
94        }
95        if !failed.is_empty() {
96            let mut topics = self.0.write().await;
97            if let Some(subs) = topics.get_mut(topic) {
98                for f in failed {
99                    subs.remove(&f);
100                }
101            }
102        }
103        Ok(())
104    }
105
106    pub async fn close_topic(&self, topic: &str) -> Result<(), Error> {
107        let mut topics = self.0.write().await;
108        if let Some(subs) = topics.remove(topic) {
109            for sub in subs {
110                if let Err(e) = sub.close().await {
111                    tracing::warn!("failed to close connection on topic '{topic}': {e}");
112                }
113            }
114        }
115        Ok(())
116    }
117
118    pub async fn close(self) -> Result<(), Error> {
119        let mut topics = self.0.write_owned().await;
120        let mut all_conns = HashSet::new();
121        for (_, subs) in topics.drain() {
122            for sub in subs {
123                all_conns.insert(sub);
124            }
125        }
126
127        for conn in all_conns {
128            if let Err(e) = conn.close().await {
129                tracing::warn!("failed to close connection: {e}");
130            }
131        }
132
133        Ok(())
134    }
135}
136
137impl Default for SignalingService {
138    fn default() -> Self {
139        Self::new()
140    }
141}
142
143type Topics = Arc<RwLock<HashMap<Arc<str>, HashSet<WsSink>>>>;
144
145#[derive(Debug, Clone)]
146struct WsSink(Arc<Mutex<SplitSink<WebSocket, Message>>>);
147
148impl WsSink {
149    fn new(sink: SplitSink<WebSocket, Message>) -> Self {
150        WsSink(Arc::new(Mutex::new(sink)))
151    }
152
153    async fn try_send(&self, msg: Message) -> Result<(), Error> {
154        let mut sink = self.0.lock().await;
155        if let Err(e) = sink.send(msg).await {
156            sink.close().await?;
157            Err(e)
158        } else {
159            Ok(())
160        }
161    }
162
163    async fn close(&self) -> Result<(), Error> {
164        let mut sink = self.0.lock().await;
165        sink.close().await
166    }
167}
168
169impl Hash for WsSink {
170    fn hash<H: Hasher>(&self, state: &mut H) {
171        let ptr = Arc::as_ptr(&self.0) as usize;
172        ptr.hash(state);
173    }
174}
175
176impl PartialEq<Self> for WsSink {
177    fn eq(&self, other: &Self) -> bool {
178        Arc::ptr_eq(&self.0, &other.0)
179    }
180}
181
182impl Eq for WsSink {}
183
184/// Handle incoming signaling connection - it's a websocket connection used by y-webrtc protocol
185/// to exchange offering metadata between y-webrtc peers. It also manages topic/room access.
186pub async fn signaling_conn(ws: WebSocket, service: SignalingService) -> Result<(), Error> {
187    let mut topics: Topics = service.0;
188    let (sink, mut stream) = ws.split();
189    let ws = WsSink::new(sink);
190    let mut ping_interval = interval(PING_TIMEOUT);
191    let mut state = ConnState::default();
192    loop {
193        select! {
194            _ = ping_interval.tick() => {
195                if !state.pong_received {
196                    ws.close().await?;
197                    drop(ping_interval);
198                    return Ok(());
199                } else {
200                    state.pong_received = false;
201                    if let Err(e) = ws.try_send(Message::Ping(Bytes::default())).await {
202                        ws.close().await?;
203                        return Err(e);
204                    }
205                }
206            },
207            res = stream.next() => {
208                match res {
209                    None => {
210                        ws.close().await?;
211                        return Ok(());
212                    },
213                    Some(Err(e)) => {
214                        ws.close().await?;
215                        return Err(e);
216                    },
217                    Some(Ok(msg)) => {
218                        process_msg(msg, &ws, &mut state, &mut topics).await?;
219                    }
220                }
221            }
222        }
223    }
224}
225
/// y-webrtc JSON keep-alive request.
const PING_MSG: &str = r#"{"type":"ping"}"#;
/// y-webrtc JSON keep-alive reply, sent in response to a client's `ping`.
const PONG_MSG: &str = r#"{"type":"pong"}"#;
228
229async fn process_msg(
230    msg: Message,
231    ws: &WsSink,
232    state: &mut ConnState,
233    topics: &mut Topics,
234) -> Result<(), Error> {
235    match msg {
236        Message::Text(txt) => {
237            let json = txt.as_str();
238            let msg = serde_json::from_str(json).unwrap();
239            match msg {
240                Signal::Subscribe {
241                    topics: topic_names,
242                } => {
243                    if !topic_names.is_empty() {
244                        let mut topics = topics.write().await;
245                        for topic in topic_names {
246                            tracing::trace!("subscribing new client to '{topic}'");
247                            if let Some((key, _)) = topics.get_key_value(topic) {
248                                state.subscribed_topics.insert(key.clone());
249                                let subs = topics.get_mut(topic).unwrap();
250                                subs.insert(ws.clone());
251                            } else {
252                                let topic: Arc<str> = topic.into();
253                                state.subscribed_topics.insert(topic.clone());
254                                let mut subs = HashSet::new();
255                                subs.insert(ws.clone());
256                                topics.insert(topic, subs);
257                            };
258                        }
259                    }
260                }
261                Signal::Unsubscribe {
262                    topics: topic_names,
263                } => {
264                    if !topic_names.is_empty() {
265                        let mut topics = topics.write().await;
266                        for topic in topic_names {
267                            if let Some(subs) = topics.get_mut(topic) {
268                                tracing::trace!("unsubscribing client from '{topic}'");
269                                subs.remove(ws);
270                            }
271                        }
272                    }
273                }
274                Signal::Publish { topic } => {
275                    let mut failed = Vec::new();
276                    {
277                        let topics = topics.read().await;
278                        if let Some(receivers) = topics.get(topic) {
279                            let client_count = receivers.len();
280                            tracing::trace!(
281                                "publishing on {client_count} clients at '{topic}': {json}"
282                            );
283                            for receiver in receivers.iter() {
284                                if let Err(e) = receiver.try_send(Message::text(json)).await {
285                                    tracing::info!(
286                                        "failed to publish message {json} on '{topic}': {e}"
287                                    );
288                                    failed.push(receiver.clone());
289                                }
290                            }
291                        }
292                    }
293                    if !failed.is_empty() {
294                        let mut topics = topics.write().await;
295                        if let Some(receivers) = topics.get_mut(topic) {
296                            for f in failed {
297                                receivers.remove(&f);
298                            }
299                        }
300                    }
301                }
302                Signal::Ping => {
303                    ws.try_send(Message::text(PONG_MSG)).await?;
304                }
305                Signal::Pong => {
306                    ws.try_send(Message::text(PING_MSG)).await?;
307                }
308            }
309        },
310        Message::Close(_close_frame) => {
311            let mut topics = topics.write().await;
312            for topic in state.subscribed_topics.drain() {
313                if let Some(subs) = topics.get_mut(&topic) {
314                    subs.remove(ws);
315                    if subs.is_empty() {
316                        topics.remove(&topic);
317                    }
318                }
319            }
320            state.closed = true;
321        },
322        Message::Ping(_bytes) => {
323            ws.try_send(Message::Ping(Bytes::default())).await?;
324        }, 
325        _ => {}
326
327    }
328    Ok(())
329}
330
/// Per-connection bookkeeping kept by the signaling loop.
#[derive(Debug)]
struct ConnState {
    /// Set once the peer's close frame has been processed.
    closed: bool,
    /// Cleared each time a ping is sent. If it is still cleared at the next ping tick, the
    /// connection is closed. It starts out `true`, so the first tick sends a ping instead.
    pong_received: bool,
    /// Topics this connection subscribed to; used to unregister it from the shared registry.
    subscribed_topics: HashSet<Arc<str>>,
}

impl Default for ConnState {
    fn default() -> Self {
        Self {
            subscribed_topics: HashSet::default(),
            pong_received: true,
            closed: false,
        }
    }
}
347
348#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
349#[serde(tag = "type")]
350pub(crate) enum Signal<'a> {
351    #[serde(rename = "publish")]
352    Publish { topic: &'a str },
353    #[serde(rename = "subscribe")]
354    Subscribe { topics: Vec<&'a str> },
355    #[serde(rename = "unsubscribe")]
356    Unsubscribe { topics: Vec<&'a str> },
357    #[serde(rename = "ping")]
358    Ping,
359    #[serde(rename = "pong")]
360    Pong,
361}