ruvector_raft/node.rs

//! Raft node implementation
//!
//! Coordinates all Raft components:
//! - State machine management
//! - RPC message handling
//! - Log replication
//! - Leader election
//! - Client request processing

use crate::{
    election::{ElectionState, VoteValidator},
    rpc::{
        AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest,
        InstallSnapshotResponse, RaftMessage, RequestVoteRequest, RequestVoteResponse,
    },
    state::{LeaderState, PersistentState, RaftState, VolatileState},
    LogIndex, NodeId, RaftError, RaftResult, Term,
};
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, Mutex as AsyncMutex};
use tokio::time::interval;
use tracing::{debug, error, info, warn};

/// Configuration for a Raft node
#[derive(Debug, Clone)]
pub struct RaftNodeConfig {
    /// This node's ID
    pub node_id: NodeId,

    /// IDs of all cluster members (including self)
    pub cluster_members: Vec<NodeId>,

    /// Minimum election timeout (milliseconds)
    pub election_timeout_min: u64,

    /// Maximum election timeout (milliseconds)
    pub election_timeout_max: u64,

    /// Heartbeat interval (milliseconds)
    pub heartbeat_interval: u64,

    /// Maximum entries per AppendEntries RPC
    pub max_entries_per_message: usize,

    /// Snapshot chunk size (bytes)
    pub snapshot_chunk_size: usize,
}

impl RaftNodeConfig {
    /// Create a new configuration with defaults
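    ///
    /// A minimal construction sketch (string `NodeId`s assumed, as in the
    /// unit tests at the bottom of this file):
    ///
    /// ```ignore
    /// let members = vec!["node1".to_string(), "node2".to_string(), "node3".to_string()];
    /// let config = RaftNodeConfig::new("node1".to_string(), members);
    /// assert_eq!(config.election_timeout_min, 150);
    /// assert_eq!(config.heartbeat_interval, 50);
    /// ```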
    pub fn new(node_id: NodeId, cluster_members: Vec<NodeId>) -> Self {
        Self {
            node_id,
            cluster_members,
            election_timeout_min: 150,
            election_timeout_max: 300,
            heartbeat_interval: 50,
            max_entries_per_message: 100,
            snapshot_chunk_size: 64 * 1024, // 64 KiB
        }
    }
}

/// Command to apply to the state machine
#[derive(Debug, Clone)]
pub struct Command {
    pub data: Vec<u8>,
}

/// Result of applying a command
#[derive(Debug, Clone)]
pub struct CommandResult {
    pub index: LogIndex,
    pub term: Term,
}

/// Internal messages for the Raft node
#[derive(Debug)]
enum InternalMessage {
    /// RPC message from another node
    Rpc { from: NodeId, message: RaftMessage },
    /// Client command to replicate
    ClientCommand {
        command: Command,
        response_tx: mpsc::Sender<RaftResult<CommandResult>>,
    },
    /// Election timeout fired
    ElectionTimeout,
    /// Heartbeat timeout fired
    HeartbeatTimeout,
}

/// The Raft consensus node
pub struct RaftNode {
    /// Configuration
    config: RaftNodeConfig,

    /// Persistent state
    persistent: Arc<RwLock<PersistentState>>,

    /// Volatile state
    volatile: Arc<RwLock<VolatileState>>,

    /// Current Raft state (Follower, Candidate, Leader)
    state: Arc<RwLock<RaftState>>,

    /// Leader-specific state (only valid when state is Leader)
    leader_state: Arc<RwLock<Option<LeaderState>>>,

    /// Election state
    election_state: Arc<RwLock<ElectionState>>,

    /// Current leader ID (if known)
    current_leader: Arc<RwLock<Option<NodeId>>>,

    /// Channel for internal messages
    internal_tx: mpsc::UnboundedSender<InternalMessage>,
    /// The receiver sits behind an async mutex: its guard is held across
    /// `.await` points in `run`, which a `parking_lot` guard must never be
    internal_rx: Arc<AsyncMutex<mpsc::UnboundedReceiver<InternalMessage>>>,
}

impl RaftNode {
    /// Create a new Raft node
    pub fn new(config: RaftNodeConfig) -> Self {
        let (internal_tx, internal_rx) = mpsc::unbounded_channel();
        let cluster_size = config.cluster_members.len();

        Self {
            persistent: Arc::new(RwLock::new(PersistentState::new())),
            volatile: Arc::new(RwLock::new(VolatileState::new())),
            state: Arc::new(RwLock::new(RaftState::Follower)),
            leader_state: Arc::new(RwLock::new(None)),
            election_state: Arc::new(RwLock::new(ElectionState::new(
                cluster_size,
                config.election_timeout_min,
                config.election_timeout_max,
            ))),
            current_leader: Arc::new(RwLock::new(None)),
            config,
            internal_tx,
            internal_rx: Arc::new(AsyncMutex::new(internal_rx)),
        }
    }

    /// Start the Raft node
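    ///
    /// Consumes an `Arc<Self>` because the timer tasks keep clones of the
    /// node. A launch sketch, assuming a multi-threaded Tokio runtime:
    ///
    /// ```ignore
    /// let node = Arc::new(RaftNode::new(config));
    /// tokio::spawn(node.clone().start());
    /// ```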
    pub async fn start(self: Arc<Self>) {
        info!("Starting Raft node: {}", self.config.node_id);

        // Spawn election timer task
        self.clone().spawn_election_timer();

        // Spawn heartbeat timer task (for leaders)
        self.clone().spawn_heartbeat_timer();

        // Main message processing loop
        self.run().await;
    }

    /// Main message processing loop
    async fn run(self: Arc<Self>) {
        // Lock the receiver once for the node's lifetime. The tokio mutex
        // guard may be held across the `.await` below; a `parking_lot`
        // guard here would make this future !Send and could block the
        // executor thread.
        let mut rx = self.internal_rx.lock().await;

        while let Some(message) = rx.recv().await {
            match message {
                InternalMessage::Rpc { from, message } => {
                    self.handle_rpc_message(from, message).await;
                }
                InternalMessage::ClientCommand {
                    command,
                    response_tx,
                } => {
                    self.handle_client_command(command, response_tx).await;
                }
                InternalMessage::ElectionTimeout => {
                    self.handle_election_timeout().await;
                }
                InternalMessage::HeartbeatTimeout => {
                    self.handle_heartbeat_timeout().await;
                }
            }
        }

        warn!("Internal channel closed, stopping node");
    }

    /// Handle RPC message from another node
    async fn handle_rpc_message(&self, from: NodeId, message: RaftMessage) {
        // Update term if necessary
        let message_term = message.term();
        let current_term = self.persistent.read().current_term;

        if message_term > current_term {
            self.step_down(message_term).await;
        }

        match message {
            RaftMessage::AppendEntriesRequest(req) => {
                let response = self.handle_append_entries(req).await;
                // TODO: Send response back to sender
                debug!("AppendEntries response to {}: {:?}", from, response);
            }
            RaftMessage::AppendEntriesResponse(resp) => {
                self.handle_append_entries_response(from, resp).await;
            }
            RaftMessage::RequestVoteRequest(req) => {
                let response = self.handle_request_vote(req).await;
                // TODO: Send response back to sender
                debug!("RequestVote response to {}: {:?}", from, response);
            }
            RaftMessage::RequestVoteResponse(resp) => {
                self.handle_request_vote_response(from, resp).await;
            }
            RaftMessage::InstallSnapshotRequest(req) => {
                let response = self.handle_install_snapshot(req).await;
                // TODO: Send response back to sender
                debug!("InstallSnapshot response to {}: {:?}", from, response);
            }
            RaftMessage::InstallSnapshotResponse(resp) => {
                self.handle_install_snapshot_response(from, resp).await;
            }
        }
    }

    /// Handle AppendEntries RPC
    async fn handle_append_entries(&self, req: AppendEntriesRequest) -> AppendEntriesResponse {
        let mut persistent = self.persistent.write();
        let mut volatile = self.volatile.write();

        // Reply false if term < currentTerm
        if req.term < persistent.current_term {
            return AppendEntriesResponse::failure(persistent.current_term, None, None);
        }

        // Valid leader contact: reset election timer
        self.election_state.write().reset_timer();
        *self.current_leader.write() = Some(req.leader_id.clone());

        // Reply false if log doesn't contain an entry at prevLogIndex with prevLogTerm
        if !persistent
            .log
            .matches(req.prev_log_index, req.prev_log_term)
        {
            let conflict_index = req.prev_log_index;
            let conflict_term = persistent.log.term_at(conflict_index);
            return AppendEntriesResponse::failure(
                persistent.current_term,
                Some(conflict_index),
                conflict_term,
            );
        }

        // Skip entries already in the log, truncate at the first term
        // conflict, then append only the genuinely new suffix (Raft §5.3)
        if !req.entries.is_empty() {
            let mut first_new = req.entries.len();
            for (i, entry) in req.entries.iter().enumerate() {
                match persistent.log.term_at(entry.index) {
                    // Already present with the same term: keep it
                    Some(existing_term) if existing_term == entry.term => {}
                    // Conflict: delete this entry and everything after it
                    Some(_) => {
                        let _ = persistent.log.truncate_from(entry.index);
                        first_new = i;
                        break;
                    }
                    // Local log ends here: everything from i onward is new
                    None => {
                        first_new = i;
                        break;
                    }
                }
            }

            if first_new < req.entries.len() {
                if let Err(e) = persistent
                    .log
                    .append_entries(req.entries[first_new..].to_vec())
                {
                    error!("Failed to append entries: {}", e);
                    return AppendEntriesResponse::failure(persistent.current_term, None, None);
                }
            }
        }

        // If leaderCommit > commitIndex, set commitIndex =
        // min(leaderCommit, index of last new entry)
        if req.leader_commit > volatile.commit_index {
            let last_new_entry = if req.entries.is_empty() {
                req.prev_log_index
            } else {
                req.entries.last().unwrap().index
            };
            volatile.update_commit_index(std::cmp::min(req.leader_commit, last_new_entry));
        }

        AppendEntriesResponse::success(persistent.current_term, persistent.log.last_index())
    }

    /// Handle AppendEntries response
    async fn handle_append_entries_response(&self, from: NodeId, resp: AppendEntriesResponse) {
        if !self.state.read().is_leader() {
            return;
        }

        // Read access to the log is sufficient here
        let persistent = self.persistent.read();
        let mut leader_state_guard = self.leader_state.write();

        if let Some(leader_state) = leader_state_guard.as_mut() {
            if resp.success {
                // Update next_index and match_index
                if let Some(match_index) = resp.match_index {
                    leader_state.update_replication(&from, match_index);

                    // Advance commit index if a majority has replicated
                    let new_commit = leader_state.calculate_commit_index();
                    let mut volatile = self.volatile.write();
                    if new_commit > volatile.commit_index {
                        // A leader only commits entries from its own term (Raft §5.4.2)
                        if let Some(term) = persistent.log.term_at(new_commit) {
                            if term == persistent.current_term {
                                volatile.update_commit_index(new_commit);
                                info!("Updated commit index to {}", new_commit);
                            }
                        }
                    }
                }
            } else {
                // Log inconsistency: back up next_index and retry
                leader_state.decrement_next_index(&from);
                debug!("Replication failed for {}, decrementing next_index", from);
            }
        }
    }

    /// Handle RequestVote RPC
    async fn handle_request_vote(&self, req: RequestVoteRequest) -> RequestVoteResponse {
        let mut persistent = self.persistent.write();

        // Reply false if term < currentTerm
        if req.term < persistent.current_term {
            return RequestVoteResponse::denied(persistent.current_term);
        }

        let last_log_index = persistent.log.last_index();
        let last_log_term = persistent.log.last_term();

        // Check if we should grant vote
        let should_grant = VoteValidator::should_grant_vote(
            persistent.current_term,
            &persistent.voted_for,
            last_log_index,
            last_log_term,
            &req.candidate_id,
            req.term,
            req.last_log_index,
            req.last_log_term,
        );

        if should_grant {
            persistent.vote_for(req.candidate_id.clone());
            self.election_state.write().reset_timer();
            info!("Granted vote to {} for term {}", req.candidate_id, req.term);
            RequestVoteResponse::granted(persistent.current_term)
        } else {
            debug!("Denied vote to {} for term {}", req.candidate_id, req.term);
            RequestVoteResponse::denied(persistent.current_term)
        }
    }

    /// Handle RequestVote response
    async fn handle_request_vote_response(&self, from: NodeId, resp: RequestVoteResponse) {
        if !self.state.read().is_candidate() {
            return;
        }

        let current_term = self.persistent.read().current_term;
        if resp.term != current_term {
            return;
        }

        if resp.vote_granted {
            let won_election = self.election_state.write().record_vote(from.clone());
            if won_election {
                info!("Won election for term {}", current_term);
                self.become_leader().await;
            }
        }
    }

    /// Handle InstallSnapshot RPC
    async fn handle_install_snapshot(
        &self,
        req: InstallSnapshotRequest,
    ) -> InstallSnapshotResponse {
        // Read access is sufficient until installation is implemented
        let persistent = self.persistent.read();

        if req.term < persistent.current_term {
            return InstallSnapshotResponse::failure(persistent.current_term);
        }

        // TODO: Implement snapshot installation
        // For now, just acknowledge
        InstallSnapshotResponse::success(persistent.current_term, None)
    }

    /// Handle InstallSnapshot response
    async fn handle_install_snapshot_response(
        &self,
        _from: NodeId,
        _resp: InstallSnapshotResponse,
    ) {
        // TODO: Implement snapshot response handling
    }

    /// Handle client command
    ///
    /// Note: the reply is sent as soon as the entry is appended to the
    /// leader's log, not when it commits; a fuller implementation would
    /// defer the response until the entry's index is committed.
    async fn handle_client_command(
        &self,
        command: Command,
        response_tx: mpsc::Sender<RaftResult<CommandResult>>,
    ) {
        // Only the leader can handle client commands
        if !self.state.read().is_leader() {
            let _ = response_tx.send(Err(RaftError::NotLeader)).await;
            return;
        }

        // Append inside a scope so the lock guard is released before awaiting
        let result = {
            let mut persistent = self.persistent.write();
            let term = persistent.current_term;
            let index = persistent.log.append(term, command.data);
            CommandResult { index, term }
        };

        let _ = response_tx.send(Ok(result)).await;

        // Trigger immediate replication
        let _ = self.internal_tx.send(InternalMessage::HeartbeatTimeout);
    }

    /// Handle election timeout
    async fn handle_election_timeout(&self) {
        if self.state.read().is_leader() {
            return;
        }

        if !self.election_state.read().should_start_election() {
            return;
        }

        info!("Election timeout, starting election");
        self.start_election().await;
    }

    /// Start a new election
    async fn start_election(&self) {
        // Transition to candidate
        *self.state.write() = RaftState::Candidate;

        // Increment term and vote for self
        let mut persistent = self.persistent.write();
        persistent.increment_term();
        persistent.vote_for(self.config.node_id.clone());
        let term = persistent.current_term;

        // Initialize election state
        self.election_state
            .write()
            .start_election(term, &self.config.node_id);

        let last_log_index = persistent.log.last_index();
        let last_log_term = persistent.log.last_term();

        info!(
            "Starting election for term {} as {}",
            term, self.config.node_id
        );

        // Send RequestVote RPCs to all other nodes
        for member in &self.config.cluster_members {
            if member != &self.config.node_id {
                let _request = RequestVoteRequest::new(
                    term,
                    self.config.node_id.clone(),
                    last_log_index,
                    last_log_term,
                );
                // TODO: Send request to member
                debug!("Would send RequestVote to {}", member);
            }
        }
    }

    /// Become leader after winning election
    async fn become_leader(&self) {
        info!(
            "Becoming leader for term {}",
            self.persistent.read().current_term
        );

        *self.state.write() = RaftState::Leader;
        *self.current_leader.write() = Some(self.config.node_id.clone());

        let last_log_index = self.persistent.read().log.last_index();
        let other_members: Vec<_> = self
            .config
            .cluster_members
            .iter()
            .filter(|m| *m != &self.config.node_id)
            .cloned()
            .collect();

        *self.leader_state.write() = Some(LeaderState::new(&other_members, last_log_index));

        // Send initial heartbeats
        let _ = self.internal_tx.send(InternalMessage::HeartbeatTimeout);
    }

    /// Step down to follower (when discovering a higher term)
    async fn step_down(&self, term: Term) {
        info!("Stepping down to follower for term {}", term);

        *self.state.write() = RaftState::Follower;
        *self.leader_state.write() = None;
        *self.current_leader.write() = None;

        let mut persistent = self.persistent.write();
        persistent.update_term(term);
    }

    /// Handle heartbeat timeout (for leaders)
    async fn handle_heartbeat_timeout(&self) {
        if !self.state.read().is_leader() {
            return;
        }

        self.send_heartbeats().await;
    }

    /// Send heartbeats to all followers
    async fn send_heartbeats(&self) {
        let persistent = self.persistent.read();
        let term = persistent.current_term;
        let commit_index = self.volatile.read().commit_index;

        for member in &self.config.cluster_members {
            if member != &self.config.node_id {
                let _request = AppendEntriesRequest::heartbeat(
                    term,
                    self.config.node_id.clone(),
                    commit_index,
                );
                // TODO: Send heartbeat to member
                debug!("Would send heartbeat to {}", member);
            }
        }
    }

    /// Spawn election timer task
    fn spawn_election_timer(self: Arc<Self>) {
        let node = self.clone();
        tokio::spawn(async move {
            // Poll every 50ms; the election state tracks the randomized timeout
            let mut ticker = interval(Duration::from_millis(50));
            loop {
                ticker.tick().await;
                if node.election_state.read().should_start_election() {
                    let _ = node.internal_tx.send(InternalMessage::ElectionTimeout);
                }
            }
        });
    }

    /// Spawn heartbeat timer task
    fn spawn_heartbeat_timer(self: Arc<Self>) {
        let node = self.clone();
        tokio::spawn(async move {
            let interval_ms = node.config.heartbeat_interval;
            let mut ticker = interval(Duration::from_millis(interval_ms));
            loop {
                ticker.tick().await;
                if node.state.read().is_leader() {
                    let _ = node.internal_tx.send(InternalMessage::HeartbeatTimeout);
                }
            }
        });
    }

    /// Submit a command to the Raft cluster
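    ///
    /// A usage sketch (assumes the node has been started and is currently
    /// the leader; followers answer with `RaftError::NotLeader`):
    ///
    /// ```ignore
    /// let result = node.submit_command(b"set x = 1".to_vec()).await?;
    /// println!("appended at index {} in term {}", result.index, result.term);
    /// ```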
    pub async fn submit_command(&self, data: Vec<u8>) -> RaftResult<CommandResult> {
        let (tx, mut rx) = mpsc::channel(1);
        let command = Command { data };

        self.internal_tx
            .send(InternalMessage::ClientCommand {
                command,
                response_tx: tx,
            })
            .map_err(|_| RaftError::Internal("Node stopped".to_string()))?;

        rx.recv()
            .await
            .ok_or_else(|| RaftError::Internal("Response channel closed".to_string()))?
    }

    /// Get current state
    pub fn current_state(&self) -> RaftState {
        *self.state.read()
    }

    /// Get current term
    pub fn current_term(&self) -> Term {
        self.persistent.read().current_term
    }

    /// Get current leader
    pub fn current_leader(&self) -> Option<NodeId> {
        self.current_leader.read().clone()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_node_creation() {
        let config = RaftNodeConfig::new(
            "node1".to_string(),
            vec![
                "node1".to_string(),
                "node2".to_string(),
                "node3".to_string(),
            ],
        );

        let node = RaftNode::new(config);
        assert_eq!(node.current_state(), RaftState::Follower);
        assert_eq!(node.current_term(), 0);
    }
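
    // A hedged follow-up test: checks the documented constructor defaults
    // and that a fresh node knows no leader. Assumes `NodeId` is `String`,
    // as in the test above.
    #[test]
    fn test_initial_defaults() {
        let config = RaftNodeConfig::new("node1".to_string(), vec!["node1".to_string()]);
        assert_eq!(config.election_timeout_min, 150);
        assert_eq!(config.election_timeout_max, 300);
        assert_eq!(config.heartbeat_interval, 50);

        let node = RaftNode::new(config);
        assert!(node.current_leader().is_none());
    }
}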
631}