ruvector_raft/
state.rs

//! Raft state management
//!
//! Implements the per-node state bookkeeping for Raft consensus, including:
//! - Persistent state (term, vote, log)
//! - Volatile state (commit index, last applied)
//! - Leader-specific state (next index, match index)
7
8use crate::{log::RaftLog, LogIndex, NodeId, Term};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12/// The three states a Raft node can be in
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14pub enum RaftState {
15    /// Follower state - responds to RPCs from leaders and candidates
16    Follower,
17    /// Candidate state - attempts to become leader
18    Candidate,
19    /// Leader state - handles client requests and replicates log
20    Leader,
21}
22
23impl RaftState {
24    /// Returns true if this node is the leader
25    pub fn is_leader(&self) -> bool {
26        matches!(self, RaftState::Leader)
27    }
28
29    /// Returns true if this node is a candidate
30    pub fn is_candidate(&self) -> bool {
31        matches!(self, RaftState::Candidate)
32    }
33
34    /// Returns true if this node is a follower
35    pub fn is_follower(&self) -> bool {
36        matches!(self, RaftState::Follower)
37    }
38}
39
40/// Persistent state on all servers
41///
42/// Updated on stable storage before responding to RPCs
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct PersistentState {
45    /// Latest term server has seen (initialized to 0, increases monotonically)
46    pub current_term: Term,
47
48    /// Candidate ID that received vote in current term (or None)
49    pub voted_for: Option<NodeId>,
50
51    /// Log entries (each entry contains command and term)
52    pub log: RaftLog,
53}
54
55impl PersistentState {
56    /// Create new persistent state with initial values
57    pub fn new() -> Self {
58        Self {
59            current_term: 0,
60            voted_for: None,
61            log: RaftLog::new(),
62        }
63    }
64
65    /// Increment the current term
66    pub fn increment_term(&mut self) {
67        self.current_term += 1;
68        self.voted_for = None;
69    }
70
71    /// Update term if the given term is higher
72    pub fn update_term(&mut self, term: Term) -> bool {
73        if term > self.current_term {
74            self.current_term = term;
75            self.voted_for = None;
76            true
77        } else {
78            false
79        }
80    }
81
82    /// Vote for a candidate in the current term
83    pub fn vote_for(&mut self, candidate_id: NodeId) {
84        self.voted_for = Some(candidate_id);
85    }
86
87    /// Check if vote can be granted for the given candidate
88    pub fn can_vote_for(&self, candidate_id: &NodeId) -> bool {
89        match &self.voted_for {
90            None => true,
91            Some(voted) => voted == candidate_id,
92        }
93    }
94
95    /// Serialize state to bytes for persistence
96    pub fn to_bytes(&self) -> Result<Vec<u8>, bincode::error::EncodeError> {
97        use bincode::config;
98        bincode::encode_to_vec(bincode::serde::Compat(self), config::standard())
99    }
100
101    /// Deserialize state from bytes
102    pub fn from_bytes(bytes: &[u8]) -> Result<Self, bincode::error::DecodeError> {
103        use bincode::config;
104        let (compat, _): (bincode::serde::Compat<Self>, _) =
105            bincode::decode_from_slice(bytes, config::standard())?;
106        Ok(compat.0)
107    }
108}
109
110impl Default for PersistentState {
111    fn default() -> Self {
112        Self::new()
113    }
114}
115
116/// Volatile state on all servers
117///
118/// Can be reconstructed from persistent state
119#[derive(Debug, Clone)]
120pub struct VolatileState {
121    /// Index of highest log entry known to be committed
122    /// (initialized to 0, increases monotonically)
123    pub commit_index: LogIndex,
124
125    /// Index of highest log entry applied to state machine
126    /// (initialized to 0, increases monotonically)
127    pub last_applied: LogIndex,
128}
129
130impl VolatileState {
131    /// Create new volatile state with initial values
132    pub fn new() -> Self {
133        Self {
134            commit_index: 0,
135            last_applied: 0,
136        }
137    }
138
139    /// Update commit index
140    pub fn update_commit_index(&mut self, index: LogIndex) {
141        if index > self.commit_index {
142            self.commit_index = index;
143        }
144    }
145
146    /// Advance last_applied index
147    pub fn apply_entries(&mut self, up_to_index: LogIndex) {
148        if up_to_index > self.last_applied {
149            self.last_applied = up_to_index;
150        }
151    }
152
153    /// Get the number of entries that need to be applied
154    pub fn pending_entries(&self) -> u64 {
155        self.commit_index.saturating_sub(self.last_applied)
156    }
157}
158
159impl Default for VolatileState {
160    fn default() -> Self {
161        Self::new()
162    }
163}
164
165/// Volatile state on leaders
166///
167/// Reinitialized after election
168#[derive(Debug, Clone)]
169pub struct LeaderState {
170    /// For each server, index of the next log entry to send to that server
171    /// (initialized to leader last log index + 1)
172    pub next_index: HashMap<NodeId, LogIndex>,
173
174    /// For each server, index of highest log entry known to be replicated
175    /// (initialized to 0, increases monotonically)
176    pub match_index: HashMap<NodeId, LogIndex>,
177}
178
179impl LeaderState {
180    /// Create new leader state for the given cluster members
181    pub fn new(cluster_members: &[NodeId], last_log_index: LogIndex) -> Self {
182        let mut next_index = HashMap::new();
183        let mut match_index = HashMap::new();
184
185        for member in cluster_members {
186            // Initialize next_index to last log index + 1
187            next_index.insert(member.clone(), last_log_index + 1);
188            // Initialize match_index to 0
189            match_index.insert(member.clone(), 0);
190        }
191
192        Self {
193            next_index,
194            match_index,
195        }
196    }
197
198    /// Update next_index for a follower (decrement on failure)
199    pub fn decrement_next_index(&mut self, node_id: &NodeId) {
200        if let Some(index) = self.next_index.get_mut(node_id) {
201            if *index > 1 {
202                *index -= 1;
203            }
204        }
205    }
206
207    /// Update both next_index and match_index for successful replication
208    pub fn update_replication(&mut self, node_id: &NodeId, match_index: LogIndex) {
209        self.match_index.insert(node_id.clone(), match_index);
210        self.next_index.insert(node_id.clone(), match_index + 1);
211    }
212
213    /// Get the median match_index for determining commit_index
214    pub fn calculate_commit_index(&self) -> LogIndex {
215        if self.match_index.is_empty() {
216            return 0;
217        }
218
219        let mut indices: Vec<LogIndex> = self.match_index.values().copied().collect();
220        indices.sort_unstable();
221
222        // Return the median (quorum)
223        let mid = indices.len() / 2;
224        indices.get(mid).copied().unwrap_or(0)
225    }
226
227    /// Get next_index for a specific follower
228    pub fn get_next_index(&self, node_id: &NodeId) -> Option<LogIndex> {
229        self.next_index.get(node_id).copied()
230    }
231
232    /// Get match_index for a specific follower
233    pub fn get_match_index(&self, node_id: &NodeId) -> Option<LogIndex> {
234        self.match_index.get(node_id).copied()
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// Build owned node ids from string literals.
    fn node_ids(names: &[&str]) -> Vec<NodeId> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_raft_state_checks() {
        let (leader, candidate, follower) =
            (RaftState::Leader, RaftState::Candidate, RaftState::Follower);
        assert!(leader.is_leader());
        assert!(candidate.is_candidate());
        assert!(follower.is_follower());
    }

    #[test]
    fn test_persistent_state_term_management() {
        let mut persistent = PersistentState::new();
        assert_eq!(persistent.current_term, 0);

        persistent.increment_term();
        assert_eq!(persistent.current_term, 1);
        assert!(persistent.voted_for.is_none());

        persistent.update_term(5);
        assert_eq!(persistent.current_term, 5);
    }

    #[test]
    fn test_voting() {
        let mut persistent = PersistentState::new();
        let ids = node_ids(&["node1", "node2"]);
        let (first, second) = (&ids[0], &ids[1]);

        assert!(persistent.can_vote_for(first));
        persistent.vote_for(first.clone());
        assert!(persistent.can_vote_for(first));
        assert!(!persistent.can_vote_for(second));
    }

    #[test]
    fn test_volatile_state() {
        let mut volatile = VolatileState::new();
        assert_eq!((volatile.commit_index, volatile.last_applied), (0, 0));

        volatile.update_commit_index(10);
        assert_eq!(volatile.commit_index, 10);
        assert_eq!(volatile.pending_entries(), 10);

        volatile.apply_entries(5);
        assert_eq!(volatile.last_applied, 5);
        assert_eq!(volatile.pending_entries(), 5);
    }

    #[test]
    fn test_leader_state() {
        let members = node_ids(&["node1", "node2"]);
        let first = &members[0];
        let mut leader = LeaderState::new(&members, 10);

        assert_eq!(leader.get_next_index(first), Some(11));
        assert_eq!(leader.get_match_index(first), Some(0));

        leader.update_replication(first, 10);
        assert_eq!(leader.get_next_index(first), Some(11));
        assert_eq!(leader.get_match_index(first), Some(10));
    }

    #[test]
    fn test_commit_index_calculation() {
        let members = node_ids(&["node1", "node2", "node3"]);
        let mut leader = LeaderState::new(&members, 10);

        for (member, matched) in members.iter().zip([5, 8, 3]) {
            leader.update_replication(member, matched);
        }

        // Median of [3, 5, 8].
        assert_eq!(leader.calculate_commit_index(), 5);
    }
}