vortex_raft/
raft.rs

1use std::{
2    collections::{HashMap, HashSet},
3    fmt,
4    sync::Arc,
5    time::Duration,
6};
7
8use rand::Rng;
9
10use crate::*;
11
12#[derive(Debug, Deserialize, Serialize, Clone)]
13#[serde(tag = "type")]
14#[serde(rename_all = "snake_case")]
15pub enum RaftPayload<Entry = ()> {
16    RequestVote {
17        /// candidate’s term
18        term: u64,
19        /// index of candidate’s last log entry
20        last_log_index: usize,
21        /// term of candidate’s last log entry
22        last_log_term: u64,
23    },
24    RequestVoteOk {
25        /// currentTerm, for candidate to update itself
26        term: u64,
27        /// true means candidate received vote
28        vote_granted: bool,
29    },
30    /// Invoked by leader to replicate log entries; also used as heartbeat.
31    AppendEntries {
32        /// leader’s term
33        term: u64,
34        /// index of log entry immediately preceding new ones
35        prev_log_index: usize,
36        /// term of prev_log_index entry
37        prev_log_term: u64,
38        entries: Vec<(Entry, u64)>,
39        leader_commit: usize,
40    },
41    AppendEntriesOk {
42        term: u64,
43        success: bool,
44        /// Last index that was applied in the AppendEntries
45        applied_up_to: usize,
46    },
47}
48
49#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
50pub enum RaftSignal {
51    Heartbeat,
52    Campaign,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
56#[serde(untagged)]
57pub enum RaftEvent<Entry: Clone> {
58    RaftMessage(Message<RaftPayload<Entry>>),
59    RaftSignal(RaftSignal),
60    CommitedEntry(Entry),
61}
62
63pub struct RaftService<Entry: Clone> {
64    msg_id: IdCounter,
65    /// latest term server has seen
66    current_term: u64,
67    /// log entries; each entry contains command for state machine, and term when entry was received by leader
68    log: Vec<Option<(Entry, u64)>>,
69    /// index of highest log entry known to be committed
70    commit_index: usize,
71    /// Node that received our vote in current term
72    voted_for: Option<String>,
73    /// for each server, index of the next log entry to send to that server
74    next_index: HashMap<String, usize>,
75    /// for each server, index of highest log entry known to be replicated on server
76    match_index: HashMap<String, usize>,
77    /// Votes recieved from servers
78    votes: HashSet<String>,
79    /// Whether or not to reset the election timeout
80    /// (set to true only when leader sends an AppendEntries)
81    reset_election_timer: Arc<AtomicBool>,
82    /// Sends commited log entries to the client node
83    log_sender: Box<dyn SenderExt<Entry>>,
84}
85
86impl<Entry: Clone + Serialize + fmt::Debug + 'static> RaftService<Entry> {
87    pub fn create(network: &mut Network, sender: impl SenderExt<RaftEvent<Entry>> + Clone) -> Self {
88        let reset_election_timer = Arc::new(AtomicBool::new(false));
89        let sender_clone = sender.clone();
90
91        let mut node = Self {
92            msg_id: IdCounter::new(),
93            log: vec![None],
94            commit_index: 0,
95            current_term: 0,
96            voted_for: None,
97            next_index: HashMap::new(),
98            match_index: HashMap::new(),
99            votes: HashSet::new(),
100            reset_election_timer: reset_election_timer.clone(),
101            log_sender: Box::new(sender_clone.map_input(RaftEvent::CommitedEntry)),
102        };
103
104        // No need to worry about silly things like consensus in singleton mode
105        if network.is_singleton() {
106            // Make sure we're considered leader still
107            node.next_index.insert(network.node_id.clone(), 1);
108            node.match_index.insert(network.node_id.clone(), 0);
109            return node;
110        }
111
112        // Between 150 and 300 as per Raft spec
113        let election_timeout = Duration::from_millis(rand::thread_rng().gen_range(150..=300));
114        let heartbeat_timeout = Duration::from_millis(150);
115
116        let sender_clone = sender.clone();
117
118        spawn_timer(
119            Box::new(move || {
120                let _ = sender_clone.send(RaftEvent::RaftSignal(RaftSignal::Campaign));
121                Ok(())
122            }),
123            election_timeout,
124            Some(reset_election_timer),
125        );
126
127        spawn_timer(
128            Box::new(move || {
129                let _ = sender.send(RaftEvent::RaftSignal(RaftSignal::Heartbeat));
130                Ok(())
131            }),
132            heartbeat_timeout,
133            None,
134        );
135
136        node
137    }
138
139    pub fn become_leader(&mut self, network: &mut Network) -> anyhow::Result<()> {
140        eprintln!(
141            "{} has been elected as leader in term {}",
142            network.node_id, self.current_term
143        );
144
145        self.next_index = network
146            .all_nodes
147            .iter()
148            .cloned()
149            .map(|n| (n, self.last_log_index() + 1))
150            .collect();
151
152        self.match_index = network.all_nodes.iter().cloned().map(|n| (n, 0)).collect();
153
154        // Send inital heartbeat as soon as elected
155        self.step(RaftEvent::RaftSignal(RaftSignal::Heartbeat), network)
156            .context("Heartbeat on elected")
157    }
158
159    pub fn is_leader(&self) -> bool {
160        !self.next_index.is_empty()
161    }
162
163    /// Fallible
164    pub fn voted_for(&self) -> Option<String> {
165        self.voted_for.clone()
166    }
167
168    fn last_log_index(&self) -> usize {
169        self.log.len() - 1
170    }
171
172    fn log_term_at(&self, idx: usize) -> u64 {
173        self.log[idx].as_ref().map(|(_, term)| *term).unwrap_or(0)
174    }
175
176    fn delay_election(&self) {
177        self.reset_election_timer.store(true, Ordering::Relaxed);
178    }
179
180    fn update_term(&mut self, term: u64) {
181        self.current_term = term;
182        self.votes.clear();
183        self.next_index.clear();
184        self.match_index.clear();
185        // Voted for no one in this new term
186        self.voted_for = None;
187    }
188
189    fn send_append_entries(&mut self, dest: String, network: &mut Network) -> anyhow::Result<()> {
190        let next_index = self.next_index[&dest];
191        let entries = self.log[next_index..]
192            .iter()
193            .map(|e| e.clone().expect("Only the 0th element should be None"))
194            .collect();
195
196        network
197            .send(
198                dest,
199                Body {
200                    msg_id: self.msg_id.next(),
201                    in_reply_to: None,
202                    payload: RaftPayload::AppendEntries {
203                        term: self.current_term,
204                        prev_log_index: next_index - 1,
205                        prev_log_term: self.log_term_at(next_index - 1),
206                        entries,
207                        leader_commit: self.commit_index,
208                    },
209                },
210            )
211            .context("Send append entries")
212    }
213
214    /// Request to append entry to the log
215    pub fn request(&mut self, entry: Entry, network: &mut Network) -> anyhow::Result<()> {
216        if !self.is_leader() {
217            return Ok(());
218        }
219
220        self.log.push(Some((entry, self.current_term)));
221        *self
222            .match_index
223            .get_mut(&network.node_id)
224            .expect("Node should have itself in match_index") += 1;
225
226        if network.is_singleton() {
227            *self
228                .next_index
229                .get_mut(&network.node_id)
230                .expect("Node should have itself in next_index") += 1;
231
232            // As singleton we can immediately commit
233            self.commit(self.commit_index + 1);
234        }
235
236        Ok(())
237    }
238
239    fn commit(&mut self, new_idx: usize) {
240        let prev = self.commit_index;
241        self.commit_index = new_idx;
242        for i in prev + 1..=self.commit_index {
243            if let Some((e, _)) = self.log[i].clone() {
244                let _ = self.log_sender.send(e);
245            }
246        }
247
248        eprintln!("Committed entries {}-{}", prev + 1, self.commit_index);
249    }
250
251    pub fn step(&mut self, event: RaftEvent<Entry>, network: &mut Network) -> anyhow::Result<()> {
252        match event {
253            RaftEvent::RaftSignal(signal) => match signal {
254                RaftSignal::Heartbeat => {
255                    // If leader, send append entries
256                    if !self.is_leader() {
257                        return Ok(());
258                    }
259
260                    eprintln!("{} sending heartbeats", network.node_id);
261                    for follower in network.all_nodes.clone() {
262                        if follower == network.node_id {
263                            continue;
264                        }
265
266                        self.send_append_entries(follower, network)
267                            .context("Heartbeat send append entries")?;
268                    }
269                }
270                RaftSignal::Campaign => {
271                    if self.is_leader() {
272                        return Ok(());
273                    }
274
275                    self.update_term(self.current_term + 1);
276
277                    eprintln!(
278                        "{} campaigning in new term {}",
279                        network.node_id, self.current_term
280                    );
281
282                    // Broadcast RequestVote, vote for self
283                    self.voted_for = Some(network.node_id.clone());
284                    self.votes.insert(network.node_id.clone());
285                    self.delay_election();
286
287                    for other in network.all_nodes.clone() {
288                        if other == network.node_id {
289                            continue;
290                        }
291
292                        network
293                            .send(
294                                other,
295                                Body {
296                                    msg_id: self.msg_id.next(),
297                                    in_reply_to: None,
298                                    payload: RaftPayload::<Entry>::RequestVote {
299                                        term: self.current_term,
300                                        last_log_index: self.last_log_index(),
301                                        last_log_term: self.log_term_at(self.last_log_index()),
302                                    },
303                                },
304                            )
305                            .context("Send request vote")?;
306                    }
307                }
308            },
309            RaftEvent::RaftMessage(msg) => match msg.body.payload {
310                RaftPayload::RequestVote {
311                    term,
312                    last_log_index,
313                    last_log_term,
314                } => {
315                    eprintln!("Got vote request from {} at term {}", msg.src, term);
316                    //  Grant vote iff term >= currentTerm,
317                    //  votedFor is null or candidateId, and
318                    //  candidate’s log is at least as up-to-date as
319                    //  receiver’s log
320                    if term > self.current_term {
321                        self.update_term(term);
322                    }
323
324                    let granted = term >= self.current_term
325                        && !self.voted_for.as_ref().is_some_and(|v| v != &msg.src)
326                        && (last_log_term, last_log_index)
327                            >= (
328                                self.log_term_at(self.last_log_index()),
329                                self.last_log_index(),
330                            );
331
332                    if granted {
333                        self.update_term(term);
334
335                        self.voted_for = Some(msg.src.clone());
336                        self.delay_election();
337
338                        eprintln!("{} voted for {} in term {term}", network.node_id, msg.src);
339                    }
340
341                    network
342                        .reply(
343                            msg.src,
344                            self.msg_id.next(),
345                            msg.body.msg_id,
346                            RaftPayload::<Entry>::RequestVoteOk {
347                                term: self.current_term,
348                                vote_granted: granted,
349                            },
350                        )
351                        .context("Request vote OK reply")?;
352                }
353                RaftPayload::RequestVoteOk { term, vote_granted } => {
354                    if term > self.current_term {
355                        self.update_term(term);
356                        return Ok(());
357                    }
358
359                    if vote_granted && term == self.current_term {
360                        self.votes.insert(msg.src);
361                    }
362
363                    if self.votes.len() > network.all_nodes.len() / 2 && !self.is_leader() {
364                        self.become_leader(network)
365                            .context("Become leader on majority")?;
366                    }
367                }
368                RaftPayload::AppendEntries {
369                    term,
370                    prev_log_index,
371                    prev_log_term,
372                    mut entries,
373                    leader_commit,
374                } => {
375                    if term >= self.current_term {
376                        eprintln!("got appendEntries from {}", msg.src);
377
378                        // Must be from a new leader
379                        self.update_term(term);
380                        self.voted_for = Some(msg.src.clone());
381                        self.delay_election();
382                    }
383
384                    if term < self.current_term
385                        || self.last_log_index() < prev_log_index
386                        || self.log_term_at(prev_log_index) != prev_log_term
387                    {
388                        network
389                            .reply(
390                                msg.src,
391                                self.msg_id.next(),
392                                msg.body.msg_id,
393                                RaftPayload::<Entry>::AppendEntriesOk {
394                                    term: self.current_term,
395                                    success: false,
396                                    applied_up_to: self.last_log_index(),
397                                },
398                            )
399                            .context("Append entries rejection")?;
400                        return Ok(());
401                    }
402
403                    for i in prev_log_index + 1
404                        ..std::cmp::min(prev_log_index + 1 + entries.len(), self.log.len())
405                    {
406                        if self.log_term_at(i) != entries[0].1 {
407                            // Conflicting entries
408                            // so remove everything from this point on
409                            self.log.drain(i..);
410                            break;
411                        }
412                        entries.remove(0);
413                    }
414
415                    self.log.extend(entries.into_iter().map(Some));
416
417                    let new_idx = std::cmp::max(
418                        self.commit_index,
419                        std::cmp::min(leader_commit, self.last_log_index()),
420                    );
421
422                    self.commit(dbg!(new_idx));
423
424                    network
425                        .reply(
426                            msg.src,
427                            self.msg_id.next(),
428                            msg.body.msg_id,
429                            RaftPayload::<Entry>::AppendEntriesOk {
430                                term: self.current_term,
431                                success: true,
432                                applied_up_to: self.last_log_index(),
433                            },
434                        )
435                        .context("Append entries acceptance")?;
436                }
437                RaftPayload::AppendEntriesOk {
438                    term,
439                    success,
440                    applied_up_to,
441                } => {
442                    if term > self.current_term {
443                        self.update_term(term);
444                        return Ok(());
445                    }
446
447                    if !self.is_leader() {
448                        return Ok(());
449                    }
450
451                    // Should never have term < current_term
452                    // Since follower updated their term to leader's
453                    // On AppendEntries, thus this message is too delayed
454                    if term < self.current_term {
455                        return Ok(());
456                    }
457
458                    *self
459                        .next_index
460                        .get_mut(&msg.src)
461                        .expect("Leader should have all nodes in next_index") = applied_up_to + 1;
462
463                    *self
464                        .match_index
465                        .get_mut(&msg.src)
466                        .expect("Leader should have all nodes in match_index") = applied_up_to;
467
468                    if success {
469                        let mut counts = HashMap::new();
470
471                        for &e in self.match_index.values() {
472                            *counts.entry(e).or_insert(0) += 1
473                        }
474
475                        let half_nodes = network.all_nodes.len() / 2;
476
477                        let new_idx = counts
478                            .into_iter()
479                            .filter(|(_, c)| c > &half_nodes)
480                            .map(|(v, _)| v)
481                            .max()
482                            .unwrap_or(0);
483
484                        self.commit(dbg!(new_idx));
485                    } else {
486                        self.send_append_entries(msg.src, network)
487                            .context("Append entries retry")?;
488                    }
489                }
490            },
491            RaftEvent::CommitedEntry(_) => bail!("Commited entry should be handled by Raft client"),
492        }
493        Ok(())
494    }
495}