1#![allow(dead_code)]
2use futures_util::sink::SinkExt;
3use futures_util::StreamExt;
4use std::future::Future;
5use std::marker::PhantomData;
6use std::pin::Pin;
7use std::sync::{Arc, Weak};
8use std::task::{Context, Poll};
9use tokio::spawn;
10use tokio::sync::{Mutex, RwLock};
11use tokio::task::JoinHandle;
12use yrs::encoding::read::Cursor;
13use yrs::sync::Awareness;
14use yrs::sync::{DefaultProtocol, Error, Message, MessageReader, Protocol, SyncMessage};
15use yrs::updates::decoder::{Decode, DecoderV1};
16use yrs::updates::encoder::{Encode, Encoder, EncoderV1};
17use yrs::Update;
18
19#[derive(Debug)]
25pub struct Connection<Sink, Stream> {
26 processing_loop: JoinHandle<Result<(), Error>>,
27 awareness: Arc<RwLock<Awareness>>,
28 inbox: Arc<Mutex<Sink>>,
29 _stream: PhantomData<Stream>,
30}
31
32impl<Sink, Stream, E> Connection<Sink, Stream>
33where
34 Sink: SinkExt<Vec<u8>, Error = E> + Send + Sync + Unpin + 'static,
35 E: Into<Error> + Send + Sync,
36{
37 pub async fn send(&self, msg: Vec<u8>) -> Result<(), Error> {
38 let mut inbox = self.inbox.lock().await;
39 match inbox.send(msg).await {
40 Ok(_) => Ok(()),
41 Err(err) => Err(err.into()),
42 }
43 }
44
45 pub async fn close(self) -> Result<(), E> {
46 let mut inbox = self.inbox.lock().await;
47 inbox.close().await
48 }
49
50 pub fn sink(&self) -> Weak<Mutex<Sink>> {
51 Arc::downgrade(&self.inbox)
52 }
53}
54
55impl<Sink, Stream, E> Connection<Sink, Stream>
56where
57 Stream: StreamExt<Item = Result<Vec<u8>, E>> + Send + Sync + Unpin + 'static,
58 Sink: SinkExt<Vec<u8>, Error = E> + Send + Sync + Unpin + 'static,
59 E: Into<Error> + Send + Sync,
60{
61 pub fn new(awareness: Arc<RwLock<Awareness>>, sink: Sink, stream: Stream) -> Self {
68 Self::with_protocol(awareness, sink, stream, DefaultProtocol)
69 }
70
71 pub fn awareness(&self) -> &Arc<RwLock<Awareness>> {
73 &self.awareness
74 }
75
76 pub fn with_protocol<P>(
83 awareness: Arc<RwLock<Awareness>>,
84 sink: Sink,
85 mut stream: Stream,
86 protocol: P,
87 ) -> Self
88 where
89 P: Protocol + Send + Sync + 'static,
90 {
91 let sink = Arc::new(Mutex::new(sink));
92 let inbox = sink.clone();
93 let loop_sink = Arc::downgrade(&sink);
94 let loop_awareness = Arc::downgrade(&awareness);
95 let processing_loop: JoinHandle<Result<(), Error>> = spawn(async move {
96 let payload = {
98 let awareness = loop_awareness.upgrade().unwrap();
99 let mut encoder = EncoderV1::new();
100 let awareness = awareness.read().await;
101 protocol.start(&awareness, &mut encoder)?;
102 encoder.to_vec()
103 };
104 if !payload.is_empty() {
105 if let Some(sink) = loop_sink.upgrade() {
106 let mut s = sink.lock().await;
107 if let Err(e) = s.send(payload).await {
108 return Err(e.into());
109 }
110 } else {
111 return Ok(()); }
113 }
114
115 while let Some(input) = stream.next().await {
116 match input {
117 Ok(data) => {
118 if let Some(mut sink) = loop_sink.upgrade() {
119 if let Some(awareness) = loop_awareness.upgrade() {
120 match Self::process(&protocol, &awareness, &mut sink, data).await {
121 Ok(()) => { }
122 Err(e) => {
123 return Err(e);
124 }
125 }
126 } else {
127 return Ok(()); }
129 } else {
130 return Ok(()); }
132 }
133 Err(e) => return Err(e.into()),
134 }
135 }
136
137 Ok(())
138 });
139 Connection {
140 processing_loop,
141 awareness,
142 inbox,
143 _stream: PhantomData::default(),
144 }
145 }
146
147 async fn process<P: Protocol>(
148 protocol: &P,
149 awareness: &Arc<RwLock<Awareness>>,
150 sink: &mut Arc<Mutex<Sink>>,
151 input: Vec<u8>,
152 ) -> Result<(), Error> {
153 let mut decoder = DecoderV1::new(Cursor::new(&input));
154 let reader = MessageReader::new(&mut decoder);
155 for r in reader {
156 let msg = r?;
157 if let Some(reply) = handle_msg(protocol, &awareness, msg).await? {
158 let mut sender = sink.lock().await;
159 if let Err(e) = sender.send(reply.encode_v1()).await {
160 println!("connection failed to send back the reply");
161 return Err(e.into());
162 } else {
163 println!("connection send back the reply");
164 }
165 }
166 }
167 Ok(())
168 }
169}
170
171impl<Sink, Stream> Unpin for Connection<Sink, Stream> {}
172
173impl<Sink, Stream> Future for Connection<Sink, Stream> {
174 type Output = Result<(), Error>;
175
176 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
177 match Pin::new(&mut self.processing_loop).poll(cx) {
178 Poll::Pending => Poll::Pending,
179 Poll::Ready(Err(e)) => Poll::Ready(Err(Error::Other(e.into()))),
180 Poll::Ready(Ok(r)) => Poll::Ready(r),
181 }
182 }
183}
184
185pub async fn handle_msg<P: Protocol>(
186 protocol: &P,
187 a: &Arc<RwLock<Awareness>>,
188 msg: Message,
189) -> Result<Option<Message>, Error> {
190 match msg {
191 Message::Sync(msg) => match msg {
192 SyncMessage::SyncStep1(sv) => {
193 let awareness = a.read().await;
194 protocol.handle_sync_step1(&awareness, sv)
195 }
196 SyncMessage::SyncStep2(update) => {
197 let mut awareness = a.write().await;
198 protocol.handle_sync_step2(&mut awareness, Update::decode_v1(&update)?)
199 }
200 SyncMessage::Update(update) => {
201 let mut awareness = a.write().await;
202 protocol.handle_update(&mut awareness, Update::decode_v1(&update)?)
203 }
204 },
205 Message::Auth(reason) => {
206 let awareness = a.read().await;
207 protocol.handle_auth(&awareness, reason)
208 }
209 Message::AwarenessQuery => {
210 let awareness = a.read().await;
211 protocol.handle_awareness_query(&awareness)
212 }
213 Message::Awareness(update) => {
214 let mut awareness = a.write().await;
215 protocol.handle_awareness_update(&mut awareness, update)
216 }
217 Message::Custom(tag, data) => {
218 let mut awareness = a.write().await;
219 protocol.missing_handle(&mut awareness, tag, data)
220 }
221 }
222}
223
#[cfg(test)]
mod test {
    use crate::broadcast::BroadcastGroup;
    use crate::conn::Connection;
    use bytes::{Bytes, BytesMut};
    use futures_util::SinkExt;
    use std::net::SocketAddr;
    use std::str::FromStr;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
    use tokio::net::{TcpListener, TcpSocket};
    use tokio::sync::{Mutex, Notify, RwLock};
    use tokio::task;
    use tokio::task::JoinHandle;
    use tokio::time::{sleep, timeout};
    use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite, LengthDelimitedCodec};
    use yrs::sync::{Awareness, Error, Message, SyncMessage};
    use yrs::updates::encoder::Encode;
    use yrs::{Doc, GetString, Subscription, Text, Transact};

    /// Length-delimited framing for raw y-sync payloads.
    #[derive(Debug, Default)]
    struct YrsCodec(LengthDelimitedCodec);

    impl Encoder<Vec<u8>> for YrsCodec {
        type Error = Error;

        fn encode(&mut self, item: Vec<u8>, dst: &mut BytesMut) -> Result<(), Self::Error> {
            let frame = Bytes::from(item);
            self.0.encode(frame, dst).map_err(Error::from)
        }
    }

    impl Decoder for YrsCodec {
        type Item = Vec<u8>;
        type Error = Error;

        fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
            let frame = self.0.decode(src)?;
            Ok(frame.map(|bytes| bytes.to_vec()))
        }
    }

    type WrappedStream = FramedRead<OwnedReadHalf, YrsCodec>;
    type WrappedSink = FramedWrite<OwnedWriteHalf, YrsCodec>;

    /// Wraps both halves of a TCP connection with `YrsCodec`.
    fn frame(reader: OwnedReadHalf, writer: OwnedWriteHalf) -> (WrappedSink, WrappedStream) {
        (
            WrappedSink::new(writer, YrsCodec::default()),
            WrappedStream::new(reader, YrsCodec::default()),
        )
    }

    /// Accepts connections on `addr` and subscribes each one to `bcast`.
    async fn start_server(
        addr: SocketAddr,
        bcast: BroadcastGroup,
    ) -> Result<JoinHandle<()>, Box<dyn std::error::Error>> {
        let listener = TcpListener::bind(addr).await?;
        Ok(tokio::spawn(async move {
            // Subscriptions stay alive for as long as the accept loop runs.
            let mut subscriptions = Vec::new();
            while let Ok((socket, _)) = listener.accept().await {
                let (reader, writer) = socket.into_split();
                let (sink, stream) = frame(reader, writer);
                subscriptions.push(bcast.subscribe(Arc::new(Mutex::new(sink)), stream));
            }
        }))
    }

    /// Connects to `addr` and wraps the socket into a `Connection` over `doc`.
    async fn client(
        addr: SocketAddr,
        doc: Doc,
    ) -> Result<Connection<WrappedSink, WrappedStream>, Box<dyn std::error::Error>> {
        let socket = TcpSocket::new_v4()?.connect(addr).await?;
        let (reader, writer) = socket.into_split();
        let (sink, stream) = frame(reader, writer);
        let awareness = Arc::new(RwLock::new(Awareness::new(doc)));
        Ok(Connection::new(awareness, sink, stream))
    }

    /// Returns a `Notify` that wakes current waiters on every update to `doc`.
    fn create_notifier(doc: &Doc) -> (Arc<Notify>, Subscription) {
        let notify = Arc::new(Notify::new());
        let waker = Arc::clone(&notify);
        let subscription = doc
            .observe_update_v1(move |_, _| waker.notify_waiters())
            .unwrap();
        (notify, subscription)
    }

    /// Reads the "test" text of the document behind `awareness`.
    async fn read_text(awareness: &Arc<RwLock<Awareness>>) -> String {
        let guard = awareness.read().await;
        let doc = guard.doc();
        let text = doc.get_or_insert_text("test");
        let txn = doc.transact();
        text.get_string(&txn)
    }

    /// Appends `chunk` to the "test" text while holding the awareness write lock.
    async fn append_text(awareness: &Arc<RwLock<Awareness>>, chunk: &str) {
        let guard = awareness.write().await;
        let doc = guard.doc();
        let text = doc.get_or_insert_text("test");
        let mut txn = doc.transact_mut();
        text.push(&mut txn, chunk);
    }

    /// Sends every local update of `conn`'s document to its peer.
    async fn forward_updates(conn: &Connection<WrappedSink, WrappedStream>) -> Subscription {
        let sink = conn.sink();
        let guard = conn.awareness().write().await;
        guard
            .doc()
            .observe_update_v1(move |_, e| {
                let update = e.update.to_owned();
                if let Some(sink) = sink.upgrade() {
                    task::spawn(async move {
                        let msg = Message::Sync(SyncMessage::Update(update)).encode_v1();
                        let mut sink = sink.lock().await;
                        sink.send(msg).await.unwrap();
                    });
                }
            })
            .unwrap()
    }

    const TIMEOUT: Duration = Duration::from_secs(5);

    #[tokio::test]
    async fn change_introduced_by_server_reaches_subscribed_clients(
    ) -> Result<(), Box<dyn std::error::Error>> {
        let server_addr = SocketAddr::from_str("127.0.0.1:6600").unwrap();
        let server_doc = Doc::with_client_id(1);
        let text = server_doc.get_or_insert_text("test");
        let awareness = Arc::new(RwLock::new(Awareness::new(server_doc)));
        let bcast = BroadcastGroup::new(awareness.clone(), 10).await;
        let _server = start_server(server_addr, bcast).await?;

        let client_doc = Doc::new();
        let (notify, _sub) = create_notifier(&client_doc);
        let c1 = client(server_addr, client_doc).await?;

        {
            let lock = awareness.write().await;
            text.push(&mut lock.doc().transact_mut(), "abc");
        }

        timeout(TIMEOUT, notify.notified()).await?;

        assert_eq!(read_text(c1.awareness()).await, "abc");

        Ok(())
    }

    #[tokio::test]
    async fn subscribed_client_fetches_initial_state() -> Result<(), Box<dyn std::error::Error>> {
        let server_addr = SocketAddr::from_str("127.0.0.1:6601").unwrap();
        let server_doc = Doc::with_client_id(1);
        let text = server_doc.get_or_insert_text("test");

        text.push(&mut server_doc.transact_mut(), "abc");

        let awareness = Arc::new(RwLock::new(Awareness::new(server_doc)));
        let bcast = BroadcastGroup::new(awareness.clone(), 10).await;
        let _server = start_server(server_addr, bcast).await?;

        let client_doc = Doc::new();
        let (notify, _sub) = create_notifier(&client_doc);
        let c1 = client(server_addr, client_doc).await?;

        timeout(TIMEOUT, notify.notified()).await?;

        assert_eq!(read_text(c1.awareness()).await, "abc");

        Ok(())
    }

    #[tokio::test]
    async fn changes_from_one_client_reach_others() -> Result<(), Box<dyn std::error::Error>> {
        let server_addr = SocketAddr::from_str("127.0.0.1:6602").unwrap();
        let server_doc = Doc::with_client_id(1);
        let _text = server_doc.get_or_insert_text("test");

        let awareness = Arc::new(RwLock::new(Awareness::new(server_doc)));
        let bcast = BroadcastGroup::new(awareness.clone(), 10).await;
        let _server = start_server(server_addr, bcast).await?;

        let c1 = client(server_addr, Doc::with_client_id(2)).await?;
        let _forwarder = forward_updates(&c1).await;

        let d2 = Doc::with_client_id(3);
        let (n2, _sub2) = create_notifier(&d2);
        let c2 = client(server_addr, d2).await?;

        append_text(c1.awareness(), "def").await;

        timeout(TIMEOUT, n2.notified()).await?;

        assert_eq!(read_text(&c2.awareness).await, "def");

        Ok(())
    }

    #[tokio::test]
    async fn client_failure_doesnt_affect_others() -> Result<(), Box<dyn std::error::Error>> {
        let server_addr = SocketAddr::from_str("127.0.0.1:6604").unwrap();
        let server_doc = Doc::with_client_id(1);
        let _ = server_doc.get_or_insert_text("test");

        let awareness = Arc::new(RwLock::new(Awareness::new(server_doc)));
        let bcast = BroadcastGroup::new(awareness.clone(), 10).await;
        let _server = start_server(server_addr, bcast).await?;

        let c1 = client(server_addr, Doc::with_client_id(2)).await?;
        let _forwarder = forward_updates(&c1).await;

        let d2 = Doc::with_client_id(3);
        let (n2, sub2) = create_notifier(&d2);
        let c2 = client(server_addr, d2).await?;

        let d3 = Doc::with_client_id(4);
        let (n3, sub3) = create_notifier(&d3);
        let c3 = client(server_addr, d3).await?;

        append_text(c1.awareness(), "abc").await;

        sleep(TIMEOUT).await;

        assert_eq!(read_text(&c2.awareness).await, "abc");
        assert_eq!(read_text(&c3.awareness).await, "abc");

        // Disconnect the third client and discard every existing notifier.
        drop(c3);
        drop(n3);
        drop(sub3);
        drop(n2);
        drop(sub2);

        let (n2, _sub2) = {
            let guard = c2.awareness().write().await;
            create_notifier(guard.doc())
        };

        append_text(c1.awareness(), "def").await;

        timeout(TIMEOUT, n2.notified()).await.unwrap();

        assert_eq!(read_text(&c2.awareness).await, "abcdef");

        Ok(())
    }
}