yrs_axum/
ws.rs

1use crate::conn::Connection;
2use crate::AwarenessRef;
3use futures_util::stream::{SplitSink, SplitStream};
4use futures_util::{Stream, StreamExt};
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use axum::extract::ws::{WebSocket, Message};
8use yrs::sync::Error;
9
10/// Connection Wrapper over a [WebSocket], which implements a Yjs/Yrs awareness and update exchange
11/// protocol.
12///
13/// This connection implements Future pattern and can be awaited upon in order for a caller to
14/// recognize whether underlying websocket connection has been finished gracefully or abruptly.
15#[repr(transparent)]
16#[derive(Debug)]
17pub struct AxumConn(Connection<AxumSink, AxumStream>);
18
19impl AxumConn {
20    pub fn new(awareness: AwarenessRef, socket: WebSocket) -> Self {
21        let (sink, stream) = socket.split();
22        let conn = Connection::new(awareness, AxumSink(sink), AxumStream(stream));
23        AxumConn(conn)
24    }
25}
26
27impl core::future::Future for AxumConn {
28    type Output = Result<(), Error>;
29
30    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
31        match Pin::new(&mut self.0).poll(cx) {
32            Poll::Pending => Poll::Pending,
33            Poll::Ready(Err(e)) => Poll::Ready(Err(Error::Other(e.into()))),
34            Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
35        }
36    }
37}
38
39/// An axum websocket sink wrapper, that implements futures `Sink` in a way, that makes it compatible
40/// with y-sync protocol, so that it can be used by y-sync crate [BroadcastGroup].
41///
42/// # Examples
43///
44/// ```rust
45/// use std::net::SocketAddr;
46/// use std::str::FromStr;
47/// use std::sync::Arc;
48/// use tokio::sync::Mutex;
49/// use tokio::task::JoinHandle;
50/// use futures_util::stream::StreamExt;
51/// use axum::{
52///     Router,
53///     routing::get,
54///     extract::ws::{WebSocket, WebSocketUpgrade},
55///     extract::State,
56///     response::IntoResponse,
57/// };
58/// use yrs_axum::broadcast::BroadcastGroup;
59/// use yrs_axum::ws::{AxumSink, AxumStream};
60///
61/// async fn start_server(
62///     addr: &str,
63///     bcast: Arc<BroadcastGroup>,
64/// ) -> Result<JoinHandle<()>, Box<dyn std::error::Error>> {
65///     let addr = SocketAddr::from_str(addr)?;
66///     let listener = tokio::net::TcpListener::bind(addr).await?;
67///     
68///     let app = Router::new()
69///         .route("/my-room", get(ws_handler))
70///         .with_state(bcast);
71///
72///     Ok(tokio::spawn(async move {
73///         axum::serve(listener, app.into_make_service())
74///             .await
75///             .unwrap();
76///     }))
77/// }
78///
79/// async fn ws_handler(
80///     ws: WebSocketUpgrade,
81///     State(bcast): State<Arc<BroadcastGroup>>,
82/// ) -> impl IntoResponse {
83///     ws.on_upgrade(move |socket| peer(socket, bcast))
84/// }
85///
86/// async fn peer(ws: WebSocket, bcast: Arc<BroadcastGroup>) {
87///     let (sink, stream) = ws.split();
88///     // convert axum web socket into compatible sink/stream
89///     let sink = Arc::new(Mutex::new(AxumSink(sink)));
90///     let stream = AxumStream(stream);
91///     // subscribe to broadcast group
92///     let sub = bcast.subscribe(sink, stream);
93///     // wait for subscribed connection to close itself
94///     match sub.completed().await {
95///         Ok(_) => println!("broadcasting for channel finished successfully"),
96///         Err(e) => eprintln!("broadcasting for channel finished abruptly: {}", e),
97///     }
98/// }
99/// ```
100#[repr(transparent)]
101#[derive(Debug)]
102pub struct AxumSink(pub SplitSink<WebSocket, Message>);
103
104impl From<SplitSink<WebSocket, Message>> for AxumSink {
105    fn from(sink: SplitSink<WebSocket, Message>) -> Self {
106        AxumSink(sink)
107    }
108}
109
110impl Into<SplitSink<WebSocket, Message>> for AxumSink {
111    fn into(self) -> SplitSink<WebSocket, Message> {
112        self.0
113    }
114}
115
116impl futures_util::Sink<Vec<u8>> for AxumSink {
117    type Error = Error;
118
119    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
120        match Pin::new(&mut self.0).poll_ready(cx) {
121            Poll::Pending => Poll::Pending,
122            Poll::Ready(Err(e)) => Poll::Ready(Err(Error::Other(e.into()))),
123            Poll::Ready(_) => Poll::Ready(Ok(())),
124        }
125    }
126
127    fn start_send(mut self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
128        if let Err(e) = Pin::new(&mut self.0).start_send(Message::binary(item)) {
129            Err(Error::Other(e.into()))
130        } else {
131            Ok(())
132        }
133    }
134
135    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
136        match Pin::new(&mut self.0).poll_flush(cx) {
137            Poll::Pending => Poll::Pending,
138            Poll::Ready(Err(e)) => Poll::Ready(Err(Error::Other(e.into()))),
139            Poll::Ready(_) => Poll::Ready(Ok(())),
140        }
141    }
142
143    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
144        match Pin::new(&mut self.0).poll_close(cx) {
145            Poll::Pending => Poll::Pending,
146            Poll::Ready(Err(e)) => Poll::Ready(Err(Error::Other(e.into()))),
147            Poll::Ready(_) => Poll::Ready(Ok(())),
148        }
149    }
150}
151
152/// An axum websocket stream wrapper, that implements futures `Stream` in a way, that makes it compatible
153/// with y-sync protocol, so that it can be used by y-sync crate [BroadcastGroup].
154///
155/// # Examples
156///
157/// ```rust
158/// use std::net::SocketAddr;
159/// use std::str::FromStr;
160/// use std::sync::Arc;
161/// use tokio::sync::Mutex;
162/// use tokio::task::JoinHandle;
163/// use futures_util::stream::StreamExt;
164/// use axum::{
165///     Router,
166///     routing::get,
167///     extract::ws::{WebSocket, WebSocketUpgrade},
168///     extract::State,
169///     response::IntoResponse,
170/// };
171/// use yrs_axum::broadcast::BroadcastGroup;
172/// use yrs_axum::ws::{AxumSink, AxumStream};
173///
174/// async fn start_server(
175///     addr: &str,
176///     bcast: Arc<BroadcastGroup>,
177/// ) -> Result<JoinHandle<()>, Box<dyn std::error::Error>> {
178///     let addr = SocketAddr::from_str(addr)?;
179///     let listener = tokio::net::TcpListener::bind(addr).await?;
180///     
181///     let app = Router::new()
182///         .route("/my-room", get(ws_handler))
183///         .with_state(bcast);
184///
185///     Ok(tokio::spawn(async move {
186///         axum::serve(listener, app.into_make_service())
187///             .await
188///             .unwrap();
189///     }))
190/// }
191///
192/// async fn ws_handler(
193///     ws: WebSocketUpgrade,
194///     State(bcast): State<Arc<BroadcastGroup>>,
195/// ) -> impl IntoResponse {
196///     ws.on_upgrade(move |socket| peer(socket, bcast))
197/// }
198///
199/// async fn peer(ws: WebSocket, bcast: Arc<BroadcastGroup>) {
200///     let (sink, stream) = ws.split();
201///     // convert axum web socket into compatible sink/stream
202///     let sink = Arc::new(Mutex::new(AxumSink(sink)));
203///     let stream = AxumStream(stream);
204///     // subscribe to broadcast group
205///     let sub = bcast.subscribe(sink, stream);
206///     // wait for subscribed connection to close itself
207///     match sub.completed().await {
208///         Ok(_) => println!("broadcasting for channel finished successfully"),
209///         Err(e) => eprintln!("broadcasting for channel finished abruptly: {}", e),
210///     }
211/// }
212/// ```
213#[derive(Debug)]
214pub struct AxumStream(pub SplitStream<WebSocket>);
215
216impl From<SplitStream<WebSocket>> for AxumStream {
217    fn from(stream: SplitStream<WebSocket>) -> Self {
218        AxumStream(stream)
219    }
220}
221
222impl Into<SplitStream<WebSocket>> for AxumStream {
223    fn into(self) -> SplitStream<WebSocket> {
224        self.0
225    }
226}
227
228impl Stream for AxumStream {
229    type Item = Result<Vec<u8>, Error>;
230
231    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
232        match Pin::new(&mut self.0).poll_next(cx) {
233            Poll::Pending => Poll::Pending,
234            Poll::Ready(None) => Poll::Ready(None),
235            Poll::Ready(Some(res)) => match res {
236                Ok(item) => Poll::Ready(Some(Ok(item.into_data().to_vec()))),
237                Err(e) => Poll::Ready(Some(Err(Error::Other(e.into())))),
238            },
239        }
240    }
241}
242
#[cfg(test)]
mod test {
    use crate::broadcast::BroadcastGroup;
    use crate::conn::Connection;
    use crate::ws::{AxumSink, AxumStream};
    use axum::{
        extract::ws::{WebSocket, WebSocketUpgrade},
        extract::State,
        response::IntoResponse,
        routing::get,
        Router,
    };
    use futures_util::stream::{SplitSink, SplitStream};
    use futures_util::{ready, SinkExt, Stream, StreamExt};
    use std::pin::Pin;
    use std::sync::Arc;
    use std::task::{Context, Poll};
    use std::time::Duration;
    use tokio::net::TcpStream;
    use tokio::sync::{Mutex, Notify, RwLock};
    use tokio::task;
    use tokio::task::JoinHandle;
    use tokio::time::{sleep, timeout};
    use tokio_tungstenite::tungstenite::Message;
    use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
    use yrs::sync::{Awareness, Error};
    use yrs::updates::encoder::Encode;
    use yrs::{Doc, GetString, Subscription, Text, Transact};

