1use std::collections::hash_map::Entry;
2use std::collections::HashMap;
3use std::fmt::Formatter;
4use std::sync::Arc;
5use std::time::Instant;
6use thiserror::Error;
7use yrs::block::ClientID;
8use yrs::encoding::read;
9use yrs::updates::decoder::{Decode, Decoder};
10use yrs::updates::encoder::{Encode, Encoder};
11use yrs::{Doc, Observer, Subscription};
12
/// JSON text used on the wire for an absent awareness state.
///
/// `update_with_clients` emits it for clients that have metadata but no state, and
/// `apply_update` treats an entry whose JSON is exactly this string as a removal.
const NULL_STR: &str = "null";
14
15pub struct Awareness {
28 doc: Doc,
29 states: HashMap<ClientID, String>,
30 meta: HashMap<ClientID, MetaClientState>,
31 on_update: Option<Observer<Arc<dyn Fn(&Awareness, &Event) -> () + 'static>>>,
32}
33
// NOTE(review): these impls are asserted, not proven. `on_update` stores callbacks as
// `Arc<dyn Fn(&Awareness, &Event) + 'static>` with no `Send`/`Sync` bound, so a closure
// capturing non-thread-safe data (e.g. an `Rc` or `RefCell`) can be registered and the
// `Awareness` then moved to or shared with another thread. Whether `yrs::Doc` is itself
// thread-safe is not visible here either — TODO confirm. A sound alternative is to require
// `F: Send + Sync` in `on_update` (and in the stored callback type) and drop these impls.
unsafe impl Send for Awareness {}
unsafe impl Sync for Awareness {}
36
37impl Awareness {
38 pub fn new(doc: Doc) -> Self {
42 Awareness {
43 doc,
44 on_update: None,
45 states: HashMap::new(),
46 meta: HashMap::new(),
47 }
48 }
49
50 pub fn on_update<F>(&mut self, f: F) -> UpdateSubscription
52 where
53 F: Fn(&Awareness, &Event) -> () + 'static,
54 {
55 let eh = self.on_update.get_or_insert_with(Observer::default);
56 eh.subscribe(Arc::new(f))
57 }
58
59 pub fn doc(&self) -> &Doc {
61 &self.doc
62 }
63
64 pub fn doc_mut(&mut self) -> &mut Doc {
66 &mut self.doc
67 }
68
69 pub fn client_id(&self) -> ClientID {
71 self.doc.client_id()
72 }
73
74 pub fn clients(&self) -> &HashMap<ClientID, String> {
78 &self.states
79 }
80
81 pub fn local_state(&self) -> Option<&str> {
83 Some(self.states.get(&self.doc.client_id())?.as_str())
84 }
85
86 pub fn set_local_state<S: Into<String>>(&mut self, json: S) {
91 let client_id = self.doc.client_id();
92 self.update_meta(client_id);
93 let new: String = json.into();
94 match self.states.entry(client_id) {
95 Entry::Occupied(mut e) => {
96 e.insert(new);
97 if let Some(eh) = self.on_update.as_ref() {
98 let e = Event::new(vec![], vec![client_id], vec![]);
99 for cb in eh.callbacks() {
100 cb(self, &e);
101 }
102 }
103 }
104 Entry::Vacant(e) => {
105 e.insert(new);
106 if let Some(eh) = self.on_update.as_ref() {
107 let e = Event::new(vec![client_id], vec![], vec![]);
108 for cb in eh.callbacks() {
109 cb(self, &e);
110 }
111 }
112 }
113 }
114 }
115
116 pub fn remove_state(&mut self, client_id: ClientID) {
118 let prev_state = self.states.remove(&client_id);
119 self.update_meta(client_id);
120 if let Some(eh) = self.on_update.as_ref() {
121 if prev_state.is_some() {
122 let e = Event::new(Vec::default(), Vec::default(), vec![client_id]);
123 for cb in eh.callbacks() {
124 cb(self, &e);
125 }
126 }
127 }
128 }
129
130 pub fn clean_local_state(&mut self) {
133 let client_id = self.doc.client_id();
134 self.remove_state(client_id);
135 }
136
137 fn update_meta(&mut self, client_id: ClientID) {
138 match self.meta.entry(client_id) {
139 Entry::Occupied(mut e) => {
140 let clock = e.get().clock + 1;
141 let meta = MetaClientState::new(clock, Instant::now());
142 e.insert(meta);
143 }
144 Entry::Vacant(e) => {
145 e.insert(MetaClientState::new(1, Instant::now()));
146 }
147 }
148 }
149
150 pub fn update(&self) -> Result<AwarenessUpdate, Error> {
152 let clients = self.states.keys().cloned();
153 self.update_with_clients(clients)
154 }
155
156 pub fn update_with_clients<I: IntoIterator<Item = ClientID>>(
161 &self,
162 clients: I,
163 ) -> Result<AwarenessUpdate, Error> {
164 let mut res = HashMap::new();
165 for client_id in clients {
166 let clock = if let Some(meta) = self.meta.get(&client_id) {
167 meta.clock
168 } else {
169 return Err(Error::ClientNotFound(client_id));
170 };
171 let json = if let Some(json) = self.states.get(&client_id) {
172 json.clone()
173 } else {
174 String::from(NULL_STR)
175 };
176 res.insert(client_id, AwarenessUpdateEntry { clock, json });
177 }
178 Ok(AwarenessUpdate { clients: res })
179 }
180
181 pub fn apply_update(&mut self, update: AwarenessUpdate) -> Result<(), Error> {
187 let now = Instant::now();
188
189 let mut added = Vec::new();
190 let mut updated = Vec::new();
191 let mut removed = Vec::new();
192
193 for (client_id, entry) in update.clients {
194 let mut clock = entry.clock;
195 let is_null = entry.json.as_str() == NULL_STR;
196 match self.meta.entry(client_id) {
197 Entry::Occupied(mut e) => {
198 let prev = e.get();
199 let is_removed =
200 prev.clock == clock && is_null && self.states.contains_key(&client_id);
201 let is_new = prev.clock < clock;
202 if is_new || is_removed {
203 if is_null {
204 if client_id == self.doc.client_id()
206 && self.states.get(&client_id).is_some()
207 {
208 clock += 1;
211 } else {
212 self.states.remove(&client_id);
213 if self.on_update.is_some() {
214 removed.push(client_id);
215 }
216 }
217 } else {
218 match self.states.entry(client_id) {
219 Entry::Occupied(mut e) => {
220 if self.on_update.is_some() {
221 updated.push(client_id);
222 }
223 e.insert(entry.json);
224 }
225 Entry::Vacant(e) => {
226 e.insert(entry.json);
227 if self.on_update.is_some() {
228 updated.push(client_id);
229 }
230 }
231 }
232 }
233 e.insert(MetaClientState::new(clock, now));
234 true
235 } else {
236 false
237 }
238 }
239 Entry::Vacant(e) => {
240 e.insert(MetaClientState::new(clock, now));
241 self.states.insert(client_id, entry.json);
242 if self.on_update.is_some() {
243 added.push(client_id);
244 }
245 true
246 }
247 };
248 }
249
250 if let Some(eh) = self.on_update.as_ref() {
251 if !added.is_empty() || !updated.is_empty() || !removed.is_empty() {
252 let e = Event::new(added, updated, removed);
253 for cb in eh.callbacks() {
254 cb(self, &e);
255 }
256 }
257 }
258
259 Ok(())
260 }
261}
262
263impl Default for Awareness {
264 fn default() -> Self {
265 Awareness::new(Doc::new())
266 }
267}
268
269impl std::fmt::Debug for Awareness {
270 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
271 f.debug_struct("Awareness")
272 .field("state", &self.states)
273 .field("meta", &self.meta)
274 .field("doc", &self.doc)
275 .finish()
276 }
277}
278
279pub type UpdateSubscription = Subscription<Arc<dyn Fn(&Awareness, &Event) -> () + 'static>>;
282
283#[derive(Debug, Eq, PartialEq)]
285pub struct AwarenessUpdate {
286 pub(crate) clients: HashMap<ClientID, AwarenessUpdateEntry>,
287}
288
289impl Encode for AwarenessUpdate {
290 fn encode<E: Encoder>(&self, encoder: &mut E) {
291 encoder.write_var(self.clients.len());
292 for (&client_id, e) in self.clients.iter() {
293 encoder.write_var(client_id);
294 encoder.write_var(e.clock);
295 encoder.write_string(&e.json);
296 }
297 }
298}
299
300impl Decode for AwarenessUpdate {
301 fn decode<D: Decoder>(decoder: &mut D) -> Result<Self, read::Error> {
302 let len: usize = decoder.read_var()?;
303 let mut clients = HashMap::with_capacity(len);
304 for _ in 0..len {
305 let client_id: ClientID = decoder.read_var()?;
306 let clock: u32 = decoder.read_var()?;
307 let json = decoder.read_string()?.to_string();
308 clients.insert(client_id, AwarenessUpdateEntry { clock, json });
309 }
310
311 Ok(AwarenessUpdate { clients })
312 }
313}
314
/// The state of a single client inside an [`AwarenessUpdate`].
#[derive(PartialEq, Eq, Debug)]
pub struct AwarenessUpdateEntry {
    /// Logical clock of the client's state; a higher clock supersedes a lower one.
    pub(crate) clock: u32,
    /// JSON-encoded state, or `"null"` when the client's state has been removed.
    pub(crate) json: String,
}
322
323#[derive(Error, Debug)]
325pub enum Error {
326 #[error("client ID `{0}` not found")]
328 ClientNotFound(ClientID),
329}
330
/// Bookkeeping kept for every client the awareness instance has seen.
#[derive(Clone, Debug, Eq, PartialEq)]
struct MetaClientState {
    /// Logical clock of the client's most recently accepted state.
    clock: u32,
    /// Local time at which this entry was last written.
    last_updated: Instant,
}
336
337impl MetaClientState {
338 fn new(clock: u32, last_updated: Instant) -> Self {
339 MetaClientState {
340 clock,
341 last_updated,
342 }
343 }
344}
345
346#[derive(Debug, Default, Clone, Eq, PartialEq)]
348pub struct Event {
349 added: Vec<ClientID>,
350 updated: Vec<ClientID>,
351 removed: Vec<ClientID>,
352}
353
354impl Event {
355 pub fn new(added: Vec<ClientID>, updated: Vec<ClientID>, removed: Vec<ClientID>) -> Self {
356 Event {
357 added,
358 updated,
359 removed,
360 }
361 }
362
363 pub fn added(&self) -> &[ClientID] {
366 &self.added
367 }
368
369 pub fn updated(&self) -> &[ClientID] {
372 &self.updated
373 }
374
375 pub fn removed(&self) -> &[ClientID] {
378 &self.removed
379 }
380}
381
#[cfg(test)]
mod test {
    use crate::awareness::{Awareness, Event};
    use std::sync::mpsc::{channel, Receiver};
    use yrs::Doc;

