1use crate::conn::Connection;
2use crate::AwarenessRef;
3use futures_util::stream::{SplitSink, SplitStream};
4use futures_util::{Stream, StreamExt};
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use axum::extract::ws::{WebSocket, Message};
8use yrs::sync::Error;
9
10#[repr(transparent)]
16#[derive(Debug)]
17pub struct AxumConn(Connection<AxumSink, AxumStream>);
18
19impl AxumConn {
20 pub fn new(awareness: AwarenessRef, socket: WebSocket) -> Self {
21 let (sink, stream) = socket.split();
22 let conn = Connection::new(awareness, AxumSink(sink), AxumStream(stream));
23 AxumConn(conn)
24 }
25}
26
27impl core::future::Future for AxumConn {
28 type Output = Result<(), Error>;
29
30 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
31 match Pin::new(&mut self.0).poll(cx) {
32 Poll::Pending => Poll::Pending,
33 Poll::Ready(Err(e)) => Poll::Ready(Err(Error::Other(e.into()))),
34 Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
35 }
36 }
37}
38
39#[repr(transparent)]
101#[derive(Debug)]
102pub struct AxumSink(pub SplitSink<WebSocket, Message>);
103
104impl From<SplitSink<WebSocket, Message>> for AxumSink {
105 fn from(sink: SplitSink<WebSocket, Message>) -> Self {
106 AxumSink(sink)
107 }
108}
109
110impl Into<SplitSink<WebSocket, Message>> for AxumSink {
111 fn into(self) -> SplitSink<WebSocket, Message> {
112 self.0
113 }
114}
115
116impl futures_util::Sink<Vec<u8>> for AxumSink {
117 type Error = Error;
118
119 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
120 match Pin::new(&mut self.0).poll_ready(cx) {
121 Poll::Pending => Poll::Pending,
122 Poll::Ready(Err(e)) => Poll::Ready(Err(Error::Other(e.into()))),
123 Poll::Ready(_) => Poll::Ready(Ok(())),
124 }
125 }
126
127 fn start_send(mut self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
128 if let Err(e) = Pin::new(&mut self.0).start_send(Message::binary(item)) {
129 Err(Error::Other(e.into()))
130 } else {
131 Ok(())
132 }
133 }
134
135 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
136 match Pin::new(&mut self.0).poll_flush(cx) {
137 Poll::Pending => Poll::Pending,
138 Poll::Ready(Err(e)) => Poll::Ready(Err(Error::Other(e.into()))),
139 Poll::Ready(_) => Poll::Ready(Ok(())),
140 }
141 }
142
143 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
144 match Pin::new(&mut self.0).poll_close(cx) {
145 Poll::Pending => Poll::Pending,
146 Poll::Ready(Err(e)) => Poll::Ready(Err(Error::Other(e.into()))),
147 Poll::Ready(_) => Poll::Ready(Ok(())),
148 }
149 }
150}
151
152#[derive(Debug)]
214pub struct AxumStream(pub SplitStream<WebSocket>);
215
216impl From<SplitStream<WebSocket>> for AxumStream {
217 fn from(stream: SplitStream<WebSocket>) -> Self {
218 AxumStream(stream)
219 }
220}
221
222impl Into<SplitStream<WebSocket>> for AxumStream {
223 fn into(self) -> SplitStream<WebSocket> {
224 self.0
225 }
226}
227
228impl Stream for AxumStream {
229 type Item = Result<Vec<u8>, Error>;
230
231 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
232 match Pin::new(&mut self.0).poll_next(cx) {
233 Poll::Pending => Poll::Pending,
234 Poll::Ready(None) => Poll::Ready(None),
235 Poll::Ready(Some(res)) => match res {
236 Ok(item) => Poll::Ready(Some(Ok(item.into_data().to_vec()))),
237 Err(e) => Poll::Ready(Some(Err(Error::Other(e.into())))),
238 },
239 }
240 }
241}
242
#[cfg(test)]
mod test {
    use crate::broadcast::BroadcastGroup;
    use crate::conn::Connection;
    use crate::ws::{AxumSink, AxumStream};
    use axum::{
        extract::ws::{WebSocket, WebSocketUpgrade},
        extract::State,
        response::IntoResponse,
        routing::get,
        Router,
    };
    use futures_util::stream::{SplitSink, SplitStream};
    use futures_util::{ready, SinkExt, Stream, StreamExt};
    use std::pin::Pin;
    use std::sync::Arc;
    use std::task::{Context, Poll};
    use std::time::Duration;
    use tokio::net::TcpStream;
    use tokio::sync::{Mutex, Notify, RwLock};
    use tokio::task;
    use tokio::task::JoinHandle;
    use tokio::time::{sleep, timeout};
    use tokio_tungstenite::tungstenite::Message;
    use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
    use yrs::sync::{Awareness, Error};
    use yrs::updates::encoder::Encode;
    use yrs::{Doc, GetString, Subscription, Text, Transact};