    /// How long a test waits for an update to propagate.
    const TIMEOUT: Duration = Duration::from_secs(5);

    /// Client-side connection type produced by [`client`].
    type ClientConn = Connection<TungsteniteSink, TungsteniteStream>;

    /// Binds `addr` and spawns an axum server exposing `bcast` on the `/my-room` websocket route.
    async fn start_server(
        addr: &str,
        bcast: Arc<BroadcastGroup>,
    ) -> Result<JoinHandle<()>, Box<dyn std::error::Error>> {
        let listener = tokio::net::TcpListener::bind(addr).await?;

        let app = Router::new()
            .route("/my-room", get(ws_handler))
            .with_state(bcast);

        Ok(tokio::spawn(async move {
            axum::serve(listener, app.into_make_service())
                .await
                .unwrap();
        }))
    }

    /// Upgrades the HTTP request and hands the websocket over to [`peer`].
    async fn ws_handler(
        ws: WebSocketUpgrade,
        State(bcast): State<Arc<BroadcastGroup>>,
    ) -> impl IntoResponse {
        ws.on_upgrade(move |socket| peer(socket, bcast))
    }

    /// Subscribes one server-side websocket to the broadcast group and waits for it to finish.
    async fn peer(ws: WebSocket, bcast: Arc<BroadcastGroup>) {
        let (sink, stream) = ws.split();
        let sink = Arc::new(Mutex::new(AxumSink(sink)));
        let stream = AxumStream(stream);
        let sub = bcast.subscribe(sink, stream);
        match sub.completed().await {
            Ok(_) => println!("broadcasting for channel finished successfully"),
            Err(e) => eprintln!("broadcasting for channel finished abruptly: {}", e),
        }
    }

