1use std::{
2 collections::{HashMap, HashSet},
3 fmt,
4 sync::Arc,
5 time::Duration,
6};
7
8use rand::Rng;
9
10use crate::*;
11
/// Wire payloads exchanged between Raft nodes.
///
/// Serialized with an internal `"type"` tag whose value is the variant name
/// in `snake_case`. `Entry` is the application-defined log entry type.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum RaftPayload<Entry = ()> {
    /// Sent by a campaigning node to every other node to ask for its vote.
    RequestVote {
        /// Term the sender is campaigning in.
        term: u64,
        /// Index of the last entry in the sender's log.
        last_log_index: usize,
        /// Term of the last entry in the sender's log.
        last_log_term: u64,
    },
    /// Reply to `RequestVote`.
    RequestVoteOk {
        /// Current term of the replying node.
        term: u64,
        /// Whether the replying node granted its vote to the requester.
        vote_granted: bool,
    },
    /// Sent by the leader to replicate entries; also sent on every heartbeat.
    AppendEntries {
        /// Current term of the sender.
        term: u64,
        /// Index of the log entry immediately preceding `entries`.
        prev_log_index: usize,
        /// Term of the entry at `prev_log_index`.
        prev_log_term: u64,
        /// Entries to append, each paired with the term it was created in.
        entries: Vec<(Entry, u64)>,
        /// Commit index of the sender.
        leader_commit: usize,
    },
    /// Reply to `AppendEntries`.
    AppendEntriesOk {
        /// Current term of the replying node.
        term: u64,
        /// Whether the replying node accepted the request.
        success: bool,
        /// Log index reported by the replying node; the leader uses it to
        /// position its `next_index` / `match_index` bookkeeping for that node.
        applied_up_to: usize,
    },
}
48
/// Locally generated timer signals fed into [`RaftService::step`].
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
pub enum RaftSignal {
    /// Heartbeat tick: a leader sends `AppendEntries` to every other node;
    /// non-leaders ignore it.
    Heartbeat,
    /// Election tick: a non-leader starts a new term and requests votes;
    /// leaders ignore it.
    Campaign,
}
54
/// Everything the Raft service can emit to, or receive from, its owner.
///
/// Serialized untagged, i.e. as the bare contents of whichever variant is held.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum RaftEvent<Entry: Clone> {
    /// A Raft protocol message received from another node.
    RaftMessage(Message<RaftPayload<Entry>>),
    /// A timer signal generated on this node.
    RaftSignal(RaftSignal),
    /// An entry that has been committed. Produced by the service for its
    /// owner; `RaftService::step` rejects this variant with an error.
    CommitedEntry(Entry),
}
62
/// State of a single Raft node, driven by [`RaftService::step`].
pub struct RaftService<Entry: Clone> {
    /// Generator for the ids of outgoing messages.
    msg_id: IdCounter,
    /// Latest term this node has seen.
    current_term: u64,
    /// Replicated log of `(entry, term)` pairs. Slot 0 is a `None` sentinel
    /// so that real entries start at index 1.
    log: Vec<Option<(Entry, u64)>>,
    /// Index of the highest log entry known to be committed.
    commit_index: usize,
    /// Node this node voted for (or follows) in the current term, if any.
    voted_for: Option<String>,
    /// Per node: index of the next log entry to send to it. Only populated
    /// while this node is leader; `is_leader` is defined as "non-empty".
    next_index: HashMap<String, usize>,
    /// Per node: highest log index recorded as replicated on it.
    match_index: HashMap<String, usize>,
    /// Nodes that granted this node their vote in the current term.
    votes: HashSet<String>,
    /// Flag shared with the election timer; set to `true` to ask the timer
    /// to restart instead of firing a campaign.
    reset_election_timer: Arc<AtomicBool>,
    /// Sink that receives every newly committed entry.
    log_sender: Box<dyn SenderExt<Entry>>,
}
85
86impl<Entry: Clone + Serialize + fmt::Debug + 'static> RaftService<Entry> {
87 pub fn create(network: &mut Network, sender: impl SenderExt<RaftEvent<Entry>> + Clone) -> Self {
88 let reset_election_timer = Arc::new(AtomicBool::new(false));
89 let sender_clone = sender.clone();
90
91 let mut node = Self {
92 msg_id: IdCounter::new(),
93 log: vec![None],
94 commit_index: 0,
95 current_term: 0,
96 voted_for: None,
97 next_index: HashMap::new(),
98 match_index: HashMap::new(),
99 votes: HashSet::new(),
100 reset_election_timer: reset_election_timer.clone(),
101 log_sender: Box::new(sender_clone.map_input(RaftEvent::CommitedEntry)),
102 };
103
104 if network.is_singleton() {
106 node.next_index.insert(network.node_id.clone(), 1);
108 node.match_index.insert(network.node_id.clone(), 0);
109 return node;
110 }
111
112 let election_timeout = Duration::from_millis(rand::thread_rng().gen_range(150..=300));
114 let heartbeat_timeout = Duration::from_millis(150);
115
116 let sender_clone = sender.clone();
117
118 spawn_timer(
119 Box::new(move || {
120 let _ = sender_clone.send(RaftEvent::RaftSignal(RaftSignal::Campaign));
121 Ok(())
122 }),
123 election_timeout,
124 Some(reset_election_timer),
125 );
126
127 spawn_timer(
128 Box::new(move || {
129 let _ = sender.send(RaftEvent::RaftSignal(RaftSignal::Heartbeat));
130 Ok(())
131 }),
132 heartbeat_timeout,
133 None,
134 );
135
136 node
137 }
138
139 pub fn become_leader(&mut self, network: &mut Network) -> anyhow::Result<()> {
140 eprintln!(
141 "{} has been elected as leader in term {}",
142 network.node_id, self.current_term
143 );
144
145 self.next_index = network
146 .all_nodes
147 .iter()
148 .cloned()
149 .map(|n| (n, self.last_log_index() + 1))
150 .collect();
151
152 self.match_index = network.all_nodes.iter().cloned().map(|n| (n, 0)).collect();
153
154 self.step(RaftEvent::RaftSignal(RaftSignal::Heartbeat), network)
156 .context("Heartbeat on elected")
157 }
158
    /// Returns `true` while this node acts as leader.
    ///
    /// Leadership is tracked implicitly: `next_index` is filled when the node
    /// becomes leader and cleared by `update_term`, so a non-empty map means
    /// "leader".
    pub fn is_leader(&self) -> bool {
        !self.next_index.is_empty()
    }
