//! ha.rs — high-availability primitives for sqlrite: replication
//! configuration, runtime cluster state, and a checksummed, quorum-acked
//! replication log.

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::collections::{HashMap, HashSet};
4use std::time::{SystemTime, UNIX_EPOCH};
5
/// Role a node plays in the cluster; serialized in snake_case
/// (`standalone`, `primary`, `replica`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ServerRole {
    /// Single-node operation; required when replication is disabled (default).
    #[default]
    Standalone,
    /// Cluster leader; `HaRuntimeState::new` seeds `leader_id` with its own node id.
    Primary,
    /// Follower of a primary's replication log.
    Replica,
}
14
/// How failover is initiated; serialized in snake_case (`manual`, `automatic`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum FailoverMode {
    /// Operator-driven failover (the default).
    #[default]
    Manual,
    /// Failover initiated without an operator.
    Automatic,
}
22
/// Static replication settings for one node; invariants are enforced by
/// [`ReplicationConfig::validate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReplicationConfig {
    /// Master switch; when `false` the node must be standalone with no peers.
    pub enabled: bool,
    /// Cluster identifier; must be non-blank when replication is enabled.
    pub cluster_id: String,
    /// This node's identifier; must be non-blank when replication is enabled.
    pub node_id: String,
    /// Starting role; must not be `Standalone` when replication is enabled.
    pub role: ServerRole,
    /// Address other nodes use to reach this node; non-blank when enabled.
    pub advertise_addr: String,
    /// Addresses of the other cluster members (excluding this node — the
    /// cluster size used for quorum checks is `peers.len() + 1`).
    pub peers: Vec<String>,
    /// Acks (the local node included) required to commit; must be >= 1 and,
    /// for a primary, no larger than the cluster size.
    pub sync_ack_quorum: usize,
    /// Leader heartbeat period in milliseconds; must be > 0.
    pub heartbeat_interval_ms: u64,
    /// Election timeout in milliseconds; must exceed `heartbeat_interval_ms`.
    pub election_timeout_ms: u64,
    /// Maximum tolerated replica lag in milliseconds; must be > 0.
    pub max_replication_lag_ms: u64,
    /// Manual vs automatic failover behavior.
    pub failover_mode: FailoverMode,
}
37
38impl Default for ReplicationConfig {
39    fn default() -> Self {
40        Self {
41            enabled: false,
42            cluster_id: "local-cluster".to_string(),
43            node_id: "node-1".to_string(),
44            role: ServerRole::Standalone,
45            advertise_addr: "127.0.0.1:8099".to_string(),
46            peers: Vec::new(),
47            sync_ack_quorum: 1,
48            heartbeat_interval_ms: 1_000,
49            election_timeout_ms: 3_000,
50            max_replication_lag_ms: 2_000,
51            failover_mode: FailoverMode::Manual,
52        }
53    }
54}
55
56impl ReplicationConfig {
57    pub fn validate(&self) -> Result<(), String> {
58        if !self.enabled {
59            if self.role != ServerRole::Standalone {
60                return Err(
61                    "replication role must be `standalone` when replication is disabled"
62                        .to_string(),
63                );
64            }
65            if !self.peers.is_empty() {
66                return Err(
67                    "replication peers are not allowed when replication is disabled".to_string(),
68                );
69            }
70            return Ok(());
71        }
72
73        if self.cluster_id.trim().is_empty() {
74            return Err(
75                "replication cluster_id cannot be empty when replication is enabled".to_string(),
76            );
77        }
78        if self.node_id.trim().is_empty() {
79            return Err(
80                "replication node_id cannot be empty when replication is enabled".to_string(),
81            );
82        }
83        if self.advertise_addr.trim().is_empty() {
84            return Err(
85                "replication advertise_addr cannot be empty when replication is enabled"
86                    .to_string(),
87            );
88        }
89        if self.role == ServerRole::Standalone {
90            return Err(
91                "replication role must be `primary` or `replica` when replication is enabled"
92                    .to_string(),
93            );
94        }
95        if self.sync_ack_quorum == 0 {
96            return Err("replication sync_ack_quorum must be at least 1".to_string());
97        }
98        if self.heartbeat_interval_ms == 0 {
99            return Err("replication heartbeat_interval_ms must be greater than 0".to_string());
100        }
101        if self.election_timeout_ms <= self.heartbeat_interval_ms {
102            return Err(
103                "replication election_timeout_ms must be greater than heartbeat_interval_ms"
104                    .to_string(),
105            );
106        }
107        if self.max_replication_lag_ms == 0 {
108            return Err("replication max_replication_lag_ms must be greater than 0".to_string());
109        }
110
111        if self.role == ServerRole::Primary {
112            let cluster_size = self.peers.len() + 1;
113            if self.sync_ack_quorum > cluster_size {
114                return Err(format!(
115                    "replication sync_ack_quorum {} exceeds cluster size {}",
116                    self.sync_ack_quorum, cluster_size
117                ));
118            }
119        }
120
121        Ok(())
122    }
123}
124
/// Backup / point-in-time-recovery settings; invariants are enforced by
/// [`RecoveryConfig::validate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecoveryConfig {
    /// Directory where backups are written; must be non-blank.
    pub backup_dir: String,
    /// Seconds between snapshots; must be > 0.
    pub snapshot_interval_seconds: u64,
    /// PITR retention window in seconds; must be >= the snapshot interval.
    pub pitr_retention_seconds: u64,
}
131
132impl Default for RecoveryConfig {
133    fn default() -> Self {
134        Self {
135            backup_dir: "./backups".to_string(),
136            snapshot_interval_seconds: 300,
137            pitr_retention_seconds: 86_400,
138        }
139    }
140}
141
142impl RecoveryConfig {
143    pub fn validate(&self) -> Result<(), String> {
144        if self.backup_dir.trim().is_empty() {
145            return Err("recovery backup_dir cannot be empty".to_string());
146        }
147        if self.snapshot_interval_seconds == 0 {
148            return Err("recovery snapshot_interval_seconds must be greater than 0".to_string());
149        }
150        if self.pitr_retention_seconds < self.snapshot_interval_seconds {
151            return Err(
152                "recovery pitr_retention_seconds must be >= snapshot_interval_seconds".to_string(),
153            );
154        }
155        Ok(())
156    }
157}
158
/// Combined HA configuration: replication plus recovery settings.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct HaRuntimeProfile {
    /// Replication / cluster-membership settings.
    pub replication: ReplicationConfig,
    /// Backup and point-in-time-recovery settings.
    pub recovery: RecoveryConfig,
}
164
165impl HaRuntimeProfile {
166    pub fn validate(&self) -> Result<(), String> {
167        self.replication.validate()?;
168        self.recovery.validate()?;
169        Ok(())
170    }
171}
172
/// Mutable runtime view of this node's place in the cluster: role, known
/// leader, term/vote bookkeeping, and log/commit progress counters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HaRuntimeState {
    /// Current role of this node.
    pub role: ServerRole,
    /// Node id of the known leader, if any (a primary lists itself).
    pub leader_id: Option<String>,
    /// Highest term this node has seen or adopted.
    pub current_term: u64,
    /// Candidate voted for in `current_term`, if any.
    pub voted_for: Option<String>,
    /// Highest log index known committed (never past `last_log_index`).
    pub commit_index: u64,
    /// Highest log index applied; dragged up to `commit_index` on updates.
    pub last_applied_index: u64,
    /// Index of the local log's tail entry.
    pub last_log_index: u64,
    /// Term of the local log's tail entry.
    pub last_log_term: u64,
    /// Most recently reported replication lag in milliseconds.
    pub replication_lag_ms: u64,
    /// True between `mark_failover_started` and a promotion/step-down.
    pub failover_in_progress: bool,
    /// Wall-clock time (unix ms) of the last role/term/recovery transition.
    pub last_transition_unix_ms: u64,
    /// Wall-clock time (unix ms) of the last heartbeat; `None` until one arrives.
    pub last_heartbeat_unix_ms: Option<u64>,
    /// Free-form description of the last recovery event, if any.
    pub last_recovery_event: Option<String>,
}
189
impl HaRuntimeState {
    /// Builds the initial runtime state for a node: a configured primary
    /// starts out considering itself the leader, any other role starts with
    /// no known leader. All term/log/commit counters begin at zero.
    pub fn new(profile: &HaRuntimeProfile) -> Self {
        let role = profile.replication.role;
        let leader_id = if role == ServerRole::Primary {
            Some(profile.replication.node_id.clone())
        } else {
            None
        };
        Self {
            role,
            leader_id,
            current_term: 0,
            voted_for: None,
            commit_index: 0,
            last_applied_index: 0,
            last_log_index: 0,
            last_log_term: 0,
            replication_lag_ms: 0,
            failover_in_progress: false,
            last_transition_unix_ms: unix_ms_now(),
            last_heartbeat_unix_ms: None,
            last_recovery_event: None,
        }
    }

