1use crate::{NodeId, Term};
10use rand::Rng;
11use std::time::Duration;
12use tokio::time::Instant;
13
14#[derive(Debug)]
16pub struct ElectionTimer {
17 last_reset: Instant,
19
20 timeout: Duration,
22
23 min_timeout_ms: u64,
25
26 max_timeout_ms: u64,
28}
29
30impl ElectionTimer {
31 pub fn new(min_timeout_ms: u64, max_timeout_ms: u64) -> Self {
33 let timeout = Self::random_timeout(min_timeout_ms, max_timeout_ms);
34 Self {
35 last_reset: Instant::now(),
36 timeout,
37 min_timeout_ms,
38 max_timeout_ms,
39 }
40 }
41
42 pub fn with_defaults() -> Self {
44 Self::new(150, 300)
45 }
46
47 pub fn reset(&mut self) {
49 self.last_reset = Instant::now();
50 self.timeout = Self::random_timeout(self.min_timeout_ms, self.max_timeout_ms);
51 }
52
53 pub fn is_elapsed(&self) -> bool {
55 self.last_reset.elapsed() >= self.timeout
56 }
57
58 pub fn time_remaining(&self) -> Duration {
60 self.timeout.saturating_sub(self.last_reset.elapsed())
61 }
62
63 fn random_timeout(min_ms: u64, max_ms: u64) -> Duration {
65 let mut rng = rand::thread_rng();
66 let timeout_ms = rng.gen_range(min_ms..=max_ms);
67 Duration::from_millis(timeout_ms)
68 }
69
70 pub fn timeout(&self) -> Duration {
72 self.timeout
73 }
74}
75
76#[derive(Debug)]
78pub struct VoteTracker {
79 votes_received: Vec<NodeId>,
81
82 cluster_size: usize,
84
85 quorum_size: usize,
87}
88
89impl VoteTracker {
90 pub fn new(cluster_size: usize) -> Self {
92 let quorum_size = (cluster_size / 2) + 1;
93 Self {
94 votes_received: Vec::new(),
95 cluster_size,
96 quorum_size,
97 }
98 }
99
100 pub fn record_vote(&mut self, node_id: NodeId) {
102 if !self.votes_received.contains(&node_id) {
103 self.votes_received.push(node_id);
104 }
105 }
106
107 pub fn has_quorum(&self) -> bool {
109 self.votes_received.len() >= self.quorum_size
110 }
111
112 pub fn vote_count(&self) -> usize {
114 self.votes_received.len()
115 }
116
117 pub fn quorum_size(&self) -> usize {
119 self.quorum_size
120 }
121
122 pub fn reset(&mut self) {
124 self.votes_received.clear();
125 }
126}
127
128#[derive(Debug)]
130pub struct ElectionState {
131 pub timer: ElectionTimer,
133
134 pub votes: VoteTracker,
136
137 pub current_term: Term,
139}
140
141impl ElectionState {
142 pub fn new(cluster_size: usize, min_timeout_ms: u64, max_timeout_ms: u64) -> Self {
144 Self {
145 timer: ElectionTimer::new(min_timeout_ms, max_timeout_ms),
146 votes: VoteTracker::new(cluster_size),
147 current_term: 0,
148 }
149 }
150
151 pub fn start_election(&mut self, term: Term, self_id: &NodeId) {
153 self.current_term = term;
154 self.votes.reset();
155 self.votes.record_vote(self_id.clone());
156 self.timer.reset();
157 }
158
159 pub fn reset_timer(&mut self) {
161 self.timer.reset();
162 }
163
164 pub fn should_start_election(&self) -> bool {
166 self.timer.is_elapsed()
167 }
168
169 pub fn record_vote(&mut self, node_id: NodeId) -> bool {
171 self.votes.record_vote(node_id);
172 self.votes.has_quorum()
173 }
174
175 pub fn update_cluster_size(&mut self, cluster_size: usize) {
177 self.votes = VoteTracker::new(cluster_size);
178 }
179}
180
181pub struct VoteValidator;
183
184impl VoteValidator {
185 pub fn should_grant_vote(
192 receiver_term: Term,
193 receiver_voted_for: &Option<NodeId>,
194 receiver_last_log_index: u64,
195 receiver_last_log_term: Term,
196 candidate_id: &NodeId,
197 candidate_term: Term,
198 candidate_last_log_index: u64,
199 candidate_last_log_term: Term,
200 ) -> bool {
201 if candidate_term < receiver_term {
203 return false;
204 }
205
206 let can_vote = match receiver_voted_for {
208 None => true,
209 Some(voted_for) => voted_for == candidate_id,
210 };
211
212 if !can_vote {
213 return false;
214 }
215
216 Self::is_log_up_to_date(
218 candidate_last_log_term,
219 candidate_last_log_index,
220 receiver_last_log_term,
221 receiver_last_log_index,
222 )
223 }
224
225 fn is_log_up_to_date(
233 candidate_last_term: Term,
234 candidate_last_index: u64,
235 receiver_last_term: Term,
236 receiver_last_index: u64,
237 ) -> bool {
238 if candidate_last_term != receiver_last_term {
239 candidate_last_term >= receiver_last_term
240 } else {
241 candidate_last_index >= receiver_last_index
242 }
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread::sleep;

    #[test]
    fn test_election_timer() {
        let mut election_timer = ElectionTimer::new(50, 100);
        // Freshly created: the countdown has only just begun.
        assert!(!election_timer.is_elapsed());

        // Wait past the 100 ms upper bound so any drawn timeout has fired.
        sleep(Duration::from_millis(150));
        assert!(election_timer.is_elapsed());

        // A reset restarts the countdown.
        election_timer.reset();
        assert!(!election_timer.is_elapsed());
    }

    #[test]
    fn test_vote_tracker() {
        let mut tracker = VoteTracker::new(5);
        // Five nodes need three votes for a majority.
        assert_eq!(tracker.quorum_size(), 3);
        assert!(!tracker.has_quorum());

        // Quorum is reached on the third distinct vote and not before.
        let ballots = [("node1", false), ("node2", false), ("node3", true)];
        for (voter, quorum_expected) in ballots {
            tracker.record_vote(voter.to_string());
            assert_eq!(tracker.has_quorum(), quorum_expected);
        }
    }

    #[test]
    fn test_election_state() {
        let mut state = ElectionState::new(5, 50, 100);
        let candidate_id = "node1".to_string();

        state.start_election(1, &candidate_id);
        assert_eq!(state.current_term, 1);
        // The candidate's own vote is recorded when the election starts.
        assert_eq!(state.votes.vote_count(), 1);

        // Self + node2 = 2 votes: still short of the quorum of 3.
        assert!(!state.record_vote("node2".to_string()));
        // Self + node2 + node3 = 3 votes: election won.
        assert!(state.record_vote("node3".to_string()));
    }

    #[test]
    fn test_vote_validation() {
        let candidate = "candidate".to_string();

        // Newer candidate term, receiver has not voted, identical logs: grant.
        assert!(VoteValidator::should_grant_vote(
            1,     // receiver_term
            &None, // receiver_voted_for
            10,    // receiver_last_log_index
            1,     // receiver_last_log_term
            &candidate,
            2,  // candidate_term
            10, // candidate_last_log_index
            1,  // candidate_last_log_term
        ));

        // Candidate term older than the receiver's: reject.
        assert!(!VoteValidator::should_grant_vote(
            2,     // receiver_term
            &None, // receiver_voted_for
            10,    // receiver_last_log_index
            1,     // receiver_last_log_term
            &candidate,
            1,  // candidate_term
            10, // candidate_last_log_index
            1,  // candidate_last_log_term
        ));

        // Same term, receiver already voted for someone else: reject.
        assert!(!VoteValidator::should_grant_vote(
            1,                          // receiver_term
            &Some("other".to_string()), // receiver_voted_for
            10,                         // receiver_last_log_index
            1,                          // receiver_last_log_term
            &candidate,
            1,  // candidate_term
            10, // candidate_last_log_index
            1,  // candidate_last_log_term
        ));

        // Same term, receiver already voted for this candidate: grant again.
        assert!(VoteValidator::should_grant_vote(
            1,                              // receiver_term
            &Some("candidate".to_string()), // receiver_voted_for
            10,                             // receiver_last_log_index
            1,                              // receiver_last_log_term
            &candidate,
            1,  // candidate_term
            10, // candidate_last_log_index
            1,  // candidate_last_log_term
        ));
    }

    #[test]
    fn test_log_up_to_date() {
        // (candidate last term, candidate last index,
        //  receiver last term, receiver last index, expected)
        let cases = [
            // A newer last term wins regardless of the index.
            (2, 5, 1, 10, true),
            (1, 10, 2, 5, false),
            // Equal last terms: the higher (or equal) last index wins.
            (1, 10, 1, 5, true),
            (1, 5, 1, 10, false),
            (1, 10, 1, 10, true),
        ];
        for (cand_term, cand_index, recv_term, recv_index, expected) in cases {
            assert_eq!(
                VoteValidator::is_log_up_to_date(cand_term, cand_index, recv_term, recv_index),
                expected
            );
        }
    }
}