    /// Client-side sink adapter: sends each encoded message as a binary websocket frame.
    ///
    /// `SplitSink` is `Unpin`, so the inner half is re-pinned with the safe `Pin::new`.
    struct TungsteniteSink(SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>);

    impl futures_util::Sink<Vec<u8>> for TungsteniteSink {
        type Error = Error;

        fn poll_ready(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Pin::new(&mut self.0)
                .poll_ready(cx)
                .map_err(|e| Error::Other(Box::new(e)))
        }

        fn start_send(mut self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
            Pin::new(&mut self.0)
                .start_send(Message::binary(item))
                .map_err(|e| Error::Other(Box::new(e)))
        }

        fn poll_flush(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Pin::new(&mut self.0)
                .poll_flush(cx)
                .map_err(|e| Error::Other(Box::new(e)))
        }

        fn poll_close(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Pin::new(&mut self.0)
                .poll_close(cx)
                .map_err(|e| Error::Other(Box::new(e)))
        }
    }

    /// Client-side stream adapter: yields the payload of every received websocket message.
    ///
    /// `SplitStream` is `Unpin`, so the inner half is re-pinned with the safe `Pin::new`.
    struct TungsteniteStream(SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>);

    impl Stream for TungsteniteStream {
        type Item = Result<Vec<u8>, Error>;

        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let next = ready!(Pin::new(&mut self.0).poll_next(cx));
            Poll::Ready(next.map(|res| match res {
                Ok(msg) => Ok(msg.into_data()),
                Err(e) => Err(Error::Other(Box::new(e))),
            }))
        }
    }

    /// Connects to `addr` and returns a client [`Connection`] whose awareness wraps `doc`.
    async fn client(addr: &str, doc: Doc) -> Result<ClientConn, Box<dyn std::error::Error>> {
        let (ws, _) = tokio_tungstenite::connect_async(addr).await?;
        let (sink, stream) = ws.split();
        Ok(Connection::new(
            Arc::new(RwLock::new(Awareness::new(doc))),
            TungsteniteSink(sink),
            TungsteniteStream(stream),
        ))
    }

    /// Returns a [`Notify`] that is signalled (via `notify_waiters`) on every update to `doc`,
    /// together with the subscription that must be kept alive for it to fire.
    ///
    /// `notify_waiters` stores no permit, so only `notified()` futures that already exist when
    /// the update happens are woken.
    fn create_notifier(doc: &Doc) -> (Arc<Notify>, Subscription) {
        let n = Arc::new(Notify::new());
        let sub = {
            let n = n.clone();
            doc.observe_update_v1(move |_, _| n.notify_waiters())
                .unwrap()
        };
        (n, sub)
    }

    /// Forwards every local update of `conn`'s document to the server.
    ///
    /// Changes made to a client-side document are not propagated automatically, so tests that
    /// edit a client document install this observer. The returned subscription must be kept
    /// alive for forwarding to continue.
    async fn propagate_local_updates(conn: &ClientConn) -> Subscription {
        let sink = conn.sink();
        let a = conn.awareness().write().await;
        let doc = a.doc();
        doc.observe_update_v1(move |_, e| {
            let update = e.update.to_owned();
            if let Some(sink) = sink.upgrade() {
                task::spawn(async move {
                    let msg = yrs::sync::Message::Sync(yrs::sync::SyncMessage::Update(update))
                        .encode_v1();
                    let mut sink = sink.lock().await;
                    sink.send(msg).await.unwrap();
                });
            }
        })
        .unwrap()
    }

    /// Appends `chunk` to the `"test"` text of `conn`'s document.
    async fn push_text(conn: &ClientConn, chunk: &str) {
        let a = conn.awareness().write().await;
        let doc = a.doc();
        let text = doc.get_or_insert_text("test");
        text.push(&mut doc.transact_mut(), chunk);
    }

    /// Asserts that the `"test"` text of `conn`'s document currently equals `expected`.
    async fn assert_text(conn: &ClientConn, expected: &str) {
        let awareness = conn.awareness().read().await;
        let doc = awareness.doc();
        let text = doc.get_or_insert_text("test");
        let actual = text.get_string(&doc.transact());
        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn change_introduced_by_server_reaches_subscribed_clients() {
        let doc = Doc::with_client_id(1);
        let text = doc.get_or_insert_text("test");
        let awareness = Arc::new(RwLock::new(Awareness::new(doc)));
        let bcast = BroadcastGroup::new(awareness.clone(), 10).await;
        let _server = start_server("0.0.0.0:16600", Arc::new(bcast)).await.unwrap();

        let doc = Doc::new();
        let (n, _sub) = create_notifier(&doc);
        let c1 = client("ws://localhost:16600/my-room", doc).await.unwrap();

        {
            let lock = awareness.write().await;
            text.push(&mut lock.doc().transact_mut(), "abc");
        }

        timeout(TIMEOUT, n.notified()).await.unwrap();

        assert_text(&c1, "abc").await;
    }

    #[tokio::test]
    async fn subscribed_client_fetches_initial_state() {
        let doc = Doc::with_client_id(1);
        let text = doc.get_or_insert_text("test");

        text.push(&mut doc.transact_mut(), "abc");

        let awareness = Arc::new(RwLock::new(Awareness::new(doc)));
        let bcast = BroadcastGroup::new(awareness.clone(), 10).await;
        let _server = start_server("0.0.0.0:16601", Arc::new(bcast)).await.unwrap();

        let doc = Doc::new();
        let (n, _sub) = create_notifier(&doc);
        let c1 = client("ws://localhost:16601/my-room", doc).await.unwrap();

        timeout(TIMEOUT, n.notified()).await.unwrap();

        assert_text(&c1, "abc").await;
    }

    #[tokio::test]
    async fn changes_from_one_client_reach_others() {
        let doc = Doc::with_client_id(1);
        let _ = doc.get_or_insert_text("test");

        let awareness = Arc::new(RwLock::new(Awareness::new(doc)));
        let bcast = BroadcastGroup::new(awareness.clone(), 10).await;
        let _server = start_server("0.0.0.0:16602", Arc::new(bcast)).await.unwrap();

        let c1 = client("ws://localhost:16602/my-room", Doc::with_client_id(2))
            .await
            .unwrap();
        let _sub1 = propagate_local_updates(&c1).await;

        let d2 = Doc::with_client_id(3);
        let (n2, _sub2) = create_notifier(&d2);
        let c2 = client("ws://localhost:16602/my-room", d2).await.unwrap();

        push_text(&c1, "def").await;

        timeout(TIMEOUT, n2.notified()).await.unwrap();

        assert_text(&c2, "def").await;
    }

    #[tokio::test]
    async fn client_failure_doesnt_affect_others() {
        let doc = Doc::with_client_id(1);
        let _text = doc.get_or_insert_text("test");

        let awareness = Arc::new(RwLock::new(Awareness::new(doc)));
        let bcast = BroadcastGroup::new(awareness.clone(), 10).await;
        let _server = start_server("0.0.0.0:16603", Arc::new(bcast)).await.unwrap();

        let c1 = client("ws://localhost:16603/my-room", Doc::with_client_id(2))
            .await
            .unwrap();
        let _sub1 = propagate_local_updates(&c1).await;
        let c2 = client("ws://localhost:16603/my-room", Doc::with_client_id(3))
            .await
            .unwrap();
        let c3 = client("ws://localhost:16603/my-room", Doc::with_client_id(4))
            .await
            .unwrap();

        push_text(&c1, "abc").await;

        // On the first try both C2 and C3 should receive the update. Use a fixed wait rather than
        // awaiting a notifier for C2 and then one for C3. `notify_waiters` stores no permit, so an
        // update reaching C3 while the test is still waiting on C2 would never be observed.
        sleep(TIMEOUT).await;

        assert_text(&c2, "abc").await;
        assert_text(&c3, "abc").await;

        // Drop C3 so that its connection ends abruptly.
        drop(c3);

        // C2 must keep receiving updates after C3 is gone.
        let (n2, _sub2) = {
            let a = c2.awareness().write().await;
            let doc = a.doc();
            create_notifier(doc)
        };

        push_text(&c1, "def").await;

        timeout(TIMEOUT, n2.notified()).await.unwrap();

        assert_text(&c2, "abcdef").await;
    }
}