    /// Upper bound for every wait on a cross-connection round trip.
    const TIMEOUT: Duration = Duration::from_secs(5);

    /// Client-side connection type produced by [`client`].
    type ClientConn = Connection<TungsteniteSink, TungsteniteStream>;

    /// Binds `addr` and serves the `/my-room` WebSocket endpoint on a spawned
    /// task. Every upgraded socket is attached to `bcast`.
    ///
    /// Bind failures are returned as errors rather than panicking.
    async fn start_server(
        addr: &str,
        bcast: Arc<BroadcastGroup>,
    ) -> Result<JoinHandle<()>, Box<dyn std::error::Error>> {
        let listener = tokio::net::TcpListener::bind(addr).await?;

        let app = Router::new()
            .route("/my-room", get(ws_handler))
            .with_state(bcast);

        Ok(tokio::spawn(async move {
            axum::serve(listener, app.into_make_service())
                .await
                .unwrap();
        }))
    }

    /// Upgrades the HTTP request and hands the socket to [`peer`].
    async fn ws_handler(
        ws: WebSocketUpgrade,
        State(bcast): State<Arc<BroadcastGroup>>,
    ) -> impl IntoResponse {
        ws.on_upgrade(move |socket| peer(socket, bcast))
    }

    /// Subscribes an upgraded server-side socket to `bcast`, then logs how the
    /// subscription ended.
    async fn peer(ws: WebSocket, bcast: Arc<BroadcastGroup>) {
        let (sink, stream) = ws.split();
        let sink = Arc::new(Mutex::new(AxumSink(sink)));
        let stream = AxumStream(stream);
        let sub = bcast.subscribe(sink, stream);
        match sub.completed().await {
            Ok(_) => println!("broadcasting for channel finished successfully"),
            Err(e) => eprintln!("broadcasting for channel finished abruptly: {}", e),
        }
    }

    /// Client-side write half. Each y-sync payload is sent as a binary
    /// tungstenite message.
    struct TungsteniteSink(SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>);

    impl futures_util::Sink<Vec<u8>> for TungsteniteSink {
        type Error = Error;

        fn poll_ready(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            // `SplitSink` is `Unpin`, so the safe `Pin::new` is enough; no unsafe projection needed.
            let sink = Pin::new(&mut self.0);
            let result = ready!(sink.poll_ready(cx));
            match result {
                Ok(_) => Poll::Ready(Ok(())),
                Err(e) => Poll::Ready(Err(Error::Other(Box::new(e)))),
            }
        }

        fn start_send(mut self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
            let sink = Pin::new(&mut self.0);
            let result = sink.start_send(Message::binary(item));
            match result {
                Ok(_) => Ok(()),
                Err(e) => Err(Error::Other(Box::new(e))),
            }
        }

        fn poll_flush(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            let sink = Pin::new(&mut self.0);
            let result = ready!(sink.poll_flush(cx));
            match result {
                Ok(_) => Poll::Ready(Ok(())),
                Err(e) => Poll::Ready(Err(Error::Other(Box::new(e)))),
            }
        }