162
163 pub fn voted_for(&self) -> Option<String> {
165 self.voted_for.clone()
166 }
167
    /// Index of the last slot in the log.
    ///
    /// The log is created with a sentinel at index 0, so an otherwise empty
    /// log yields 0.
    fn last_log_index(&self) -> usize {
        self.log.len() - 1
    }
171
172 fn log_term_at(&self, idx: usize) -> u64 {
173 self.log[idx].as_ref().map(|(_, term)| *term).unwrap_or(0)
174 }
175
    /// Raises the flag shared with the election timer (handed to it in
    /// `create`) to request that the pending election be postponed.
    fn delay_election(&self) {
        self.reset_election_timer.store(true, Ordering::Relaxed);
    }
179
180 fn update_term(&mut self, term: u64) {
181 self.current_term = term;
182 self.votes.clear();
183 self.next_index.clear();
184 self.match_index.clear();
185 self.voted_for = None;
187 }
188
189 fn send_append_entries(&mut self, dest: String, network: &mut Network) -> anyhow::Result<()> {
190 let next_index = self.next_index[&dest];
191 let entries = self.log[next_index..]
192 .iter()
193 .map(|e| e.clone().expect("Only the 0th element should be None"))
194 .collect();
195
196 network
197 .send(
198 dest,
199 Body {
200 msg_id: self.msg_id.next(),
201 in_reply_to: None,
202 payload: RaftPayload::AppendEntries {
203 term: self.current_term,
204 prev_log_index: next_index - 1,
205 prev_log_term: self.log_term_at(next_index - 1),
206 entries,
207 leader_commit: self.commit_index,
208 },
209 },
210 )
211 .context("Send append entries")
212 }
213
214 pub fn request(&mut self, entry: Entry, network: &mut Network) -> anyhow::Result<()> {
216 if !self.is_leader() {
217 return Ok(());
218 }
219
220 self.log.push(Some((entry, self.current_term)));
221 *self
222 .match_index
223 .get_mut(&network.node_id)
224 .expect("Node should have itself in match_index") += 1;
225
226 if network.is_singleton() {
227 *self
228 .next_index
229 .get_mut(&network.node_id)
230 .expect("Node should have itself in next_index") += 1;
231
232 self.commit(self.commit_index + 1);
234 }
235
236 Ok(())
237 }
238
239 fn commit(&mut self, new_idx: usize) {
240 let prev = self.commit_index;
241 self.commit_index = new_idx;
242 for i in prev + 1..=self.commit_index {
243 if let Some((e, _)) = self.log[i].clone() {
244 let _ = self.log_sender.send(e);
245 }
246 }
247
248 eprintln!("Committed entries {}-{}", prev + 1, self.commit_index);
249 }
250
251 pub fn step(&mut self, event: RaftEvent<Entry>, network: &mut Network) -> anyhow::Result<()> {
252 match event {
253 RaftEvent::RaftSignal(signal) => match signal {
254 RaftSignal::Heartbeat => {
255 if !self.is_leader() {
257 return Ok(());
258 }
259
260 eprintln!("{} sending heartbeats", network.node_id);
261 for follower in network.all_nodes.clone() {
262 if follower == network.node_id {
263 continue;
264 }
265
266 self.send_append_entries(follower, network)
267 .context("Heartbeat send append entries")?;
268 }
269 }
270 RaftSignal::Campaign => {
271 if self.is_leader() {
272 return Ok(());
273 }
274
275 self.update_term(self.current_term + 1);
276
277 eprintln!(
278 "{} campaigning in new term {}",
279 network.node_id, self.current_term
280 );
281
282 self.voted_for = Some(network.node_id.clone());
284 self.votes.insert(network.node_id.clone());
285 self.delay_election();
286
287 for other in network.all_nodes.clone() {
288 if other == network.node_id {
289 continue;
290 }
291
292 network
293 .send(
294 other,
295 Body {
296 msg_id: self.msg_id.next(),
297 in_reply_to: None,
298 payload: RaftPayload::<Entry>::RequestVote {
299 term: self.current_term,
300 last_log_index: self.last_log_index(),
301 last_log_term: self.log_term_at(self.last_log_index()),
302 },
303 },
304 )
305 .context("Send request vote")?;
306 }
307 }
308 },
309 RaftEvent::RaftMessage(msg) => match msg.body.payload {
310 RaftPayload::RequestVote {
311 term,
312 last_log_index,
313 last_log_term,
314 } => {
315 eprintln!("Got vote request from {} at term {}", msg.src, term);
316 if term > self.current_term {
321 self.update_term(term);
322 }
323
324 let granted = term >= self.current_term
325 && !self.voted_for.as_ref().is_some_and(|v| v != &msg.src)
326 && (last_log_term, last_log_index)
327 >= (
328 self.log_term_at(self.last_log_index()),
329 self.last_log_index(),
330 );
331
332 if granted {
333 self.update_term(term);
334
335 self.voted_for = Some(msg.src.clone());
336 self.delay_election();
337
338 eprintln!("{} voted for {} in term {term}", network.node_id, msg.src);
339 }
340
341 network
342 .reply(
343 msg.src,
344 self.msg_id.next(),
345 msg.body.msg_id,
346 RaftPayload::<Entry>::RequestVoteOk {
347 term: self.current_term,
348 vote_granted: granted,
349 },
350 )
351 .context("Request vote OK reply")?;
352 }
353 RaftPayload::RequestVoteOk { term, vote_granted } => {
354 if term > self.current_term {
355 self.update_term(term);
356 return Ok(());
357 }
358
359 if vote_granted && term == self.current_term {
360 self.votes.insert(msg.src);
361 }
362
363 if self.votes.len() > network.all_nodes.len() / 2 && !self.is_leader() {
364 self.become_leader(network)
365 .context("Become leader on majority")?;
366 }
367 }
368 RaftPayload::AppendEntries {
369 term,
370 prev_log_index,
371 prev_log_term,
372 mut entries,
373 leader_commit,
374 } => {
375 if term >= self.current_term {
376 eprintln!("got appendEntries from {}", msg.src);
377
378 self.update_term(term);
380 self.voted_for = Some(msg.src.clone());
381 self.delay_election();
382 }
383
384 if term < self.current_term
385 || self.last_log_index() < prev_log_index
386 || self.log_term_at(prev_log_index) != prev_log_term
387 {
388 network
389 .reply(
390 msg.src,
391 self.msg_id.next(),
392 msg.body.msg_id,
393 RaftPayload::<Entry>::AppendEntriesOk {
394 term: self.current_term,
395 success: false,
396 applied_up_to: self.last_log_index(),
397 },
398 )
399 .context("Append entries rejection")?;
400 return Ok(());
401 }
402
403 for i in prev_log_index + 1
404 ..std::cmp::min(prev_log_index + 1 + entries.len(), self.log.len())
405 {
406 if self.log_term_at(i) != entries[0].1 {
407 self.log.drain(i..);
410 break;
411 }
412 entries.remove(0);
413 }
414
415 self.log.extend(entries.into_iter().map(Some));
416
417 let new_idx = std::cmp::max(
418 self.commit_index,
419 std::cmp::min(leader_commit, self.last_log_index()),
420 );
421
422 self.commit(dbg!(new_idx));
423
424 network
425 .reply(
426 msg.src,
427 self.msg_id.next(),
428 msg.body.msg_id,
429 RaftPayload::<Entry>::AppendEntriesOk {
430 term: self.current_term,
431 success: true,
432 applied_up_to: self.last_log_index(),
433 },
434 )
435 .context("Append entries acceptance")?;
436 }
437 RaftPayload::AppendEntriesOk {
438 term,
439 success,
440 applied_up_to,
441 } => {
442 if term > self.current_term {
443 self.update_term(term);
444 return Ok(());
445 }
446
447 if !self.is_leader() {
448 return Ok(());
449 }
450
451 if term < self.current_term {
455 return Ok(());
456 }
457
458 *self
459 .next_index
460 .get_mut(&msg.src)
461 .expect("Leader should have all nodes in next_index") = applied_up_to + 1;
462
463 *self
464 .match_index
465 .get_mut(&msg.src)
466 .expect("Leader should have all nodes in match_index") = applied_up_to;
467
468 if success {
469 let mut counts = HashMap::new();
470
471 for &e in self.match_index.values() {
472 *counts.entry(e).or_insert(0) += 1
473 }
474
475 let half_nodes = network.all_nodes.len() / 2;
476
477 let new_idx = counts
478 .into_iter()
479 .filter(|(_, c)| c > &half_nodes)
480 .map(|(v, _)| v)
481 .max()
482 .unwrap_or(0);
483
484 self.commit(dbg!(new_idx));
485 } else {
486 self.send_append_entries(msg.src, network)
487 .context("Append entries retry")?;
488 }
489 }
490 },
491 RaftEvent::CommitedEntry(_) => bail!("Commited entry should be handled by Raft client"),
492 }
493 Ok(())
494 }
495}