1pub mod raft;
2
3pub use anyhow::Context;
4pub use raft::*;
5pub use serde::{Deserialize, Serialize};
6pub use std::io::{StdoutLock, Write};
7
8use serde::de::DeserializeOwned;
9use std::{
10 fmt::Debug,
11 sync::{
12 atomic::{AtomicBool, Ordering},
13 mpsc::{self, Sender},
14 Arc,
15 },
16 thread::{self, JoinHandle},
17 time::{Duration, Instant},
18};
19
20use anyhow::{anyhow, bail};
21
22#[derive(Debug, Deserialize, Serialize, Clone)]
23#[serde(untagged)]
24pub enum Event<Payload: Clone, Signal = (), LogEntry: Clone = ()> {
25 Message(Message<Payload>),
26 Signal(Signal),
27 Raft(RaftEvent<LogEntry>),
28 EOF,
29}
30
31#[derive(Debug, Deserialize, Serialize, Clone)]
32pub struct Message<Payload: Clone> {
33 pub src: String,
34 pub dest: String,
35 pub body: Body<Payload>,
36}
37
38#[derive(Debug, Deserialize, Serialize, Clone)]
39pub struct Body<Payload: Clone> {
40 pub msg_id: Option<u64>,
41 pub in_reply_to: Option<u64>,
42 #[serde(flatten)]
43 pub payload: Payload,
44}
45
46#[derive(Debug, Deserialize, Serialize, Clone)]
47#[serde(tag = "type")]
48#[serde(rename_all = "snake_case")]
49enum InitPayload {
50 Init {
51 node_id: String,
52 node_ids: Vec<String>,
53 },
54 InitOk,
55}
56
/// Outgoing side of a node: a locked stdout handle plus what the node knows
/// about itself and the other nodes.
pub struct Network {
    /// Locked stdout; every outgoing message is written here as one JSON line.
    output: StdoutLock<'static>,
    /// Id of this node; used as `src` of the messages it builds.
    pub node_id: String,
    /// Nodes selected by one of the `set_*_topology` methods (initially empty).
    pub neighbors: Vec<String>,
    /// Ids of all nodes, as announced by the init message.
    pub all_nodes: Vec<String>,
}
63
64impl Network {
65 pub fn reply<Payload: Serialize + Clone>(
66 &mut self,
67 dest: String,
68 msg_id: Option<u64>,
69 in_reply_to: Option<u64>,
70 payload: Payload,
71 ) -> anyhow::Result<()> {
72 let reply = Message {
73 src: self.node_id.clone(),
74 dest,
75 body: Body {
76 msg_id,
77 in_reply_to,
78 payload,
79 },
80 };
81
82 serde_json::to_writer(&mut self.output, &reply).context("Serialize reply")?;
83 self.output.write_all(b"\n")?;
84 Ok(())
85 }
86
87 pub fn send<Payload: Serialize + Clone>(
88 &mut self,
89 dest: String,
90 body: Body<Payload>,
91 ) -> anyhow::Result<()> {
92 let msg = Message {
93 src: self.node_id.clone(),
94 dest,
95 body,
96 };
97
98 serde_json::to_writer(&mut self.output, &msg).context("Serialize message")?;
99 self.output.write_all(b"\n")?;
100 Ok(())
101 }
102
103 pub fn send_msg<Payload: Serialize + Clone>(
104 &mut self,
105 msg: Message<Payload>,
106 ) -> anyhow::Result<()> {
107 serde_json::to_writer(&mut self.output, &msg).context("Serialize message")?;
108 self.output.write_all(b"\n")?;
109 Ok(())
110 }
111
112 pub fn set_sqrt_topology(&mut self) {
115 if self.all_nodes.len() < 4 {
116 self.set_mesh_topology();
117 return;
118 }
119
120 let root_nodes = (self.all_nodes.len() as f64).sqrt() as usize;
121
122 let idx = self
123 .all_nodes
124 .iter()
125 .position(|n| n == &self.node_id)
126 .unwrap_or_else(|| panic!("Node {} is unknown", self.node_id));
127
128 self.neighbors = if idx % root_nodes == 0 {
129 (idx + 1..(idx + root_nodes).min(self.all_nodes.len()))
131 .map(|i| self.all_nodes[i].clone())
132 .collect()
133 } else {
134 (0..self.all_nodes.len())
136 .filter(|i| i % root_nodes == 0)
137 .map(|i| self.all_nodes[i].clone())
138 .collect()
139 };
140 }
141
142 pub fn set_mesh_topology(&mut self) {
143 self.neighbors = self.all_nodes.clone();
144 }
145
146 pub fn is_singleton(&self) -> bool {
147 self.all_nodes.len() == 1
148 }
149}
150
/// Generator of consecutive `u64` ids, starting at zero.
///
/// As an iterator it yields `0, 1, 2, …` and never returns `None`.
pub struct IdCounter(u64);

impl IdCounter {
    /// Creates a counter whose first issued id is `0`.
    pub fn new() -> Self {
        Self(0)
    }

    /// Returns the id the next call to `next` will hand out, without
    /// consuming it.
    ///
    /// Takes `&self`: looking at the counter does not modify it, so a shared
    /// reference (or an immutable binding) is sufficient.
    pub fn peek(&self) -> u64 {
        self.0
    }
}