        fn poll_close(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            let sink = Pin::new(&mut self.0);
            let result = ready!(sink.poll_close(cx));
            match result {
                Ok(_) => Poll::Ready(Ok(())),
                Err(e) => Poll::Ready(Err(Error::Other(Box::new(e)))),
            }
        }
    }

    /// Client-side read half. Yields the payload of each received tungstenite
    /// message.
    struct TungsteniteStream(SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>);

    impl Stream for TungsteniteStream {
        type Item = Result<Vec<u8>, Error>;

        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            // `SplitStream` is `Unpin`, so the safe `Pin::new` is enough.
            let stream = Pin::new(&mut self.0);
            let result = ready!(stream.poll_next(cx));
            match result {
                None => Poll::Ready(None),
                Some(Ok(msg)) => Poll::Ready(Some(Ok(msg.into_data()))),
                Some(Err(e)) => Poll::Ready(Some(Err(Error::Other(Box::new(e))))),
            }
        }
    }

    /// Connects to `addr` with tokio-tungstenite. The socket is wrapped in a
    /// y-sync [`Connection`] backed by a new awareness around `doc`.
    async fn client(addr: &str, doc: Doc) -> Result<ClientConn, Box<dyn std::error::Error>> {
        let (ws, _) = tokio_tungstenite::connect_async(addr).await?;
        let (sink, stream) = ws.split();
        Ok(Connection::new(
            Arc::new(RwLock::new(Awareness::new(doc))),
            TungsteniteSink(sink),
            TungsteniteStream(stream),
        ))
    }

    /// Returns a [`Notify`] that is signalled on every v1 update of `doc`,
    /// together with the subscription that keeps the observer alive.
    ///
    /// NOTE(review): `notify_waiters` only wakes `Notified` futures that already
    /// exist. An update that lands before the test calls `notified()` is not
    /// observed.
    fn create_notifier(doc: &Doc) -> (Arc<Notify>, Subscription) {
        let n = Arc::new(Notify::new());
        let sub = {
            let n = n.clone();
            doc.observe_update_v1(move |_, _| n.notify_waiters())
                .unwrap()
        };
        (n, sub)
    }

    /// Forwards every local v1 update of `conn`'s document to its peer as a
    /// y-sync `Update` message. Forwarding lasts as long as the returned
    /// subscription is kept alive.
    async fn forward_local_updates(conn: &ClientConn) -> Subscription {
        let sink = conn.sink();
        let awareness = conn.awareness().write().await;
        let doc = awareness.doc();
        doc.observe_update_v1(move |_, e| {
            let update = e.update.to_owned();
            if let Some(sink) = sink.upgrade() {
                task::spawn(async move {
                    let msg = yrs::sync::Message::Sync(yrs::sync::SyncMessage::Update(update))
                        .encode_v1();
                    let mut sink = sink.lock().await;
                    sink.send(msg).await.unwrap();
                });
            }
        })
        .unwrap()
    }

    /// Appends `chunk` to the `"test"` text of `conn`'s document in a single
    /// write transaction.
    async fn append_text(conn: &ClientConn, chunk: &str) {
        let awareness = conn.awareness().write().await;
        let doc = awareness.doc();
        let text = doc.get_or_insert_text("test");
        text.push(&mut doc.transact_mut(), chunk);
    }

    /// Reads the current contents of the `"test"` text in `conn`'s document.
    async fn text_of(conn: &ClientConn) -> String {
        let awareness = conn.awareness().read().await;
        let doc = awareness.doc();
        let text = doc.get_or_insert_text("test");
        // Bound to a local so the read transaction is dropped before `awareness`.
        let contents = text.get_string(&doc.transact());
        contents
    }

    #[tokio::test]
    async fn change_introduced_by_server_reaches_subscribed_clients() {
        let doc = Doc::with_client_id(1);
        let text = doc.get_or_insert_text("test");
        let awareness = Arc::new(RwLock::new(Awareness::new(doc)));
        let bcast = BroadcastGroup::new(awareness.clone(), 10).await;
        let _server = start_server("0.0.0.0:16600", Arc::new(bcast)).await.unwrap();

        let doc = Doc::new();
        let (n, _sub) = create_notifier(&doc);
        let c1 = client("ws://localhost:16600/my-room", doc).await.unwrap();

        {
            let lock = awareness.write().await;
            text.push(&mut lock.doc().transact_mut(), "abc");
        }

        timeout(TIMEOUT, n.notified()).await.unwrap();

        assert_eq!(text_of(&c1).await, "abc");
    }

    #[tokio::test]
    async fn subscribed_client_fetches_initial_state() {
        let doc = Doc::with_client_id(1);
        let text = doc.get_or_insert_text("test");

        text.push(&mut doc.transact_mut(), "abc");

        let awareness = Arc::new(RwLock::new(Awareness::new(doc)));
        let bcast = BroadcastGroup::new(awareness.clone(), 10).await;
        let _server = start_server("0.0.0.0:16601", Arc::new(bcast)).await.unwrap();

        let doc = Doc::new();
        let (n, _sub) = create_notifier(&doc);
        let c1 = client("ws://localhost:16601/my-room", doc).await.unwrap();

        timeout(TIMEOUT, n.notified()).await.unwrap();

        assert_eq!(text_of(&c1).await, "abc");
    }

    #[tokio::test]
    async fn changes_from_one_client_reach_others() {
        let doc = Doc::with_client_id(1);
        let _ = doc.get_or_insert_text("test");

        let awareness = Arc::new(RwLock::new(Awareness::new(doc)));
        let bcast = BroadcastGroup::new(awareness.clone(), 10).await;
        let _server = start_server("0.0.0.0:16602", Arc::new(bcast)).await.unwrap();

        let d1 = Doc::with_client_id(2);
        let c1 = client("ws://localhost:16602/my-room", d1).await.unwrap();
        let _sub11 = forward_local_updates(&c1).await;

        let d2 = Doc::with_client_id(3);
        let (n2, _sub2) = create_notifier(&d2);
        let c2 = client("ws://localhost:16602/my-room", d2).await.unwrap();

        append_text(&c1, "def").await;

        timeout(TIMEOUT, n2.notified()).await.unwrap();

        assert_eq!(text_of(&c2).await, "def");
    }

    #[tokio::test]
    async fn client_failure_doesnt_affect_others() {
        let doc = Doc::with_client_id(1);
        let _text = doc.get_or_insert_text("test");

        let awareness = Arc::new(RwLock::new(Awareness::new(doc)));
        let bcast = BroadcastGroup::new(awareness.clone(), 10).await;
        let _server = start_server("0.0.0.0:16603", Arc::new(bcast)).await.unwrap();

        let d1 = Doc::with_client_id(2);
        let c1 = client("ws://localhost:16603/my-room", d1).await.unwrap();
        let _sub11 = forward_local_updates(&c1).await;

        let d2 = Doc::with_client_id(3);
        let (n2, sub2) = create_notifier(&d2);
        let c2 = client("ws://localhost:16603/my-room", d2).await.unwrap();

        let d3 = Doc::with_client_id(4);
        let (n3, sub3) = create_notifier(&d3);
        let c3 = client("ws://localhost:16603/my-room", d3).await.unwrap();

        append_text(&c1, "abc").await;

        // Fixed delay so the broadcast can reach both c2 and c3 before they are inspected.
        sleep(TIMEOUT).await;

        assert_eq!(text_of(&c2).await, "abc");
        assert_eq!(text_of(&c3).await, "abc");

        // Simulate c3 going away, and discard the notifiers created before the first edit.
        drop(c3);
        drop(n3);
        drop(sub3);
        drop(n2);
        drop(sub2);

        let (n2, _sub2) = {
            let a = c2.awareness().write().await;
            let doc = a.doc();
            create_notifier(doc)
        };

        append_text(&c1, "def").await;

        timeout(TIMEOUT, n2.notified()).await.unwrap();

        assert_eq!(text_of(&c2).await, "abcdef");
    }
}