1use crate::{
11 election::{ElectionState, VoteValidator},
12 rpc::{
13 AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest,
14 InstallSnapshotResponse, RaftMessage, RequestVoteRequest, RequestVoteResponse,
15 },
16 state::{LeaderState, PersistentState, RaftState, VolatileState},
17 LogIndex, NodeId, RaftError, RaftResult, Term,
18};
19use parking_lot::RwLock;
20use std::sync::Arc;
21use std::time::Duration;
22use tokio::sync::mpsc;
23use tokio::time::{interval, sleep};
24use tracing::{debug, error, info, warn};
25
/// Static configuration for a single Raft node.
#[derive(Debug, Clone)]
pub struct RaftNodeConfig {
    /// Identifier of this node.
    pub node_id: NodeId,

    /// Identifiers of the cluster members. The node skips its own id when iterating this
    /// list to address peers, and the list length is used as the cluster size for elections.
    pub cluster_members: Vec<NodeId>,

    /// Lower bound of the election timeout, forwarded to `ElectionState::new`
    /// (presumably milliseconds — TODO confirm against `ElectionState`).
    pub election_timeout_min: u64,

    /// Upper bound of the election timeout, forwarded to `ElectionState::new`
    /// (presumably milliseconds — TODO confirm against `ElectionState`).
    pub election_timeout_max: u64,

    /// Period of the leader heartbeat timer, in milliseconds.
    pub heartbeat_interval: u64,

    /// NOTE(review): not read anywhere in this file; presumably caps the number of log
    /// entries carried by one AppendEntries message — confirm against the sender.
    pub max_entries_per_message: usize,

    /// NOTE(review): not read anywhere in this file; presumably the snapshot chunk size in
    /// bytes (the default is `64 * 1024`) — confirm against the snapshot sender.
    pub snapshot_chunk_size: usize,
}
50
51impl RaftNodeConfig {
52 pub fn new(node_id: NodeId, cluster_members: Vec<NodeId>) -> Self {
54 Self {
55 node_id,
56 cluster_members,
57 election_timeout_min: 150,
58 election_timeout_max: 300,
59 heartbeat_interval: 50,
60 max_entries_per_message: 100,
61 snapshot_chunk_size: 64 * 1024, }
63 }
64}
65
/// A client command submitted to the node.
#[derive(Debug, Clone)]
pub struct Command {
    /// Payload bytes; appended to the log as-is when the node accepts the command as leader.
    pub data: Vec<u8>,
}
71
/// Outcome reported to the client for an accepted command.
///
/// It is produced right after the leader appends the command to its own log.
#[derive(Debug, Clone)]
pub struct CommandResult {
    /// Log index returned by the append.
    pub index: LogIndex,
    /// Term of the leader at the time of the append.
    pub term: Term,
}
78
/// Events consumed by the node's main loop through the internal channel.
#[derive(Debug)]
enum InternalMessage {
    /// A Raft RPC message attributed to peer `from`.
    /// NOTE(review): nothing in this file enqueues this variant — presumably the transport
    /// layer does; confirm.
    Rpc { from: NodeId, message: RaftMessage },
    /// A client command enqueued by `submit_command`.
    ClientCommand {
        /// The command to append.
        command: Command,
        /// Channel on which the outcome (or an error) is reported back to the submitter.
        response_tx: mpsc::Sender<RaftResult<CommandResult>>,
    },
    /// Emitted by the election timer task when the election state reports that an election
    /// should start.
    ElectionTimeout,
    /// Emitted by the heartbeat timer task while the node is leader, and also right after
    /// becoming leader or accepting a client command.
    HeartbeatTimeout,
}
94
/// A single Raft participant: its state plus the internal event channel that drives it.
///
/// All shared state sits behind `parking_lot::RwLock`s wrapped in `Arc`.
pub struct RaftNode {
    /// Static configuration (node id, membership, timer settings).
    config: RaftNodeConfig,

    /// Term, vote and log. NOTE(review): nothing in this file writes it to stable
    /// storage — confirm where persistence happens.
    persistent: Arc<RwLock<PersistentState>>,

    /// Volatile state; this file reads and updates its `commit_index`.
    volatile: Arc<RwLock<VolatileState>>,

    /// Current role (follower, candidate or leader).
    state: Arc<RwLock<RaftState>>,

    /// Per-peer replication bookkeeping; set to `Some` in `become_leader` and cleared in
    /// `step_down`.
    leader_state: Arc<RwLock<Option<LeaderState>>>,

    /// Election timer and vote tracking.
    election_state: Arc<RwLock<ElectionState>>,

    /// Last known leader; `None` after stepping down until a leader is heard from.
    current_leader: Arc<RwLock<Option<NodeId>>>,