    /// Takes the next event emitted by `from` and builds an update covering every client
    /// that event mentions. Applies the update to `to` and returns the consumed event.
    fn update(
        recv: &Receiver<Event>,
        from: &Awareness,
        to: &mut Awareness,
    ) -> Result<Event, Box<dyn std::error::Error>> {
        let e = recv.try_recv()?;
        let u = from.update_with_clients([e.added(), e.updated(), e.removed()].concat())?;
        to.apply_update(u)?;
        Ok(e)
    }

    #[test]
    fn awareness() -> Result<(), Box<dyn std::error::Error>> {
        let (s1, o_local) = channel();
        let mut local = Awareness::new(Doc::with_client_id(1));
        let _sub_local = local.on_update(move |_, e| {
            s1.send(e.clone()).unwrap();
        });

        let (s2, o_remote) = channel();
        let mut remote = Awareness::new(Doc::with_client_id(2));
        // Subscribe on `remote` (not `local`) so that `o_remote` carries the events the
        // remote side observes after applying updates, which the assertions compare
        // against the local events.
        let _sub_remote = remote.on_update(move |_, e| {
            s2.send(e.clone()).unwrap();
        });

        local.set_local_state("{x:3}");
        let _e_local = update(&o_local, &local, &mut remote)?;
        assert_eq!(remote.clients()[&1], "{x:3}");
        assert_eq!(remote.meta[&1].clock, 1);
        assert_eq!(o_remote.try_recv()?.added, &[1]);

        local.set_local_state("{x:4}");
        let e_local = update(&o_local, &local, &mut remote)?;
        let e_remote = o_remote.try_recv()?;
        assert_eq!(remote.clients()[&1], "{x:4}");
        assert_eq!(e_remote, Event::new(vec![], vec![1], vec![]));
        assert_eq!(e_remote, e_local);

        local.clean_local_state();
        let e_local = update(&o_local, &local, &mut remote)?;
        let e_remote = o_remote.try_recv()?;
        assert_eq!(e_remote.removed.len(), 1);
        assert_eq!(local.clients().get(&1), None);
        assert_eq!(e_remote, e_local);
        Ok(())
    }
}