impl Iterator for IdCounter {
    type Item = u64;

    /// Hands out the current id and advances the counter by one.
    fn next(&mut self) -> Option<Self::Item> {
        let issued = self.0;
        self.0 += 1;
        Some(issued)
    }
}

impl Default for IdCounter {
    /// Same as [`IdCounter::new`].
    fn default() -> Self {
        Self::new()
    }
}
177
178pub trait Service<Payload, Signal = (), RaftEntry = ()>: Sized
179where
180 Payload: DeserializeOwned + Send + Clone + 'static + Debug,
181 Signal: Send + 'static,
182 RaftEntry: Clone + DeserializeOwned + Send + 'static + Debug,
183{
184 fn create(network: &mut Network, sender: Sender<Event<Payload, Signal, RaftEntry>>) -> Self;
185
186 fn step(
187 &mut self,
188 input: Event<Payload, Signal, RaftEntry>,
189 network: &mut Network,
190 ) -> anyhow::Result<()>;
191
192 fn run() -> anyhow::Result<()> {
193 let mut stdin = std::io::stdin().lock();
194
195 let mut input =
196 serde_json::Deserializer::from_reader(&mut stdin).into_iter::<Message<InitPayload>>();
197
198 let init = input
199 .next()
200 .expect("input will block until next message")
201 .context("Deserialize init message")?;
202
203 let (node_id, all_nodes) = match &init.body.payload {
205 InitPayload::Init { node_id, node_ids } => (node_id.clone(), node_ids.clone()),
206 _ => bail!("First message should have been an init message"),
207 };
208
209 let mut network = Network {
210 output: std::io::stdout().lock(),
211 node_id,
212 neighbors: Vec::new(),
213 all_nodes,
214 };
215
216 let (sender, receiver) = mpsc::channel();
217
218 network
219 .reply(init.src, None, init.body.msg_id, InitPayload::InitOk)
220 .context("Init reply")?;
221
222 drop(stdin);
223 let sender_clone = sender.clone();
224 let handle = thread::spawn(move || {
225 let stdin = std::io::stdin().lock();
226 let input = serde_json::Deserializer::from_reader(stdin)
227 .into_iter::<Event<Payload, (), RaftEntry>>();
228
229 for event in input {
230 let event = event.context("Deserialize event")?;
231
232 let event: Event<Payload, Signal, RaftEntry> = match event {
233 Event::Message(msg) => Event::Message(msg),
234 Event::Raft(msg) => Event::Raft(msg),
235 _ => bail!("Got local event over the network"),
236 };
237
238 if sender_clone.send(event).is_err() {
239 return Ok::<_, anyhow::Error>(());
240 }
241 }
242
243 let _ = sender_clone.send(Event::EOF);
244
245 Ok(())
246 });
247
248 let mut service = Self::create(&mut network, sender);
249 for event in receiver {
250 service.step(event, &mut network)?;
251 }
252
253 handle
254 .join()
255 .expect("Stdin reader thread panicked")
256 .context("Stdin reader thread err")?;
257
258 Ok(())
259 }
260}
261
262pub fn spawn_timer<F>(
263 cb: Box<F>,
264 dur: Duration,
265 interrupt: Option<Arc<AtomicBool>>,
266) -> JoinHandle<anyhow::Result<()>>
267where
268 F: Fn() -> anyhow::Result<()> + Send + 'static,
269{
270 thread::spawn(move || {
271 let mut now = Instant::now();
272 loop {
273 if let Some(interrupt) = &interrupt {
274 if interrupt.load(Ordering::Relaxed) {
275 now = Instant::now();
276 interrupt.store(false, Ordering::Relaxed);
277 }
278 }
279
280 if now.elapsed() > dur {
281 cb()?;
282 now = Instant::now();
283 }
284 }
285 })
286}
287
288pub trait SenderExt<T>: Send + Sync + 'static {
289 fn send(&self, t: T) -> anyhow::Result<()>;
290
291 fn map_input<U, F>(self, func: F) -> MapSender<Self, F>
292 where
293 Self: Sized,
294 F: Fn(U) -> T,
295 {
296 MapSender { sender: self, func }
297 }
298}
299
300impl<T: Send + 'static> SenderExt<T> for Sender<T> {
301 fn send(&self, t: T) -> anyhow::Result<()> {
302 self.send(t).map_err(|_| anyhow!("Send error"))
303 }
304}
305
/// Sender adapter produced by [`SenderExt::map_input`]: it converts every
/// value with `func` before handing it to the wrapped `sender`.
#[derive(Clone)]
pub struct MapSender<S, F> {
    /// The wrapped sender that receives the converted values.
    sender: S,
    /// Conversion applied to every value before it is sent.
    func: F,
}
311
312impl<S: SenderExt<U>, F, T, U> SenderExt<T> for MapSender<S, F>
313where
314 F: Fn(T) -> U + Clone + Send + Sync + 'static,
315{
316 fn send(&self, value: T) -> anyhow::Result<()> {
317 self.sender.send((self.func)(value))
318 }
319}