1use crate::{ReplicaSet, ReplicationError, Result};
7use chrono::{DateTime, Utc};
8use dashmap::DashMap;
9use parking_lot::RwLock;
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::time::timeout;
14use uuid::Uuid;
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
18pub enum SyncMode {
19 Sync,
21 Async,
23 SemiSync { min_replicas: usize },
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct LogEntry {
30 pub id: Uuid,
32 pub sequence: u64,
34 pub timestamp: DateTime<Utc>,
36 pub data: Vec<u8>,
38 pub checksum: u64,
40 pub source_replica: String,
42}
43
44impl LogEntry {
45 pub fn new(sequence: u64, data: Vec<u8>, source_replica: String) -> Self {
47 let checksum = Self::calculate_checksum(&data);
48 Self {
49 id: Uuid::new_v4(),
50 sequence,
51 timestamp: Utc::now(),
52 data,
53 checksum,
54 source_replica,
55 }
56 }
57
58 fn calculate_checksum(data: &[u8]) -> u64 {
60 use std::collections::hash_map::DefaultHasher;
61 use std::hash::{Hash, Hasher};
62 let mut hasher = DefaultHasher::new();
63 data.hash(&mut hasher);
64 hasher.finish()
65 }
66
67 pub fn verify(&self) -> bool {
69 Self::calculate_checksum(&self.data) == self.checksum
70 }
71}
72
73pub struct ReplicationLog {
75 entries: Arc<DashMap<u64, LogEntry>>,
77 sequence: Arc<RwLock<u64>>,
79 replica_id: String,
81}
82
83impl ReplicationLog {
84 pub fn new(replica_id: impl Into<String>) -> Self {
86 Self {
87 entries: Arc::new(DashMap::new()),
88 sequence: Arc::new(RwLock::new(0)),
89 replica_id: replica_id.into(),
90 }
91 }
92
93 pub fn append(&self, data: Vec<u8>) -> LogEntry {
95 let mut seq = self.sequence.write();
96 *seq += 1;
97 let entry = LogEntry::new(*seq, data, self.replica_id.clone());
98 self.entries.insert(*seq, entry.clone());
99 entry
100 }
101
102 pub fn get(&self, sequence: u64) -> Option<LogEntry> {
104 self.entries.get(&sequence).map(|e| e.clone())
105 }
106
107 pub fn get_range(&self, start: u64, end: u64) -> Vec<LogEntry> {
109 let mut entries = Vec::new();
110 for seq in start..=end {
111 if let Some(entry) = self.entries.get(&seq) {
112 entries.push(entry.clone());
113 }
114 }
115 entries
116 }
117
118 pub fn current_sequence(&self) -> u64 {
120 *self.sequence.read()
121 }
122
123 pub fn get_since(&self, since: u64) -> Vec<LogEntry> {
125 let current = self.current_sequence();
126 self.get_range(since + 1, current)
127 }
128
129 pub fn truncate_before(&self, before: u64) {
131 self.entries.retain(|seq, _| *seq >= before);
132 }
133
134 pub fn size(&self) -> usize {
136 self.entries.len()
137 }
138}
139
140pub struct SyncManager {
142 replica_set: Arc<ReplicaSet>,
144 log: Arc<ReplicationLog>,
146 sync_mode: Arc<RwLock<SyncMode>>,
148 sync_timeout: Duration,
150}
151
152impl SyncManager {
153 pub fn new(replica_set: Arc<ReplicaSet>, log: Arc<ReplicationLog>) -> Self {
155 Self {
156 replica_set,
157 log,
158 sync_mode: Arc::new(RwLock::new(SyncMode::Async)),
159 sync_timeout: Duration::from_secs(5),
160 }
161 }
162
163 pub fn set_sync_mode(&self, mode: SyncMode) {
165 *self.sync_mode.write() = mode;
166 }
167
168 pub fn sync_mode(&self) -> SyncMode {
170 *self.sync_mode.read()
171 }
172
173 pub fn set_sync_timeout(&mut self, timeout: Duration) {
175 self.sync_timeout = timeout;
176 }
177
178 pub async fn replicate(&self, data: Vec<u8>) -> Result<LogEntry> {
180 let entry = self.log.append(data);
182
183 let mode = self.sync_mode();
185
186 match mode {
187 SyncMode::Sync => {
188 self.replicate_sync(&entry).await?;
189 }
190 SyncMode::Async => {
191 let entry_clone = entry.clone();
193 let replica_set = self.replica_set.clone();
194 tokio::spawn(async move {
195 if let Err(e) = Self::send_to_replicas(&replica_set, &entry_clone).await {
196 tracing::error!("Async replication failed: {}", e);
197 }
198 });
199 }
200 SyncMode::SemiSync { min_replicas } => {
201 self.replicate_semi_sync(&entry, min_replicas).await?;
202 }
203 }
204
205 Ok(entry)
206 }
207
208 async fn replicate_sync(&self, entry: &LogEntry) -> Result<()> {
210 timeout(
211 self.sync_timeout,
212 Self::send_to_replicas(&self.replica_set, entry),
213 )
214 .await
215 .map_err(|_| ReplicationError::Timeout("Sync replication timed out".to_string()))?
216 }
217
218 async fn replicate_semi_sync(&self, entry: &LogEntry, min_replicas: usize) -> Result<()> {
220 let secondaries = self.replica_set.get_secondaries();
221 if secondaries.len() < min_replicas {
222 return Err(ReplicationError::QuorumNotMet {
223 needed: min_replicas,
224 available: secondaries.len(),
225 });
226 }
227
228 let entry_clone = entry.clone();
230 let replica_set = self.replica_set.clone();
231 let min = min_replicas;
232
233 timeout(self.sync_timeout, async move {
234 let acks = secondaries.len().min(min);
237 if acks >= min {
238 Ok(())
239 } else {
240 Err(ReplicationError::QuorumNotMet {
241 needed: min,
242 available: acks,
243 })
244 }
245 })
246 .await
247 .map_err(|_| ReplicationError::Timeout("Semi-sync replication timed out".to_string()))?
248 }
249
250 async fn send_to_replicas(replica_set: &ReplicaSet, entry: &LogEntry) -> Result<()> {
252 let secondaries = replica_set.get_secondaries();
253
254 for replica in secondaries {
257 if replica.is_healthy() {
258 tracing::debug!("Replicating entry {} to {}", entry.sequence, replica.id);
259 }
260 }
261
262 Ok(())
263 }
264
265 pub async fn catchup(&self, replica_id: &str, from_sequence: u64) -> Result<Vec<LogEntry>> {
267 let replica = self
268 .replica_set
269 .get_replica(replica_id)
270 .ok_or_else(|| ReplicationError::ReplicaNotFound(replica_id.to_string()))?;
271
272 let current_sequence = self.log.current_sequence();
273 if from_sequence >= current_sequence {
274 return Ok(Vec::new());
275 }
276
277 let entries = self.log.get_since(from_sequence);
279
280 tracing::info!(
281 "Catching up replica {} with {} entries (from {} to {})",
282 replica_id,
283 entries.len(),
284 from_sequence + 1,
285 current_sequence
286 );
287
288 Ok(entries)
289 }
290
291 pub fn current_position(&self) -> u64 {
293 self.log.current_sequence()
294 }
295
296 pub fn verify_entry(&self, sequence: u64) -> Result<bool> {
298 let entry = self
299 .log
300 .get(sequence)
301 .ok_or_else(|| ReplicationError::InvalidState("Log entry not found".to_string()))?;
302 Ok(entry.verify())
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ReplicaRole;

    /// Builds the cluster used by the manager tests: primary `r1` plus
    /// secondary `r2`.
    fn two_node_cluster() -> ReplicaSet {
        let mut cluster = ReplicaSet::new("cluster-1");
        cluster
            .add_replica("r1", "127.0.0.1:9001", ReplicaRole::Primary)
            .unwrap();
        cluster
            .add_replica("r2", "127.0.0.1:9002", ReplicaRole::Secondary)
            .unwrap();
        cluster
    }

    #[test]
    fn test_log_entry_creation() {
        let entry = LogEntry::new(1, b"test data".to_vec(), "replica-1".to_string());

        assert_eq!(entry.sequence, 1);
        assert!(entry.verify());
    }

    #[test]
    fn test_replication_log() {
        let log = ReplicationLog::new("replica-1");

        let first = log.append(b"data1".to_vec());
        let second = log.append(b"data2".to_vec());

        assert_eq!(first.sequence, 1);
        assert_eq!(second.sequence, 2);
        assert_eq!(log.current_sequence(), 2);

        let both = log.get_range(1, 2);
        assert_eq!(both.len(), 2);
    }

    #[tokio::test]
    async fn test_sync_manager() {
        let log = Arc::new(ReplicationLog::new("r1"));
        let manager = SyncManager::new(Arc::new(two_node_cluster()), log);

        manager.set_sync_mode(SyncMode::Async);
        let entry = manager.replicate(b"test".to_vec()).await.unwrap();
        assert_eq!(entry.sequence, 1);
    }

    #[tokio::test]
    async fn test_catchup() {
        let log = Arc::new(ReplicationLog::new("r1"));
        let manager = SyncManager::new(Arc::new(two_node_cluster()), log.clone());

        log.append(b"data1".to_vec());
        log.append(b"data2".to_vec());
        log.append(b"data3".to_vec());

        // Entries after sequence 1 are sequences 2 and 3.
        let entries = manager.catchup("r2", 1).await.unwrap();
        assert_eq!(entries.len(), 2);
    }
}