Skip to main content

ruvector_replication/
conflict.rs

1//! Conflict resolution strategies for distributed replication
2//!
3//! Provides vector clocks for causality tracking and various
4//! conflict resolution strategies including Last-Write-Wins
5//! and custom merge functions.
6
7use crate::{ReplicationError, Result};
8use serde::{Deserialize, Serialize};
9use std::cmp::Ordering;
10use std::collections::HashMap;
11use std::fmt;
12
/// Vector clock for tracking causality in distributed systems
///
/// Stores one logical timestamp per replica ID. A replica without an entry
/// is read back as timestamp `0` by [`VectorClock::get`]. Note that the
/// derived `PartialEq` compares the underlying maps literally, so a clock
/// holding an explicit `0` entry is not `==` to one missing that entry.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct VectorClock {
    /// Map of replica ID to logical timestamp
    clock: HashMap<String, u64>,
}
19
20impl VectorClock {
21    /// Create a new vector clock
22    pub fn new() -> Self {
23        Self {
24            clock: HashMap::new(),
25        }
26    }
27
28    /// Increment the clock for a replica
29    pub fn increment(&mut self, replica_id: &str) {
30        let counter = self.clock.entry(replica_id.to_string()).or_insert(0);
31        *counter += 1;
32    }
33
34    /// Get the timestamp for a replica
35    pub fn get(&self, replica_id: &str) -> u64 {
36        self.clock.get(replica_id).copied().unwrap_or(0)
37    }
38
39    /// Update with another vector clock (taking max of each component)
40    pub fn merge(&mut self, other: &VectorClock) {
41        for (replica_id, &timestamp) in &other.clock {
42            let current = self.clock.entry(replica_id.clone()).or_insert(0);
43            *current = (*current).max(timestamp);
44        }
45    }
46
47    /// Check if this clock happens-before another clock
48    pub fn happens_before(&self, other: &VectorClock) -> bool {
49        let mut less = false;
50        let mut equal = true;
51
52        // Check all replicas in self
53        for (replica_id, &self_ts) in &self.clock {
54            let other_ts = other.get(replica_id);
55            if self_ts > other_ts {
56                return false;
57            }
58            if self_ts < other_ts {
59                less = true;
60                equal = false;
61            }
62        }
63
64        // Check replicas only in other
65        for (replica_id, &other_ts) in &other.clock {
66            if !self.clock.contains_key(replica_id) && other_ts > 0 {
67                less = true;
68                equal = false;
69            }
70        }
71
72        less || equal
73    }
74
75    /// Compare vector clocks for causality
76    pub fn compare(&self, other: &VectorClock) -> ClockOrdering {
77        if self == other {
78            return ClockOrdering::Equal;
79        }
80
81        if self.happens_before(other) {
82            return ClockOrdering::Before;
83        }
84
85        if other.happens_before(self) {
86            return ClockOrdering::After;
87        }
88
89        ClockOrdering::Concurrent
90    }
91
92    /// Check if two clocks are concurrent (conflicting)
93    pub fn is_concurrent(&self, other: &VectorClock) -> bool {
94        matches!(self.compare(other), ClockOrdering::Concurrent)
95    }
96}
97
98impl Default for VectorClock {
99    fn default() -> Self {
100        Self::new()
101    }
102}
103
104impl fmt::Display for VectorClock {
105    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106        write!(f, "{{")?;
107        for (i, (replica, ts)) in self.clock.iter().enumerate() {
108            if i > 0 {
109                write!(f, ", ")?;
110            }
111            write!(f, "{}: {}", replica, ts)?;
112        }
113        write!(f, "}}")
114    }
115}
116
/// Ordering relationship between vector clocks
///
/// Produced by [`VectorClock::compare`]; "first" refers to the clock the
/// method is called on and "second" to its argument.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClockOrdering {
    /// Clocks are equal
    Equal,
    /// First clock happens before second
    Before,
    /// First clock happens after second
    After,
    /// Clocks are concurrent (conflicting)
    Concurrent,
}
129
/// A versioned value with vector clock
///
/// Pairs a payload with the clock describing its causal history and the ID
/// of the replica whose clock component is advanced on local updates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Versioned<T> {
    /// The value
    pub value: T,
    /// Vector clock for this version
    pub clock: VectorClock,
    /// Replica that created this version
    pub replica_id: String,
}
140
141impl<T> Versioned<T> {
142    /// Create a new versioned value
143    pub fn new(value: T, replica_id: String) -> Self {
144        let mut clock = VectorClock::new();
145        clock.increment(&replica_id);
146        Self {
147            value,
148            clock,
149            replica_id,
150        }
151    }
152
153    /// Update the version with a new value
154    pub fn update(&mut self, value: T) {
155        self.value = value;
156        self.clock.increment(&self.replica_id);
157    }
158
159    /// Compare versions for causality
160    pub fn compare(&self, other: &Versioned<T>) -> ClockOrdering {
161        self.clock.compare(&other.clock)
162    }
163}
164
165/// Trait for conflict resolution strategies
166pub trait ConflictResolver<T: Clone>: Send + Sync {
167    /// Resolve a conflict between two versions
168    fn resolve(&self, v1: &Versioned<T>, v2: &Versioned<T>) -> Result<Versioned<T>>;
169
170    /// Resolve multiple conflicting versions
171    fn resolve_many(&self, versions: Vec<Versioned<T>>) -> Result<Versioned<T>> {
172        if versions.is_empty() {
173            return Err(ReplicationError::ConflictResolution(
174                "No versions to resolve".to_string(),
175            ));
176        }
177
178        if versions.len() == 1 {
179            // SAFETY: We just checked versions.len() == 1
180            return Ok(versions
181                .into_iter()
182                .next()
183                .expect("versions verified non-empty"));
184        }
185
186        let mut result = versions[0].clone();
187        for version in versions.iter().skip(1) {
188            result = self.resolve(&result, version)?;
189        }
190        Ok(result)
191    }
192}
193
194/// Last-Write-Wins conflict resolution strategy
195pub struct LastWriteWins;
196
197impl<T: Clone> ConflictResolver<T> for LastWriteWins {
198    fn resolve(&self, v1: &Versioned<T>, v2: &Versioned<T>) -> Result<Versioned<T>> {
199        match v1.compare(v2) {
200            ClockOrdering::Before | ClockOrdering::Concurrent => Ok(v2.clone()),
201            ClockOrdering::After | ClockOrdering::Equal => Ok(v1.clone()),
202        }
203    }
204}
205
/// Custom merge function for conflict resolution
///
/// Holds a caller-supplied closure that combines two values into one.
pub struct MergeFunction<T, F>
where
    F: Fn(&T, &T) -> T + Send + Sync,
{
    // Closure that combines two values into one.
    merge_fn: F,
    // Zero-sized marker tying the otherwise unused `T` to the struct.
    _phantom: std::marker::PhantomData<T>,
}
214
215impl<T, F> MergeFunction<T, F>
216where
217    F: Fn(&T, &T) -> T + Send + Sync,
218{
219    /// Create a new merge function resolver
220    pub fn new(merge_fn: F) -> Self {
221        Self {
222            merge_fn,
223            _phantom: std::marker::PhantomData,
224        }
225    }
226}
227
228impl<T: Clone + Send + Sync, F> ConflictResolver<T> for MergeFunction<T, F>
229where
230    F: Fn(&T, &T) -> T + Send + Sync,
231{
232    fn resolve(&self, v1: &Versioned<T>, v2: &Versioned<T>) -> Result<Versioned<T>> {
233        match v1.compare(v2) {
234            ClockOrdering::Equal | ClockOrdering::Before => Ok(v2.clone()),
235            ClockOrdering::After => Ok(v1.clone()),
236            ClockOrdering::Concurrent => {
237                let merged_value = (self.merge_fn)(&v1.value, &v2.value);
238                let mut merged_clock = v1.clock.clone();
239                merged_clock.merge(&v2.clock);
240
241                Ok(Versioned {
242                    value: merged_value,
243                    clock: merged_clock,
244                    replica_id: v1.replica_id.clone(),
245                })
246            }
247        }
248    }
249}
250
251/// CRDT-inspired merge for numeric values (takes max)
252pub struct MaxMerge;
253
254impl ConflictResolver<i64> for MaxMerge {
255    fn resolve(&self, v1: &Versioned<i64>, v2: &Versioned<i64>) -> Result<Versioned<i64>> {
256        match v1.compare(v2) {
257            ClockOrdering::Equal | ClockOrdering::Before => Ok(v2.clone()),
258            ClockOrdering::After => Ok(v1.clone()),
259            ClockOrdering::Concurrent => {
260                let merged_value = v1.value.max(v2.value);
261                let mut merged_clock = v1.clock.clone();
262                merged_clock.merge(&v2.clock);
263
264                Ok(Versioned {
265                    value: merged_value,
266                    clock: merged_clock,
267                    replica_id: v1.replica_id.clone(),
268                })
269            }
270        }
271    }
272}
273
274/// CRDT-inspired merge for sets (takes union)
275pub struct SetUnion;
276
277impl<T: Clone + Eq + std::hash::Hash> ConflictResolver<Vec<T>> for SetUnion {
278    fn resolve(&self, v1: &Versioned<Vec<T>>, v2: &Versioned<Vec<T>>) -> Result<Versioned<Vec<T>>> {
279        match v1.compare(v2) {
280            ClockOrdering::Equal | ClockOrdering::Before => Ok(v2.clone()),
281            ClockOrdering::After => Ok(v1.clone()),
282            ClockOrdering::Concurrent => {
283                let mut merged_value = v1.value.clone();
284                for item in &v2.value {
285                    if !merged_value.contains(item) {
286                        merged_value.push(item.clone());
287                    }
288                }
289
290                let mut merged_clock = v1.clock.clone();
291                merged_clock.merge(&v2.clock);
292
293                Ok(Versioned {
294                    value: merged_value,
295                    clock: merged_clock,
296                    replica_id: v1.replica_id.clone(),
297                })
298            }
299        }
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a clock by incrementing each listed replica the given number
    /// of times.
    fn clock_of(ticks: &[(&str, usize)]) -> VectorClock {
        let mut clock = VectorClock::new();
        for &(replica, count) in ticks {
            for _ in 0..count {
                clock.increment(replica);
            }
        }
        clock
    }

    // A clock with more ticks on the same replica is causally later.
    #[test]
    fn test_vector_clock() {
        let ahead = clock_of(&[("r1", 2)]);
        let behind = clock_of(&[("r1", 1)]);

        assert_eq!(ahead.compare(&behind), ClockOrdering::After);
        assert_eq!(behind.compare(&ahead), ClockOrdering::Before);
    }

    // Ticks on disjoint replicas are causally unrelated.
    #[test]
    fn test_concurrent_clocks() {
        let left = clock_of(&[("r1", 1)]);
        let right = clock_of(&[("r2", 1)]);

        assert_eq!(left.compare(&right), ClockOrdering::Concurrent);
        assert!(left.is_concurrent(&right));
    }

    // Merging keeps the maximum of every component.
    #[test]
    fn test_clock_merge() {
        let mut merged = clock_of(&[("r1", 2)]);
        let incoming = clock_of(&[("r2", 3)]);

        merged.merge(&incoming);
        assert_eq!(merged.get("r1"), 2);
        assert_eq!(merged.get("r2"), 3);
    }

    // Creating and then updating a version ticks its replica twice.
    #[test]
    fn test_versioned() {
        let mut version = Versioned::new(100, "r1".to_string());
        version.update(200);

        assert_eq!(version.value, 200);
        assert_eq!(version.clock.get("r1"), 2);
    }

    // The causally later version wins.
    #[test]
    fn test_last_write_wins() {
        let older = Versioned::new(100, "r1".to_string());
        let mut newer = Versioned::new(200, "r1".to_string());
        newer.clock.increment("r1");

        let resolver = LastWriteWins;
        let result = resolver.resolve(&older, &newer).unwrap();
        assert_eq!(result.value, 200);
    }

    // Concurrent versions are combined by the supplied closure.
    #[test]
    fn test_merge_function() {
        let left = Versioned::new(100, "r1".to_string());
        let right = Versioned::new(200, "r2".to_string());

        let resolver = MergeFunction::new(|a, b| a + b);
        let result = resolver.resolve(&left, &right).unwrap();
        assert_eq!(result.value, 300);
    }

    // Concurrent numeric versions resolve to the larger value.
    #[test]
    fn test_max_merge() {
        let left = Versioned::new(100, "r1".to_string());
        let right = Versioned::new(200, "r2".to_string());

        let resolver = MaxMerge;
        let result = resolver.resolve(&left, &right).unwrap();
        assert_eq!(result.value, 200);
    }

    // Concurrent collections resolve to their de-duplicated union.
    #[test]
    fn test_set_union() {
        let left = Versioned::new(vec![1, 2, 3], "r1".to_string());
        let right = Versioned::new(vec![3, 4, 5], "r2".to_string());

        let resolver = SetUnion;
        let result = resolver.resolve(&left, &right).unwrap();
        assert_eq!(result.value.len(), 5);
        assert!(result.value.contains(&1));
        assert!(result.value.contains(&4));
    }
}