1use crate::{Event, NodeId, SequenceNumber};
2use rkyv::{Archive, Deserialize, Serialize};
3use std::collections::HashMap;
4
5#[derive(Archive, Deserialize, Serialize, Debug, Clone, Default, PartialEq, Eq)]
6#[archive(check_bytes)]
7pub struct VersionVector {
8 pub entries: HashMap<u32, SequenceNumber>,
9}
10
11impl VersionVector {
12 pub fn new() -> Self {
13 Self::default()
14 }
15
16 pub fn update(&mut self, node_id: NodeId, seq: SequenceNumber) {
17 let entry = self
18 .entries
19 .entry(node_id.0)
20 .or_insert(SequenceNumber::ZERO);
21 if seq > *entry {
22 *entry = seq;
23 }
24 }
25
26 pub fn get(&self, node_id: NodeId) -> SequenceNumber {
27 self.entries
28 .get(&node_id.0)
29 .cloned()
30 .unwrap_or(SequenceNumber::ZERO)
31 }
32
33 pub fn find_gaps(&self, other: &VersionVector) -> Vec<(NodeId, SequenceNumber)> {
37 let mut gaps = Vec::new();
38 for (node_id_raw, other_seq) in &other.entries {
39 let node_id = NodeId(*node_id_raw);
40 match self.entries.get(node_id_raw) {
41 Some(&local_last) if *other_seq > local_last => {
42 gaps.push((node_id, local_last.next()));
43 }
44 None => {
45 gaps.push((node_id, SequenceNumber::ZERO));
46 }
47 _ => {}
48 }
49 }
50 gaps
51 }
52}
53
54#[derive(Archive, Deserialize, Serialize, Debug, Clone, Default)]
55#[archive(check_bytes)]
56pub struct PeerSyncState {
57 pub known_vv: VersionVector,
58 pub last_acked: Option<SequenceNumber>,
59}
60
61#[derive(Archive, Deserialize, Serialize, Debug, Clone, Default)]
62#[archive(check_bytes)]
63pub struct ReplicationState {
64 pub self_id: NodeId,
65 pub local_vv: VersionVector,
66 pub peers: HashMap<u32, PeerSyncState>,
67}
68
69#[cfg(test)]
70mod tests {
71 use super::*;
72 use std::collections::HashMap;
73
74 #[test]
75 fn test_vv_update_only_advances() {
76 let mut vv = VersionVector::new();
77 vv.update(NodeId(1), SequenceNumber(5));
78 assert_eq!(vv.get(NodeId(1)), SequenceNumber(5));
79 vv.update(NodeId(1), SequenceNumber(3));
81 assert_eq!(vv.get(NodeId(1)), SequenceNumber(5), "VV must never decrease");
82 vv.update(NodeId(1), SequenceNumber(5));
84 assert_eq!(vv.get(NodeId(1)), SequenceNumber(5));
85 vv.update(NodeId(1), SequenceNumber(10));
87 assert_eq!(vv.get(NodeId(1)), SequenceNumber(10));
88 }
89
90 #[test]
91 fn test_vv_get_unknown_node_returns_zero() {
92 let vv = VersionVector::new();
93 assert_eq!(vv.get(NodeId(42)), SequenceNumber::ZERO);
94 }
95
96 #[test]
97 fn test_vv_find_gaps_empty_local_needs_everything_from_zero() {
98 let local = VersionVector::new();
99 let mut remote = VersionVector::new();
100 remote.update(NodeId(1), SequenceNumber(10));
101 remote.update(NodeId(2), SequenceNumber(5));
102
103 let gaps: HashMap<u32, SequenceNumber> =
104 local.find_gaps(&remote).into_iter().map(|(n, s)| (n.0, s)).collect();
105 assert_eq!(gaps[&1], SequenceNumber::ZERO, "unknown node: start from seq 0");
106 assert_eq!(gaps[&2], SequenceNumber::ZERO);
107 }
108
109 #[test]
110 fn test_vv_find_gaps_partial_overlap_returns_next_needed() {
111 let mut local = VersionVector::new();
112 local.update(NodeId(1), SequenceNumber(5));
113
114 let mut remote = VersionVector::new();
115 remote.update(NodeId(1), SequenceNumber(10));
116 remote.update(NodeId(2), SequenceNumber(3));
117
118 let gaps: HashMap<u32, SequenceNumber> =
119 local.find_gaps(&remote).into_iter().map(|(n, s)| (n.0, s)).collect();
120 assert_eq!(gaps[&1], SequenceNumber(6));
122 assert_eq!(gaps[&2], SequenceNumber::ZERO);
124 }
125
126 #[test]
127 fn test_vv_find_gaps_up_to_date_returns_empty() {
128 let mut local = VersionVector::new();
129 local.update(NodeId(1), SequenceNumber(10));
130
131 let mut remote = VersionVector::new();
132 remote.update(NodeId(1), SequenceNumber(10));
133
134 assert!(local.find_gaps(&remote).is_empty(), "no gap when equal");
135 }
136
137 #[test]
138 fn test_vv_find_gaps_local_ahead_returns_empty() {
139 let mut local = VersionVector::new();
140 local.update(NodeId(1), SequenceNumber(15));
141
142 let mut remote = VersionVector::new();
143 remote.update(NodeId(1), SequenceNumber(10));
144
145 assert!(
146 local.find_gaps(&remote).is_empty(),
147 "local is ahead, nothing to pull"
148 );
149 }
150
151 #[test]
152 fn test_vv_find_gaps_start_is_inclusive_next_seq() {
153 let mut local = VersionVector::new();
154 local.update(NodeId(1), SequenceNumber(3));
155
156 let mut remote = VersionVector::new();
157 remote.update(NodeId(1), SequenceNumber(7));
158
159 let gaps = local.find_gaps(&remote);
160 assert_eq!(gaps.len(), 1);
161 assert_eq!(gaps[0], (NodeId(1), SequenceNumber(4)));
163 }
164
165 #[test]
166 fn test_vv_find_gaps_200_peers_correct_at_scale() {
167 const PEER_COUNT: u32 = 200;
168
169 let mut local = VersionVector::new();
170 let mut remote = VersionVector::new();
171
172 for i in 0..PEER_COUNT {
175 remote.update(NodeId(i), SequenceNumber(10));
176 if i < 100 {
177 local.update(NodeId(i), SequenceNumber(5));
178 }
179 }
180
181 let gaps = local.find_gaps(&remote);
182 assert_eq!(gaps.len(), PEER_COUNT as usize, "must find 200 gaps");
183
184 let gap_map: HashMap<u32, SequenceNumber> =
185 gaps.into_iter().map(|(n, s)| (n.0, s)).collect();
186
187 for i in 0..PEER_COUNT {
188 if i < 100 {
189 assert_eq!(gap_map[&i], SequenceNumber(6), "peer {i}: expected next seq 6");
191 } else {
192 assert_eq!(gap_map[&i], SequenceNumber::ZERO, "peer {i}: expected seq 0");
194 }
195 }
196 }
197
198 #[test]
199 fn test_vv_find_gaps_ignores_nodes_not_in_remote() {
200 let mut local = VersionVector::new();
201 local.update(NodeId(1), SequenceNumber(5));
202
203 let remote = VersionVector::new();
206 assert!(local.find_gaps(&remote).is_empty());
207 }
208}
209
210#[derive(Archive, Deserialize, Serialize, Debug, Clone)]
211#[archive(check_bytes)]
212pub enum SyncMessage {
213 Handshake {
214 node_id: NodeId,
215 vv: VersionVector,
216 },
217 PullRequest {
218 origin_node: NodeId,
219 start_seq: SequenceNumber,
220 limit: u32,
221 },
222 EventBatch {
223 origin_node: NodeId,
224 events: Vec<Event>,
225 },
226 SyncComplete,
227}