1use crate::{
11 election::{ElectionState, VoteValidator},
12 rpc::{
13 AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest,
14 InstallSnapshotResponse, RaftMessage, RequestVoteRequest, RequestVoteResponse,
15 },
16 state::{LeaderState, PersistentState, RaftState, VolatileState},
17 LogIndex, NodeId, RaftError, RaftResult, Term,
18};
19use parking_lot::RwLock;
20use std::sync::Arc;
21use std::time::Duration;
22use tokio::sync::mpsc;
23use tokio::time::{interval, sleep};
24use tracing::{debug, error, info, warn};
25
/// Static configuration for a single Raft node.
#[derive(Debug, Clone)]
pub struct RaftNodeConfig {
    /// Identifier of this node.
    pub node_id: NodeId,

    /// Identifiers of the cluster members. Entries equal to `node_id` are skipped when the
    /// node iterates over its peers, and the list length is handed to `ElectionState::new`
    /// as the cluster size.
    pub cluster_members: Vec<NodeId>,

    /// Lower bound of the election timeout, forwarded to `ElectionState::new`.
    /// NOTE(review): presumably milliseconds — confirm against `ElectionState`.
    pub election_timeout_min: u64,

    /// Upper bound of the election timeout, forwarded to `ElectionState::new`.
    /// NOTE(review): presumably milliseconds — confirm against `ElectionState`.
    pub election_timeout_max: u64,

    /// Period of the heartbeat timer in milliseconds (it is passed to
    /// `Duration::from_millis` by the heartbeat timer task).
    pub heartbeat_interval: u64,

    /// Not read anywhere in this file.
    /// NOTE(review): presumably the cap on log entries per AppendEntries message — confirm.
    pub max_entries_per_message: usize,

    /// Not read anywhere in this file.
    /// NOTE(review): presumably the snapshot chunk size in bytes — confirm.
    pub snapshot_chunk_size: usize,
}
50
51impl RaftNodeConfig {
52 pub fn new(node_id: NodeId, cluster_members: Vec<NodeId>) -> Self {
54 Self {
55 node_id,
56 cluster_members,
57 election_timeout_min: 150,
58 election_timeout_max: 300,
59 heartbeat_interval: 50,
60 max_entries_per_message: 100,
61 snapshot_chunk_size: 64 * 1024, }
63 }
64}
65
/// A client command to be appended to the replicated log.
#[derive(Debug, Clone)]
pub struct Command {
    /// Opaque command payload; it is handed to the log unchanged.
    pub data: Vec<u8>,
}
71
/// Location at which a submitted command was appended to the leader's log.
#[derive(Debug, Clone)]
pub struct CommandResult {
    /// Log index returned by the log when the command was appended.
    pub index: LogIndex,
    /// Term of the leader at the time the command was appended.
    pub term: Term,
}
78
/// Events consumed by the node's main loop through its internal channel.
#[derive(Debug)]
enum InternalMessage {
    /// A Raft RPC message attributed to the node `from`.
    Rpc { from: NodeId, message: RaftMessage },
    /// A command submitted through `submit_command`; the outcome is reported on
    /// `response_tx`.
    ClientCommand {
        command: Command,
        response_tx: mpsc::Sender<RaftResult<CommandResult>>,
    },
    /// Emitted by the election timer task when the election state reports that an
    /// election should start.
    ElectionTimeout,
    /// Emitted by the heartbeat timer task while the node is leader, and also right after
    /// becoming leader or appending a client command.
    HeartbeatTimeout,
}
94
/// A single Raft participant: its configuration, its shared state and the internal event
/// channel that drives the main loop.
pub struct RaftNode {
    /// Configuration the node was created with.
    config: RaftNodeConfig,

    /// Persistent Raft state; this file reads and writes its `current_term`, `voted_for`
    /// and `log`.
    persistent: Arc<RwLock<PersistentState>>,

    /// Volatile Raft state; this file reads and updates its `commit_index`.
    volatile: Arc<RwLock<VolatileState>>,

    /// Current role of the node (follower, candidate or leader).
    state: Arc<RwLock<RaftState>>,

    /// Replication bookkeeping; set to `Some` when the node becomes leader and reset to
    /// `None` when it steps down.
    leader_state: Arc<RwLock<Option<LeaderState>>>,

    /// Election timer and vote bookkeeping.
    election_state: Arc<RwLock<ElectionState>>,

    /// Leader this node currently knows about, if any.
    current_leader: Arc<RwLock<Option<NodeId>>>,

