1use std::{collections::HashMap, time::Duration};
26
27use async_channel::{Receiver, Sender};
28use async_timeouts::Timeout;
29use error::{Result, ResultExt, StateError};
30use inner::TinyRaftInner;
31pub use node::{NodeId, NodeType};
32use send::{Message, MessageSender, MessageType, MsgRouter, NodeState, SignalSender};
33use tokio::{task, time};
34use tracing::{error, info};
35
36pub mod error;
37pub mod inner;
38mod node;
39pub mod send;
40
41type OnReceive = (Sender<(NodeId, Message)>, Receiver<(NodeId, Message)>);
42
43#[derive(Clone)]
52pub struct TinyRaft {
53 stop_tinyraft_tx: Sender<()>,
54 on_receive: OnReceive,
55 raft_change_rx: Receiver<MsgRouter>,
56 set_nodes: (Sender<Vec<NodeId>>, Receiver<Vec<NodeId>>),
57}
58
59impl TinyRaft {
60 #[allow(clippy::too_many_arguments)]
108 pub async fn start(
109 nodes: impl IntoIterator<Item = impl Into<NodeId>>,
110 node_id: impl Into<NodeId>,
111 election_timeout: Duration,
112 heartbeat_timeout: Duration,
113 random_timeout: Option<Duration>,
114 leadership_timeout: Option<Duration>,
115 leader_priority: Option<u64>,
116 term: u64,
117 ) -> Self {
118 let nodes = match node::check_node_ids(nodes) {
119 Ok(nodes) => nodes,
120 Err(e) => panic!("{e:?}"),
121 };
122
123 let node_id = node_id.into();
124
125 info!(node = %node_id, nodes = ?nodes, term, "starting tinyraft node");
126
127 if !nodes.contains(&node_id) {
128 panic!("{}", StateError::SetNodesMissingSelf);
129 }
130
131 let leadership_timeout = leadership_timeout.unwrap_or_default();
132 let random_timeout = random_timeout.unwrap_or(election_timeout);
133
134 let (stop_tx, stop_rx) = async_channel::unbounded();
135 let (start_tx, start_rx) = async_channel::unbounded();
136 let (on_receive_tx, on_receive_rx) = async_channel::unbounded();
137 let (raft_change_tx, raft_change_rx) = async_channel::unbounded();
138 let (set_nodes_tx, set_nodes_rx) = async_channel::unbounded();
139 let (heartbeat_tx, heartbeat_rx) = async_channel::unbounded();
140
141 let set_nodes_rx_clone = set_nodes_rx.clone();
142 let on_receive_rx_clone = on_receive_rx.clone();
143
144 let heartbeat_sender = SignalSender::new(heartbeat_tx.clone());
145
146 let start_sender = SignalSender::new(start_tx.clone());
147 let message_sender = MessageSender::new(raft_change_tx);
148
149 let _ = start_tx.send(()).await;
150 let _ = heartbeat_tx.send(()).await;
151
152 task::spawn(async move {
153 let mut tinyraft = TinyRaftInner {
154 node_type: NodeType::Candidate,
155 confirmed: HashMap::new(),
156 leader_priority,
157 term,
158 leader: node_id.clone(),
159 followers: vec![],
160 votes: HashMap::new(),
161 node_id,
162 nodes,
163 election_timeout,
164 heartbeat_timeout,
165 start_sender: start_sender.clone(),
166 next_term: Timeout::default(),
167 heartbeat_sender: heartbeat_sender.clone(),
168 message_sender,
169 leadership_timeout,
170 random_timeout,
171 };
172
173 loop {
174 tokio::select! {
175 _ = stop_rx.recv() => {
176 info!(node = %tinyraft.node_id, term = tinyraft.term, "stopping tinyraft worker");
177 break
178 },
179
180 _ = start_rx.recv() => {
181 tinyraft.reset_election_state().await;
182 tinyraft.broadcast_vote_requests().await;
183 tinyraft.publish_candidate_state().await;
184 tinyraft.start_timeout(start_sender.clone()).await;
185 }
186
187 _ = heartbeat_rx.recv() => {
188 if let NodeType::Leader = tinyraft.node_type {
189 tinyraft.broadcast_heartbeat().await;
190 }
191
192 let heartbeat_sender = heartbeat_sender.clone();
193
194 task::spawn(async move {
195 time::sleep(heartbeat_timeout).await;
196
197 if let Err(e) = heartbeat_sender.send().await {
198 info!(error = ?e, "heartbeat loop stopped because node is shutting down");
199 }
200 });
201 },
202
203 Ok(nodes) = set_nodes_rx_clone.recv() => {
204 if let Err(e) = tinyraft.set_nodes(nodes).await {
205 error!(node = %tinyraft.node_id, error = ?e, "failed to update cluster nodes");
206 }
207 },
208
209 Ok((from, message)) = on_receive_rx_clone.recv() => {
210 if let Err(e) = tinyraft.on_receive(from, message).await {
211 error!(node = %tinyraft.node_id, error = ?e, "failed to handle incoming message");
212 panic!("{e:?}")
213 }
214 }
215 }
216 }
217 });
218
219 Self {
220 stop_tinyraft_tx: stop_tx,
221 on_receive: (on_receive_tx, on_receive_rx),
222 raft_change_rx,
223 set_nodes: (set_nodes_tx, set_nodes_rx),
224 }
225 }
226
227 pub async fn stop(&self) -> Result<()> {
228 info!("stop requested");
229 self.stop_tinyraft_tx.send(()).await.or_request_stop()
230 }
231
232 pub async fn on_receive(&self, from: NodeId, message: Message) -> Result<()> {
237 info!(from = %from, message = ?message, "enqueueing incoming message");
238 self.on_receive
239 .0
240 .send((from.clone(), message.clone()))
241 .await
242 .or_request_on_receive(from, message)
243 }
244
245 pub fn get_receiver_from_raft(&self) -> Receiver<MsgRouter> {
252 self.raft_change_rx.clone()
253 }
254
255 pub async fn set_nodes(
260 &self,
261 nodes: impl IntoIterator<Item = impl Into<NodeId>>,
262 ) -> Result<()> {
263 let nodes = node::check_node_ids(nodes)?;
264 info!(nodes = ?nodes, "requesting cluster nodes update");
265
266 self.set_nodes
267 .0
268 .send(nodes.clone())
269 .await
270 .or_request_set_nodes(nodes)
271 }
272}
273
274impl TinyRaftInner {
275 async fn reset_election_state(&mut self) {
276 info!(node = %self.node_id, next_term = self.term + 1, "starting election");
277 self.next_term.stop().await;
278 self.term += 1;
279 self.leader = self.node_id.clone();
280 self.node_type = NodeType::Candidate;
281 self.followers.clear();
282 self.confirmed.clear();
283 self.votes = HashMap::from([(self.node_id.clone(), vec![self.node_id.clone()])]);
284 }
285
286 async fn broadcast_vote_requests(&self) {
287 for node in self.nodes.iter().filter(|node| *node != &self.node_id) {
288 if let Err(e) = self
289 .message_sender
290 .send(MsgRouter::ToNode(
291 node.clone(),
292 Message {
293 msg_type: MessageType::VoteRequest,
294 term: self.term,
295 leader: Some(self.leader.clone()),
296 priority: self.leader_priority,
297 },
298 ))
299 .await
300 {
301 error!(node = %self.node_id, target = %node, term = self.term, error = ?e, "failed to send vote request");
302 panic!("start error: {e:?}");
303 }
304 }
305 }
306
307 async fn publish_candidate_state(&self) {
308 if let Err(e) = self
309 .message_sender
310 .send(MsgRouter::ToSelf(NodeState {
311 node_type: self.node_type,
312 term: self.term,
313 leader: None,
314 followers: vec![],
315 }))
316 .await
317 {
318 error!(node = %self.node_id, term = self.term, error = ?e, "failed to publish candidate state");
319 panic!("start error: {e:?}");
320 }
321 }
322
323 async fn broadcast_heartbeat(&mut self) {
324 info!(node = %self.node_id, term = self.term, "sending heartbeat to followers");
325 self.confirmed.clear();
326 for node in self.nodes.iter().filter(|node| *node != &self.node_id) {
327 if let Err(e) = self
328 .message_sender
329 .send(MsgRouter::ToNode(
330 node.clone(),
331 Message {
332 msg_type: MessageType::Ping,
333 term: self.term,
334 leader: None,
335 priority: self.leader_priority,
336 },
337 ))
338 .await
339 {
340 error!(node = %self.node_id, target = %node, term = self.term, error = ?e, "failed to send heartbeat");
341 panic!("heartbeat error: {e:?}");
342 }
343 }
344 }
345}