1use crate::awareness;
2use crate::awareness::{Awareness, AwarenessUpdate};
3use thiserror::Error;
4use yrs::encoding::read;
5use yrs::updates::decoder::{Decode, Decoder};
6use yrs::updates::encoder::{Encode, Encoder};
7use yrs::{ReadTxn, StateVector, Transact, Update};
8
9pub struct DefaultProtocol;
36
37impl Protocol for DefaultProtocol {}
38
39pub trait Protocol {
43 fn start<E: Encoder>(&self, awareness: &Awareness, encoder: &mut E) -> Result<(), Error> {
47 let (sv, update) = {
48 let sv = awareness.doc().transact().state_vector();
49 let update = awareness.update()?;
50 (sv, update)
51 };
52 Message::Sync(SyncMessage::SyncStep1(sv)).encode(encoder);
53 Message::Awareness(update).encode(encoder);
54 Ok(())
55 }
56
57 fn handle_sync_step1(
60 &self,
61 awareness: &Awareness,
62 sv: StateVector,
63 ) -> Result<Option<Message>, Error> {
64 let update = awareness.doc().transact().encode_state_as_update_v1(&sv);
65 Ok(Some(Message::Sync(SyncMessage::SyncStep2(update))))
66 }
67
68 fn handle_sync_step2(
71 &self,
72 awareness: &mut Awareness,
73 update: Update,
74 ) -> Result<Option<Message>, Error> {
75 let mut txn = awareness.doc().transact_mut();
76 txn.apply_update(update);
77 Ok(None)
78 }
79
80 fn handle_update(
83 &self,
84 awareness: &mut Awareness,
85 update: Update,
86 ) -> Result<Option<Message>, Error> {
87 self.handle_sync_step2(awareness, update)
88 }
89
90 fn handle_auth(
93 &self,
94 _awareness: &Awareness,
95 deny_reason: Option<String>,
96 ) -> Result<Option<Message>, Error> {
97 if let Some(reason) = deny_reason {
98 Err(Error::PermissionDenied { reason })
99 } else {
100 Ok(None)
101 }
102 }
103
104 fn handle_awareness_query(&self, awareness: &Awareness) -> Result<Option<Message>, Error> {
107 let update = awareness.update()?;
108 Ok(Some(Message::Awareness(update)))
109 }
110
111 fn handle_awareness_update(
114 &self,
115 awareness: &mut Awareness,
116 update: AwarenessUpdate,
117 ) -> Result<Option<Message>, Error> {
118 awareness.apply_update(update)?;
119 Ok(None)
120 }
121
122 fn missing_handle(
125 &self,
126 _awareness: &mut Awareness,
127 tag: u8,
128 _data: Vec<u8>,
129 ) -> Result<Option<Message>, Error> {
130 Err(Error::Unsupported(tag))
131 }
132}
133
/// Top-level tag of a [`Message::Sync`] message.
pub const MSG_SYNC: u8 = 0;
/// Top-level tag of a [`Message::Awareness`] message.
pub const MSG_AWARENESS: u8 = 1;
/// Top-level tag of a [`Message::Auth`] message.
pub const MSG_AUTH: u8 = 2;
/// Top-level tag of a [`Message::AwarenessQuery`] message.
pub const MSG_QUERY_AWARENESS: u8 = 3;