    /// Sending half of the internal event channel (timers, client commands).
    internal_tx: mpsc::UnboundedSender<InternalMessage>,
    /// Receiving half of the internal event channel, drained by `run`.
    internal_rx: Arc<RwLock<mpsc::UnboundedReceiver<InternalMessage>>>,
}
122
123impl RaftNode {
124 pub fn new(config: RaftNodeConfig) -> Self {
126 let (internal_tx, internal_rx) = mpsc::unbounded_channel();
127 let cluster_size = config.cluster_members.len();
128
129 Self {
130 persistent: Arc::new(RwLock::new(PersistentState::new())),
131 volatile: Arc::new(RwLock::new(VolatileState::new())),
132 state: Arc::new(RwLock::new(RaftState::Follower)),
133 leader_state: Arc::new(RwLock::new(None)),
134 election_state: Arc::new(RwLock::new(ElectionState::new(
135 cluster_size,
136 config.election_timeout_min,
137 config.election_timeout_max,
138 ))),
139 current_leader: Arc::new(RwLock::new(None)),
140 config,
141 internal_tx,
142 internal_rx: Arc::new(RwLock::new(internal_rx)),
143 }
144 }
145
146 pub async fn start(self: Arc<Self>) {
148 info!("Starting Raft node: {}", self.config.node_id);
149
150 self.clone().spawn_election_timer();
152
153 self.clone().spawn_heartbeat_timer();
155
156 self.run().await;
158 }
159
160 async fn run(self: Arc<Self>) {
162 loop {
163 let message = {
164 let mut rx = self.internal_rx.write();
165 rx.recv().await
166 };
167
168 match message {
169 Some(InternalMessage::Rpc { from, message }) => {
170 self.handle_rpc_message(from, message).await;
171 }
172 Some(InternalMessage::ClientCommand {
173 command,
174 response_tx,
175 }) => {
176 self.handle_client_command(command, response_tx).await;
177 }
178 Some(InternalMessage::ElectionTimeout) => {
179 self.handle_election_timeout().await;
180 }
181 Some(InternalMessage::HeartbeatTimeout) => {
182 self.handle_heartbeat_timeout().await;
183 }
184 None => {
185 warn!("Internal channel closed, stopping node");
186 break;
187 }
188 }
189 }
190 }
191
192 async fn handle_rpc_message(&self, from: NodeId, message: RaftMessage) {
194 let message_term = message.term();
196 let current_term = self.persistent.read().current_term;
197
198 if message_term > current_term {
199 self.step_down(message_term).await;
200 }
201
202 match message {
203 RaftMessage::AppendEntriesRequest(req) => {
204 let response = self.handle_append_entries(req).await;
205 debug!("AppendEntries response to {}: {:?}", from, response);
207 }
208 RaftMessage::AppendEntriesResponse(resp) => {
209 self.handle_append_entries_response(from, resp).await;
210 }
211 RaftMessage::RequestVoteRequest(req) => {
212 let response = self.handle_request_vote(req).await;
213 debug!("RequestVote response to {}: {:?}", from, response);
215 }
216 RaftMessage::RequestVoteResponse(resp) => {
217 self.handle_request_vote_response(from, resp).await;
218 }
219 RaftMessage::InstallSnapshotRequest(req) => {
220 let response = self.handle_install_snapshot(req).await;
221 debug!("InstallSnapshot response to {}: {:?}", from, response);
223 }
224 RaftMessage::InstallSnapshotResponse(resp) => {
225 self.handle_install_snapshot_response(from, resp).await;
226 }
227 }
228 }
229
230 async fn handle_append_entries(&self, req: AppendEntriesRequest) -> AppendEntriesResponse {
232 let mut persistent = self.persistent.write();
233 let mut volatile = self.volatile.write();
234
235 if req.term < persistent.current_term {
237 return AppendEntriesResponse::failure(persistent.current_term, None, None);
238 }
239
240 self.election_state.write().reset_timer();
242 *self.current_leader.write() = Some(req.leader_id.clone());
243
244 if !persistent
246 .log
247 .matches(req.prev_log_index, req.prev_log_term)
248 {
249 let conflict_index = req.prev_log_index;
250 let conflict_term = persistent.log.term_at(conflict_index);
251 return AppendEntriesResponse::failure(
252 persistent.current_term,
253 Some(conflict_index),
254 conflict_term,
255 );
256 }
257
258 if !req.entries.is_empty() {
260 let mut index = req.prev_log_index + 1;
262 for entry in &req.entries {
263 if let Some(existing_term) = persistent.log.term_at(index) {
264 if existing_term != entry.term {
265 let _ = persistent.log.truncate_from(index);
267 }
268 }
269 index += 1;
270 }
271
272 if let Err(e) = persistent.log.append_entries(req.entries.clone()) {
274 error!("Failed to append entries: {}", e);
275 return AppendEntriesResponse::failure(persistent.current_term, None, None);
276 }
277 }
278
279 if req.leader_commit > volatile.commit_index {
281 let last_new_entry = if req.entries.is_empty() {
282 req.prev_log_index
283 } else {
284 req.entries
286 .last()
287 .expect("entries verified non-empty")
288 .index
289 };
290 volatile.update_commit_index(std::cmp::min(req.leader_commit, last_new_entry));
291 }
292
293 AppendEntriesResponse::success(persistent.current_term, persistent.log.last_index())
294 }
295
296 async fn handle_append_entries_response(&self, from: NodeId, resp: AppendEntriesResponse) {
298 if !self.state.read().is_leader() {
299 return;
300 }
301
302 let persistent = self.persistent.write();
303 let mut leader_state_guard = self.leader_state.write();
304
305 if let Some(leader_state) = leader_state_guard.as_mut() {
306 if resp.success {
307 if let Some(match_index) = resp.match_index {
309 leader_state.update_replication(&from, match_index);
310
311 let new_commit = leader_state.calculate_commit_index();
313 let mut volatile = self.volatile.write();
314 if new_commit > volatile.commit_index {
315 if let Some(term) = persistent.log.term_at(new_commit) {
317 if term == persistent.current_term {
318 volatile.update_commit_index(new_commit);
319 info!("Updated commit index to {}", new_commit);
320 }
321 }
322 }
323 }
324 } else {
325 leader_state.decrement_next_index(&from);
327 debug!("Replication failed for {}, decrementing next_index", from);
328 }
329 }
330 }
331
332 async fn handle_request_vote(&self, req: RequestVoteRequest) -> RequestVoteResponse {
334 let mut persistent = self.persistent.write();
335
336 if req.term < persistent.current_term {
338 return RequestVoteResponse::denied(persistent.current_term);
339 }
340
341 let last_log_index = persistent.log.last_index();
342 let last_log_term = persistent.log.last_term();
343
344 let should_grant = VoteValidator::should_grant_vote(
346 persistent.current_term,
347 &persistent.voted_for,
348 last_log_index,
349 last_log_term,
350 &req.candidate_id,
351 req.term,
352 req.last_log_index,
353 req.last_log_term,
354 );
355
356 if should_grant {
357 persistent.vote_for(req.candidate_id.clone());
358 self.election_state.write().reset_timer();
359 info!("Granted vote to {} for term {}", req.candidate_id, req.term);
360 RequestVoteResponse::granted(persistent.current_term)
361 } else {
362 debug!("Denied vote to {} for term {}", req.candidate_id, req.term);
363 RequestVoteResponse::denied(persistent.current_term)
364 }
365 }
366
367 async fn handle_request_vote_response(&self, from: NodeId, resp: RequestVoteResponse) {
369 if !self.state.read().is_candidate() {
370 return;
371 }
372
373 let current_term = self.persistent.read().current_term;
374 if resp.term != current_term {
375 return;
376 }
377
378 if resp.vote_granted {
379 let won_election = self.election_state.write().record_vote(from.clone());
380 if won_election {
381 info!("Won election for term {}", current_term);
382 self.become_leader().await;
383 }
384 }
385 }
386
387 async fn handle_install_snapshot(
389 &self,
390 req: InstallSnapshotRequest,
391 ) -> InstallSnapshotResponse {
392 let persistent = self.persistent.write();
393
394 if req.term < persistent.current_term {
395 return InstallSnapshotResponse::failure(persistent.current_term);
396 }
397
398 InstallSnapshotResponse::success(persistent.current_term, None)
401 }
402
/// Handles a peer's reply to an `InstallSnapshot` request.
///
/// Currently a no-op: both arguments are ignored.
async fn handle_install_snapshot_response(
    &self,
    _from: NodeId,
    _resp: InstallSnapshotResponse,
) {
}
411
412 async fn handle_client_command(
414 &self,
415 command: Command,
416 response_tx: mpsc::Sender<RaftResult<CommandResult>>,
417 ) {
418 if !self.state.read().is_leader() {
420 let _ = response_tx.send(Err(RaftError::NotLeader)).await;
421 return;
422 }
423
424 let mut persistent = self.persistent.write();
425 let term = persistent.current_term;
426 let index = persistent.log.append(term, command.data);
427
428 let result = CommandResult { index, term };
429 let _ = response_tx.send(Ok(result)).await;
430
431 drop(persistent);
433 let _ = self.internal_tx.send(InternalMessage::HeartbeatTimeout);
434 }
435
436 async fn handle_election_timeout(&self) {
438 if self.state.read().is_leader() {
439 return;
440 }
441
442 if !self.election_state.read().should_start_election() {
443 return;
444 }
445
446 info!("Election timeout, starting election");
447 self.start_election().await;
448 }
449
450 async fn start_election(&self) {
452 *self.state.write() = RaftState::Candidate;
454
455 let mut persistent = self.persistent.write();
457 persistent.increment_term();
458 persistent.vote_for(self.config.node_id.clone());
459 let term = persistent.current_term;
460
461 self.election_state
463 .write()
464 .start_election(term, &self.config.node_id);
465
466 let last_log_index = persistent.log.last_index();
467 let last_log_term = persistent.log.last_term();
468
469 info!(
470 "Starting election for term {} as {}",
471 term, self.config.node_id
472 );
473
474 for member in &self.config.cluster_members {
476 if member != &self.config.node_id {
477 let _request = RequestVoteRequest::new(
478 term,
479 self.config.node_id.clone(),
480 last_log_index,
481 last_log_term,
482 );
483 debug!("Would send RequestVote to {}", member);
485 }
486 }
487 }
488
489 async fn become_leader(&self) {
491 info!(
492 "Becoming leader for term {}",
493 self.persistent.read().current_term
494 );
495
496 *self.state.write() = RaftState::Leader;
497 *self.current_leader.write() = Some(self.config.node_id.clone());
498
499 let last_log_index = self.persistent.read().log.last_index();
500 let other_members: Vec<_> = self
501 .config
502 .cluster_members
503 .iter()
504 .filter(|m| *m != &self.config.node_id)
505 .cloned()
506 .collect();
507
508 *self.leader_state.write() = Some(LeaderState::new(&other_members, last_log_index));
509
510 let _ = self.internal_tx.send(InternalMessage::HeartbeatTimeout);
512 }
513
514 async fn step_down(&self, term: Term) {
516 info!("Stepping down to follower for term {}", term);
517
518 *self.state.write() = RaftState::Follower;
519 *self.leader_state.write() = None;
520 *self.current_leader.write() = None;
521
522 let mut persistent = self.persistent.write();
523 persistent.update_term(term);
524 }
525
526 async fn handle_heartbeat_timeout(&self) {
528 if !self.state.read().is_leader() {
529 return;
530 }
531
532 self.send_heartbeats().await;
533 }
534
535 async fn send_heartbeats(&self) {
537 let persistent = self.persistent.read();
538 let term = persistent.current_term;
539 let commit_index = self.volatile.read().commit_index;
540
541 for member in &self.config.cluster_members {
542 if member != &self.config.node_id {
543 let request = AppendEntriesRequest::heartbeat(
544 term,
545 self.config.node_id.clone(),
546 commit_index,
547 );
548 debug!("Would send heartbeat to {}", member);
550 }
551 }
552 }
553
554 fn spawn_election_timer(self: Arc<Self>) {
556 let node = self.clone();
557 tokio::spawn(async move {
558 let mut interval = interval(Duration::from_millis(50));
559 loop {
560 interval.tick().await;
561 if node.election_state.read().should_start_election() {
562 let _ = node.internal_tx.send(InternalMessage::ElectionTimeout);
563 }
564 }
565 });
566 }
567
568 fn spawn_heartbeat_timer(self: Arc<Self>) {
570 let node = self.clone();
571 tokio::spawn(async move {
572 let interval_ms = node.config.heartbeat_interval;
573 let mut interval = interval(Duration::from_millis(interval_ms));
574 loop {
575 interval.tick().await;
576 if node.state.read().is_leader() {
577 let _ = node.internal_tx.send(InternalMessage::HeartbeatTimeout);
578 }
579 }
580 });
581 }
582
583 pub async fn submit_command(&self, data: Vec<u8>) -> RaftResult<CommandResult> {
585 let (tx, mut rx) = mpsc::channel(1);
586 let command = Command { data };
587
588 self.internal_tx
589 .send(InternalMessage::ClientCommand {
590 command,
591 response_tx: tx,
592 })
593 .map_err(|_| RaftError::Internal("Node stopped".to_string()))?;
594
595 rx.recv()
596 .await
597 .ok_or_else(|| RaftError::Internal("Response channel closed".to_string()))?
598 }
599
600 pub fn current_state(&self) -> RaftState {
602 *self.state.read()
603 }
604
605 pub fn current_term(&self) -> Term {
607 self.persistent.read().current_term
608 }
609
610 pub fn current_leader(&self) -> Option<NodeId> {
612 self.current_leader.read().clone()
613 }
614}
615
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed node is a follower in term 0.
    #[test]
    fn test_node_creation() {
        let members: Vec<_> = ["node1", "node2", "node3"]
            .iter()
            .map(|id| id.to_string())
            .collect();

        let node = RaftNode::new(RaftNodeConfig::new("node1".to_string(), members));

        assert_eq!(node.current_state(), RaftState::Follower);
        assert_eq!(node.current_term(), 0);
    }
}