1use crate::{log::RaftLog, LogIndex, NodeId, Term};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
/// The role a Raft node is currently in.
///
/// NOTE(review): this type derives `Serialize`/`Deserialize`. If it is ever encoded with an
/// index-based format such as bincode (which this file uses for `PersistentState`), the
/// variant declaration order becomes part of the encoded form, so avoid reordering variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RaftState {
    /// Follower role; see [`RaftState::is_follower`].
    Follower,
    /// Candidate role; see [`RaftState::is_candidate`].
    Candidate,
    /// Leader role; see [`RaftState::is_leader`].
    Leader,
}
22
23impl RaftState {
24 pub fn is_leader(&self) -> bool {
26 matches!(self, RaftState::Leader)
27 }
28
29 pub fn is_candidate(&self) -> bool {
31 matches!(self, RaftState::Candidate)
32 }
33
34 pub fn is_follower(&self) -> bool {
36 matches!(self, RaftState::Follower)
37 }
38}
39
/// Node state that this file serializes as a unit: the current term, the vote cast in that
/// term, and the log.
///
/// Converted to and from bytes by [`PersistentState::to_bytes`] /
/// [`PersistentState::from_bytes`]. NOTE(review): when it is written to stable storage is
/// decided by callers, not here — confirm the persistence points.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistentState {
    /// Latest term this node has recorded; starts at 0 in `new`.
    pub current_term: Term,

    /// Node this state voted for in `current_term`, or `None` if no vote was recorded.
    /// The term-advancing methods reset it to `None`.
    pub voted_for: Option<NodeId>,

    /// The node's log.
    pub log: RaftLog,
}
54
55impl PersistentState {
56 pub fn new() -> Self {
58 Self {
59 current_term: 0,
60 voted_for: None,
61 log: RaftLog::new(),
62 }
63 }
64
65 pub fn increment_term(&mut self) {
67 self.current_term += 1;
68 self.voted_for = None;
69 }
70
71 pub fn update_term(&mut self, term: Term) -> bool {
73 if term > self.current_term {
74 self.current_term = term;
75 self.voted_for = None;
76 true
77 } else {
78 false
79 }
80 }
81
82 pub fn vote_for(&mut self, candidate_id: NodeId) {
84 self.voted_for = Some(candidate_id);
85 }
86
87 pub fn can_vote_for(&self, candidate_id: &NodeId) -> bool {
89 match &self.voted_for {
90 None => true,
91 Some(voted) => voted == candidate_id,
92 }
93 }
94
95 pub fn to_bytes(&self) -> Result<Vec<u8>, bincode::error::EncodeError> {
97 use bincode::config;
98 bincode::encode_to_vec(bincode::serde::Compat(self), config::standard())
99 }
100
101 pub fn from_bytes(bytes: &[u8]) -> Result<Self, bincode::error::DecodeError> {
103 use bincode::config;
104 let (compat, _): (bincode::serde::Compat<Self>, _) =
105 bincode::decode_from_slice(bytes, config::standard())?;
106 Ok(compat.0)
107 }
108}
109
110impl Default for PersistentState {
111 fn default() -> Self {
112 Self::new()
113 }
114}
115
116#[derive(Debug, Clone)]
120pub struct VolatileState {
121 pub commit_index: LogIndex,
124
125 pub last_applied: LogIndex,
128}
129
130impl VolatileState {
131 pub fn new() -> Self {
133 Self {
134 commit_index: 0,
135 last_applied: 0,
136 }
137 }
138
139 pub fn update_commit_index(&mut self, index: LogIndex) {
141 if index > self.commit_index {
142 self.commit_index = index;
143 }
144 }
145
146 pub fn apply_entries(&mut self, up_to_index: LogIndex) {
148 if up_to_index > self.last_applied {
149 self.last_applied = up_to_index;
150 }
151 }
152
153 pub fn pending_entries(&self) -> u64 {
155 self.commit_index.saturating_sub(self.last_applied)
156 }
157}
158
159impl Default for VolatileState {
160 fn default() -> Self {
161 Self::new()
162 }
163}
164
165#[derive(Debug, Clone)]
169pub struct LeaderState {
170 pub next_index: HashMap<NodeId, LogIndex>,
173
174 pub match_index: HashMap<NodeId, LogIndex>,
177}
178
179impl LeaderState {
180 pub fn new(cluster_members: &[NodeId], last_log_index: LogIndex) -> Self {
182 let mut next_index = HashMap::new();
183 let mut match_index = HashMap::new();
184
185 for member in cluster_members {
186 next_index.insert(member.clone(), last_log_index + 1);
188 match_index.insert(member.clone(), 0);
190 }
191
192 Self {
193 next_index,
194 match_index,
195 }
196 }
197
198 pub fn decrement_next_index(&mut self, node_id: &NodeId) {
200 if let Some(index) = self.next_index.get_mut(node_id) {
201 if *index > 1 {
202 *index -= 1;
203 }
204 }
205 }
206
207 pub fn update_replication(&mut self, node_id: &NodeId, match_index: LogIndex) {
209 self.match_index.insert(node_id.clone(), match_index);
210 self.next_index.insert(node_id.clone(), match_index + 1);
211 }
212
213 pub fn calculate_commit_index(&self) -> LogIndex {
215 if self.match_index.is_empty() {
216 return 0;
217 }
218
219 let mut indices: Vec<LogIndex> = self.match_index.values().copied().collect();
220 indices.sort_unstable();
221
222 let mid = indices.len() / 2;
224 indices.get(mid).copied().unwrap_or(0)
225 }
226
227 pub fn get_next_index(&self, node_id: &NodeId) -> Option<LogIndex> {
229 self.next_index.get(node_id).copied()
230 }
231
232 pub fn get_match_index(&self, node_id: &NodeId) -> Option<LogIndex> {
234 self.match_index.get(node_id).copied()
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_raft_state_checks() {
        assert!(RaftState::Leader.is_leader());
        assert!(RaftState::Candidate.is_candidate());
        assert!(RaftState::Follower.is_follower());

        // Each role answers `true` only for its own predicate.
        assert!(!RaftState::Leader.is_follower());
        assert!(!RaftState::Leader.is_candidate());
        assert!(!RaftState::Follower.is_leader());
        assert!(!RaftState::Candidate.is_leader());
    }

    #[test]
    fn test_persistent_state_term_management() {
        let mut state = PersistentState::new();
        assert_eq!(state.current_term, 0);

        state.increment_term();
        assert_eq!(state.current_term, 1);
        assert!(state.voted_for.is_none());

        state.update_term(5);
        assert_eq!(state.current_term, 5);
    }

    #[test]
    fn test_update_term_only_moves_forward() {
        let mut state = PersistentState::new();
        assert!(state.update_term(5));
        state.vote_for("node1".to_string());

        // Equal or older terms are rejected and keep the recorded vote.
        assert!(!state.update_term(5));
        assert!(!state.update_term(3));
        assert_eq!(state.current_term, 5);
        assert_eq!(state.voted_for, Some("node1".to_string()));

        // A newer term is adopted and clears the vote.
        assert!(state.update_term(6));
        assert_eq!(state.current_term, 6);
        assert!(state.voted_for.is_none());
    }

    #[test]
    fn test_increment_term_clears_vote() {
        let mut state = PersistentState::new();
        state.vote_for("node1".to_string());
        assert!(!state.can_vote_for(&"node2".to_string()));

        state.increment_term();
        assert_eq!(state.current_term, 1);
        assert!(state.can_vote_for(&"node2".to_string()));
    }

    #[test]
    fn test_voting() {
        let mut state = PersistentState::new();
        let candidate = "node1".to_string();

        assert!(state.can_vote_for(&candidate));
        state.vote_for(candidate.clone());
        assert!(state.can_vote_for(&candidate));
        assert!(!state.can_vote_for(&"node2".to_string()));
    }

    #[test]
    fn test_persistent_state_bytes_roundtrip() {
        let mut state = PersistentState::new();
        state.update_term(7);
        state.vote_for("node3".to_string());

        let bytes = state.to_bytes().expect("encoding a fresh state should succeed");
        let decoded =
            PersistentState::from_bytes(&bytes).expect("decoding our own encoding should succeed");

        assert_eq!(decoded.current_term, 7);
        assert_eq!(decoded.voted_for, Some("node3".to_string()));
    }

    #[test]
    fn test_volatile_state() {
        let mut state = VolatileState::new();
        assert_eq!(state.commit_index, 0);
        assert_eq!(state.last_applied, 0);

        state.update_commit_index(10);
        assert_eq!(state.commit_index, 10);
        assert_eq!(state.pending_entries(), 10);

        // Lower commit indices are ignored.
        state.update_commit_index(3);
        assert_eq!(state.commit_index, 10);

        state.apply_entries(5);
        assert_eq!(state.last_applied, 5);
        assert_eq!(state.pending_entries(), 5);
    }

    #[test]
    fn test_leader_state() {
        let members = vec!["node1".to_string(), "node2".to_string()];
        let mut leader_state = LeaderState::new(&members, 10);

        assert_eq!(leader_state.get_next_index(&members[0]), Some(11));
        assert_eq!(leader_state.get_match_index(&members[0]), Some(0));

        leader_state.update_replication(&members[0], 10);
        assert_eq!(leader_state.get_next_index(&members[0]), Some(11));
        assert_eq!(leader_state.get_match_index(&members[0]), Some(10));
    }

    #[test]
    fn test_decrement_next_index_stops_at_one() {
        let members = vec!["node1".to_string()];
        let mut leader_state = LeaderState::new(&members, 1);
        assert_eq!(leader_state.get_next_index(&members[0]), Some(2));

        leader_state.decrement_next_index(&members[0]);
        leader_state.decrement_next_index(&members[0]);
        assert_eq!(leader_state.get_next_index(&members[0]), Some(1));

        // Untracked nodes are ignored rather than inserted.
        let unknown = "node9".to_string();
        leader_state.decrement_next_index(&unknown);
        assert_eq!(leader_state.get_next_index(&unknown), None);
    }

    #[test]
    fn test_commit_index_calculation() {
        let members = vec![
            "node1".to_string(),
            "node2".to_string(),
            "node3".to_string(),
        ];
        let mut leader_state = LeaderState::new(&members, 10);

        leader_state.update_replication(&members[0], 5);
        leader_state.update_replication(&members[1], 8);
        leader_state.update_replication(&members[2], 3);

        let commit = leader_state.calculate_commit_index();
        assert_eq!(commit, 5);
    }

    #[test]
    fn test_commit_index_without_members() {
        let leader_state = LeaderState::new(&[], 10);
        assert_eq!(leader_state.calculate_commit_index(), 0);
    }
}