// yrs_rocket_ws/lib.rs

1use futures_util::Sink;
2use futures_util::stream::{SplitSink, SplitStream};
3use rocket_ws::Message;
4use rocket_ws::stream::DuplexStream;
5use yrs_tokio::signaling::Message as SignalingMessage;
6use yrs_tokio::{
7    YrsExchange, YrsSink, YrsStream, impl_yrs_signal_stream, to_signaling_message, yrs_common_sink,
8};
9
10#[derive(YrsStream)]
11pub struct YrsStream(SplitStream<DuplexStream>);
12#[derive(YrsExchange)]
13pub struct YrsSignalStream(SplitStream<DuplexStream>);
14
15impl_yrs_signal_stream!(YrsSignalStream, item => to_signaling_message!(item, custom Message::Frame(frame) => SignalingMessage::Binary(frame.into_data())));
16#[derive(YrsSink)]
17pub struct YrsSink(SplitSink<DuplexStream, Message>);
18#[yrs_common_sink]
19impl Sink<SignalingMessage> for YrsSink {}
20
#[cfg(test)]
mod test {
    use crate::{YrsSink, YrsStream};
    use futures_util::{SinkExt, ready};
    use rocket::{State, get, routes};
    use rocket_ws::stream::DuplexStream;
    use rocket_ws::{Channel, WebSocket};
    use std::net::SocketAddr;
    use std::str::FromStr;
    use std::sync::Arc;
    use tokio::sync::Mutex;
    use tokio::task::JoinHandle;
    use yrs::updates::encoder::Encode;
    use yrs::{GetString, Text, Transact};
    use yrs_tokio::broadcast::BroadcastGroup;
    use yrs_tokio::yrs_common_test;

    /// Upgrades `GET /my-room` to a WebSocket and hands the connection to
    /// [`peer`], together with the [`BroadcastGroup`] from Rocket's managed state.
    #[get("/my-room")]
    fn ws_handler(ws: WebSocket, bcast: &State<Arc<BroadcastGroup>>) -> Channel<'_> {
        let bcast = bcast.inner();

        ws.channel(move |stream| {
            Box::pin(async move {
                peer(stream, bcast).await;

                Ok(())
            })
        })
    }

    /// Splits `stream` into its sink and stream halves and wraps each half in the
    /// crate's yrs adapter. It then subscribes the pair to `bcast`, waits for the
    /// subscription to finish, and logs the outcome.
    async fn peer(stream: DuplexStream, bcast: &Arc<BroadcastGroup>) {
        use rocket::futures::StreamExt;

        let (sink, stream) = stream.split();
        // `subscribe` takes the sink shared behind an async mutex. Presumably the
        // broadcast group writes to it from its own task; confirm in yrs_tokio.
        let sink = Arc::new(Mutex::new(YrsSink::from(sink)));
        let stream = YrsStream::from(stream);

        let sub = bcast.subscribe(sink, stream);
        match sub.completed().await {
            Ok(_) => println!("broadcasting for channel finished successfully"),
            Err(e) => eprintln!("broadcasting for channel finished abruptly: {}", e),
        }
    }

    /// Starts a Rocket server in a background Tokio task. The server listens on
    /// `addr` and serves [`ws_handler`], backed by `bcast`.
    ///
    /// Returns the task's handle immediately after spawning. It does not wait
    /// for the server to start accepting connections.
    ///
    /// # Errors
    /// Returns an error if `addr` does not parse as an `ip:port` socket address.
    ///
    /// # Panics
    /// The spawned task panics with a descriptive message if Rocket fails to
    /// launch, for example when the port is already in use. The failure is
    /// therefore visible through the returned `JoinHandle`.
    #[yrs_common_test]
    async fn start_server(
        addr: &str,
        bcast: Arc<BroadcastGroup>,
    ) -> Result<JoinHandle<()>, Box<dyn std::error::Error>> {
        let addr = SocketAddr::from_str(addr)?;

        let rocket_handle = tokio::spawn(async move {
            let launched = rocket::build()
                .configure(
                    rocket::config::Config::figment()
                        .merge(("address", addr.ip().to_string()))
                        .merge(("port", addr.port())),
                )
                // Put the BroadcastGroup into Rocket's managed state so that
                // `ws_handler` can obtain it through `&State<Arc<BroadcastGroup>>`.
                // `bcast` is already owned by this task, so no clone is needed.
                .manage(bcast)
                .mount("/", routes![ws_handler])
                .launch()
                .await;

            // Handle the launch result explicitly instead of silently dropping it.
            if let Err(e) = launched {
                panic!("rocket server failed to launch: {e}");
            }
        });

        Ok(rocket_handle)
    }
}