Skip to main content

zamsync_core/
sync.rs

1use crate::{Event, NodeId, SequenceNumber};
2use rkyv::{Archive, Deserialize, Serialize};
3use std::collections::HashMap;
4
/// Per-node high-water marks: the highest sequence number recorded for each node.
///
/// Entries are keyed by the raw `u32` inside `NodeId` (see `update`/`get`), and
/// `update` never lets an entry move backwards.
///
/// Note: the derived `PartialEq`/`Eq` compare the `entries` maps directly, so a vector
/// holding an explicit `SequenceNumber::ZERO` entry for a node is *not* equal to one
/// that has no entry for it, even though `get` reports `ZERO` for both.
#[derive(Archive, Deserialize, Serialize, Debug, Clone, Default, PartialEq, Eq)]
#[archive(check_bytes)]
pub struct VersionVector {
    /// Raw node id (`NodeId.0`) → highest sequence number recorded for that node.
    pub entries: HashMap<u32, SequenceNumber>,
}
10
11impl VersionVector {
12    pub fn new() -> Self {
13        Self::default()
14    }
15
16    pub fn update(&mut self, node_id: NodeId, seq: SequenceNumber) {
17        let entry = self
18            .entries
19            .entry(node_id.0)
20            .or_insert(SequenceNumber::ZERO);
21        if seq > *entry {
22            *entry = seq;
23        }
24    }
25
26    pub fn get(&self, node_id: NodeId) -> SequenceNumber {
27        self.entries
28            .get(&node_id.0)
29            .cloned()
30            .unwrap_or(SequenceNumber::ZERO)
31    }
32
33    /// Returns the first sequence number this VV needs from `other`, for each node where
34    /// `other` is ahead. The returned `SequenceNumber` is the inclusive start of the gap
35    /// (i.e. `events_since(node, start)` should return events with `seq >= start`).
36    pub fn find_gaps(&self, other: &VersionVector) -> Vec<(NodeId, SequenceNumber)> {
37        let mut gaps = Vec::new();
38        for (node_id_raw, other_seq) in &other.entries {
39            let node_id = NodeId(*node_id_raw);
40            match self.entries.get(node_id_raw) {
41                Some(&local_last) if *other_seq > local_last => {
42                    gaps.push((node_id, local_last.next()));
43                }
44                None => {
45                    gaps.push((node_id, SequenceNumber::ZERO));
46                }
47                _ => {}
48            }
49        }
50        gaps
51    }
52}
53
/// Replication bookkeeping kept for a single remote peer.
///
/// NOTE(review): this type is rkyv-archived; adding, removing or reordering fields
/// likely changes the archived byte format — check persisted/wire compatibility first.
#[derive(Archive, Deserialize, Serialize, Debug, Clone, Default)]
#[archive(check_bytes)]
pub struct PeerSyncState {
    /// The peer's version vector as last known locally (presumably taken from its most
    /// recent `SyncMessage::Handshake` — TODO confirm against the sync driver).
    pub known_vv: VersionVector,
    /// Last acknowledged sequence number; `None` presumably means nothing has been
    /// acknowledged yet. NOTE(review): which node's sequence space this refers to is not
    /// visible in this module — confirm.
    pub last_acked: Option<SequenceNumber>,
}
60
/// Aggregate replication state: this node's identity, its own version vector, and the
/// per-peer sync bookkeeping.
///
/// NOTE(review): this type is rkyv-archived; adding, removing or reordering fields
/// likely changes the archived byte format — check persisted/wire compatibility first.
#[derive(Archive, Deserialize, Serialize, Debug, Clone, Default)]
#[archive(check_bytes)]
pub struct ReplicationState {
    /// Id of the local node.
    pub self_id: NodeId,
    /// Version vector for the local node (presumably what it has applied so far — confirm).
    pub local_vv: VersionVector,
    /// Per-peer sync state. Keys are presumably raw node ids (`NodeId.0`), mirroring how
    /// `VersionVector::entries` is keyed — TODO confirm with the code that populates it.
    pub peers: HashMap<u32, PeerSyncState>,
}
68
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a version vector by applying each `(node, seq)` pair with `update`, in order.
    fn vv_of(pairs: impl IntoIterator<Item = (NodeId, SequenceNumber)>) -> VersionVector {
        let mut vv = VersionVector::new();
        for (node, seq) in pairs {
            vv.update(node, seq);
        }
        vv
    }

    /// Re-keys `find_gaps` output by raw node id so assertions don't depend on the order
    /// in which gaps are returned.
    fn by_node(gaps: Vec<(NodeId, SequenceNumber)>) -> HashMap<u32, SequenceNumber> {
        gaps.into_iter()
            .map(|(node, start)| (node.0, start))
            .collect()
    }

    #[test]
    fn test_vv_update_only_advances() {
        let node_one = |vv: &VersionVector| vv.get(NodeId(1));
        let mut vv = VersionVector::new();

        vv.update(NodeId(1), SequenceNumber(5));
        assert_eq!(node_one(&vv), SequenceNumber(5));

        // A stale (lower) sequence number must be ignored.
        vv.update(NodeId(1), SequenceNumber(3));
        assert_eq!(node_one(&vv), SequenceNumber(5), "VV must never decrease");

        // Re-applying the current value is a no-op.
        vv.update(NodeId(1), SequenceNumber(5));
        assert_eq!(node_one(&vv), SequenceNumber(5));

        // A newer sequence number moves the entry forward.
        vv.update(NodeId(1), SequenceNumber(10));
        assert_eq!(node_one(&vv), SequenceNumber(10));
    }

    #[test]
    fn test_vv_get_unknown_node_returns_zero() {
        assert_eq!(VersionVector::new().get(NodeId(42)), SequenceNumber::ZERO);
    }

    #[test]
    fn test_vv_find_gaps_empty_local_needs_everything_from_zero() {
        let remote = vv_of(vec![
            (NodeId(1), SequenceNumber(10)),
            (NodeId(2), SequenceNumber(5)),
        ]);

        let starts = by_node(VersionVector::new().find_gaps(&remote));
        assert_eq!(starts[&1], SequenceNumber::ZERO, "unknown node: start from seq 0");
        assert_eq!(starts[&2], SequenceNumber::ZERO);
    }

    #[test]
    fn test_vv_find_gaps_partial_overlap_returns_next_needed() {
        let local = vv_of(vec![(NodeId(1), SequenceNumber(5))]);
        let remote = vv_of(vec![
            (NodeId(1), SequenceNumber(10)),
            (NodeId(2), SequenceNumber(3)),
        ]);

        let starts = by_node(local.find_gaps(&remote));
        // Node 1: local stops at 5, remote reaches 10 → first missing is 6.
        assert_eq!(starts[&1], SequenceNumber(6));
        // Node 2: never seen locally → pull from seq 0.
        assert_eq!(starts[&2], SequenceNumber::ZERO);
    }

    #[test]
    fn test_vv_find_gaps_up_to_date_returns_empty() {
        let local = vv_of(vec![(NodeId(1), SequenceNumber(10))]);
        let remote = vv_of(vec![(NodeId(1), SequenceNumber(10))]);

        assert!(local.find_gaps(&remote).is_empty(), "no gap when equal");
    }

    #[test]
    fn test_vv_find_gaps_local_ahead_returns_empty() {
        let local = vv_of(vec![(NodeId(1), SequenceNumber(15))]);
        let remote = vv_of(vec![(NodeId(1), SequenceNumber(10))]);

        assert!(
            local.find_gaps(&remote).is_empty(),
            "local is ahead, nothing to pull"
        );
    }

    #[test]
    fn test_vv_find_gaps_start_is_inclusive_next_seq() {
        let local = vv_of(vec![(NodeId(1), SequenceNumber(3))]);
        let remote = vv_of(vec![(NodeId(1), SequenceNumber(7))]);

        // Local has up to seq 3, remote has 7 → the single gap starts at 4 (3.next()).
        assert_eq!(
            local.find_gaps(&remote),
            vec![(NodeId(1), SequenceNumber(4))]
        );
    }

    #[test]
    fn test_vv_find_gaps_200_peers_correct_at_scale() {
        const PEER_COUNT: u32 = 200;
        const LOCAL_PEERS: u32 = 100;

        // Remote has seen all 200 peers up to seq 10; local only the first 100, up to seq 5.
        let remote = vv_of((0..PEER_COUNT).map(|i| (NodeId(i), SequenceNumber(10))));
        let local = vv_of((0..LOCAL_PEERS).map(|i| (NodeId(i), SequenceNumber(5))));

        let gaps = local.find_gaps(&remote);
        assert_eq!(gaps.len(), PEER_COUNT as usize, "must find 200 gaps");

        let starts = by_node(gaps);
        for i in 0..PEER_COUNT {
            if i >= LOCAL_PEERS {
                // Unknown locally → pull from seq 0.
                assert_eq!(starts[&i], SequenceNumber::ZERO, "peer {i}: expected seq 0");
            } else {
                // Local stops at 5, remote reaches 10 → first missing is 6.
                assert_eq!(starts[&i], SequenceNumber(6), "peer {i}: expected next seq 6");
            }
        }
    }

    #[test]
    fn test_vv_find_gaps_ignores_nodes_not_in_remote() {
        let local = vv_of(vec![(NodeId(1), SequenceNumber(5))]);

        // The remote knows nothing. Local holds events the remote lacks, but that direction
        // is not a gap: gaps only describe what the remote has and we don't.
        assert!(local.find_gaps(&VersionVector::new()).is_empty());
    }
}
209
/// Messages exchanged between nodes while synchronizing event logs.
///
/// NOTE(review): this enum is rkyv-archived; reordering variants or their fields likely
/// changes the archived byte format — check wire compatibility before editing.
#[derive(Archive, Deserialize, Serialize, Debug, Clone)]
#[archive(check_bytes)]
pub enum SyncMessage {
    /// Carries a node id together with that node's version vector (presumably the
    /// sender's own, exchanged at the start of a session — confirm).
    Handshake {
        node_id: NodeId,
        vv: VersionVector,
    },
    /// Asks for events originated by `origin_node`, beginning at `start_seq`.
    ///
    /// NOTE(review): presumably `start_seq` is inclusive, matching the gap starts produced
    /// by `VersionVector::find_gaps`, and `limit` caps how many events are returned —
    /// confirm against the request handler.
    PullRequest {
        origin_node: NodeId,
        start_seq: SequenceNumber,
        limit: u32,
    },
    /// A batch of events attributed to `origin_node`.
    EventBatch {
        origin_node: NodeId,
        events: Vec<Event>,
    },
    /// Carries no payload; presumably signals the end of a sync exchange — confirm.
    SyncComplete,
}