    /// Takes over as primary: bumps the term, votes for itself, advertises
    /// itself as leader, and ends any in-progress failover.
    pub fn promote_to_primary(&mut self, node_id: String) {
        self.role = ServerRole::Primary;
        // saturating_add avoids term overflow; .max(1) guarantees the first
        // promotion leaves term >= 1 even from the initial term 0.
        self.current_term = self.current_term.saturating_add(1).max(1);
        self.voted_for = Some(node_id.clone());
        self.leader_id = Some(node_id);
        self.failover_in_progress = false;
        self.last_transition_unix_ms = unix_ms_now();
    }

    /// Demotes to replica, adopting the given leader (if known) and clearing
    /// the failover flag. The current term is left untouched.
    pub fn step_down_to_replica(&mut self, leader_id: Option<String>) {
        self.role = ServerRole::Replica;
        self.leader_id = leader_id;
        self.failover_in_progress = false;
        self.last_transition_unix_ms = unix_ms_now();
    }

    /// Flags that a failover attempt has begun; cleared by a later
    /// promotion or step-down.
    pub fn mark_failover_started(&mut self) {
        self.failover_in_progress = true;
        self.last_transition_unix_ms = unix_ms_now();
    }

    /// Records a leader heartbeat: adopts the reported leader, moves the
    /// commit index forward (clamped so it never passes the local last log
    /// index), drags `last_applied_index` up to the commit index, and
    /// refreshes the lag figure and heartbeat timestamp.
    pub fn mark_heartbeat(
        &mut self,
        leader_id: Option<String>,
        commit_index: u64,
        replication_lag_ms: u64,
    ) {
        self.leader_id = leader_id;
        self.commit_index = self.commit_index.max(commit_index).min(self.last_log_index);
        self.last_applied_index = self.last_applied_index.max(self.commit_index);
        self.replication_lag_ms = replication_lag_ms;
        self.last_heartbeat_unix_ms = Some(unix_ms_now());
    }

