vortex_raft/
lib.rs

1pub mod raft;
2
3pub use anyhow::Context;
4pub use raft::*;
5pub use serde::{Deserialize, Serialize};
6pub use std::io::{StdoutLock, Write};
7
8use serde::de::DeserializeOwned;
9use std::{
10    fmt::Debug,
11    sync::{
12        atomic::{AtomicBool, Ordering},
13        mpsc::{self, Sender},
14        Arc,
15    },
16    thread::{self, JoinHandle},
17    time::{Duration, Instant},
18};
19
20use anyhow::{anyhow, bail};
21
22#[derive(Debug, Deserialize, Serialize, Clone)]
23#[serde(untagged)]
24pub enum Event<Payload: Clone, Signal = (), LogEntry: Clone = ()> {
25    Message(Message<Payload>),
26    Signal(Signal),
27    Raft(RaftEvent<LogEntry>),
28    EOF,
29}
30
31#[derive(Debug, Deserialize, Serialize, Clone)]
32pub struct Message<Payload: Clone> {
33    pub src: String,
34    pub dest: String,
35    pub body: Body<Payload>,
36}
37
38#[derive(Debug, Deserialize, Serialize, Clone)]
39pub struct Body<Payload: Clone> {
40    pub msg_id: Option<u64>,
41    pub in_reply_to: Option<u64>,
42    #[serde(flatten)]
43    pub payload: Payload,
44}
45
46#[derive(Debug, Deserialize, Serialize, Clone)]
47#[serde(tag = "type")]
48#[serde(rename_all = "snake_case")]
49enum InitPayload {
50    Init {
51        node_id: String,
52        node_ids: Vec<String>,
53    },
54    InitOk,
55}
56
/// This node's view of the cluster plus the locked stdout handle used to
/// write outgoing messages.
pub struct Network {
    /// Locked stdout; every outgoing message is written here as one line.
    output: StdoutLock<'static>,
    /// Id of this node; used as `src` by `reply` and `send`.
    pub node_id: String,
    /// Neighbor ids chosen by `set_sqrt_topology` / `set_mesh_topology`;
    /// `Service::run` starts with an empty list.
    pub neighbors: Vec<String>,
    /// Node ids taken from the init message.
    pub all_nodes: Vec<String>,
}
63
64impl Network {
65    pub fn reply<Payload: Serialize + Clone>(
66        &mut self,
67        dest: String,
68        msg_id: Option<u64>,
69        in_reply_to: Option<u64>,
70        payload: Payload,
71    ) -> anyhow::Result<()> {
72        let reply = Message {
73            src: self.node_id.clone(),
74            dest,
75            body: Body {
76                msg_id,
77                in_reply_to,
78                payload,
79            },
80        };
81
82        serde_json::to_writer(&mut self.output, &reply).context("Serialize reply")?;
83        self.output.write_all(b"\n")?;
84        Ok(())
85    }
86
87    pub fn send<Payload: Serialize + Clone>(
88        &mut self,
89        dest: String,
90        body: Body<Payload>,
91    ) -> anyhow::Result<()> {
92        let msg = Message {
93            src: self.node_id.clone(),
94            dest,
95            body,
96        };
97
98        serde_json::to_writer(&mut self.output, &msg).context("Serialize message")?;
99        self.output.write_all(b"\n")?;
100        Ok(())
101    }
102
103    pub fn send_msg<Payload: Serialize + Clone>(
104        &mut self,
105        msg: Message<Payload>,
106    ) -> anyhow::Result<()> {
107        serde_json::to_writer(&mut self.output, &msg).context("Serialize message")?;
108        self.output.write_all(b"\n")?;
109        Ok(())
110    }
111
112    /// sqrt(n) root nodes, all with sqrt(n)-1 children
113    /// Each child connects to all the root nodes
114    pub fn set_sqrt_topology(&mut self) {
115        if self.all_nodes.len() < 4 {
116            self.set_mesh_topology();
117            return;
118        }
119
120        let root_nodes = (self.all_nodes.len() as f64).sqrt() as usize;
121
122        let idx = self
123            .all_nodes
124            .iter()
125            .position(|n| n == &self.node_id)
126            .unwrap_or_else(|| panic!("Node {} is unknown", self.node_id));
127
128        self.neighbors = if idx % root_nodes == 0 {
129            // Node is a root node
130            (idx + 1..(idx + root_nodes).min(self.all_nodes.len()))
131                .map(|i| self.all_nodes[i].clone())
132                .collect()
133        } else {
134            // Node is a child node
135            (0..self.all_nodes.len())
136                .filter(|i| i % root_nodes == 0)
137                .map(|i| self.all_nodes[i].clone())
138                .collect()
139        };
140    }
141
142    pub fn set_mesh_topology(&mut self) {
143        self.neighbors = self.all_nodes.clone();
144    }
145
146    pub fn is_singleton(&self) -> bool {
147        self.all_nodes.len() == 1
148    }
149}
150
/// Hands out sequential `u64` ids, starting at 0.
pub struct IdCounter(u64);

impl IdCounter {
    /// Creates a counter whose first id is 0.
    pub fn new() -> Self {
        Self(0)
    }

    /// Returns the id the next call to `next` would try to hand out, without
    /// consuming it.
    pub fn peek(&self) -> u64 {
        self.0
    }
}

