1use crate::{ReplicationError, Result};
8use serde::{Deserialize, Serialize};
9use std::cmp::Ordering;
10use std::collections::HashMap;
11use std::fmt;
12
13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
15pub struct VectorClock {
16 clock: HashMap<String, u64>,
18}
19
20impl VectorClock {
21 pub fn new() -> Self {
23 Self {
24 clock: HashMap::new(),
25 }
26 }
27
28 pub fn increment(&mut self, replica_id: &str) {
30 let counter = self.clock.entry(replica_id.to_string()).or_insert(0);
31 *counter += 1;
32 }
33
34 pub fn get(&self, replica_id: &str) -> u64 {
36 self.clock.get(replica_id).copied().unwrap_or(0)
37 }
38
39 pub fn merge(&mut self, other: &VectorClock) {
41 for (replica_id, ×tamp) in &other.clock {
42 let current = self.clock.entry(replica_id.clone()).or_insert(0);
43 *current = (*current).max(timestamp);
44 }
45 }
46
47 pub fn happens_before(&self, other: &VectorClock) -> bool {
49 let mut less = false;
50 let mut equal = true;
51
52 for (replica_id, &self_ts) in &self.clock {
54 let other_ts = other.get(replica_id);
55 if self_ts > other_ts {
56 return false;
57 }
58 if self_ts < other_ts {
59 less = true;
60 equal = false;
61 }
62 }
63
64 for (replica_id, &other_ts) in &other.clock {
66 if !self.clock.contains_key(replica_id) && other_ts > 0 {
67 less = true;
68 equal = false;
69 }
70 }
71
72 less || equal
73 }
74
75 pub fn compare(&self, other: &VectorClock) -> ClockOrdering {
77 if self == other {
78 return ClockOrdering::Equal;
79 }
80
81 if self.happens_before(other) {
82 return ClockOrdering::Before;
83 }
84
85 if other.happens_before(self) {
86 return ClockOrdering::After;
87 }
88
89 ClockOrdering::Concurrent
90 }
91
92 pub fn is_concurrent(&self, other: &VectorClock) -> bool {
94 matches!(self.compare(other), ClockOrdering::Concurrent)
95 }
96}
97
98impl Default for VectorClock {
99 fn default() -> Self {
100 Self::new()
101 }
102}
103
104impl fmt::Display for VectorClock {
105 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106 write!(f, "{{")?;
107 for (i, (replica, ts)) in self.clock.iter().enumerate() {
108 if i > 0 {
109 write!(f, ", ")?;
110 }
111 write!(f, "{}: {}", replica, ts)?;
112 }
113 write!(f, "}}")
114 }
115}
116
/// Causal relationship of one vector clock relative to another.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ClockOrdering {
    /// The two clocks record the same counters.
    Equal,
    /// The left-hand clock causally precedes the right-hand one.
    Before,
    /// The left-hand clock causally follows the right-hand one.
    After,
    /// Neither clock precedes the other.
    Concurrent,
}
129
130#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct Versioned<T> {
133 pub value: T,
135 pub clock: VectorClock,
137 pub replica_id: String,
139}
140
141impl<T> Versioned<T> {
142 pub fn new(value: T, replica_id: String) -> Self {
144 let mut clock = VectorClock::new();
145 clock.increment(&replica_id);
146 Self {
147 value,
148 clock,
149 replica_id,
150 }
151 }
152
153 pub fn update(&mut self, value: T) {
155 self.value = value;
156 self.clock.increment(&self.replica_id);
157 }
158
159 pub fn compare(&self, other: &Versioned<T>) -> ClockOrdering {
161 self.clock.compare(&other.clock)
162 }
163}
164
165pub trait ConflictResolver<T: Clone>: Send + Sync {
167 fn resolve(&self, v1: &Versioned<T>, v2: &Versioned<T>) -> Result<Versioned<T>>;
169
170 fn resolve_many(&self, versions: Vec<Versioned<T>>) -> Result<Versioned<T>> {
172 if versions.is_empty() {
173 return Err(ReplicationError::ConflictResolution(
174 "No versions to resolve".to_string(),
175 ));
176 }
177
178 if versions.len() == 1 {
179 return Ok(versions
181 .into_iter()
182 .next()
183 .expect("versions verified non-empty"));
184 }
185
186 let mut result = versions[0].clone();
187 for version in versions.iter().skip(1) {
188 result = self.resolve(&result, version)?;
189 }
190 Ok(result)
191 }
192}
193
/// Resolver that keeps the causally newest version and never merges payloads.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct LastWriteWins;
196
197impl<T: Clone> ConflictResolver<T> for LastWriteWins {
198 fn resolve(&self, v1: &Versioned<T>, v2: &Versioned<T>) -> Result<Versioned<T>> {
199 match v1.compare(v2) {
200 ClockOrdering::Before | ClockOrdering::Concurrent => Ok(v2.clone()),
201 ClockOrdering::After | ClockOrdering::Equal => Ok(v1.clone()),
202 }
203 }
204}
205
/// Resolver that combines concurrent values with a caller-supplied function.
///
/// The `Fn(&T, &T) -> T` bound lives on the `impl` blocks rather than on the
/// struct. Code that merely names `MergeFunction<T, F>` therefore does not have to
/// repeat it.
pub struct MergeFunction<T, F> {
    /// Combining function, invoked as `merge_fn(&v1.value, &v2.value)`.
    merge_fn: F,
    /// `T` is not stored in any field; the marker keeps it as a type parameter.
    _phantom: std::marker::PhantomData<T>,
}