    /// Records a recovery event description and stamps the transition time.
    pub fn mark_recovery_event(&mut self, event: String) {
        self.last_recovery_event = Some(event);
        self.last_transition_unix_ms = unix_ms_now();
    }

    /// Adopts a strictly newer term: resets the vote for the new term and
    /// demotes this node to replica if it was primary. Older or equal terms
    /// are ignored.
    pub fn adopt_term(&mut self, term: u64) {
        if term > self.current_term {
            self.current_term = term;
            self.voted_for = None;
            if self.role == ServerRole::Primary {
                self.role = ServerRole::Replica;
            }
            self.last_transition_unix_ms = unix_ms_now();
        }
    }

    /// Records a vote for `candidate_id`, first adopting `term` if newer.
    /// NOTE(review): eligibility is not re-checked here — callers are
    /// expected to gate with [`Self::can_grant_vote`] first.
    pub fn grant_vote(&mut self, term: u64, candidate_id: String) {
        self.adopt_term(term);
        self.voted_for = Some(candidate_id);
        self.last_transition_unix_ms = unix_ms_now();
    }

    /// Vote guard: rejects a candidate with a stale term, rejects a second
    /// vote within the same term for a different candidate, and requires
    /// the candidate's log to be at least as up to date as the local log
    /// (greater last term, or equal last term with index >= ours).
    pub fn can_grant_vote(
        &self,
        term: u64,
        candidate_id: &str,
        candidate_last_log_index: u64,
        candidate_last_log_term: u64,
    ) -> bool {
        if term < self.current_term {
            return false;
        }
        if term == self.current_term
            && self
                .voted_for
                .as_ref()
                .is_some_and(|voted| voted != candidate_id)
        {
            return false;
        }

        candidate_last_log_term > self.last_log_term
            || (candidate_last_log_term == self.last_log_term
                && candidate_last_log_index >= self.last_log_index)
    }