impl Iterator for IdCounter {
    type Item = u64;

    /// Returns the current id and advances the counter.
    ///
    /// Yields `None` once advancing would overflow `u64`, instead of wrapping
    /// around to 0 and issuing ids that were already handed out.
    fn next(&mut self) -> Option<Self::Item> {
        let id = self.0;
        self.0 = id.checked_add(1)?;
        Some(id)
    }
}

impl Default for IdCounter {
    /// Same as [`IdCounter::new`].
    fn default() -> Self {
        Self::new()
    }
}
177
178pub trait Service<Payload, Signal = (), RaftEntry = ()>: Sized
179where
180    Payload: DeserializeOwned + Send + Clone + 'static + Debug,
181    Signal: Send + 'static,
182    RaftEntry: Clone + DeserializeOwned + Send + 'static + Debug,
183{
184    fn create(network: &mut Network, sender: Sender<Event<Payload, Signal, RaftEntry>>) -> Self;
185
186    fn step(
187        &mut self,
188        input: Event<Payload, Signal, RaftEntry>,
189        network: &mut Network,
190    ) -> anyhow::Result<()>;
191
192    fn run() -> anyhow::Result<()> {
193        let mut stdin = std::io::stdin().lock();
194
195        let mut input =
196            serde_json::Deserializer::from_reader(&mut stdin).into_iter::<Message<InitPayload>>();
197
198        let init = input
199            .next()
200            .expect("input will block until next message")
201            .context("Deserialize init message")?;
202
203        // Initialize state
204        let (node_id, all_nodes) = match &init.body.payload {
205            InitPayload::Init { node_id, node_ids } => (node_id.clone(), node_ids.clone()),
206            _ => bail!("First message should have been an init message"),
207        };
208
209        let mut network = Network {
210            output: std::io::stdout().lock(),
211            node_id,
212            neighbors: Vec::new(),
213            all_nodes,
214        };
215
216        let (sender, receiver) = mpsc::channel();
217
218        network
219            .reply(init.src, None, init.body.msg_id, InitPayload::InitOk)
220            .context("Init reply")?;
221
222        drop(stdin);
223        let sender_clone = sender.clone();
224        let handle = thread::spawn(move || {
225            let stdin = std::io::stdin().lock();
226            let input = serde_json::Deserializer::from_reader(stdin)
227                .into_iter::<Event<Payload, (), RaftEntry>>();
228
229            for event in input {
230                let event = event.context("Deserialize event")?;
231
232                let event: Event<Payload, Signal, RaftEntry> = match event {
233                    Event::Message(msg) => Event::Message(msg),
234                    Event::Raft(msg) => Event::Raft(msg),
235                    _ => bail!("Got local event over the network"),
236                };
237
238                if sender_clone.send(event).is_err() {
239                    return Ok::<_, anyhow::Error>(());
240                }
241            }
242
243            let _ = sender_clone.send(Event::EOF);
244
245            Ok(())
246        });
247
248        let mut service = Self::create(&mut network, sender);
249        for event in receiver {
250            service.step(event, &mut network)?;
251        }
252
253        handle
254            .join()
255            .expect("Stdin reader thread panicked")
256            .context("Stdin reader thread err")?;
257
258        Ok(())
259    }
260}
261
262pub fn spawn_timer<F>(
263    cb: Box<F>,
264    dur: Duration,
265    interrupt: Option<Arc<AtomicBool>>,
266) -> JoinHandle<anyhow::Result<()>>
267where
268    F: Fn() -> anyhow::Result<()> + Send + 'static,
269{
270    thread::spawn(move || {
271        let mut now = Instant::now();
272        loop {
273            if let Some(interrupt) = &interrupt {
274                if interrupt.load(Ordering::Relaxed) {
275                    now = Instant::now();
276                    interrupt.store(false, Ordering::Relaxed);
277                }
278            }
279
280            if now.elapsed() > dur {
281                cb()?;
282                now = Instant::now();
283            }
284        }
285    })
286}
287
288pub trait SenderExt<T>: Send + Sync + 'static {
289    fn send(&self, t: T) -> anyhow::Result<()>;
290
291    fn map_input<U, F>(self, func: F) -> MapSender<Self, F>
292    where
293        Self: Sized,
294        F: Fn(U) -> T,
295    {
296        MapSender { sender: self, func }
297    }
298}
299
300impl<T: Send + 'static> SenderExt<T> for Sender<T> {
301    fn send(&self, t: T) -> anyhow::Result<()> {
302        self.send(t).map_err(|_| anyhow!("Send error"))
303    }
304}
305
/// A sender paired with a conversion function. Its `SenderExt` impl applies
/// `func` to each value and passes the result on to `sender`.
#[derive(Clone)]
pub struct MapSender<S, F> {
    /// The wrapped sender that receives the converted values.
    sender: S,
    /// Conversion applied to every value before it is forwarded.
    func: F,
}
311
312impl<S: SenderExt<U>, F, T, U> SenderExt<T> for MapSender<S, F>
313where
314    F: Fn(T) -> U + Clone + Send + Sync + 'static,
315{
316    fn send(&self, value: T) -> anyhow::Result<()> {
317        self.sender.send((self.func)(value))
318    }
319}