y_octo/doc/common/
state.rs

1use std::{
2    collections::HashMap,
3    ops::{Deref, DerefMut},
4};
5
6use crate::{Client, Clock, CrdtRead, CrdtReader, CrdtWrite, CrdtWriter, Id, JwstCodecResult};
7
8#[derive(Default, Debug, PartialEq, Clone)]
9pub struct StateVector(HashMap<Client, Clock>);
10
11impl StateVector {
12    pub fn set_max(&mut self, client: Client, clock: Clock) {
13        self.entry(client)
14            .and_modify(|m_clock| {
15                if *m_clock < clock {
16                    *m_clock = clock;
17                }
18            })
19            .or_insert(clock);
20    }
21
22    pub fn get(&self, client: &Client) -> Clock {
23        *self.0.get(client).unwrap_or(&0)
24    }
25
26    pub fn contains(&self, id: &Id) -> bool {
27        id.clock <= self.get(&id.client)
28    }
29
30    pub fn set_min(&mut self, client: Client, clock: Clock) {
31        self.entry(client)
32            .and_modify(|m_clock| {
33                if *m_clock > clock {
34                    *m_clock = clock;
35                }
36            })
37            .or_insert(clock);
38    }
39
40    pub fn iter(&self) -> impl Iterator<Item = (&Client, &Clock)> {
41        self.0.iter()
42    }
43
44    pub fn merge_with(&mut self, other: &Self) {
45        for (client, clock) in other.iter() {
46            self.set_min(*client, *clock);
47        }
48    }
49}
50
51impl Deref for StateVector {
52    type Target = HashMap<Client, Clock>;
53
54    fn deref(&self) -> &Self::Target {
55        &self.0
56    }
57}
58
59impl DerefMut for StateVector {
60    fn deref_mut(&mut self) -> &mut Self::Target {
61        &mut self.0
62    }
63}
64
65impl<const N: usize> From<[(Client, Clock); N]> for StateVector {
66    fn from(value: [(Client, Clock); N]) -> Self {
67        let mut map = HashMap::with_capacity(N);
68
69        for (client, clock) in value {
70            map.insert(client, clock);
71        }
72
73        Self(map)
74    }
75}
76
77impl<R: CrdtReader> CrdtRead<R> for StateVector {
78    fn read(decoder: &mut R) -> JwstCodecResult<Self> {
79        let len = decoder.read_var_u64()? as usize;
80
81        let mut map = HashMap::with_capacity(len);
82        for _ in 0..len {
83            let client = decoder.read_var_u64()?;
84            let clock = decoder.read_var_u64()?;
85            map.insert(client, clock);
86        }
87
88        Ok(Self(map))
89    }
90}
91
92impl<W: CrdtWriter> CrdtWrite<W> for StateVector {
93    fn write(&self, encoder: &mut W) -> JwstCodecResult {
94        encoder.write_var_u64(self.len() as u64)?;
95
96        for (client, clock) in self.iter() {
97            encoder.write_var_u64(*client)?;
98            encoder.write_var_u64(*clock)?;
99        }
100
101        Ok(())
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_state_vector_basic() {
        let mut sv = StateVector::from([(1, 1), (2, 2), (3, 3)]);

        // construction keeps every entry
        assert_eq!(sv.len(), 3);
        assert_eq!(sv.get(&1), 1);

        // lowering an existing clock
        sv.set_min(1, 0);
        assert_eq!(sv.get(&1), 0);

        // raising it again
        sv.set_max(1, 4);
        assert_eq!(sv.get(&1), 4);

        // set_max on a client without an entry inserts it
        sv.set_max(4, 1);
        assert_eq!(sv.get(&4), 1);

        // an id whose clock is past the recorded one is not contained
        assert!(!sv.contains(&(1, 5).into()));
    }

    #[test]
    fn test_state_vector_merge() {
        let mut sv = StateVector::from([(1, 1), (2, 2), (3, 3)]);
        let other = StateVector::from([(1, 5), (2, 6), (3, 7)]);

        sv.merge_with(&other);

        // every clock in `other` is larger, so the original clocks survive;
        // equality is order-insensitive
        let expected = StateVector::from([(2, 2), (3, 3), (1, 1)]);
        assert_eq!(sv, expected);
    }
}