    /// Records the tail position (index and term) of the local log.
    pub fn note_log_position(&mut self, last_log_index: u64, last_log_term: u64) {
        self.last_log_index = last_log_index;
        self.last_log_term = last_log_term;
    }

    /// Moves the commit index forward monotonically, clamped to the local
    /// last log index; `last_applied_index` follows the commit index.
    pub fn advance_commit_index(&mut self, commit_index: u64) {
        self.commit_index = commit_index.min(self.last_log_index).max(self.commit_index);
        self.last_applied_index = self.last_applied_index.max(self.commit_index);
    }
}
305
/// One replication log record: a 1-based index, the leader's term and id,
/// an operation name with a JSON payload, and an FNV-1a integrity checksum.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ReplicationLogEntry {
    /// 1-based position in the log (constructor rejects 0).
    pub index: u64,
    /// Leader term under which the entry was created (constructor rejects 0).
    pub term: u64,
    /// Id of the leader that produced the entry; non-blank.
    pub leader_id: String,
    /// Operation name; non-blank.
    pub operation: String,
    /// Arbitrary JSON payload for the operation.
    pub payload: Value,
    /// 16-hex-digit checksum over index/term/leader/operation/payload.
    pub checksum: String,
    /// Creation wall-clock time in unix milliseconds.
    pub created_at_unix_ms: u64,
}
316
317impl ReplicationLogEntry {
318    pub fn new(
319        index: u64,
320        term: u64,
321        leader_id: String,
322        operation: String,
323        payload: Value,
324    ) -> Result<Self, String> {
325        if index == 0 {
326            return Err("replication log index must be >= 1".to_string());
327        }
328        if term == 0 {
329            return Err("replication log term must be >= 1".to_string());
330        }
331        if leader_id.trim().is_empty() {
332            return Err("replication log leader_id cannot be empty".to_string());
333        }
334        if operation.trim().is_empty() {
335            return Err("replication log operation cannot be empty".to_string());
336        }
337        let checksum = compute_log_checksum(index, term, &leader_id, &operation, &payload)?;
338        Ok(Self {
339            index,
340            term,
341            leader_id,
342            operation,
343            payload,
344            checksum,
345            created_at_unix_ms: unix_ms_now(),
346        })
347    }
348
349    pub fn verify_checksum(&self) -> bool {
350        compute_log_checksum(
351            self.index,
352            self.term,
353            &self.leader_id,
354            &self.operation,
355            &self.payload,
356        )
357        .is_ok_and(|checksum| checksum == self.checksum)
358    }
359}
360
/// In-memory, index-ordered replication log with per-entry ack sets.
#[derive(Debug, Clone, Default)]
pub struct ReplicationLog {
    // Entries in index order; the entry with index i sits at position i - 1.
    entries: Vec<ReplicationLogEntry>,
    // Log index -> set of node ids that have acknowledged that index.
    acked_by: HashMap<u64, HashSet<String>>,
}
366
impl ReplicationLog {
    /// Creates an empty log with no acknowledgements.
    pub fn new() -> Self {
        Self::default()
    }

