Skip to main content

ruvector_raft/
node.rs

1//! Raft node implementation
2//!
3//! Coordinates all Raft components:
4//! - State machine management
5//! - RPC message handling
6//! - Log replication
7//! - Leader election
8//! - Client request processing
9
10use crate::{
11    election::{ElectionState, VoteValidator},
12    rpc::{
13        AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest,
14        InstallSnapshotResponse, RaftMessage, RequestVoteRequest, RequestVoteResponse,
15    },
16    state::{LeaderState, PersistentState, RaftState, VolatileState},
17    LogIndex, NodeId, RaftError, RaftResult, Term,
18};
19use parking_lot::RwLock;
20use std::sync::Arc;
21use std::time::Duration;
22use tokio::sync::mpsc;
23use tokio::time::{interval, sleep};
24use tracing::{debug, error, info, warn};
25
26/// Configuration for a Raft node
27#[derive(Debug, Clone)]
28pub struct RaftNodeConfig {
29    /// This node's ID
30    pub node_id: NodeId,
31
32    /// IDs of all cluster members (including self)
33    pub cluster_members: Vec<NodeId>,
34
35    /// Minimum election timeout (milliseconds)
36    pub election_timeout_min: u64,
37
38    /// Maximum election timeout (milliseconds)
39    pub election_timeout_max: u64,
40
41    /// Heartbeat interval (milliseconds)
42    pub heartbeat_interval: u64,
43
44    /// Maximum entries per AppendEntries RPC
45    pub max_entries_per_message: usize,
46
47    /// Snapshot chunk size (bytes)
48    pub snapshot_chunk_size: usize,
49}
50
51impl RaftNodeConfig {
52    /// Create a new configuration with defaults
53    pub fn new(node_id: NodeId, cluster_members: Vec<NodeId>) -> Self {
54        Self {
55            node_id,
56            cluster_members,
57            election_timeout_min: 150,
58            election_timeout_max: 300,
59            heartbeat_interval: 50,
60            max_entries_per_message: 100,
61            snapshot_chunk_size: 64 * 1024, // 64KB
62        }
63    }
64}
65
/// An opaque client command destined for the replicated state machine.
#[derive(Clone, Debug)]
pub struct Command {
    /// Raw command payload; this module stores it in the log without
    /// interpreting it.
    pub data: Vec<u8>,
}
71
72/// Result of applying a command
73#[derive(Debug, Clone)]
74pub struct CommandResult {
75    pub index: LogIndex,
76    pub term: Term,
77}
78
79/// Internal messages for the Raft node
80#[derive(Debug)]
81enum InternalMessage {
82    /// RPC message from another node
83    Rpc { from: NodeId, message: RaftMessage },
84    /// Client command to replicate
85    ClientCommand {
86        command: Command,
87        response_tx: mpsc::Sender<RaftResult<CommandResult>>,
88    },
89    /// Election timeout fired
90    ElectionTimeout,
91    /// Heartbeat timeout fired
92    HeartbeatTimeout,
93}
94
95/// The Raft consensus node
96pub struct RaftNode {
97    /// Configuration
98    config: RaftNodeConfig,
99
100    /// Persistent state
101    persistent: Arc<RwLock<PersistentState>>,
102
103    /// Volatile state
104    volatile: Arc<RwLock<VolatileState>>,
105
106    /// Current Raft state (Follower, Candidate, Leader)
107    state: Arc<RwLock<RaftState>>,
108
109    /// Leader-specific state (only valid when state is Leader)
110    leader_state: Arc<RwLock<Option<LeaderState>>>,
111
112    /// Election state
113    election_state: Arc<RwLock<ElectionState>>,
114
115    /// Current leader ID (if known)
116    current_leader: Arc<RwLock<Option<NodeId>>>,
117
118    /// Channel for internal messages
119    internal_tx: mpsc::UnboundedSender<InternalMessage>,
120    internal_rx: Arc<RwLock<mpsc::UnboundedReceiver<InternalMessage>>>,
121}
122
123impl RaftNode {
124    /// Create a new Raft node
125    pub fn new(config: RaftNodeConfig) -> Self {
126        let (internal_tx, internal_rx) = mpsc::unbounded_channel();
127        let cluster_size = config.cluster_members.len();
128
129        Self {
130            persistent: Arc::new(RwLock::new(PersistentState::new())),
131            volatile: Arc::new(RwLock::new(VolatileState::new())),
132            state: Arc::new(RwLock::new(RaftState::Follower)),
133            leader_state: Arc::new(RwLock::new(None)),
134            election_state: Arc::new(RwLock::new(ElectionState::new(
135                cluster_size,
136                config.election_timeout_min,
137                config.election_timeout_max,
138            ))),
139            current_leader: Arc::new(RwLock::new(None)),
140            config,
141            internal_tx,
142            internal_rx: Arc::new(RwLock::new(internal_rx)),
143        }
144    }
145
146    /// Start the Raft node
147    pub async fn start(self: Arc<Self>) {
148        info!("Starting Raft node: {}", self.config.node_id);
149
150        // Spawn election timer task
151        self.clone().spawn_election_timer();
152
153        // Spawn heartbeat timer task (for leaders)
154        self.clone().spawn_heartbeat_timer();
155
156        // Main message processing loop
157        self.run().await;
158    }
159
160    /// Main message processing loop
161    async fn run(self: Arc<Self>) {
162        loop {
163            let message = {
164                let mut rx = self.internal_rx.write();
165                rx.recv().await
166            };
167
168            match message {
169                Some(InternalMessage::Rpc { from, message }) => {
170                    self.handle_rpc_message(from, message).await;
171                }
172                Some(InternalMessage::ClientCommand {
173                    command,
174                    response_tx,
175                }) => {
176                    self.handle_client_command(command, response_tx).await;
177                }
178                Some(InternalMessage::ElectionTimeout) => {
179                    self.handle_election_timeout().await;
180                }
181                Some(InternalMessage::HeartbeatTimeout) => {
182                    self.handle_heartbeat_timeout().await;
183                }
184                None => {
185                    warn!("Internal channel closed, stopping node");
186                    break;
187                }
188            }
189        }
190    }
191
192    /// Handle RPC message from another node
193    async fn handle_rpc_message(&self, from: NodeId, message: RaftMessage) {
194        // Update term if necessary
195        let message_term = message.term();
196        let current_term = self.persistent.read().current_term;
197
198        if message_term > current_term {
199            self.step_down(message_term).await;
200        }
201
202        match message {
203            RaftMessage::AppendEntriesRequest(req) => {
204                let response = self.handle_append_entries(req).await;
205                // TODO: Send response back to sender
206                debug!("AppendEntries response to {}: {:?}", from, response);
207            }
208            RaftMessage::AppendEntriesResponse(resp) => {
209                self.handle_append_entries_response(from, resp).await;
210            }
211            RaftMessage::RequestVoteRequest(req) => {
212                let response = self.handle_request_vote(req).await;
213                // TODO: Send response back to sender
214                debug!("RequestVote response to {}: {:?}", from, response);
215            }
216            RaftMessage::RequestVoteResponse(resp) => {
217                self.handle_request_vote_response(from, resp).await;
218            }
219            RaftMessage::InstallSnapshotRequest(req) => {
220                let response = self.handle_install_snapshot(req).await;
221                // TODO: Send response back to sender
222                debug!("InstallSnapshot response to {}: {:?}", from, response);
223            }
224            RaftMessage::InstallSnapshotResponse(resp) => {
225                self.handle_install_snapshot_response(from, resp).await;
226            }
227        }
228    }
229
230    /// Handle AppendEntries RPC
231    async fn handle_append_entries(&self, req: AppendEntriesRequest) -> AppendEntriesResponse {
232        let mut persistent = self.persistent.write();
233        let mut volatile = self.volatile.write();
234
235        // Reply false if term < currentTerm
236        if req.term < persistent.current_term {
237            return AppendEntriesResponse::failure(persistent.current_term, None, None);
238        }
239
240        // Reset election timer
241        self.election_state.write().reset_timer();
242        *self.current_leader.write() = Some(req.leader_id.clone());
243
244        // Reply false if log doesn't contain an entry at prevLogIndex with prevLogTerm
245        if !persistent
246            .log
247            .matches(req.prev_log_index, req.prev_log_term)
248        {
249            let conflict_index = req.prev_log_index;
250            let conflict_term = persistent.log.term_at(conflict_index);
251            return AppendEntriesResponse::failure(
252                persistent.current_term,
253                Some(conflict_index),
254                conflict_term,
255            );
256        }
257
258        // Append new entries
259        if !req.entries.is_empty() {
260            // Delete conflicting entries and append new ones
261            let mut index = req.prev_log_index + 1;
262            for entry in &req.entries {
263                if let Some(existing_term) = persistent.log.term_at(index) {
264                    if existing_term != entry.term {
265                        // Conflict found, truncate from here
266                        let _ = persistent.log.truncate_from(index);
267                    }
268                }
269                index += 1;
270            }
271
272            // Append entries
273            if let Err(e) = persistent.log.append_entries(req.entries.clone()) {
274                error!("Failed to append entries: {}", e);
275                return AppendEntriesResponse::failure(persistent.current_term, None, None);
276            }
277        }
278
279        // Update commit index
280        if req.leader_commit > volatile.commit_index {
281            let last_new_entry = if req.entries.is_empty() {
282                req.prev_log_index
283            } else {
284                // SAFETY: We just checked entries is not empty in the if condition
285                req.entries
286                    .last()
287                    .expect("entries verified non-empty")
288                    .index
289            };
290            volatile.update_commit_index(std::cmp::min(req.leader_commit, last_new_entry));
291        }
292
293        AppendEntriesResponse::success(persistent.current_term, persistent.log.last_index())
294    }
295
296    /// Handle AppendEntries response
297    async fn handle_append_entries_response(&self, from: NodeId, resp: AppendEntriesResponse) {
298        if !self.state.read().is_leader() {
299            return;
300        }
301
302        let persistent = self.persistent.write();
303        let mut leader_state_guard = self.leader_state.write();
304
305        if let Some(leader_state) = leader_state_guard.as_mut() {
306            if resp.success {
307                // Update next_index and match_index
308                if let Some(match_index) = resp.match_index {
309                    leader_state.update_replication(&from, match_index);
310
311                    // Update commit index
312                    let new_commit = leader_state.calculate_commit_index();
313                    let mut volatile = self.volatile.write();
314                    if new_commit > volatile.commit_index {
315                        // Verify the entry is from current term
316                        if let Some(term) = persistent.log.term_at(new_commit) {
317                            if term == persistent.current_term {
318                                volatile.update_commit_index(new_commit);
319                                info!("Updated commit index to {}", new_commit);
320                            }
321                        }
322                    }
323                }
324            } else {
325                // Decrement next_index and retry
326                leader_state.decrement_next_index(&from);
327                debug!("Replication failed for {}, decrementing next_index", from);
328            }
329        }
330    }
331
332    /// Handle RequestVote RPC
333    async fn handle_request_vote(&self, req: RequestVoteRequest) -> RequestVoteResponse {
334        let mut persistent = self.persistent.write();
335
336        // Reply false if term < currentTerm
337        if req.term < persistent.current_term {
338            return RequestVoteResponse::denied(persistent.current_term);
339        }
340
341        let last_log_index = persistent.log.last_index();
342        let last_log_term = persistent.log.last_term();
343
344        // Check if we should grant vote
345        let should_grant = VoteValidator::should_grant_vote(
346            persistent.current_term,
347            &persistent.voted_for,
348            last_log_index,
349            last_log_term,
350            &req.candidate_id,
351            req.term,
352            req.last_log_index,
353            req.last_log_term,
354        );
355
356        if should_grant {
357            persistent.vote_for(req.candidate_id.clone());
358            self.election_state.write().reset_timer();
359            info!("Granted vote to {} for term {}", req.candidate_id, req.term);
360            RequestVoteResponse::granted(persistent.current_term)
361        } else {
362            debug!("Denied vote to {} for term {}", req.candidate_id, req.term);
363            RequestVoteResponse::denied(persistent.current_term)
364        }
365    }
366
367    /// Handle RequestVote response
368    async fn handle_request_vote_response(&self, from: NodeId, resp: RequestVoteResponse) {
369        if !self.state.read().is_candidate() {
370            return;
371        }
372
373        let current_term = self.persistent.read().current_term;
374        if resp.term != current_term {
375            return;
376        }
377
378        if resp.vote_granted {
379            let won_election = self.election_state.write().record_vote(from.clone());
380            if won_election {
381                info!("Won election for term {}", current_term);
382                self.become_leader().await;
383            }
384        }
385    }
386
387    /// Handle InstallSnapshot RPC
388    async fn handle_install_snapshot(
389        &self,
390        req: InstallSnapshotRequest,
391    ) -> InstallSnapshotResponse {
392        let persistent = self.persistent.write();
393
394        if req.term < persistent.current_term {
395            return InstallSnapshotResponse::failure(persistent.current_term);
396        }
397
398        // TODO: Implement snapshot installation
399        // For now, just acknowledge
400        InstallSnapshotResponse::success(persistent.current_term, None)
401    }
402
403    /// Handle InstallSnapshot response
404    async fn handle_install_snapshot_response(
405        &self,
406        _from: NodeId,
407        _resp: InstallSnapshotResponse,
408    ) {
409        // TODO: Implement snapshot response handling
410    }
411
412    /// Handle client command
413    async fn handle_client_command(
414        &self,
415        command: Command,
416        response_tx: mpsc::Sender<RaftResult<CommandResult>>,
417    ) {
418        // Only leader can handle client commands
419        if !self.state.read().is_leader() {
420            let _ = response_tx.send(Err(RaftError::NotLeader)).await;
421            return;
422        }
423
424        let mut persistent = self.persistent.write();
425        let term = persistent.current_term;
426        let index = persistent.log.append(term, command.data);
427
428        let result = CommandResult { index, term };
429        let _ = response_tx.send(Ok(result)).await;
430
431        // Trigger immediate replication
432        drop(persistent);
433        let _ = self.internal_tx.send(InternalMessage::HeartbeatTimeout);
434    }
435
436    /// Handle election timeout
437    async fn handle_election_timeout(&self) {
438        if self.state.read().is_leader() {
439            return;
440        }
441
442        if !self.election_state.read().should_start_election() {
443            return;
444        }
445
446        info!("Election timeout, starting election");
447        self.start_election().await;
448    }
449
450    /// Start a new election
451    async fn start_election(&self) {
452        // Transition to candidate
453        *self.state.write() = RaftState::Candidate;
454
455        // Increment term and vote for self
456        let mut persistent = self.persistent.write();
457        persistent.increment_term();
458        persistent.vote_for(self.config.node_id.clone());
459        let term = persistent.current_term;
460
461        // Initialize election state
462        self.election_state
463            .write()
464            .start_election(term, &self.config.node_id);
465
466        let last_log_index = persistent.log.last_index();
467        let last_log_term = persistent.log.last_term();
468
469        info!(
470            "Starting election for term {} as {}",
471            term, self.config.node_id
472        );
473
474        // Send RequestVote RPCs to all other nodes
475        for member in &self.config.cluster_members {
476            if member != &self.config.node_id {
477                let _request = RequestVoteRequest::new(
478                    term,
479                    self.config.node_id.clone(),
480                    last_log_index,
481                    last_log_term,
482                );
483                // TODO: Send request to member
484                debug!("Would send RequestVote to {}", member);
485            }
486        }
487    }
488
489    /// Become leader after winning election
490    async fn become_leader(&self) {
491        info!(
492            "Becoming leader for term {}",
493            self.persistent.read().current_term
494        );
495
496        *self.state.write() = RaftState::Leader;
497        *self.current_leader.write() = Some(self.config.node_id.clone());
498
499        let last_log_index = self.persistent.read().log.last_index();
500        let other_members: Vec<_> = self
501            .config
502            .cluster_members
503            .iter()
504            .filter(|m| *m != &self.config.node_id)
505            .cloned()
506            .collect();
507
508        *self.leader_state.write() = Some(LeaderState::new(&other_members, last_log_index));
509
510        // Send initial heartbeats
511        let _ = self.internal_tx.send(InternalMessage::HeartbeatTimeout);
512    }
513
514    /// Step down to follower (when discovering higher term)
515    async fn step_down(&self, term: Term) {
516        info!("Stepping down to follower for term {}", term);
517
518        *self.state.write() = RaftState::Follower;
519        *self.leader_state.write() = None;
520        *self.current_leader.write() = None;
521
522        let mut persistent = self.persistent.write();
523        persistent.update_term(term);
524    }
525
526    /// Handle heartbeat timeout (for leaders)
527    async fn handle_heartbeat_timeout(&self) {
528        if !self.state.read().is_leader() {
529            return;
530        }
531
532        self.send_heartbeats().await;
533    }
534
535    /// Send heartbeats to all followers
536    async fn send_heartbeats(&self) {
537        let persistent = self.persistent.read();
538        let term = persistent.current_term;
539        let commit_index = self.volatile.read().commit_index;
540
541        for member in &self.config.cluster_members {
542            if member != &self.config.node_id {
543                let request = AppendEntriesRequest::heartbeat(
544                    term,
545                    self.config.node_id.clone(),
546                    commit_index,
547                );
548                // TODO: Send heartbeat to member
549                debug!("Would send heartbeat to {}", member);
550            }
551        }
552    }
553
554    /// Spawn election timer task
555    fn spawn_election_timer(self: Arc<Self>) {
556        let node = self.clone();
557        tokio::spawn(async move {
558            let mut interval = interval(Duration::from_millis(50));
559            loop {
560                interval.tick().await;
561                if node.election_state.read().should_start_election() {
562                    let _ = node.internal_tx.send(InternalMessage::ElectionTimeout);
563                }
564            }
565        });
566    }
567
568    /// Spawn heartbeat timer task
569    fn spawn_heartbeat_timer(self: Arc<Self>) {
570        let node = self.clone();
571        tokio::spawn(async move {
572            let interval_ms = node.config.heartbeat_interval;
573            let mut interval = interval(Duration::from_millis(interval_ms));
574            loop {
575                interval.tick().await;
576                if node.state.read().is_leader() {
577                    let _ = node.internal_tx.send(InternalMessage::HeartbeatTimeout);
578                }
579            }
580        });
581    }
582
583    /// Submit a command to the Raft cluster
584    pub async fn submit_command(&self, data: Vec<u8>) -> RaftResult<CommandResult> {
585        let (tx, mut rx) = mpsc::channel(1);
586        let command = Command { data };
587
588        self.internal_tx
589            .send(InternalMessage::ClientCommand {
590                command,
591                response_tx: tx,
592            })
593            .map_err(|_| RaftError::Internal("Node stopped".to_string()))?;
594
595        rx.recv()
596            .await
597            .ok_or_else(|| RaftError::Internal("Response channel closed".to_string()))?
598    }
599
600    /// Get current state
601    pub fn current_state(&self) -> RaftState {
602        *self.state.read()
603    }
604
605    /// Get current term
606    pub fn current_term(&self) -> Term {
607        self.persistent.read().current_term
608    }
609
610    /// Get current leader
611    pub fn current_leader(&self) -> Option<NodeId> {
612        self.current_leader.read().clone()
613    }
614}
615
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created node is a follower in term 0.
    #[test]
    fn test_node_creation() {
        let members: Vec<_> = ["node1", "node2", "node3"]
            .iter()
            .map(|id| id.to_string())
            .collect();
        let node = RaftNode::new(RaftNodeConfig::new("node1".to_string(), members));

        assert_eq!(node.current_state(), RaftState::Follower);
        assert_eq!(node.current_term(), 0);
    }
}
635}