yrs_axum/
conn.rs

1#![allow(dead_code)]
2use futures_util::sink::SinkExt;
3use futures_util::StreamExt;
4use std::future::Future;
5use std::marker::PhantomData;
6use std::pin::Pin;
7use std::sync::{Arc, Weak};
8use std::task::{Context, Poll};
9use tokio::spawn;
10use tokio::sync::{Mutex, RwLock};
11use tokio::task::JoinHandle;
12use yrs::encoding::read::Cursor;
13use yrs::sync::Awareness;
14use yrs::sync::{DefaultProtocol, Error, Message, MessageReader, Protocol, SyncMessage};
15use yrs::updates::decoder::{Decode, DecoderV1};
16use yrs::updates::encoder::{Encode, Encoder, EncoderV1};
17use yrs::Update;
18
19/// Connection handler over a pair of message streams, which implements a Yjs/Yrs awareness and
20/// update exchange protocol.
21///
22/// This connection implements Future pattern and can be awaited upon in order for a caller to
23/// recognize whether underlying websocket connection has been finished gracefully or abruptly.
24#[derive(Debug)]
25pub struct Connection<Sink, Stream> {
26    processing_loop: JoinHandle<Result<(), Error>>,
27    awareness: Arc<RwLock<Awareness>>,
28    inbox: Arc<Mutex<Sink>>,
29    _stream: PhantomData<Stream>,
30}
31
32impl<Sink, Stream, E> Connection<Sink, Stream>
33where
34    Sink: SinkExt<Vec<u8>, Error = E> + Send + Sync + Unpin + 'static,
35    E: Into<Error> + Send + Sync,
36{
37    pub async fn send(&self, msg: Vec<u8>) -> Result<(), Error> {
38        let mut inbox = self.inbox.lock().await;
39        match inbox.send(msg).await {
40            Ok(_) => Ok(()),
41            Err(err) => Err(err.into()),
42        }
43    }
44
45    pub async fn close(self) -> Result<(), E> {
46        let mut inbox = self.inbox.lock().await;
47        inbox.close().await
48    }
49
50    pub fn sink(&self) -> Weak<Mutex<Sink>> {
51        Arc::downgrade(&self.inbox)
52    }
53}
54
55impl<Sink, Stream, E> Connection<Sink, Stream>
56where
57    Stream: StreamExt<Item = Result<Vec<u8>, E>> + Send + Sync + Unpin + 'static,
58    Sink: SinkExt<Vec<u8>, Error = E> + Send + Sync + Unpin + 'static,
59    E: Into<Error> + Send + Sync,
60{
61    /// Wraps incoming [WebSocket] connection and supplied [Awareness] accessor into a new
62    /// connection handler capable of exchanging Yrs/Yjs messages.
63    ///
64    /// While creation of new [AxumConn] always succeeds, a connection itself can possibly fail
65    /// while processing incoming input/output. This can be detected by awaiting for returned
66    /// [AxumConn] and handling the awaited result.
67    pub fn new(awareness: Arc<RwLock<Awareness>>, sink: Sink, stream: Stream) -> Self {
68        Self::with_protocol(awareness, sink, stream, DefaultProtocol)
69    }
70
71    /// Returns an underlying [Awareness] structure, that contains client state of that connection.
72    pub fn awareness(&self) -> &Arc<RwLock<Awareness>> {
73        &self.awareness
74    }
75
76    /// Wraps incoming [WebSocket] connection and supplied [Awareness] accessor into a new
77    /// connection handler capable of exchanging Yrs/Yjs messages.
78    ///
79    /// While creation of new [AxumConn] always succeeds, a connection itself can possibly fail
80    /// while processing incoming input/output. This can be detected by awaiting for returned
81    /// [AxumConn] and handling the awaited result.
82    pub fn with_protocol<P>(
83        awareness: Arc<RwLock<Awareness>>,
84        sink: Sink,
85        mut stream: Stream,
86        protocol: P,
87    ) -> Self
88    where
89        P: Protocol + Send + Sync + 'static,
90    {
91        let sink = Arc::new(Mutex::new(sink));
92        let inbox = sink.clone();
93        let loop_sink = Arc::downgrade(&sink);
94        let loop_awareness = Arc::downgrade(&awareness);
95        let processing_loop: JoinHandle<Result<(), Error>> = spawn(async move {
96            // at the beginning send SyncStep1 and AwarenessUpdate
97            let payload = {
98                let awareness = loop_awareness.upgrade().unwrap();
99                let mut encoder = EncoderV1::new();
100                let awareness = awareness.read().await;
101                protocol.start(&awareness, &mut encoder)?;
102                encoder.to_vec()
103            };
104            if !payload.is_empty() {
105                if let Some(sink) = loop_sink.upgrade() {
106                    let mut s = sink.lock().await;
107                    if let Err(e) = s.send(payload).await {
108                        return Err(e.into());
109                    }
110                } else {
111                    return Ok(()); // parent ConnHandler has been dropped
112                }
113            }
114
115            while let Some(input) = stream.next().await {
116                match input {
117                    Ok(data) => {
118                        if let Some(mut sink) = loop_sink.upgrade() {
119                            if let Some(awareness) = loop_awareness.upgrade() {
120                                match Self::process(&protocol, &awareness, &mut sink, data).await {
121                                    Ok(()) => { /* continue */ }
122                                    Err(e) => {
123                                        return Err(e);
124                                    }
125                                }
126                            } else {
127                                return Ok(()); // parent ConnHandler has been dropped
128                            }
129                        } else {
130                            return Ok(()); // parent ConnHandler has been dropped
131                        }
132                    }
133                    Err(e) => return Err(e.into()),
134                }
135            }
136
137            Ok(())
138        });
139        Connection {
140            processing_loop,
141            awareness,
142            inbox,
143            _stream: PhantomData::default(),
144        }
145    }
146
147    async fn process<P: Protocol>(
148        protocol: &P,
149        awareness: &Arc<RwLock<Awareness>>,
150        sink: &mut Arc<Mutex<Sink>>,
151        input: Vec<u8>,
152    ) -> Result<(), Error> {
153        let mut decoder = DecoderV1::new(Cursor::new(&input));
154        let reader = MessageReader::new(&mut decoder);
155        for r in reader {
156            let msg = r?;
157            if let Some(reply) = handle_msg(protocol, &awareness, msg).await? {
158                let mut sender = sink.lock().await;
159                if let Err(e) = sender.send(reply.encode_v1()).await {
160                    println!("connection failed to send back the reply");
161                    return Err(e.into());
162                } else {
163                    println!("connection send back the reply");
164                }
165            }
166        }
167        Ok(())
168    }
169}
170
171impl<Sink, Stream> Unpin for Connection<Sink, Stream> {}
172
173impl<Sink, Stream> Future for Connection<Sink, Stream> {
174    type Output = Result<(), Error>;
175
176    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
177        match Pin::new(&mut self.processing_loop).poll(cx) {
178            Poll::Pending => Poll::Pending,
179            Poll::Ready(Err(e)) => Poll::Ready(Err(Error::Other(e.into()))),
180            Poll::Ready(Ok(r)) => Poll::Ready(r),
181        }
182    }
183}
184
185pub async fn handle_msg<P: Protocol>(
186    protocol: &P,
187    a: &Arc<RwLock<Awareness>>,
188    msg: Message,
189) -> Result<Option<Message>, Error> {
190    match msg {
191        Message::Sync(msg) => match msg {
192            SyncMessage::SyncStep1(sv) => {
193                let awareness = a.read().await;
194                protocol.handle_sync_step1(&awareness, sv)
195            }
196            SyncMessage::SyncStep2(update) => {
197                let mut awareness = a.write().await;
198                protocol.handle_sync_step2(&mut awareness, Update::decode_v1(&update)?)
199            }
200            SyncMessage::Update(update) => {
201                let mut awareness = a.write().await;
202                protocol.handle_update(&mut awareness, Update::decode_v1(&update)?)
203            }
204        },
205        Message::Auth(reason) => {
206            let awareness = a.read().await;
207            protocol.handle_auth(&awareness, reason)
208        }
209        Message::AwarenessQuery => {
210            let awareness = a.read().await;
211            protocol.handle_awareness_query(&awareness)
212        }
213        Message::Awareness(update) => {
214            let mut awareness = a.write().await;
215            protocol.handle_awareness_update(&mut awareness, update)
216        }
217        Message::Custom(tag, data) => {
218            let mut awareness = a.write().await;
219            protocol.missing_handle(&mut awareness, tag, data)
220        }
221    }
222}
223
#[cfg(test)]
mod test {
    use crate::broadcast::BroadcastGroup;
    use crate::conn::Connection;
    use bytes::{Bytes, BytesMut};
    use futures_util::SinkExt;
    use std::net::SocketAddr;
    use std::str::FromStr;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
    use tokio::net::{TcpListener, TcpSocket};
    use tokio::sync::{Mutex, Notify, RwLock};
    use tokio::task;
    use tokio::task::JoinHandle;
    use tokio::time::timeout;
    use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite, LengthDelimitedCodec};
    use yrs::sync::{Awareness, Error, Message, SyncMessage};
    use yrs::updates::encoder::Encode;
    use yrs::{Doc, GetString, Subscription, Text, Transact};

    /// Length-delimited framing that exposes each frame as a `Vec<u8>`, which is the item
    /// type `Connection` sends and receives.
    #[derive(Debug, Default)]
    struct YrsCodec(LengthDelimitedCodec);

    impl Encoder<Vec<u8>> for YrsCodec {
        type Error = Error;

        fn encode(&mut self, item: Vec<u8>, dst: &mut BytesMut) -> Result<(), Self::Error> {
            self.0.encode(Bytes::from(item), dst)?;
            Ok(())
        }
    }

    impl Decoder for YrsCodec {
        type Item = Vec<u8>;
        type Error = Error;

        fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
            if let Some(bytes) = self.0.decode(src)? {
                Ok(Some(bytes.freeze().to_vec()))
            } else {
                Ok(None)
            }
        }
    }

    type WrappedStream = FramedRead<OwnedReadHalf, YrsCodec>;
    type WrappedSink = FramedWrite<OwnedWriteHalf, YrsCodec>;

    /// Starts a TCP server on `addr` that subscribes every accepted connection to `bcast`.
    async fn start_server(
        addr: SocketAddr,
        bcast: BroadcastGroup,
    ) -> Result<JoinHandle<()>, Box<dyn std::error::Error>> {
        let server = TcpListener::bind(addr).await?;
        Ok(tokio::spawn(async move {
            let mut subscribers = Vec::new();
            while let Ok((stream, _)) = server.accept().await {
                let (reader, writer) = stream.into_split();
                let stream = WrappedStream::new(reader, YrsCodec::default());
                let sink = WrappedSink::new(writer, YrsCodec::default());
                let sub = bcast.subscribe(Arc::new(Mutex::new(sink)), stream);
                subscribers.push(sub);
            }
        }))
    }

    /// Connects a new client `Connection` backed by `doc` to the server at `addr`.
    async fn client(
        addr: SocketAddr,
        doc: Doc,
    ) -> Result<Connection<WrappedSink, WrappedStream>, Box<dyn std::error::Error>> {
        let stream = TcpSocket::new_v4()?.connect(addr).await?;
        let (reader, writer) = stream.into_split();
        let stream: WrappedStream = WrappedStream::new(reader, YrsCodec::default());
        let sink: WrappedSink = WrappedSink::new(writer, YrsCodec::default());
        Ok(Connection::new(
            Arc::new(RwLock::new(Awareness::new(doc))),
            sink,
            stream,
        ))
    }

    /// Returns a `Notify` that wakes current waiters whenever `doc` emits a v1 update,
    /// together with the subscription that keeps the observer registered.
    fn create_notifier(doc: &Doc) -> (Arc<Notify>, Subscription) {
        let n = Arc::new(Notify::new());
        let sub = {
            let n = n.clone();
            doc.observe_update_v1(move |_, _| n.notify_waiters())
                .unwrap()
        };
        (n, sub)
    }

    /// Forwards every local update of `conn`'s document to its peer. By default changes made
    /// by the document on the client side are not propagated automatically.
    async fn forward_local_updates(conn: &Connection<WrappedSink, WrappedStream>) -> Subscription {
        let sink = conn.sink();
        let a = conn.awareness().write().await;
        let doc = a.doc();
        doc.observe_update_v1(move |_, e| {
            let update = e.update.to_owned();
            if let Some(sink) = sink.upgrade() {
                task::spawn(async move {
                    let msg = Message::Sync(SyncMessage::Update(update)).encode_v1();
                    let mut sink = sink.lock().await;
                    sink.send(msg).await.unwrap();
                });
            }
        })
        .unwrap()
    }

    const TIMEOUT: Duration = Duration::from_secs(5);

    #[tokio::test]
    async fn change_introduced_by_server_reaches_subscribed_clients(
    ) -> Result<(), Box<dyn std::error::Error>> {
        let server_addr = SocketAddr::from_str("127.0.0.1:6600").unwrap();
        let doc = Doc::with_client_id(1);
        let text = doc.get_or_insert_text("test");
        let awareness = Arc::new(RwLock::new(Awareness::new(doc)));
        let bcast = BroadcastGroup::new(awareness.clone(), 10).await;
        let _server = start_server(server_addr, bcast).await?;

        let doc = Doc::new();
        let (n, _sub) = create_notifier(&doc);
        let c1 = client(server_addr, doc).await?;

        {
            let lock = awareness.write().await;
            text.push(&mut lock.doc().transact_mut(), "abc");
        }

        timeout(TIMEOUT, n.notified()).await?;

        {
            let awareness = c1.awareness().read().await;
            let doc = awareness.doc();
            let text = doc.get_or_insert_text("test");
            let str = text.get_string(&doc.transact());
            assert_eq!(str, "abc".to_string());
        }

        Ok(())
    }

    #[tokio::test]
    async fn subscribed_client_fetches_initial_state() -> Result<(), Box<dyn std::error::Error>> {
        let server_addr = SocketAddr::from_str("127.0.0.1:6601").unwrap();
        let doc = Doc::with_client_id(1);
        let text = doc.get_or_insert_text("test");

        text.push(&mut doc.transact_mut(), "abc");

        let awareness = Arc::new(RwLock::new(Awareness::new(doc)));
        let bcast = BroadcastGroup::new(awareness.clone(), 10).await;
        let _server = start_server(server_addr, bcast).await?;

        let doc = Doc::new();
        let (n, _sub) = create_notifier(&doc);
        let c1 = client(server_addr, doc).await?;

        timeout(TIMEOUT, n.notified()).await?;

        {
            let awareness = c1.awareness().read().await;
            let doc = awareness.doc();
            let text = doc.get_or_insert_text("test");
            let str = text.get_string(&doc.transact());
            assert_eq!(str, "abc".to_string());
        }

        Ok(())
    }

    #[tokio::test]
    async fn changes_from_one_client_reach_others() -> Result<(), Box<dyn std::error::Error>> {
        let server_addr = SocketAddr::from_str("127.0.0.1:6602").unwrap();
        let doc = Doc::with_client_id(1);
        let _text = doc.get_or_insert_text("test");

        let awareness = Arc::new(RwLock::new(Awareness::new(doc)));
        let bcast = BroadcastGroup::new(awareness.clone(), 10).await;
        let _server = start_server(server_addr, bcast).await?;

        let d1 = Doc::with_client_id(2);
        let c1 = client(server_addr, d1).await?;
        let _sub11 = forward_local_updates(&c1).await;

        let d2 = Doc::with_client_id(3);
        let (n2, _sub2) = create_notifier(&d2);
        let c2 = client(server_addr, d2).await?;

        {
            let a = c1.awareness().write().await;
            let doc = a.doc();
            let text = doc.get_or_insert_text("test");
            text.push(&mut doc.transact_mut(), "def");
        }

        timeout(TIMEOUT, n2.notified()).await?;

        {
            let awareness = c2.awareness().read().await;
            let doc = awareness.doc();
            let text = doc.get_or_insert_text("test");
            let str = text.get_string(&doc.transact());
            assert_eq!(str, "def".to_string());
        }

        Ok(())
    }

    #[tokio::test]
    async fn client_failure_doesnt_affect_others() -> Result<(), Box<dyn std::error::Error>> {
        let server_addr = SocketAddr::from_str("127.0.0.1:6604").unwrap();
        let doc = Doc::with_client_id(1);
        let _ = doc.get_or_insert_text("test");

        let awareness = Arc::new(RwLock::new(Awareness::new(doc)));
        let bcast = BroadcastGroup::new(awareness.clone(), 10).await;
        let _server = start_server(server_addr, bcast).await?;

        let d1 = Doc::with_client_id(2);
        let c1 = client(server_addr, d1).await?;
        let _sub11 = forward_local_updates(&c1).await;

        let d2 = Doc::with_client_id(3);
        let (n2, sub2) = create_notifier(&d2);
        let c2 = client(server_addr, d2).await?;

        let d3 = Doc::with_client_id(4);
        let (n3, sub3) = create_notifier(&d3);
        let c3 = client(server_addr, d3).await?;

        {
            let a = c1.awareness().write().await;
            let doc = a.doc();
            let text = doc.get_or_insert_text("test");
            text.push(&mut doc.transact_mut(), "abc");
        }

        // On the first try both C2 and C3 should receive the update. `notify_waiters` does not
        // store a permit, so awaiting C2 and then C3 could miss C3's notification; polling both
        // `Notified` futures together registers them before either update can arrive.
        let (r2, r3) = tokio::join!(
            timeout(TIMEOUT, n2.notified()),
            timeout(TIMEOUT, n3.notified())
        );
        r2?;
        r3?;

        {
            let awareness = c2.awareness().read().await;
            let doc = awareness.doc();
            let text = doc.get_or_insert_text("test");
            let str = text.get_string(&doc.transact());
            assert_eq!(str, "abc".to_string());
        }
        {
            let awareness = c3.awareness().read().await;
            let doc = awareness.doc();
            let text = doc.get_or_insert_text("test");
            let str = text.get_string(&doc.transact());
            assert_eq!(str, "abc".to_string());
        }

        // drop client, causing abrupt ending
        drop(c3);
        drop(n3);
        drop(sub3);
        // C2 notification subscription has been realized, we need to refresh it
        drop(n2);
        drop(sub2);

        let (n2, _sub2) = {
            let a = c2.awareness().write().await;
            let doc = a.doc();
            create_notifier(doc)
        };

        {
            let a = c1.awareness().write().await;
            let doc = a.doc();
            let text = doc.get_or_insert_text("test");
            text.push(&mut doc.transact_mut(), "def");
        }

        timeout(TIMEOUT, n2.notified()).await?;

        {
            let awareness = c2.awareness().read().await;
            let doc = awareness.doc();
            let text = doc.get_or_insert_text("test");
            let str = text.get_string(&doc.transact());
            assert_eq!(str, "abcdef".to_string());
        }

        Ok(())
    }
}