    /// Rebuilds a log from persisted entries: sorts them by index, then
    /// requires a gap-free sequence starting at 1 with valid checksums.
    /// Acknowledgement sets are NOT restored — they start empty.
    pub fn from_entries(entries: Vec<ReplicationLogEntry>) -> Result<Self, String> {
        let mut out = Self::new();
        if entries.is_empty() {
            return Ok(out);
        }

        let mut sorted = entries;
        sorted.sort_by_key(|entry| entry.index);
        for (position, entry) in sorted.iter().enumerate() {
            // After sorting, a valid log has entry k at position k - 1.
            let expected_index = (position as u64) + 1;
            if entry.index != expected_index {
                return Err(format!(
                    "replication log index gap: expected {}, found {}",
                    expected_index, entry.index
                ));
            }
            if !entry.verify_checksum() {
                return Err(format!(
                    "replication log checksum mismatch at index {}",
                    entry.index
                ));
            }
        }

        out.entries = sorted;
        Ok(out)
    }

    /// All entries in index order.
    pub fn entries(&self) -> &[ReplicationLogEntry] {
        &self.entries
    }

    /// Number of entries in the log.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// True when the log holds no entries.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Index of the last entry, or 0 for an empty log.
    pub fn last_index(&self) -> u64 {
        self.entries.last().map(|entry| entry.index).unwrap_or(0)
    }

    /// Term of the last entry, or 0 for an empty log.
    pub fn last_term(&self) -> u64 {
        self.entries.last().map(|entry| entry.term).unwrap_or(0)
    }

    /// Entry with the given 1-based index, or `None` for 0 / out of range.
    pub fn entry_at(&self, index: u64) -> Option<&ReplicationLogEntry> {
        if index == 0 {
            return None;
        }
        self.entries.get((index - 1) as usize)
    }

    /// Up to `limit` cloned entries whose index is >= `start_index`.
    pub fn entries_from(&self, start_index: u64, limit: usize) -> Vec<ReplicationLogEntry> {
        if limit == 0 {
            return Vec::new();
        }
        self.entries
            .iter()
            .filter(|entry| entry.index >= start_index)
            .take(limit)
            .cloned()
            .collect()
    }

    /// Leader-side append: creates an entry at the next index and
    /// immediately self-acknowledges it for `local_node_id`.
    ///
    /// # Errors
    /// Propagates `ReplicationLogEntry::new` validation failures.
    pub fn append_leader_event(
        &mut self,
        term: u64,
        leader_id: &str,
        operation: String,
        payload: Value,
        local_node_id: &str,
    ) -> Result<ReplicationLogEntry, String> {
        let index = self.last_index().saturating_add(1);
        let entry =
            ReplicationLogEntry::new(index, term, leader_id.to_string(), operation, payload)?;
        self.entries.push(entry.clone());
        self.acknowledge(index, local_node_id.to_string());
        Ok(entry)
    }

    /// Follower-side append of a batch replicated from a leader: verifies
    /// the (`prev_log_index`, `prev_log_term`) anchor against the local
    /// log, requires contiguous indexing and valid checksums on the
    /// incoming entries, truncates the local suffix on a term/checksum
    /// conflict, and appends entries not already present.
    ///
    /// # Errors
    /// Returns a message describing the first anchor, sequence, or
    /// checksum mismatch; the log may have been partially modified.
    pub fn append_remote_entries(
        &mut self,
        prev_log_index: u64,
        prev_log_term: u64,
        entries: &[ReplicationLogEntry],
    ) -> Result<(), String> {
        if prev_log_index > self.last_index() {
            return Err(format!(
                "replication log mismatch: prev_log_index {} beyond local last_index {}",
                prev_log_index,
                self.last_index()
            ));
        }
        if prev_log_index > 0 {
            let local_prev = self
                .entry_at(prev_log_index)
                .ok_or_else(|| "replication log previous entry missing".to_string())?;
            if local_prev.term != prev_log_term {
                return Err(format!(
                    "replication log mismatch: prev term {} does not match local term {} at index {}",
                    prev_log_term, local_prev.term, prev_log_index
                ));
            }
        }

        let mut cursor = prev_log_index;
        for incoming in entries {
            cursor = cursor.saturating_add(1);
            if incoming.index != cursor {
                return Err(format!(
                    "replication log sequence mismatch: expected incoming index {}, found {}",
                    cursor, incoming.index
                ));
            }
            if !incoming.verify_checksum() {
                return Err(format!(
                    "replication log checksum mismatch at incoming index {}",
                    incoming.index
                ));
            }

            // Conflict: a local entry at this index disagrees with the
            // leader's — drop it and everything after it, then (re)append.
            if let Some(existing) = self.entry_at(incoming.index)
                && (existing.term != incoming.term || existing.checksum != incoming.checksum)
            {
                self.truncate_from(incoming.index);
            }
            if self.entry_at(incoming.index).is_none() {
                self.entries.push(incoming.clone());
            }
        }

        Ok(())
    }

