1use crate::{Event, NodeId, SequenceNumber};
2use rkyv::{Archive, Deserialize, Serialize};
3use std::collections::HashMap;
4
5#[derive(Archive, Deserialize, Serialize, Debug, Clone, Default, PartialEq, Eq)]
6#[archive(check_bytes)]
7pub struct VersionVector {
8 pub entries: HashMap<u32, SequenceNumber>,
9}
10
11impl VersionVector {
12 pub fn new() -> Self {
13 Self::default()
14 }
15
16 pub fn update(&mut self, node_id: NodeId, seq: SequenceNumber) {
17 let entry = self
18 .entries
19 .entry(node_id.0)
20 .or_insert(SequenceNumber::ZERO);
21 if seq > *entry {
22 *entry = seq;
23 }
24 }
25
26 pub fn get(&self, node_id: NodeId) -> SequenceNumber {
27 self.entries
28 .get(&node_id.0)
29 .cloned()
30 .unwrap_or(SequenceNumber::ZERO)
31 }
32
33 pub fn find_gaps(&self, other: &VersionVector) -> Vec<(NodeId, SequenceNumber)> {
37 let mut gaps = Vec::new();
38 for (node_id_raw, other_seq) in &other.entries {
39 let node_id = NodeId(*node_id_raw);
40 match self.entries.get(node_id_raw) {
41 Some(&local_last) if *other_seq > local_last => {
42 gaps.push((node_id, local_last.next()));
43 }
44 None => {
45 gaps.push((node_id, SequenceNumber::ZERO));
46 }
47 _ => {}
48 }
49 }
50 gaps
51 }
52}
53
54#[derive(Archive, Deserialize, Serialize, Debug, Clone, Default)]
55#[archive(check_bytes)]
56pub struct PeerSyncState {
57 pub known_vv: VersionVector,
58 pub last_acked: Option<SequenceNumber>,
59}
60
61#[derive(Archive, Deserialize, Serialize, Debug, Clone, Default)]
62#[archive(check_bytes)]
63pub struct ReplicationState {
64 pub self_id: NodeId,
65 pub local_vv: VersionVector,
66 pub peers: HashMap<u32, PeerSyncState>,
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a vector by calling `update` once per `(raw node id, seq)`
    /// pair, in slice order.
    fn vv_of(pairs: &[(u32, SequenceNumber)]) -> VersionVector {
        let mut vv = VersionVector::new();
        for &(raw_id, seq) in pairs {
            vv.update(NodeId(raw_id), seq);
        }
        vv
    }

    /// Runs `local.find_gaps(remote)` and indexes the result by raw node id.
    fn gaps_by_node(
        local: &VersionVector,
        remote: &VersionVector,
    ) -> HashMap<u32, SequenceNumber> {
        local
            .find_gaps(remote)
            .into_iter()
            .map(|(node, start)| (node.0, start))
            .collect()
    }

    #[test]
    fn test_vv_update_only_advances() {
        let mut clock = VersionVector::new();

        // First write establishes the value.
        clock.update(NodeId(1), SequenceNumber(5));
        assert_eq!(clock.get(NodeId(1)), SequenceNumber(5));

        // A smaller sequence number must be ignored.
        clock.update(NodeId(1), SequenceNumber(3));
        assert_eq!(
            clock.get(NodeId(1)),
            SequenceNumber(5),
            "VV must never decrease"
        );

        // Re-applying the current value is a no-op.
        clock.update(NodeId(1), SequenceNumber(5));
        assert_eq!(clock.get(NodeId(1)), SequenceNumber(5));

        // A larger sequence number advances the entry.
        clock.update(NodeId(1), SequenceNumber(10));
        assert_eq!(clock.get(NodeId(1)), SequenceNumber(10));
    }

    #[test]
    fn test_vv_get_unknown_node_returns_zero() {
        let empty = VersionVector::new();
        assert_eq!(empty.get(NodeId(42)), SequenceNumber::ZERO);
    }

    #[test]
    fn test_vv_find_gaps_empty_local_needs_everything_from_zero() {
        let local = VersionVector::new();
        let remote = vv_of(&[(1, SequenceNumber(10)), (2, SequenceNumber(5))]);

        let gaps = gaps_by_node(&local, &remote);
        assert_eq!(
            gaps[&1],
            SequenceNumber::ZERO,
            "unknown node: start from seq 0"
        );
        assert_eq!(gaps[&2], SequenceNumber::ZERO);
    }

    #[test]
    fn test_vv_find_gaps_partial_overlap_returns_next_needed() {
        let local = vv_of(&[(1, SequenceNumber(5))]);
        let remote = vv_of(&[(1, SequenceNumber(10)), (2, SequenceNumber(3))]);

        let gaps = gaps_by_node(&local, &remote);
        // Known node: resume right after the last local sequence number.
        assert_eq!(gaps[&1], SequenceNumber(6));
        // Node unknown locally: start from zero.
        assert_eq!(gaps[&2], SequenceNumber::ZERO);
    }

    #[test]
    fn test_vv_find_gaps_up_to_date_returns_empty() {
        let local = vv_of(&[(1, SequenceNumber(10))]);
        let remote = vv_of(&[(1, SequenceNumber(10))]);

        assert!(local.find_gaps(&remote).is_empty(), "no gap when equal");
    }

    #[test]
    fn test_vv_find_gaps_local_ahead_returns_empty() {
        let local = vv_of(&[(1, SequenceNumber(15))]);
        let remote = vv_of(&[(1, SequenceNumber(10))]);

        assert!(
            local.find_gaps(&remote).is_empty(),
            "local is ahead, nothing to pull"
        );
    }

    #[test]
    fn test_vv_find_gaps_start_is_inclusive_next_seq() {
        let local = vv_of(&[(1, SequenceNumber(3))]);
        let remote = vv_of(&[(1, SequenceNumber(7))]);

        let gaps = local.find_gaps(&remote);
        assert_eq!(gaps.len(), 1);
        assert_eq!(gaps[0], (NodeId(1), SequenceNumber(4)));
    }

    #[test]
    fn test_vv_find_gaps_200_peers_correct_at_scale() {
        const PEER_COUNT: u32 = 200;

        // Remote knows every peer at seq 10; local knows only the first 100,
        // each at seq 5.
        let mut remote = VersionVector::new();
        for i in 0..PEER_COUNT {
            remote.update(NodeId(i), SequenceNumber(10));
        }
        let mut local = VersionVector::new();
        for i in 0..100 {
            local.update(NodeId(i), SequenceNumber(5));
        }

        let gaps = local.find_gaps(&remote);
        assert_eq!(gaps.len(), PEER_COUNT as usize, "must find 200 gaps");

        let gap_map: HashMap<u32, SequenceNumber> = gaps
            .into_iter()
            .map(|(node, start)| (node.0, start))
            .collect();

        for i in 0..100 {
            assert_eq!(
                gap_map[&i],
                SequenceNumber(6),
                "peer {i}: expected next seq 6"
            );
        }
        for i in 100..PEER_COUNT {
            assert_eq!(
                gap_map[&i],
                SequenceNumber::ZERO,
                "peer {i}: expected seq 0"
            );
        }
    }

    #[test]
    fn test_vv_find_gaps_ignores_nodes_not_in_remote() {
        let local = vv_of(&[(1, SequenceNumber(5))]);
        let remote = VersionVector::new();

        // Entries that exist only locally never produce a gap.
        assert!(local.find_gaps(&remote).is_empty());
    }
}
231
232#[derive(Archive, Deserialize, Serialize, Debug, Clone)]
233#[archive(check_bytes)]
234pub enum SyncMessage {
235 Handshake {
236 node_id: NodeId,
237 vv: VersionVector,
238 },
239 PullRequest {
240 origin_node: NodeId,
241 start_seq: SequenceNumber,
242 limit: u32,
243 },
244 EventBatch {
245 origin_node: NodeId,
246 events: Vec<Event>,
247 },
248 SyncComplete,
249}