    /// Sending half of the internal event channel (used by the timers and `submit_command`).
    internal_tx: mpsc::UnboundedSender<InternalMessage>,
    /// Receiving half of the internal event channel, drained by the main loop.
    internal_rx: Arc<RwLock<mpsc::UnboundedReceiver<InternalMessage>>>,
}
122
123impl RaftNode {
124 pub fn new(config: RaftNodeConfig) -> Self {
126 let (internal_tx, internal_rx) = mpsc::unbounded_channel();
127 let cluster_size = config.cluster_members.len();
128
129 Self {
130 persistent: Arc::new(RwLock::new(PersistentState::new())),
131 volatile: Arc::new(RwLock::new(VolatileState::new())),
132 state: Arc::new(RwLock::new(RaftState::Follower)),
133 leader_state: Arc::new(RwLock::new(None)),
134 election_state: Arc::new(RwLock::new(ElectionState::new(
135 cluster_size,
136 config.election_timeout_min,
137 config.election_timeout_max,
138 ))),
139 current_leader: Arc::new(RwLock::new(None)),
140 config,
141 internal_tx,
142 internal_rx: Arc::new(RwLock::new(internal_rx)),
143 }
144 }
145
146 pub async fn start(self: Arc<Self>) {
148 info!("Starting Raft node: {}", self.config.node_id);
149
150 self.clone().spawn_election_timer();
152
153 self.clone().spawn_heartbeat_timer();
155
156 self.run().await;
158 }
159
160 async fn run(self: Arc<Self>) {
162 loop {
163 let message = {
164 let mut rx = self.internal_rx.write();
165 rx.recv().await
166 };
167
168 match message {
169 Some(InternalMessage::Rpc { from, message }) => {
170 self.handle_rpc_message(from, message).await;
171 }
172 Some(InternalMessage::ClientCommand {
173 command,
174 response_tx,
175 }) => {
176 self.handle_client_command(command, response_tx).await;
177 }
178 Some(InternalMessage::ElectionTimeout) => {
179 self.handle_election_timeout().await;
180 }
181 Some(InternalMessage::HeartbeatTimeout) => {
182 self.handle_heartbeat_timeout().await;
183 }
184 None => {
185 warn!("Internal channel closed, stopping node");
186 break;
187 }
188 }
189 }
190 }
191
192 async fn handle_rpc_message(&self, from: NodeId, message: RaftMessage) {
194 let message_term = message.term();
196 let current_term = self.persistent.read().current_term;
197
198 if message_term > current_term {
199 self.step_down(message_term).await;
200 }
201
202 match message {
203 RaftMessage::AppendEntriesRequest(req) => {
204 let response = self.handle_append_entries(req).await;
205 debug!("AppendEntries response to {}: {:?}", from, response);
207 }
208 RaftMessage::AppendEntriesResponse(resp) => {
209 self.handle_append_entries_response(from, resp).await;
210 }
211 RaftMessage::RequestVoteRequest(req) => {
212 let response = self.handle_request_vote(req).await;
213 debug!("RequestVote response to {}: {:?}", from, response);
215 }
216 RaftMessage::RequestVoteResponse(resp) => {
217 self.handle_request_vote_response(from, resp).await;
218 }
219 RaftMessage::InstallSnapshotRequest(req) => {
220 let response = self.handle_install_snapshot(req).await;
221 debug!("InstallSnapshot response to {}: {:?}", from, response);
223 }
224 RaftMessage::InstallSnapshotResponse(resp) => {
225 self.handle_install_snapshot_response(from, resp).await;
226 }
227 }
228 }
229
230 async fn handle_append_entries(&self, req: AppendEntriesRequest) -> AppendEntriesResponse {
232 let mut persistent = self.persistent.write();
233 let mut volatile = self.volatile.write();
234
235 if req.term < persistent.current_term {
237 return AppendEntriesResponse::failure(persistent.current_term, None, None);
238 }
239
240 self.election_state.write().reset_timer();
242 *self.current_leader.write() = Some(req.leader_id.clone());
243
244 if !persistent
246 .log
247 .matches(req.prev_log_index, req.prev_log_term)
248 {
249 let conflict_index = req.prev_log_index;
250 let conflict_term = persistent.log.term_at(conflict_index);
251 return AppendEntriesResponse::failure(
252 persistent.current_term,
253 Some(conflict_index),
254 conflict_term,
255 );
256 }
257
258 if !req.entries.is_empty() {
260 let mut index = req.prev_log_index + 1;
262 for entry in &req.entries {
263 if let Some(existing_term) = persistent.log.term_at(index) {
264 if existing_term != entry.term {
265 let _ = persistent.log.truncate_from(index);
267 }
268 }
269 index += 1;
270 }
271
272 if let Err(e) = persistent.log.append_entries(req.entries.clone()) {
274 error!("Failed to append entries: {}", e);
275 return AppendEntriesResponse::failure(persistent.current_term, None, None);
276 }
277 }
278
279 if req.leader_commit > volatile.commit_index {
281 let last_new_entry = if req.entries.is_empty() {
282 req.prev_log_index
283 } else {
284 req.entries.last().unwrap().index
285 };
286 volatile.update_commit_index(std::cmp::min(req.leader_commit, last_new_entry));
287 }
288
289 AppendEntriesResponse::success(persistent.current_term, persistent.log.last_index())
290 }
291
292 async fn handle_append_entries_response(&self, from: NodeId, resp: AppendEntriesResponse) {
294 if !self.state.read().is_leader() {
295 return;
296 }
297
298 let persistent = self.persistent.write();
299 let mut leader_state_guard = self.leader_state.write();
300
301 if let Some(leader_state) = leader_state_guard.as_mut() {
302 if resp.success {
303 if let Some(match_index) = resp.match_index {
305 leader_state.update_replication(&from, match_index);
306
307 let new_commit = leader_state.calculate_commit_index();
309 let mut volatile = self.volatile.write();
310 if new_commit > volatile.commit_index {
311 if let Some(term) = persistent.log.term_at(new_commit) {
313 if term == persistent.current_term {
314 volatile.update_commit_index(new_commit);
315 info!("Updated commit index to {}", new_commit);
316 }
317 }
318 }
319 }
320 } else {
321 leader_state.decrement_next_index(&from);
323 debug!("Replication failed for {}, decrementing next_index", from);
324 }
325 }
326 }
327
328 async fn handle_request_vote(&self, req: RequestVoteRequest) -> RequestVoteResponse {
330 let mut persistent = self.persistent.write();
331
332 if req.term < persistent.current_term {
334 return RequestVoteResponse::denied(persistent.current_term);
335 }
336
337 let last_log_index = persistent.log.last_index();
338 let last_log_term = persistent.log.last_term();
339
340 let should_grant = VoteValidator::should_grant_vote(
342 persistent.current_term,
343 &persistent.voted_for,
344 last_log_index,
345 last_log_term,
346 &req.candidate_id,
347 req.term,
348 req.last_log_index,
349 req.last_log_term,
350 );
351
352 if should_grant {
353 persistent.vote_for(req.candidate_id.clone());
354 self.election_state.write().reset_timer();
355 info!("Granted vote to {} for term {}", req.candidate_id, req.term);
356 RequestVoteResponse::granted(persistent.current_term)
357 } else {
358 debug!("Denied vote to {} for term {}", req.candidate_id, req.term);
359 RequestVoteResponse::denied(persistent.current_term)
360 }
361 }
362
363 async fn handle_request_vote_response(&self, from: NodeId, resp: RequestVoteResponse) {
365 if !self.state.read().is_candidate() {
366 return;
367 }
368
369 let current_term = self.persistent.read().current_term;
370 if resp.term != current_term {
371 return;
372 }
373
374 if resp.vote_granted {
375 let won_election = self.election_state.write().record_vote(from.clone());
376 if won_election {
377 info!("Won election for term {}", current_term);
378 self.become_leader().await;
379 }
380 }
381 }
382
383 async fn handle_install_snapshot(
385 &self,
386 req: InstallSnapshotRequest,
387 ) -> InstallSnapshotResponse {
388 let persistent = self.persistent.write();
389
390 if req.term < persistent.current_term {
391 return InstallSnapshotResponse::failure(persistent.current_term);
392 }
393
394 InstallSnapshotResponse::success(persistent.current_term, None)
397 }
398
/// Handles a peer's reply to an InstallSnapshot request.
///
/// Currently a no-op: both arguments are ignored.
async fn handle_install_snapshot_response(
    &self,
    _from: NodeId,
    _resp: InstallSnapshotResponse,
) {
}
407
408 async fn handle_client_command(
410 &self,
411 command: Command,
412 response_tx: mpsc::Sender<RaftResult<CommandResult>>,
413 ) {
414 if !self.state.read().is_leader() {
416 let _ = response_tx.send(Err(RaftError::NotLeader)).await;
417 return;
418 }
419
420 let mut persistent = self.persistent.write();
421 let term = persistent.current_term;
422 let index = persistent.log.append(term, command.data);
423
424 let result = CommandResult { index, term };
425 let _ = response_tx.send(Ok(result)).await;
426
427 drop(persistent);
429 let _ = self.internal_tx.send(InternalMessage::HeartbeatTimeout);
430 }
431
432 async fn handle_election_timeout(&self) {
434 if self.state.read().is_leader() {
435 return;
436 }
437
438 if !self.election_state.read().should_start_election() {
439 return;
440 }
441
442 info!("Election timeout, starting election");
443 self.start_election().await;
444 }
445
446 async fn start_election(&self) {
448 *self.state.write() = RaftState::Candidate;
450
451 let mut persistent = self.persistent.write();
453 persistent.increment_term();
454 persistent.vote_for(self.config.node_id.clone());
455 let term = persistent.current_term;
456
457 self.election_state
459 .write()
460 .start_election(term, &self.config.node_id);
461
462 let last_log_index = persistent.log.last_index();
463 let last_log_term = persistent.log.last_term();
464
465 info!(
466 "Starting election for term {} as {}",
467 term, self.config.node_id
468 );
469
470 for member in &self.config.cluster_members {
472 if member != &self.config.node_id {
473 let _request = RequestVoteRequest::new(
474 term,
475 self.config.node_id.clone(),
476 last_log_index,
477 last_log_term,
478 );
479 debug!("Would send RequestVote to {}", member);
481 }
482 }
483 }
484
485 async fn become_leader(&self) {
487 info!(
488 "Becoming leader for term {}",
489 self.persistent.read().current_term
490 );
491
492 *self.state.write() = RaftState::Leader;
493 *self.current_leader.write() = Some(self.config.node_id.clone());
494
495 let last_log_index = self.persistent.read().log.last_index();
496 let other_members: Vec<_> = self
497 .config
498 .cluster_members
499 .iter()
500 .filter(|m| *m != &self.config.node_id)
501 .cloned()
502 .collect();
503
504 *self.leader_state.write() = Some(LeaderState::new(&other_members, last_log_index));
505
506 let _ = self.internal_tx.send(InternalMessage::HeartbeatTimeout);
508 }
509
510 async fn step_down(&self, term: Term) {
512 info!("Stepping down to follower for term {}", term);
513
514 *self.state.write() = RaftState::Follower;
515 *self.leader_state.write() = None;
516 *self.current_leader.write() = None;
517
518 let mut persistent = self.persistent.write();
519 persistent.update_term(term);
520 }
521
522 async fn handle_heartbeat_timeout(&self) {
524 if !self.state.read().is_leader() {
525 return;
526 }
527
528 self.send_heartbeats().await;
529 }
530
531 async fn send_heartbeats(&self) {
533 let persistent = self.persistent.read();
534 let term = persistent.current_term;
535 let commit_index = self.volatile.read().commit_index;
536
537 for member in &self.config.cluster_members {
538 if member != &self.config.node_id {
539 let request = AppendEntriesRequest::heartbeat(
540 term,
541 self.config.node_id.clone(),
542 commit_index,
543 );
544 debug!("Would send heartbeat to {}", member);
546 }
547 }
548 }
549
550 fn spawn_election_timer(self: Arc<Self>) {
552 let node = self.clone();
553 tokio::spawn(async move {
554 let mut interval = interval(Duration::from_millis(50));
555 loop {
556 interval.tick().await;
557 if node.election_state.read().should_start_election() {
558 let _ = node.internal_tx.send(InternalMessage::ElectionTimeout);
559 }
560 }
561 });
562 }
563
564 fn spawn_heartbeat_timer(self: Arc<Self>) {
566 let node = self.clone();
567 tokio::spawn(async move {
568 let interval_ms = node.config.heartbeat_interval;
569 let mut interval = interval(Duration::from_millis(interval_ms));
570 loop {
571 interval.tick().await;
572 if node.state.read().is_leader() {
573 let _ = node.internal_tx.send(InternalMessage::HeartbeatTimeout);
574 }
575 }
576 });
577 }
578
579 pub async fn submit_command(&self, data: Vec<u8>) -> RaftResult<CommandResult> {
581 let (tx, mut rx) = mpsc::channel(1);
582 let command = Command { data };
583
584 self.internal_tx
585 .send(InternalMessage::ClientCommand {
586 command,
587 response_tx: tx,
588 })
589 .map_err(|_| RaftError::Internal("Node stopped".to_string()))?;
590
591 rx.recv()
592 .await
593 .ok_or_else(|| RaftError::Internal("Response channel closed".to_string()))?
594 }
595
596 pub fn current_state(&self) -> RaftState {
598 *self.state.read()
599 }
600
601 pub fn current_term(&self) -> Term {
603 self.persistent.read().current_term
604 }
605
606 pub fn current_leader(&self) -> Option<NodeId> {
608 self.current_leader.read().clone()
609 }
610}
611
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed node is a follower in term 0.
    #[test]
    fn test_node_creation() {
        let members: Vec<_> = ["node1", "node2", "node3"]
            .iter()
            .map(|id| id.to_string())
            .collect();

        let node = RaftNode::new(RaftNodeConfig::new("node1".to_string(), members));

        assert_eq!(node.current_state(), RaftState::Follower);
        assert_eq!(node.current_term(), 0);
    }
}