    /// Records `node_id`'s acknowledgement of `index` and returns the new
    /// ack count for that index; index 0 or an index past the tail is
    /// ignored and yields 0.
    pub fn acknowledge(&mut self, index: u64, node_id: String) -> usize {
        if index == 0 || index > self.last_index() {
            return 0;
        }
        let node_set = self.acked_by.entry(index).or_default();
        node_set.insert(node_id);
        node_set.len()
    }

    /// Number of distinct nodes that have acknowledged `index`.
    pub fn ack_count(&self, index: u64) -> usize {
        self.acked_by
            .get(&index)
            .map(HashSet::len)
            .unwrap_or_default()
    }

    /// Highest contiguous index, scanning forward from
    /// `current_commit_index`, that at least `quorum` nodes have
    /// acknowledged; a quorum of 0 commits nothing new.
    pub fn compute_commit_index(&self, current_commit_index: u64, quorum: usize) -> u64 {
        if quorum == 0 {
            return current_commit_index;
        }
        let mut next = current_commit_index.saturating_add(1);
        let mut committed = current_commit_index;
        while next <= self.last_index() {
            if self.ack_count(next) >= quorum {
                committed = next;
                next = next.saturating_add(1);
            } else {
                break;
            }
        }
        committed
    }

    /// Drops every entry and ack set with index >= `index`; an index of 0
    /// clears the whole log.
    fn truncate_from(&mut self, index: u64) {
        if index == 0 {
            self.entries.clear();
            self.acked_by.clear();
            return;
        }
        self.entries.retain(|entry| entry.index < index);
        self.acked_by.retain(|entry_index, _| *entry_index < index);
    }
}
552
553fn compute_log_checksum(
554    index: u64,
555    term: u64,
556    leader_id: &str,
557    operation: &str,
558    payload: &Value,
559) -> Result<String, String> {
560    let payload_json = serde_json::to_string(payload)
561        .map_err(|error| format!("checksum payload error: {error}"))?;
562    let raw = format!("{index}|{term}|{leader_id}|{operation}|{payload_json}");
563    Ok(format!("{:016x}", fnv1a64(raw.as_bytes())))
564}
565
/// 64-bit FNV-1a hash (offset basis 0xcbf29ce484222325, prime 0x100000001b3):
/// XOR each byte into the state, then multiply by the prime, wrapping.
fn fnv1a64(bytes: &[u8]) -> u64 {
    bytes.iter().fold(0xcbf29ce484222325u64, |hash, byte| {
        (hash ^ u64::from(*byte)).wrapping_mul(0x100000001b3)
    })
}
574
/// Current wall-clock time as milliseconds since the Unix epoch; yields 0
/// if the system clock reads earlier than the epoch.
fn unix_ms_now() -> u64 {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH);
    since_epoch.map(|elapsed| elapsed.as_millis() as u64).unwrap_or(0)
}
581
#[cfg(test)]
mod tests {
    //! Unit tests for the HA config validators, runtime state machine, and
    //! replication log.
    use super::*;
    use serde_json::json;

    // The all-default (disabled) profile must validate cleanly.
    #[test]
    fn standalone_profile_is_valid() {
        let profile = HaRuntimeProfile::default();
        assert!(profile.validate().is_ok());
    }

    // Enabling replication while still role=standalone must be rejected.
    #[test]
    fn enabled_replication_requires_ha_role() {
        let mut profile = HaRuntimeProfile::default();
        profile.replication.enabled = true;
        assert!(profile.validate().is_err());
    }