impl<T, F> MergeFunction<T, F>
where
    F: Fn(&T, &T) -> T + Send + Sync,
{
    /// Wraps `merge_fn` as a resolver.
    ///
    /// The bound is kept here so closure parameter types can be inferred from it
    /// at the call site (e.g. `MergeFunction::new(|a, b| a + b)`).
    pub fn new(merge_fn: F) -> Self {
        Self {
            merge_fn,
            _phantom: std::marker::PhantomData,
        }
    }
}
227
228impl<T: Clone + Send + Sync, F> ConflictResolver<T> for MergeFunction<T, F>
229where
230 F: Fn(&T, &T) -> T + Send + Sync,
231{
232 fn resolve(&self, v1: &Versioned<T>, v2: &Versioned<T>) -> Result<Versioned<T>> {
233 match v1.compare(v2) {
234 ClockOrdering::Equal | ClockOrdering::Before => Ok(v2.clone()),
235 ClockOrdering::After => Ok(v1.clone()),
236 ClockOrdering::Concurrent => {
237 let merged_value = (self.merge_fn)(&v1.value, &v2.value);
238 let mut merged_clock = v1.clock.clone();
239 merged_clock.merge(&v2.clock);
240
241 Ok(Versioned {
242 value: merged_value,
243 clock: merged_clock,
244 replica_id: v1.replica_id.clone(),
245 })
246 }
247 }
248 }
249}
250
/// Resolver for `i64` values: concurrent versions resolve to the larger value.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct MaxMerge;
253
254impl ConflictResolver<i64> for MaxMerge {
255 fn resolve(&self, v1: &Versioned<i64>, v2: &Versioned<i64>) -> Result<Versioned<i64>> {
256 match v1.compare(v2) {
257 ClockOrdering::Equal | ClockOrdering::Before => Ok(v2.clone()),
258 ClockOrdering::After => Ok(v1.clone()),
259 ClockOrdering::Concurrent => {
260 let merged_value = v1.value.max(v2.value);
261 let mut merged_clock = v1.clock.clone();
262 merged_clock.merge(&v2.clock);
263
264 Ok(Versioned {
265 value: merged_value,
266 clock: merged_clock,
267 replica_id: v1.replica_id.clone(),
268 })
269 }
270 }
271 }
272}
273
/// Resolver for list values: concurrent lists are combined as a set union.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SetUnion;
276
277impl<T: Clone + Eq + std::hash::Hash> ConflictResolver<Vec<T>> for SetUnion {
278 fn resolve(&self, v1: &Versioned<Vec<T>>, v2: &Versioned<Vec<T>>) -> Result<Versioned<Vec<T>>> {
279 match v1.compare(v2) {
280 ClockOrdering::Equal | ClockOrdering::Before => Ok(v2.clone()),
281 ClockOrdering::After => Ok(v1.clone()),
282 ClockOrdering::Concurrent => {
283 let mut merged_value = v1.value.clone();
284 for item in &v2.value {
285 if !merged_value.contains(item) {
286 merged_value.push(item.clone());
287 }
288 }
289
290 let mut merged_clock = v1.clock.clone();
291 merged_clock.merge(&v2.clock);
292
293 Ok(Versioned {
294 value: merged_value,
295 clock: merged_clock,
296 replica_id: v1.replica_id.clone(),
297 })
298 }
299 }
300 }
301}
302
/// Unit tests for vector clocks, versioned values and the conflict resolvers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vector_clock() {
        let mut clock1 = VectorClock::new();
        clock1.increment("r1");
        clock1.increment("r1");

        let mut clock2 = VectorClock::new();
        clock2.increment("r1");

        assert_eq!(clock1.compare(&clock2), ClockOrdering::After);
        assert_eq!(clock2.compare(&clock1), ClockOrdering::Before);
    }

    #[test]
    fn test_concurrent_clocks() {
        let mut clock1 = VectorClock::new();
        clock1.increment("r1");

        let mut clock2 = VectorClock::new();
        clock2.increment("r2");

        assert_eq!(clock1.compare(&clock2), ClockOrdering::Concurrent);
        assert!(clock1.is_concurrent(&clock2));
    }

    #[test]
    fn test_clock_merge() {
        let mut clock1 = VectorClock::new();
        clock1.increment("r1");
        clock1.increment("r1");

        let mut clock2 = VectorClock::new();
        clock2.increment("r2");
        clock2.increment("r2");
        clock2.increment("r2");

        clock1.merge(&clock2);
        assert_eq!(clock1.get("r1"), 2);
        assert_eq!(clock1.get("r2"), 3);
    }

    #[test]
    fn test_versioned() {
        let mut v1 = Versioned::new(100, "r1".to_string());
        v1.update(200);

        assert_eq!(v1.value, 200);
        assert_eq!(v1.clock.get("r1"), 2);
    }

    #[test]
    fn test_last_write_wins() {
        let v1 = Versioned::new(100, "r1".to_string());
        let mut v2 = Versioned::new(200, "r1".to_string());
        v2.clock.increment("r1");

        let resolver = LastWriteWins;
        let result = resolver.resolve(&v1, &v2).unwrap();
        assert_eq!(result.value, 200);
    }

    #[test]
    fn test_merge_function() {
        let v1 = Versioned::new(100, "r1".to_string());
        let v2 = Versioned::new(200, "r2".to_string());

        let resolver = MergeFunction::new(|a, b| a + b);
        let result = resolver.resolve(&v1, &v2).unwrap();
        assert_eq!(result.value, 300);
    }

    #[test]
    fn test_max_merge() {
        let v1 = Versioned::new(100, "r1".to_string());
        let v2 = Versioned::new(200, "r2".to_string());

        let resolver = MaxMerge;
        let result = resolver.resolve(&v1, &v2).unwrap();
        assert_eq!(result.value, 200);
    }

    #[test]
    fn test_set_union() {
        let v1 = Versioned::new(vec![1, 2, 3], "r1".to_string());
        let v2 = Versioned::new(vec![3, 4, 5], "r2".to_string());

        let resolver = SetUnion;
        let result = resolver.resolve(&v1, &v2).unwrap();
        assert_eq!(result.value.len(), 5);
        assert!(result.value.contains(&1));
        assert!(result.value.contains(&4));
    }

    #[test]
    fn test_equal_clocks() {
        let mut a = VectorClock::new();
        a.increment("r1");
        let b = a.clone();

        assert_eq!(a.compare(&b), ClockOrdering::Equal);
        assert!(!a.is_concurrent(&b));
    }

    #[test]
    fn test_merge_dominates_inputs_and_is_idempotent() {
        let mut left = VectorClock::new();
        left.increment("r1");
        left.increment("r1");
        let mut right = VectorClock::new();
        right.increment("r2");

        let mut merged = left.clone();
        merged.merge(&right);
        assert_eq!(merged.compare(&left), ClockOrdering::After);
        assert_eq!(merged.compare(&right), ClockOrdering::After);

        let snapshot = merged.clone();
        merged.merge(&right);
        assert_eq!(merged, snapshot);
    }

    #[test]
    fn test_resolve_many_rejects_empty_input() {
        let versions: Vec<Versioned<i64>> = Vec::new();
        assert!(MaxMerge.resolve_many(versions).is_err());
    }

    #[test]
    fn test_resolve_many_single_version_is_returned() {
        let only = Versioned::new(7_i64, "r1".to_string());
        let result = MaxMerge.resolve_many(vec![only]).unwrap();
        assert_eq!(result.value, 7);
        assert_eq!(result.clock.get("r1"), 1);
    }

    #[test]
    fn test_resolve_many_folds_every_version() {
        let versions = vec![
            Versioned::new(10_i64, "r1".to_string()),
            Versioned::new(30_i64, "r2".to_string()),
            Versioned::new(20_i64, "r3".to_string()),
        ];
        let result = MaxMerge.resolve_many(versions).unwrap();
        assert_eq!(result.value, 30);
        for replica in &["r1", "r2", "r3"] {
            assert_eq!(result.clock.get(replica), 1);
        }
    }

    #[test]
    fn test_set_union_preserves_first_occurrence_order() {
        let v1 = Versioned::new(vec![1, 2, 3], "r1".to_string());
        let v2 = Versioned::new(vec![3, 4, 5], "r2".to_string());

        let result = SetUnion.resolve(&v1, &v2).unwrap();
        assert_eq!(result.value, vec![1, 2, 3, 4, 5]);
    }
}