Skip to main content

zamsync_core/
sync.rs

1use crate::{Event, NodeId, SequenceNumber};
2use rkyv::{Archive, Deserialize, Serialize};
3use std::collections::HashMap;
4
5#[derive(Archive, Deserialize, Serialize, Debug, Clone, Default, PartialEq, Eq)]
6#[archive(check_bytes)]
7pub struct VersionVector {
8    pub entries: HashMap<u32, SequenceNumber>,
9}
10
11impl VersionVector {
12    pub fn new() -> Self {
13        Self::default()
14    }
15
16    pub fn update(&mut self, node_id: NodeId, seq: SequenceNumber) {
17        let entry = self
18            .entries
19            .entry(node_id.0)
20            .or_insert(SequenceNumber::ZERO);
21        if seq > *entry {
22            *entry = seq;
23        }
24    }
25
26    pub fn get(&self, node_id: NodeId) -> SequenceNumber {
27        self.entries
28            .get(&node_id.0)
29            .cloned()
30            .unwrap_or(SequenceNumber::ZERO)
31    }
32
33    /// Returns the first sequence number this VV needs from `other`, for each node where
34    /// `other` is ahead. The returned `SequenceNumber` is the inclusive start of the gap
35    /// (i.e. `events_since(node, start)` should return events with `seq >= start`).
36    pub fn find_gaps(&self, other: &VersionVector) -> Vec<(NodeId, SequenceNumber)> {
37        let mut gaps = Vec::new();
38        for (node_id_raw, other_seq) in &other.entries {
39            let node_id = NodeId(*node_id_raw);
40            match self.entries.get(node_id_raw) {
41                Some(&local_last) if *other_seq > local_last => {
42                    gaps.push((node_id, local_last.next()));
43                }
44                None => {
45                    gaps.push((node_id, SequenceNumber::ZERO));
46                }
47                _ => {}
48            }
49        }
50        gaps
51    }
52}
53
54#[derive(Archive, Deserialize, Serialize, Debug, Clone, Default)]
55#[archive(check_bytes)]
56pub struct PeerSyncState {
57    pub known_vv: VersionVector,
58    pub last_acked: Option<SequenceNumber>,
59}
60
61#[derive(Archive, Deserialize, Serialize, Debug, Clone, Default)]
62#[archive(check_bytes)]
63pub struct ReplicationState {
64    pub self_id: NodeId,
65    pub local_vv: VersionVector,
66    pub peers: HashMap<u32, PeerSyncState>,
67}
68
69#[cfg(test)]
70mod tests {
71    use super::*;
72    use std::collections::HashMap;
73
74    #[test]
75    fn test_vv_update_only_advances() {
76        let mut vv = VersionVector::new();
77        vv.update(NodeId(1), SequenceNumber(5));
78        assert_eq!(vv.get(NodeId(1)), SequenceNumber(5));
79        // Lower seq must not overwrite a higher one.
80        vv.update(NodeId(1), SequenceNumber(3));
81        assert_eq!(
82            vv.get(NodeId(1)),
83            SequenceNumber(5),
84            "VV must never decrease"
85        );
86        // Same seq is idempotent.
87        vv.update(NodeId(1), SequenceNumber(5));
88        assert_eq!(vv.get(NodeId(1)), SequenceNumber(5));
89        // Higher seq advances.
90        vv.update(NodeId(1), SequenceNumber(10));
91        assert_eq!(vv.get(NodeId(1)), SequenceNumber(10));
92    }
93
94    #[test]
95    fn test_vv_get_unknown_node_returns_zero() {
96        let vv = VersionVector::new();
97        assert_eq!(vv.get(NodeId(42)), SequenceNumber::ZERO);
98    }
99
100    #[test]
101    fn test_vv_find_gaps_empty_local_needs_everything_from_zero() {
102        let local = VersionVector::new();
103        let mut remote = VersionVector::new();
104        remote.update(NodeId(1), SequenceNumber(10));
105        remote.update(NodeId(2), SequenceNumber(5));
106
107        let gaps: HashMap<u32, SequenceNumber> = local
108            .find_gaps(&remote)
109            .into_iter()
110            .map(|(n, s)| (n.0, s))
111            .collect();
112        assert_eq!(
113            gaps[&1],
114            SequenceNumber::ZERO,
115            "unknown node: start from seq 0"
116        );
117        assert_eq!(gaps[&2], SequenceNumber::ZERO);
118    }
119
120    #[test]
121    fn test_vv_find_gaps_partial_overlap_returns_next_needed() {
122        let mut local = VersionVector::new();
123        local.update(NodeId(1), SequenceNumber(5));
124
125        let mut remote = VersionVector::new();
126        remote.update(NodeId(1), SequenceNumber(10));
127        remote.update(NodeId(2), SequenceNumber(3));
128
129        let gaps: HashMap<u32, SequenceNumber> = local
130            .find_gaps(&remote)
131            .into_iter()
132            .map(|(n, s)| (n.0, s))
133            .collect();
134        // local has node 1 up to seq 5, remote has 10 → need seq 6 (local.next())
135        assert_eq!(gaps[&1], SequenceNumber(6));
136        // node 2 is unknown to local → need from seq 0
137        assert_eq!(gaps[&2], SequenceNumber::ZERO);
138    }
139
140    #[test]
141    fn test_vv_find_gaps_up_to_date_returns_empty() {
142        let mut local = VersionVector::new();
143        local.update(NodeId(1), SequenceNumber(10));
144
145        let mut remote = VersionVector::new();
146        remote.update(NodeId(1), SequenceNumber(10));
147
148        assert!(local.find_gaps(&remote).is_empty(), "no gap when equal");
149    }
150
151    #[test]
152    fn test_vv_find_gaps_local_ahead_returns_empty() {
153        let mut local = VersionVector::new();
154        local.update(NodeId(1), SequenceNumber(15));
155
156        let mut remote = VersionVector::new();
157        remote.update(NodeId(1), SequenceNumber(10));
158
159        assert!(
160            local.find_gaps(&remote).is_empty(),
161            "local is ahead, nothing to pull"
162        );
163    }
164
165    #[test]
166    fn test_vv_find_gaps_start_is_inclusive_next_seq() {
167        let mut local = VersionVector::new();
168        local.update(NodeId(1), SequenceNumber(3));
169
170        let mut remote = VersionVector::new();
171        remote.update(NodeId(1), SequenceNumber(7));
172
173        let gaps = local.find_gaps(&remote);
174        assert_eq!(gaps.len(), 1);
175        // local has up to seq 3, remote has 7 → first missing is seq 4 (3.next())
176        assert_eq!(gaps[0], (NodeId(1), SequenceNumber(4)));
177    }
178
179    #[test]
180    fn test_vv_find_gaps_200_peers_correct_at_scale() {
181        const PEER_COUNT: u32 = 200;
182
183        let mut local = VersionVector::new();
184        let mut remote = VersionVector::new();
185
186        // local knows the first 100 peers (up to seq 5 each).
187        // remote knows all 200 peers (up to seq 10 each).
188        for i in 0..PEER_COUNT {
189            remote.update(NodeId(i), SequenceNumber(10));
190            if i < 100 {
191                local.update(NodeId(i), SequenceNumber(5));
192            }
193        }
194
195        let gaps = local.find_gaps(&remote);
196        assert_eq!(gaps.len(), PEER_COUNT as usize, "must find 200 gaps");
197
198        let gap_map: HashMap<u32, SequenceNumber> =
199            gaps.into_iter().map(|(n, s)| (n.0, s)).collect();
200
201        for i in 0..PEER_COUNT {
202            if i < 100 {
203                // local has seq 5, remote has 10 → need seq 6
204                assert_eq!(
205                    gap_map[&i],
206                    SequenceNumber(6),
207                    "peer {i}: expected next seq 6"
208                );
209            } else {
210                // unknown to local → need from seq 0
211                assert_eq!(
212                    gap_map[&i],
213                    SequenceNumber::ZERO,
214                    "peer {i}: expected seq 0"
215                );
216            }
217        }
218    }
219
220    #[test]
221    fn test_vv_find_gaps_ignores_nodes_not_in_remote() {
222        let mut local = VersionVector::new();
223        local.update(NodeId(1), SequenceNumber(5));
224
225        // remote is empty -- local has events remote doesn't, but that's not a "gap"
226        // (gaps are defined as what the remote has that we don't)
227        let remote = VersionVector::new();
228        assert!(local.find_gaps(&remote).is_empty());
229    }
230}
231
232#[derive(Archive, Deserialize, Serialize, Debug, Clone)]
233#[archive(check_bytes)]
234pub enum SyncMessage {
235    Handshake {
236        node_id: NodeId,
237        vv: VersionVector,
238    },
239    PullRequest {
240        origin_node: NodeId,
241        start_seq: SequenceNumber,
242        limit: u32,
243    },
244    EventBatch {
245        origin_node: NodeId,
246        events: Vec<Event>,
247    },
248    SyncComplete,
249}