    // Quorum 3 with cluster size 2 (one peer + self) must be rejected.
    #[test]
    fn primary_quorum_cannot_exceed_cluster_size() {
        let mut profile = HaRuntimeProfile::default();
        profile.replication.enabled = true;
        profile.replication.role = ServerRole::Primary;
        profile.replication.peers = vec!["n2".to_string()];
        profile.replication.sync_ack_quorum = 3;
        assert!(profile.validate().is_err());
    }

    // Failover start then promotion: role, leader, and flag must all update.
    #[test]
    fn state_transitions_record_role_changes() {
        let mut profile = HaRuntimeProfile::default();
        profile.replication.enabled = true;
        profile.replication.role = ServerRole::Replica;
        let mut state = HaRuntimeState::new(&profile);
        assert_eq!(state.role, ServerRole::Replica);
        state.mark_failover_started();
        assert!(state.failover_in_progress);
        state.promote_to_primary("node-a".to_string());
        assert_eq!(state.role, ServerRole::Primary);
        assert_eq!(state.leader_id.as_deref(), Some("node-a"));
        assert!(!state.failover_in_progress);
    }

    // A candidate whose log tail (index 9) trails ours (index 10) is refused;
    // an equal tail is accepted.
    #[test]
    fn vote_guard_rejects_stale_candidate_logs() {
        let mut profile = HaRuntimeProfile::default();
        profile.replication.enabled = true;
        profile.replication.role = ServerRole::Replica;
        let mut state = HaRuntimeState::new(&profile);
        state.current_term = 5;
        state.note_log_position(10, 5);
        assert!(!state.can_grant_vote(5, "node-b", 9, 5));
        assert!(state.can_grant_vote(5, "node-b", 10, 5));
    }

    // Leader append self-acks; a second ack reaches quorum 2 and commits index 1.
    #[test]
    fn replication_log_appends_and_commits_with_quorum() {
        let mut log = ReplicationLog::new();
        let entry = log
            .append_leader_event(
                1,
                "node-a",
                "ingest_chunk".to_string(),
                json!({"id": "c1"}),
                "node-a",
            )
            .expect("append must succeed");
        assert_eq!(entry.index, 1);
        assert_eq!(log.ack_count(1), 1);
        log.acknowledge(1, "node-b".to_string());
        assert_eq!(log.compute_commit_index(0, 2), 1);
    }

    // A conflicting remote entry at index 2 (newer term) must truncate the
    // old suffix and install the replacement entries.
    #[test]
    fn replication_log_conflict_truncates_suffix() {
        let mut log = ReplicationLog::new();
        let _ = log
            .append_leader_event(1, "node-a", "write".to_string(), json!({"k": 1}), "node-a")
            .expect("append 1");
        let _ = log
            .append_leader_event(1, "node-a", "write".to_string(), json!({"k": 2}), "node-a")
            .expect("append 2");

        let replacement = vec![
            ReplicationLogEntry::new(
                2,
                2,
                "node-b".to_string(),
                "write".to_string(),
                json!({"k": 20}),
            )
            .expect("entry"),
            ReplicationLogEntry::new(
                3,
                2,
                "node-b".to_string(),
                "write".to_string(),
                json!({"k": 30}),
            )
            .expect("entry"),
        ];
        log.append_remote_entries(1, 1, &replacement)
            .expect("append replacement");

        assert_eq!(log.last_index(), 3);
        assert_eq!(log.entry_at(2).map(|entry| entry.term), Some(2));
        assert_eq!(log.entry_at(3).map(|entry| entry.term), Some(2));
    }

    // A tampered checksum must make append_remote_entries fail.
    #[test]
    fn replication_log_rejects_bad_checksum() {
        let mut entry = ReplicationLogEntry::new(
            1,
            1,
            "node-a".to_string(),
            "write".to_string(),
            json!({"x": 1}),
        )
        .expect("entry");
        entry.checksum = "deadbeef".to_string();

        let mut log = ReplicationLog::new();
        let err = log
            .append_remote_entries(0, 0, &[entry])
            .expect_err("checksum must fail");
        assert!(err.contains("checksum"));
    }
}
709}