/// Auth status written before a denial reason (`Message::Auth(Some(_))`).
pub const PERMISSION_DENIED: u8 = 0;
/// Auth status written for a granted request (`Message::Auth(None)`).
pub const PERMISSION_GRANTED: u8 = 1;
145
146#[derive(Debug, Eq, PartialEq)]
147pub enum Message {
148 Sync(SyncMessage),
149 Auth(Option<String>),
150 AwarenessQuery,
151 Awareness(AwarenessUpdate),
152 Custom(u8, Vec<u8>),
153}
154
155impl Encode for Message {
156 fn encode<E: Encoder>(&self, encoder: &mut E) {
157 match self {
158 Message::Sync(msg) => {
159 encoder.write_var(MSG_SYNC);
160 msg.encode(encoder);
161 }
162 Message::Auth(reason) => {
163 encoder.write_var(MSG_AUTH);
164 if let Some(reason) = reason {
165 encoder.write_var(PERMISSION_DENIED);
166 encoder.write_string(&reason);
167 } else {
168 encoder.write_var(PERMISSION_GRANTED);
169 }
170 }
171 Message::AwarenessQuery => {
172 encoder.write_var(MSG_QUERY_AWARENESS);
173 }
174 Message::Awareness(update) => {
175 encoder.write_var(MSG_AWARENESS);
176 encoder.write_buf(&update.encode_v1())
177 }
178 Message::Custom(tag, data) => {
179 encoder.write_u8(*tag);
180 encoder.write_buf(&data);
181 }
182 }
183 }
184}
185
186impl Decode for Message {
187 fn decode<D: Decoder>(decoder: &mut D) -> Result<Self, read::Error> {
188 let tag: u8 = decoder.read_var()?;
189 match tag {
190 MSG_SYNC => {
191 let msg = SyncMessage::decode(decoder)?;
192 Ok(Message::Sync(msg))
193 }
194 MSG_AWARENESS => {
195 let data = decoder.read_buf()?;
196 let update = AwarenessUpdate::decode_v1(data)?;
197 Ok(Message::Awareness(update))
198 }
199 MSG_AUTH => {
200 let reason = if decoder.read_var::<u8>()? == PERMISSION_DENIED {
201 Some(decoder.read_string()?.to_string())
202 } else {
203 None
204 };
205 Ok(Message::Auth(reason))
206 }
207 MSG_QUERY_AWARENESS => Ok(Message::AwarenessQuery),
208 tag => {
209 let data = decoder.read_buf()?;
210 Ok(Message::Custom(tag, data.to_vec()))
211 }
212 }
213 }
214}
215
/// Sub-tag of [`SyncMessage::SyncStep1`] inside a [`Message::Sync`].
pub const MSG_SYNC_STEP_1: u8 = 0;
/// Sub-tag of [`SyncMessage::SyncStep2`] inside a [`Message::Sync`].
pub const MSG_SYNC_STEP_2: u8 = 1;
/// Sub-tag of [`SyncMessage::Update`] inside a [`Message::Sync`].
pub const MSG_SYNC_UPDATE: u8 = 2;
222
223#[derive(Debug, PartialEq, Eq)]
224pub enum SyncMessage {
225 SyncStep1(StateVector),
226 SyncStep2(Vec<u8>),
227 Update(Vec<u8>),
228}
229
230impl Encode for SyncMessage {
231 fn encode<E: Encoder>(&self, encoder: &mut E) {
232 match self {
233 SyncMessage::SyncStep1(sv) => {
234 encoder.write_var(MSG_SYNC_STEP_1);
235 encoder.write_buf(sv.encode_v1());
236 }
237 SyncMessage::SyncStep2(u) => {
238 encoder.write_var(MSG_SYNC_STEP_2);
239 encoder.write_buf(u);
240 }
241 SyncMessage::Update(u) => {
242 encoder.write_var(MSG_SYNC_UPDATE);
243 encoder.write_buf(u);
244 }
245 }
246 }
247}
248
249impl Decode for SyncMessage {
250 fn decode<D: Decoder>(decoder: &mut D) -> Result<Self, read::Error> {
251 let tag: u8 = decoder.read_var()?;
252 match tag {
253 MSG_SYNC_STEP_1 => {
254 let buf = decoder.read_buf()?;
255 let sv = StateVector::decode_v1(buf)?;
256 Ok(SyncMessage::SyncStep1(sv))
257 }
258 MSG_SYNC_STEP_2 => {
259 let buf = decoder.read_buf()?;
260 Ok(SyncMessage::SyncStep2(buf.into()))
261 }
262 MSG_SYNC_UPDATE => {
263 let buf = decoder.read_buf()?;
264 Ok(SyncMessage::Update(buf.into()))
265 }
266 _ => Err(read::Error::UnexpectedValue),
267 }
268 }
269}
270
271#[derive(Debug, Error)]
273pub enum Error {
274 #[error("failed to deserialize message: {0}")]
276 DecodingError(#[from] read::Error),
277
278 #[error("failed to process awareness update: {0}")]
280 AwarenessEncoding(#[from] awareness::Error),
281
282 #[error("permission denied to access: {reason}")]
284 PermissionDenied { reason: String },
285
286 #[error("unsupported message tag identifier: {0}")]
288 Unsupported(u8),
289
290 #[error("IO error: {0}")]
292 IO(#[from] std::io::Error),
293
294 #[error("internal failure: {0}")]
296 Other(#[from] Box<dyn std::error::Error + Send + Sync>),
297}
298
/// Lets a failed tokio task join be propagated with `?` as [`Error::Other`].
#[cfg(feature = "net")]
impl From<tokio::task::JoinError> for Error {
    fn from(join_error: tokio::task::JoinError) -> Self {
        Self::Other(Box::new(join_error))
    }
}
305
306pub struct MessageReader<'a, D: Decoder>(&'a mut D);
310
311impl<'a, D: Decoder> MessageReader<'a, D> {
312 pub fn new(decoder: &'a mut D) -> Self {
313 MessageReader(decoder)
314 }
315}
316
317impl<'a, D: Decoder> Iterator for MessageReader<'a, D> {
318 type Item = Result<Message, read::Error>;
319
320 fn next(&mut self) -> Option<Self::Item> {
321 match Message::decode(self.0) {
322 Ok(msg) => Some(Ok(msg)),
323 Err(read::Error::EndOfBuffer(_)) => None,
324 Err(error) => Some(Err(error)),
325 }
326 }
327}
328
#[cfg(test)]
mod test {
    use crate::awareness::Awareness;
    use crate::sync::*;
    use std::collections::HashMap;
    use yrs::encoding::read::Cursor;
    use yrs::updates::decoder::{Decode, DecoderV1};
    use yrs::updates::encoder::{Encode, EncoderV1};
    use yrs::{Doc, GetString, ReadTxn, StateVector, Text, Transact};

    /// Every message variant must survive an encode/decode round-trip unchanged.
    #[test]
    fn message_encoding() {
        let doc = Doc::new();
        let txt = doc.get_or_insert_text("text");
        txt.push(&mut doc.transact_mut(), "hello world");
        let mut awareness = Awareness::new(doc);
        awareness.set_local_state("{\"user\":{\"name\":\"Anonymous 50\",\"color\":\"#30bced\",\"colorLight\":\"#30bced33\"}}");

        let full_update = awareness
            .doc()
            .transact()
            .encode_state_as_update_v1(&StateVector::default());

        let messages = [
            Message::Sync(SyncMessage::SyncStep1(
                awareness.doc().transact().state_vector(),
            )),
            Message::Sync(SyncMessage::SyncStep2(full_update.clone())),
            Message::Sync(SyncMessage::Update(full_update)),
            Message::Awareness(awareness.update().unwrap()),
            Message::Auth(Some("reason".to_string())),
            Message::Auth(None),
            Message::AwarenessQuery,
            Message::Custom(5, vec![1, 2, 3]),
        ];

        for msg in messages {
            let encoded = msg.encode_v1();
            // Build the panic message lazily, and only on failure (clippy::expect_fun_call).
            let decoded = Message::decode_v1(&encoded)
                .unwrap_or_else(|e| panic!("failed to decode {:?}: {}", msg, e));
            assert_eq!(decoded, msg);
        }
    }

    #[test]
    fn protocol_init() {
        let awareness = Awareness::default();
        let protocol = DefaultProtocol;
        let mut encoder = EncoderV1::new();
        protocol.start(&awareness, &mut encoder).unwrap();
        let data = encoder.to_vec();
        let mut decoder = DecoderV1::new(Cursor::new(&data));
        let mut reader = MessageReader::new(&mut decoder);

        assert_eq!(
            reader.next().unwrap().unwrap(),
            Message::Sync(SyncMessage::SyncStep1(StateVector::default()))
        );

        assert_eq!(
            reader.next().unwrap().unwrap(),
            Message::Awareness(awareness.update().unwrap())
        );

        assert!(reader.next().is_none());
    }

    #[test]
    fn protocol_sync_steps() {
        let protocol = DefaultProtocol;

        let mut a1 = Awareness::new(Doc::with_client_id(1));
        let mut a2 = Awareness::new(Doc::with_client_id(2));

        let expected = {
            let txt = a1.doc_mut().get_or_insert_text("test");
            let mut txn = a1.doc_mut().transact_mut();
            txt.push(&mut txn, "hello");
            txn.encode_state_as_update_v1(&StateVector::default())
        };

        let result = protocol
            .handle_sync_step1(&a1, a2.doc().transact().state_vector())
            .unwrap();

        assert_eq!(
            result,
            Some(Message::Sync(SyncMessage::SyncStep2(expected)))
        );

        if let Some(Message::Sync(SyncMessage::SyncStep2(u))) = result {
            let result2 = protocol
                .handle_sync_step2(&mut a2, Update::decode_v1(&u).unwrap())
                .unwrap();

            assert!(result2.is_none());
        }

        let txt = a2.doc().transact().get_text("test").unwrap();
        assert_eq!(txt.get_string(&a2.doc().transact()), "hello".to_owned());
    }

    #[test]
    fn protocol_sync_step_update() {
        let protocol = DefaultProtocol;

        let mut a1 = Awareness::new(Doc::with_client_id(1));
        let mut a2 = Awareness::new(Doc::with_client_id(2));

        let data = {
            let txt = a1.doc_mut().get_or_insert_text("test");
            let mut txn = a1.doc_mut().transact_mut();
            txt.push(&mut txn, "hello");
            txn.encode_update_v1()
        };

        let result = protocol
            .handle_update(&mut a2, Update::decode_v1(&data).unwrap())
            .unwrap();

        assert!(result.is_none());

        let txt = a2.doc().transact().get_text("test").unwrap();
        assert_eq!(txt.get_string(&a2.doc().transact()), "hello".to_owned());
    }

    #[test]
    fn protocol_awareness_sync() {
        let protocol = DefaultProtocol;

        let mut a1 = Awareness::new(Doc::with_client_id(1));
        let mut a2 = Awareness::new(Doc::with_client_id(2));

        a1.set_local_state("{x:3}");
        let result = protocol.handle_awareness_query(&a1).unwrap();

        assert_eq!(result, Some(Message::Awareness(a1.update().unwrap())));

        if let Some(Message::Awareness(u)) = result {
            let result = protocol.handle_awareness_update(&mut a2, u).unwrap();
            assert!(result.is_none());
        }

        assert_eq!(a2.clients(), &HashMap::from([(1, "{x:3}".to_owned())]));
    }
}