// Generated gRPC message and service types for the `skiff` protobuf
// package (emitted at build time by tonic-build).
pub(crate) mod skiff_proto {
    tonic::include_proto!("skiff");
}
4
5use crate::error::Error;
6use rand::{rngs::StdRng, Rng, SeedableRng};
7use serde::{Deserialize, Serialize};
8use skiff_proto::{
9 skiff_client::SkiffClient, skiff_server::SkiffServer, DeleteReply, DeleteRequest, GetReply,
10 GetRequest, InsertReply, InsertRequest,
11};
12use skiff_proto::{
13 Empty, EntryReply, EntryRequest, ListKeysReply, ListKeysRequest, PrefixReply, ServerReply,
14 ServerRequest, SubscribeReply, SubscribeRequest, VoteReply, VoteRequest,
15};
16use std::cmp::min;
17use std::pin::Pin;
18use std::time::Duration;
19use std::{collections::HashMap, str::FromStr};
20use std::{
21 net::{Ipv4Addr, SocketAddr, SocketAddrV4},
22 sync::Arc,
23};
24use tokio_stream::wrappers::ReceiverStream;
25use tokio_stream::{Stream, StreamExt};
26
27use tokio::sync::{mpsc, watch, Mutex, Notify};
28use tonic::{transport::Channel, Request, Response, Status};
29use tracing::{debug, info, trace};
30use uuid::Uuid;
31
// sled tree names holding Raft's durable state, kept separate from user data.
const RAFT_META_TREE: &str = "__raft_meta";
const RAFT_LOG_TREE: &str = "__raft_log";
// Keys stored inside RAFT_META_TREE.
const KEY_CURRENT_TERM: &[u8] = b"current_term";
const KEY_VOTED_FOR: &[u8] = b"voted_for";
const KEY_LAST_APPLIED: &[u8] = b"last_applied";
const KEY_NODE_ID: &[u8] = b"node_id";
38
/// On-disk form of one Raft log entry, serialized with bincode into
/// `RAFT_LOG_TREE` keyed by the big-endian `index` (see `persist_log_entry`),
/// so sled iterates entries in index order.
#[derive(Serialize, Deserialize)]
struct PersistedLog {
    index: u32,
    term: u32,
    action: Action,
}
45
46fn persist_hard_state(conn: &sled::Db, term: u32, voted_for: Option<Uuid>) -> Result<(), Error> {
47 let meta = conn.open_tree(RAFT_META_TREE)?;
48 meta.insert(KEY_CURRENT_TERM, &term.to_be_bytes())?;
49 match voted_for {
50 Some(id) => {
51 meta.insert(KEY_VOTED_FOR, id.as_bytes().as_slice())?;
52 }
53 None => {
54 meta.remove(KEY_VOTED_FOR)?;
55 }
56 };
57 Ok(())
58}
59
60fn persist_log_entry(conn: &sled::Db, log: &Log) -> Result<(), Error> {
61 let tree = conn.open_tree(RAFT_LOG_TREE)?;
62 let persisted = PersistedLog {
63 index: log.index,
64 term: log.term,
65 action: log.action.clone(),
66 };
67 tree.insert(log.index.to_be_bytes(), bincode::serialize(&persisted)?)?;
68 Ok(())
69}
70
71fn truncate_log_from(conn: &sled::Db, from_index: u32) -> Result<(), Error> {
72 let tree = conn.open_tree(RAFT_LOG_TREE)?;
73 let keys: Vec<_> = tree
74 .range(from_index.to_be_bytes()..)
75 .keys()
76 .collect::<Result<Vec<_>, _>>()?;
77 for key in keys {
78 tree.remove(key)?;
79 }
80 Ok(())
81}
82
83fn persist_last_applied(conn: &sled::Db, last_applied: u32) -> Result<(), Error> {
84 let meta = conn.open_tree(RAFT_META_TREE)?;
85 meta.insert(KEY_LAST_APPLIED, &last_applied.to_be_bytes())?;
86 Ok(())
87}
88
89fn load_or_create_id(conn: &sled::Db) -> Result<Uuid, Error> {
90 let meta = conn.open_tree(RAFT_META_TREE)?;
91 if let Some(bytes) = meta.get(KEY_NODE_ID)? {
92 return Uuid::from_slice(bytes.as_ref()).map_err(|_| Error::DeserializeFailed);
93 }
94 let id = Uuid::new_v4();
95 meta.insert(KEY_NODE_ID, id.as_bytes().as_slice())?;
96 conn.flush()?;
97 Ok(id)
98}
99
/// Restore persisted Raft state: `(current_term, voted_for, log, last_applied)`.
///
/// Missing keys default to term 0 / no vote / index 0, so a fresh database
/// yields a clean initial state. NOTE(review): a wrong-length value for the
/// term or last_applied keys silently decodes as 0 via `unwrap_or([0; 4])` —
/// confirm treating corruption as "fresh" is intended.
fn load_raft_state(conn: &sled::Db) -> Result<(u32, Option<Uuid>, Vec<Log>, u32), Error> {
    let meta = conn.open_tree(RAFT_META_TREE)?;

    let current_term = meta
        .get(KEY_CURRENT_TERM)?
        .map(|b| u32::from_be_bytes(b.as_ref().try_into().unwrap_or([0; 4])))
        .unwrap_or(0);

    let voted_for = meta
        .get(KEY_VOTED_FOR)?
        .and_then(|b| Uuid::from_slice(b.as_ref()).ok());

    let last_applied = meta
        .get(KEY_LAST_APPLIED)?
        .map(|b| u32::from_be_bytes(b.as_ref().try_into().unwrap_or([0; 4])))
        .unwrap_or(0);

    // Rebuild the in-memory log; keys are big-endian indices, so sled's
    // iteration order is ascending by index.
    let log_tree = conn.open_tree(RAFT_LOG_TREE)?;
    let mut log = Vec::new();
    for result in log_tree.iter() {
        let (_, value) = result?;
        let persisted: PersistedLog = bincode::deserialize(&value)?;
        log.push(Log {
            index: persisted.index,
            term: persisted.term,
            action: persisted.action,
            // Entries restored from disk are marked committed up front.
            committed: Arc::new((Mutex::new(true), Notify::new())),
        });
    }

    Ok((current_term, voted_for, log, last_applied))
}
132
/// A state-machine command carried by a Raft log entry.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
enum Action {
    /// Set a key to the given bytes.
    Insert(String, Vec<u8>),
    /// Remove a key.
    Delete(String),
    /// Replace the cluster membership map (node id -> address).
    Configure(HashMap<Uuid, Ipv4Addr>),
}
139
/// In-memory Raft log entry.
#[derive(Debug, Clone)]
struct Log {
    index: u32,
    term: u32,
    action: Action,
    // Shared flag + notifier that waiters use to observe when this entry
    // has been applied; flipped and signalled in `commit_logs`.
    committed: Arc<(Mutex<bool>, Notify)>,
}
147
/// Raft role of this node. `Follower` carries the current leader's id;
/// `Uuid::nil()` marks "no leader known yet".
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ElectionState {
    Candidate,
    Leader,
    Follower(Uuid),
}
159
/// Mutable node state, guarded by a single `tokio::sync::Mutex` in `Skiff`.
#[derive(Debug, Clone)]
struct State {
    election_state: ElectionState,
    // Latest term this node has seen.
    current_term: u32,
    // Candidate voted for in the current term, if any.
    voted_for: Option<Uuid>,
    // Highest log index known to be committed.
    committed_index: u32,
    // Highest log index applied to the sled state machine.
    last_applied: u32,

    // Cached gRPC clients, one per peer; evicted on connection failure.
    peer_clients: HashMap<Uuid, Arc<Mutex<SkiffClient<Channel>>>>,

    // Leader-only replication bookkeeping: next log index to send to each
    // peer, and the highest index known replicated on each peer.
    next_index: HashMap<Uuid, u32>,
    match_index: HashMap<Uuid, u32>,

    log: Vec<Log>,
    conn: sled::Db,
}
176
/// Handle to a single Raft node; cheap to clone (all fields are shared).
#[derive(Debug, Clone)]
pub struct Skiff {
    id: Uuid,
    address: Ipv4Addr,
    port: u16,
    state: Arc<Mutex<State>>,
    // Heartbeat signal channel used to reset the election timeout.
    tx_entries: Arc<Mutex<mpsc::Sender<u8>>>,
    rx_entries: Arc<Mutex<mpsc::Receiver<u8>>>,
    // Key-prefix subscribers, notified on committed inserts.
    subscribers: Arc<Mutex<HashMap<String, Vec<mpsc::Sender<SubscribeReply>>>>>,
    shutdown_tx: Arc<watch::Sender<bool>>,
    shutdown_rx: watch::Receiver<bool>,
}
195
196impl Skiff {
    /// Build a node backed by the sled database at `data_dir`.
    ///
    /// Restores any persisted Raft state, then places a synthetic
    /// `Configure` entry at index 0 describing the bootstrap cluster and
    /// appends the restored entries after it.
    ///
    /// NOTE(review): peers passed by address are assigned *fresh random*
    /// UUIDs on every boot — confirm their real ids are reconciled later via
    /// the join / `Configure` replication flow.
    pub(crate) fn new(
        address: Ipv4Addr,
        port: u16,
        data_dir: String,
        peers: Vec<Ipv4Addr>,
    ) -> Result<Self, Error> {
        let conn = sled::open(data_dir)?;
        let id = load_or_create_id(&conn)?;
        let (tx_entries, rx_entries) = mpsc::channel(32);
        let (shutdown_tx, shutdown_rx) = watch::channel(false);

        let (current_term, voted_for, persisted_log, last_applied) = load_raft_state(&conn)?;

        let mut cluster: HashMap<Uuid, Ipv4Addr> = peers
            .into_iter()
            .map(|addr| (Uuid::new_v4(), addr))
            .collect();

        cluster.insert(id, address);

        // Index 0 holds the bootstrap configuration; restored entries follow.
        let mut log = vec![Log {
            term: 0,
            index: 0,
            action: Action::Configure(cluster),
            committed: Arc::new((Mutex::new(true), Notify::new())),
        }];
        log.extend(persisted_log);

        Ok(Skiff {
            id,
            address,
            port,
            state: Arc::new(Mutex::new(State {
                // Start as a follower with no known leader.
                election_state: ElectionState::Follower(Uuid::nil()),
                current_term,
                voted_for,
                committed_index: last_applied,
                last_applied,
                peer_clients: HashMap::new(),
                next_index: HashMap::new(),
                match_index: HashMap::new(),
                log,
                conn,
            })),
            tx_entries: Arc::new(Mutex::new(tx_entries)),
            rx_entries: Arc::new(Mutex::new(rx_entries)),
            subscribers: Arc::new(Mutex::new(HashMap::new())),
            shutdown_tx: Arc::new(shutdown_tx),
            shutdown_rx,
        })
    }
252
    /// This node's stable UUID (persisted across restarts).
    pub fn get_id(&self) -> Uuid {
        self.id
    }

    /// Signal every task watching the shutdown channel to stop.
    pub fn shutdown(&self) {
        let _ = self.shutdown_tx.send(true);
    }

    /// IPv4 address this node binds and advertises.
    pub fn get_address(&self) -> Ipv4Addr {
        self.address
    }
278
279 pub async fn is_leader_elected(&self) -> bool {
285 let election_state = self.get_election_state().await;
286 match election_state {
287 ElectionState::Leader => true,
288 ElectionState::Candidate => false,
289 ElectionState::Follower(id) => {
290 if Uuid::nil() == id {
291 return false;
292 }
293
294 true
295 }
296 }
297 }
298
299 pub async fn wait_for_leader(&self, timeout: Duration) -> Result<(), Error> {
332 tokio::time::timeout(timeout, async {
333 loop {
334 if self.is_leader_elected().await {
335 return;
336 }
337 tokio::time::sleep(Duration::from_millis(50)).await;
338 }
339 })
340 .await
341 .map_err(|_| Error::LeaderElectionTimeout)
342 }
343
344 pub async fn get_cluster(&self) -> Result<HashMap<Uuid, Ipv4Addr>, Error> {
354 let config = match self
355 .state
356 .lock()
357 .await
358 .log
359 .iter()
360 .rev()
361 .find(|log| matches!(log.action, Action::Configure(_)))
362 {
363 Some(log) => match &log.action {
364 Action::Configure(config) => config.clone(),
365 _ => return Err(Error::MissingClusterConfig),
366 },
367 _ => return Err(Error::MissingClusterConfig),
368 };
369
370 Ok(config)
371 }
372
    /// Cluster members other than this node (filtered by address).
    ///
    /// NOTE(review): panics via `unwrap()` if the log holds no `Configure`
    /// entry. `new` always seeds one at index 0, but conflict truncation in
    /// `append_entry` mutates the log — confirm index 0 can never be dropped.
    async fn get_peers(&self) -> HashMap<Uuid, Ipv4Addr> {
        self.get_cluster()
            .await
            .unwrap()
            .into_iter()
            .filter(|(_, addr)| *addr != self.address)
            .collect()
    }
383
    /// Return a cached gRPC client for `peer`, dialing and caching a new
    /// connection on first use. Peers dial on the same port as this node.
    async fn get_peer_client(
        &self,
        peer: &Uuid,
    ) -> Result<Arc<Mutex<SkiffClient<Channel>>>, Error> {
        let peers = self.get_peers().await;

        if !peers.contains_key(peer) {
            return Err(Error::PeerNotFound);
        }

        // Fast path: reuse an already-established connection.
        if let Some(client) = self.state.lock().await.peer_clients.get(peer) {
            return Ok(client.clone());
        }

        // Slow path: dial the peer and cache the client for later calls.
        match SkiffClient::connect(format!(
            "http://{}",
            SocketAddrV4::new(*peers.get(peer).unwrap(), self.port)
        ))
        .await
        {
            Ok(client) => {
                let arc = Arc::new(Mutex::new(client));
                self.state
                    .lock()
                    .await
                    .peer_clients
                    .insert(peer.to_owned(), arc.clone());
                Ok(arc)
            }
            Err(_) => Err(Error::PeerConnectFailed),
        }
    }
416
417 async fn drop_peer_client(&self, id: &Uuid) {
418 let mut lock = self.state.lock().await;
419 let _ = lock.peer_clients.remove(id);
420 }
421
    /// Snapshot of this node's current Raft role.
    pub async fn get_election_state(&self) -> ElectionState {
        self.state.lock().await.election_state.clone()
    }

    async fn set_election_state(&self, state: ElectionState) {
        self.state.lock().await.election_state = state;
    }

    async fn get_current_term(&self) -> u32 {
        self.state.lock().await.current_term
    }

    // Bump the term when starting an election (in-memory only; durable
    // persistence happens in the RPC handlers).
    async fn increment_term(&self) {
        self.state.lock().await.current_term += 1;
    }

    async fn set_current_term(&self, term: u32) {
        self.state.lock().await.current_term = term;
    }
442
    // Index of the newest log entry (0 when the log is empty).
    async fn get_last_log_index(&self) -> u32 {
        self.state
            .lock()
            .await
            .log
            .last()
            .map(|last| last.index)
            .unwrap_or(0)
    }

    // Term of the newest log entry (0 when the log is empty).
    async fn get_last_log_term(&self) -> u32 {
        self.state
            .lock()
            .await
            .log
            .last()
            .map(|last| last.term)
            .unwrap_or(0)
    }

    async fn get_commit_index(&self) -> u32 {
        self.state.lock().await.committed_index
    }

    // Record our vote for the current term (`None` clears it). In-memory
    // only; the RPC handlers persist the vote before granting it.
    async fn vote_for(&self, candidate_id: Option<Uuid>) {
        self.state.lock().await.voted_for = candidate_id;
    }

    async fn get_voted_for(&self) -> Option<Uuid> {
        self.state.lock().await.voted_for
    }
476
    /// Append `action` to the in-memory log under the current term and
    /// return the (flag, notifier) pair signalled once the entry commits
    /// (see `commit_logs`).
    ///
    /// NOTE(review): the entry is not persisted to sled here — only the
    /// follower path in `append_entry` calls `persist_log_entry`. Confirm
    /// leader-side durability is handled elsewhere.
    async fn log(&self, action: Action) -> Arc<(Mutex<bool>, Notify)> {
        let mut lock = self.state.lock().await;
        let current_term = lock.current_term;
        let last_index = lock.log.last().map(|last: &Log| last.index).unwrap_or(0);

        let commit_pair = Arc::new((Mutex::new(false), Notify::new()));
        lock.log.push(Log {
            index: last_index + 1,
            term: current_term,
            action,
            committed: commit_pair.clone(),
        });

        commit_pair
    }
492
493 async fn get_prefixes(&self) -> Result<Vec<String>, Error> {
494 match self.state.lock().await.conn.get("trees")? {
495 Some(tree_vec) => match bincode::deserialize::<Vec<String>>(&tree_vec) {
496 Ok(trees) => Ok(trees),
497 Err(_) => Err(Error::DeserializeFailed),
498 },
499 None => Ok(vec![]),
500 }
501 }
502
    /// List all keys under `prefix`. An empty or "/" prefix also includes
    /// the root ("base") tree; keys from prefix trees come back as
    /// "prefix/key".
    ///
    /// NOTE(review): keys are assumed to be valid UTF-8 — the `unwrap()`s
    /// below panic on arbitrary bytes; confirm writers only insert string keys.
    async fn list_keys(&self, prefix: &str) -> Result<Vec<String>, Error> {
        let mut keys = vec![];
        let trees = self.get_prefixes().await?;

        if prefix.is_empty() || prefix == "/" {
            let tree = self.state.lock().await.conn.open_tree("base")?;

            tree.into_iter()
                .keys()
                .map(|key| String::from_utf8(key.unwrap().to_vec()).unwrap())
                .for_each(|key| {
                    keys.push(key);
                });
        }

        for tree_name in &trees {
            // Treat "foo/" and "foo" as the same prefix.
            let prefix_trimmed = match prefix.ends_with("/") {
                true => prefix.trim_end_matches("/"),
                false => prefix,
            };

            if tree_name.starts_with(prefix_trimmed) {
                // Prefix trees are stored as "base_<prefix>" with '/'
                // replaced by '_'.
                let tree = self
                    .state
                    .lock()
                    .await
                    .conn
                    .open_tree(format!("base_{}", tree_name.replace("/", "_")))?;

                tree.into_iter()
                    .keys()
                    .map(|key| String::from_utf8(key.unwrap().to_vec()).unwrap())
                    .for_each(|key| {
                        keys.push(format!("{}/{}", tree_name, key));
                    });
            }
        }

        Ok(keys)
    }
544
    /// Build the AppendEntries payload for `peer`: the (index, term) of the
    /// entry just before the peer's `next_index`, plus every entry from
    /// `next_index` onward converted to its protobuf form.
    ///
    /// NOTE(review): panics via `unwrap()` when `peer` has no `next_index`
    /// entry (populated on election win and in `add_server`) — confirm all
    /// callers run on an initialized leader.
    async fn get_logs(&self, peer: &Uuid) -> (u32, u32, Vec<skiff_proto::Log>) {
        let lock = self.state.lock().await;
        let log_next_index: &u32 = lock.next_index.get(peer).unwrap();
        let mut prev_log_index = 0;
        let mut prev_log_term = 0;
        let mut new_logs: Vec<skiff_proto::Log> = vec![];

        for log in &lock.log {
            if log.index < *log_next_index && log.index > prev_log_index {
                // Track the newest entry *before* the peer's next index.
                prev_log_index = log.index;
                prev_log_term = log.term;
            } else if log.index >= *log_next_index {
                new_logs.push(match &log.action {
                    Action::Insert(key, value) => skiff_proto::Log {
                        index: log.index,
                        term: log.term,
                        action: skiff_proto::Action::Insert as i32,
                        key: key.clone(),
                        value: Some(value.clone()),
                    },
                    Action::Delete(key) => skiff_proto::Log {
                        index: log.index,
                        term: log.term,
                        action: skiff_proto::Action::Delete as i32,
                        key: key.clone(),
                        value: None,
                    },
                    Action::Configure(config) => skiff_proto::Log {
                        index: log.index,
                        term: log.term,
                        action: skiff_proto::Action::Configure as i32,
                        key: "cluster".to_string(),
                        value: Some(bincode::serialize(&config).unwrap()),
                    },
                });
            }
        }

        (prev_log_index, prev_log_term, new_logs)
    }
586
587 async fn commit_logs(&self) -> Result<(), Error> {
588 let committed_index = self.state.lock().await.committed_index;
589 let last_applied = self.state.lock().await.last_applied;
590 if committed_index <= last_applied {
591 return Ok(());
592 }
593
594 trace!("committing log entries");
595 let new_logs: Vec<(usize, Action)> = self
596 .state
597 .lock()
598 .await
599 .log
600 .iter()
601 .enumerate()
602 .filter(|(_, log)| log.index > last_applied)
603 .map(|(i, log)| (i, log.action.clone()))
604 .collect();
605
606 for (i, action) in new_logs {
607 match action {
608 Action::Insert(key, value) => {
610 let full_key = key.clone();
611 let mut tree_parts: Vec<&str> = key.split("/").collect();
612 let key = tree_parts.pop().unwrap();
613
614 let prefix = tree_parts.join("/");
615 let tree_name =
616 match prefix.len() {
617 0 => "base".to_string(),
618 _ => {
619 let _ = self.state.lock().await.conn.update_and_fetch(
620 "trees",
621 |trees| match trees {
622 Some(tree_vec) => {
623 let mut updated_tree_vec =
624 bincode::deserialize::<Vec<String>>(tree_vec)
625 .unwrap();
626
627 if !updated_tree_vec.contains(&prefix) {
628 updated_tree_vec.push(prefix.clone());
629 }
630
631 Some(bincode::serialize(&updated_tree_vec).unwrap())
632 }
633 None => Some(bincode::serialize(&vec![&prefix]).unwrap()),
634 },
635 );
636
637 format!("base_{}", &prefix.replace("/", "_"))
638 }
639 };
640
641 let mut subscribers = self.subscribers.lock().await;
643 for (sub_prefix, senders) in subscribers.iter_mut() {
644 if prefix.starts_with(sub_prefix.trim_end_matches("/")) {
646 let mut live_senders = Vec::new();
647 for sender in senders.drain(..) {
648 if sender
649 .send(SubscribeReply {
650 key: full_key.clone(),
651 action: skiff_proto::Action::Insert as i32,
652 value: Some(value.clone()),
653 })
654 .await
655 .is_ok()
656 {
657 live_senders.push(sender);
658 } }
660 *senders = live_senders;
661 }
662 }
663
664 let tree = self.state.lock().await.conn.open_tree(tree_name)?;
665 tree.insert(key, value)?;
666 }
667 Action::Delete(key) => {
668 let mut tree_parts: Vec<&str> = key.split("/").collect();
669 let key = tree_parts.pop().unwrap();
670
671 let prefix = tree_parts.join("/");
672 let tree_name = match prefix.len() {
673 0 => "base".to_string(),
674 _ => format!("base_{}", &prefix.replace("/", "_")),
675 };
676
677 let trees = self.state.lock().await.conn.tree_names();
678 if trees.contains(&tree_name.as_bytes().into()) {
679 let lock = self.state.lock().await;
680 let tree = lock.conn.open_tree(&tree_name)?;
681 tree.remove(key)?;
682
683 if tree.is_empty() {
684 let _ = lock.conn.drop_tree(&tree_name);
685 let _ = lock.conn.update_and_fetch("trees", |trees| match trees {
686 Some(tree_vec) => {
687 let mut updated_tree_vec =
688 bincode::deserialize::<Vec<String>>(tree_vec).unwrap();
689
690 trace!(trees = ?updated_tree_vec, name = %tree_name, "dropping tree");
691 if let Some(index) =
692 updated_tree_vec.iter().position(|name| name == &prefix)
693 {
694 updated_tree_vec.remove(index);
695 }
696
697 Some(bincode::serialize(&updated_tree_vec).unwrap())
698 }
699 None => Some(bincode::serialize(&Vec::<String>::new()).unwrap()),
700 });
701 }
702 }
703 }
704 Action::Configure(config) => {
705 self.state
707 .lock()
708 .await
709 .conn
710 .insert("cluster", bincode::serialize(&config).unwrap())?;
711 }
712 }
713
714 if let Some(log) = self.state.lock().await.log.get_mut(i) {
715 let (committed, notify) = &*log.committed;
716 *committed.lock().await = true;
717 notify.notify_one();
718 }
719 }
720
721 let conn = self.state.lock().await.conn.clone();
722 persist_last_applied(&conn, committed_index)?;
723 conn.flush_async().await.map_err(Error::SledError)?;
724 self.state.lock().await.last_applied = committed_index;
725
726 Ok(())
727 }
728
729 async fn reset_heartbeat_timer(&self) {
730 let _ = self.tx_entries.lock().await.send(1).await; }
732
    /// Spawn the background bootstrap task — join an existing cluster when
    /// peers were configured, then start the election loop — and return the
    /// tonic service wrapper for this node.
    fn initialize_service(&self) -> SkiffServer<Skiff> {
        let skiff = self.clone();
        drop(tokio::spawn(async move {
            // More than one member in the bootstrap config means we were
            // given peers and must ask one of them to add us.
            if skiff.get_cluster().await.unwrap().len() > 1 {
                debug!("joining cluster");

                let mut joined_cluster = false;
                for id in skiff.get_peers().await.keys() {
                    debug!(?id, "asking peer to add us to cluster");

                    let mut request = Request::new(ServerRequest {
                        id: skiff.id.to_string(),
                        address: skiff.address.to_string(),
                    });

                    request.set_timeout(Duration::from_millis(300));

                    let client_arc = skiff.get_peer_client(id).await.unwrap();
                    let mut client = client_arc.lock().await;
                    match client.add_server(request).await {
                        Ok(response) => {
                            let inner = response.into_inner();
                            if inner.success {
                                if let Some(cluster) = inner.cluster {
                                    // Adopt the authoritative membership map
                                    // returned by the accepting peer.
                                    skiff
                                        .log(Action::Configure(
                                            bincode::deserialize(&cluster).unwrap(),
                                        ))
                                        .await;

                                    joined_cluster = true;

                                    break;
                                }
                            }
                        }
                        Err(_) => continue,
                    }
                }

                if !joined_cluster {
                    return Err(Error::ClusterJoinFailed);
                }
            }
            // Run the election/heartbeat loop until shutdown.
            let mut server1 = skiff.clone();
            let _elections: tokio::task::JoinHandle<Result<(), Error>> = tokio::spawn(async move {
                server1.election_manager().await?;
                Ok(())
            });

            Ok(())
        }));

        SkiffServer::new(self.clone())
    }
796
    /// Serve the gRPC endpoint on `address:port` until `shutdown()` is
    /// called, consuming the node handle.
    pub async fn start(self) -> Result<(), Error> {
        let service = self.initialize_service();
        let mut shutdown = self.shutdown_rx.clone();

        let bind_address = SocketAddr::new(self.address.into(), self.port);
        tonic::transport::Server::builder()
            .add_service(service)
            .serve_with_shutdown(bind_address, async move {
                // Resolves when the shutdown flag flips (or sender drops).
                let _ = shutdown.changed().await;
            })
            .await?;

        Ok(())
    }
826
    /// Main Raft driver loop. As leader: every 75ms send AppendEntries
    /// (heartbeat + replication) to each peer and advance the commit index.
    /// Otherwise: wait for heartbeats, starting an election after a
    /// randomized 150-300ms timeout. Exits when shutdown is signalled.
    #[allow(unreachable_code)]
    async fn election_manager(&mut self) -> Result<(), Error> {
        let mut rng = StdRng::from_entropy();

        let mut rx_entries = self.rx_entries.lock().await;
        let mut shutdown = self.shutdown_rx.clone();

        loop {
            // Apply anything newly committed before the next tick.
            self.commit_logs().await?;

            if self.get_election_state().await == ElectionState::Leader {
                tokio::select! {
                    _ = tokio::time::sleep(Duration::from_millis(75)) => {

                        let last_log_index = self.get_last_log_index().await;
                        let committed_index = self.get_commit_index().await;
                        let current_term = self.get_current_term().await;
                        for peer in self.get_peers().await.keys() {
                            let (peer_last_log_index, peer_last_log_term, entries) = self.get_logs(peer).await;
                            let num_entries = entries.len();
                            let mut request = Request::new(EntryRequest {
                                term: current_term,
                                leader_id: self.id.to_string(),
                                prev_log_index: peer_last_log_index,
                                prev_log_term: peer_last_log_term,
                                entries,
                                leader_commit: committed_index
                            });
                            request.set_timeout(Duration::from_millis(300));

                            let client_arc = match self.get_peer_client(peer).await {
                                Ok(c) => c,
                                Err(_) => { self.drop_peer_client(peer).await; continue; }
                            };
                            let mut client = client_arc.lock().await;
                            match client.append_entry(request).await {
                                Ok(response) => {
                                    match response.into_inner().success {
                                        true => {
                                            // Replication succeeded: advance this peer's indices.
                                            if num_entries > 0 {
                                                if let Some(value) = self.state.lock().await.next_index.get_mut(peer) {
                                                    *value = last_log_index + 1;
                                                }

                                                if let Some(value) = self.state.lock().await.match_index.get_mut(peer) {
                                                    *value = last_log_index;
                                                }
                                            }
                                        },
                                        false => {
                                            // Log mismatch: back up one entry and retry next tick.
                                            // NOTE(review): unguarded `-= 1` underflows (panics in
                                            // debug builds) if next_index ever reaches 0 — confirm
                                            // it is clamped to >= 1.
                                            if let Some(value) = self.state.lock().await.next_index.get_mut(peer) {
                                                *value -= 1;
                                            }
                                        }
                                    }
                                },
                                Err(_) => { self.drop_peer_client(peer).await; }
                            }
                        }

                        // Advance committed_index to the highest index replicated on a
                        // majority (self counts as the +1) whose entry is from the
                        // current term, per the Raft commitment rule.
                        let num_peers = self.get_cluster().await?.len();
                        for i in ((committed_index + 1) ..=last_log_index).rev() {
                            let num_peers_applied = self.state.lock().await.match_index.iter().filter(|(_, &applied_index)| applied_index >= i).collect::<Vec<_>>().len() + 1;
                            let correct_term = self.state.lock().await.log.iter().filter(|log| log.index == i && log.term != current_term).collect::<Vec<_>>().is_empty();
                            if (num_peers_applied > num_peers / 2) && correct_term {
                                self.state.lock().await.committed_index = i;
                                break
                            }
                        }
                    }
                    _ = shutdown.changed() => { return Ok(()); }
                }
            } else {
                tokio::select! {
                    // A heartbeat arrived; restart the election timeout.
                    Some(1) = rx_entries.recv() => {
                        continue;
                    }

                    _ = tokio::time::sleep(Duration::from_millis(rng.gen_range(150..300))) => {
                        debug!("election timeout, starting election");
                        self.run_election().await?;
                    }

                    _ = shutdown.changed() => { return Ok(()); }
                };
            }
        }

        Ok(())
    }
926
    /// Become a candidate, solicit votes from all peers, and on winning a
    /// majority switch to leader: reinitialize next/match indices and
    /// broadcast an empty AppendEntries to assert leadership immediately.
    async fn run_election(&self) -> Result<(), Error> {
        self.set_election_state(ElectionState::Candidate).await;
        self.increment_term().await;
        // Vote for ourselves first.
        self.vote_for(Some(self.id)).await;
        let term = self.get_current_term().await;
        let state = self.get_election_state().await;
        debug!(?state, term, "starting election");

        let mut num_votes: u32 = 1; for peer in self.get_peers().await.keys() {
            trace!(?peer, "requesting vote");

            let mut request = Request::new(VoteRequest {
                term: self.get_current_term().await,
                candidate_id: self.id.to_string(),
                last_log_index: self.get_last_log_index().await,
                last_log_term: self.get_last_log_term().await,
            });
            request.set_timeout(Duration::from_millis(300));

            let client = match self.get_peer_client(peer).await {
                Ok(c) => c,
                Err(_) => {
                    // Unreachable peer: drop the cached client and move on.
                    self.drop_peer_client(peer).await;
                    continue;
                }
            };
            let response = match client.lock().await.request_vote(request).await {
                Ok(response) => response.into_inner(),
                Err(_) => {
                    self.drop_peer_client(peer).await;
                    continue;
                }
            };

            if response.vote_granted {
                debug!(peer = ?&peer, "received vote");
                num_votes += 1;
            }
        }

        // Strict majority of the whole cluster (own vote included) wins.
        if num_votes > self.get_cluster().await?.len() as u32 / 2 {
            info!("elected leader");
            self.set_election_state(ElectionState::Leader).await;
            self.vote_for(None).await;

            let last_log_index = self.get_last_log_index().await;
            let last_log_term = self.get_last_log_term().await;
            let committed_index = self.get_commit_index().await;
            let current_term = self.get_current_term().await;

            {
                // Leader bookkeeping: optimistically assume peers are caught
                // up (next = last+1) and nothing is known replicated yet.
                let peers = self.get_peers().await;
                let mut lock = self.state.lock().await;
                lock.next_index = peers
                    .keys()
                    .map(|peer_id| (*peer_id, last_log_index + 1))
                    .collect::<HashMap<Uuid, u32>>();
                lock.match_index = peers
                    .into_keys()
                    .map(|peer_id| (peer_id, 0))
                    .collect::<HashMap<Uuid, u32>>();
            }

            // Announce leadership with an empty AppendEntries round.
            for id in self.get_peers().await.keys() {
                let mut request = Request::new(EntryRequest {
                    term: current_term,
                    leader_id: self.id.to_string(),
                    prev_log_index: last_log_index,
                    prev_log_term: last_log_term,
                    entries: vec![],
                    leader_commit: committed_index,
                });
                request.set_timeout(Duration::from_millis(300));

                let client_arc = match self.get_peer_client(id).await {
                    Ok(c) => c,
                    Err(_) => {
                        self.drop_peer_client(id).await;
                        continue;
                    }
                };
                let mut client = client_arc.lock().await;
                if client.append_entry(request).await.is_err() {
                    self.drop_peer_client(id).await;
                }
            }
        }

        Ok(())
    }
1020}
1021
1022#[tonic::async_trait]
1024impl skiff_proto::skiff_server::Skiff for Skiff {
1025 type SubscribeStream = Pin<Box<dyn Stream<Item = Result<SubscribeReply, Status>> + Send>>;
1026
1027 async fn request_vote(
1028 &self,
1029 request: Request<VoteRequest>,
1030 ) -> Result<Response<VoteReply>, Status> {
1031 trace!("received vote request");
1032
1033 let current_term = self.get_current_term().await;
1034 let voted_for = self.get_voted_for().await;
1035 let last_log_index = self.get_last_log_index().await;
1036 let last_log_term = self.get_last_log_term().await;
1037 let conn = self.state.lock().await.conn.clone();
1038
1039 let vote_request = request.into_inner();
1040
1041 let candidate_id = Uuid::from_str(&vote_request.candidate_id)
1042 .map_err(|_| Status::invalid_argument("invalid candidate id"))?;
1043
1044 let can_vote = vote_request.term > current_term
1048 || (vote_request.term == current_term
1049 && (voted_for.is_none() || voted_for == Some(candidate_id)));
1050
1051 if can_vote
1052 && vote_request.last_log_index >= last_log_index
1053 && vote_request.last_log_term >= last_log_term
1054 {
1055 debug!(?candidate_id, "granting vote");
1056
1057 persist_hard_state(&conn, vote_request.term, Some(candidate_id))
1059 .map_err(|_| Status::internal("failed to persist hard state"))?;
1060 conn.flush_async()
1061 .await
1062 .map_err(|_| Status::internal("failed to flush"))?;
1063
1064 self.vote_for(Some(candidate_id)).await;
1065 self.set_election_state(ElectionState::Follower(candidate_id))
1066 .await;
1067 self.set_current_term(vote_request.term).await;
1068
1069 return Ok(Response::new(VoteReply {
1070 term: vote_request.term,
1071 vote_granted: true,
1072 }));
1073 }
1074
1075 Ok(Response::new(VoteReply {
1076 term: current_term,
1077 vote_granted: false,
1078 }))
1079 }
1080
    /// AppendEntries RPC handler (heartbeat + log replication).
    ///
    /// Rejects stale terms, records the sender as leader and resets the
    /// election timeout, checks log consistency at `prev_log_index`,
    /// truncates conflicting suffixes, appends and persists new entries,
    /// then advances the local commit index toward `leader_commit`.
    async fn append_entry(
        &self,
        request: Request<EntryRequest>,
    ) -> Result<Response<EntryReply>, Status> {
        let entry_request = request.into_inner();
        let current_term = self.get_current_term().await;
        let conn = self.state.lock().await.conn.clone();

        // Reject leaders from an older term.
        if entry_request.term < current_term {
            return Ok(Response::new(EntryReply {
                term: current_term,
                success: false,
            }));
        }

        let term_changed = entry_request.term > current_term;
        if term_changed {
            self.set_current_term(entry_request.term).await;
            self.vote_for(None).await;
        }

        // Accept the sender as leader and restart the election timer.
        self.set_election_state(ElectionState::Follower(
            Uuid::from_str(&entry_request.leader_id).unwrap(),
        ))
        .await;
        self.reset_heartbeat_timer().await;

        // Consistency check: we must hold the exact entry the leader says
        // precedes this batch.
        if entry_request.prev_log_index > 0 {
            let mut found_matching_log = false;
            for log in &self.state.lock().await.log {
                if log.index == entry_request.prev_log_index
                    && log.term == entry_request.prev_log_term
                {
                    found_matching_log = true;
                    break;
                }
            }

            if !found_matching_log {
                return Ok(Response::new(EntryReply {
                    term: entry_request.term,
                    success: false,
                }));
            }
        }

        // Conflict resolution: an existing entry with the same index but a
        // different term invalidates it and everything after it, both
        // in memory and on disk.
        for new_log in &entry_request.entries {
            let mut drop_index: Option<u32> = None;
            for current_log in &self.state.lock().await.log {
                if current_log.index == new_log.index && current_log.term != new_log.term {
                    drop_index = Some(current_log.index);
                }
            }

            if let Some(drop_index) = drop_index {
                self.state
                    .lock()
                    .await
                    .log
                    .retain(|log| log.index < drop_index);
                truncate_log_from(&conn, drop_index)
                    .map_err(|_| Status::internal("failed to truncate log"))?;
            }
        }

        let last_log_index = self.get_last_log_index().await;
        let new_term = entry_request.term;
        let mut appended_entries = false;

        // Append (and persist) entries we do not already have.
        for new_log in entry_request.entries {
            if new_log.index > last_log_index {
                trace!("appending log entry");
                let log_entry = Log {
                    index: new_log.index,
                    term: new_log.term,
                    action: match skiff_proto::Action::try_from(new_log.action) {
                        Ok(skiff_proto::Action::Insert) => {
                            Action::Insert(new_log.key, new_log.value.unwrap())
                        }
                        Ok(skiff_proto::Action::Delete) => Action::Delete(new_log.key),
                        Ok(skiff_proto::Action::Configure) => Action::Configure(
                            bincode::deserialize(&new_log.value.unwrap()).unwrap(),
                        ),
                        Err(_) => return Err(Status::invalid_argument("Invalid action")),
                    },
                    committed: Arc::new((Mutex::new(false), Notify::new())),
                };
                persist_log_entry(&conn, &log_entry)
                    .map_err(|_| Status::internal("failed to persist log entry"))?;
                self.state.lock().await.log.push(log_entry);
                appended_entries = true;
            }
        }

        // One flush covers both the term change and any appended entries.
        if term_changed || appended_entries {
            if term_changed {
                persist_hard_state(&conn, new_term, None)
                    .map_err(|_| Status::internal("failed to persist hard state"))?;
            }
            conn.flush_async()
                .await
                .map_err(|_| Status::internal("failed to flush"))?;
        }

        // Follower commit index trails min(leader_commit, our last index).
        if entry_request.leader_commit > self.get_commit_index().await {
            self.state.lock().await.committed_index =
                min(entry_request.leader_commit, self.get_last_log_index().await);
        }

        Ok(Response::new(EntryReply {
            term: new_term,
            success: true,
        }))
    }
1199
1200 async fn add_server(
1201 &self,
1202 request: Request<ServerRequest>,
1203 ) -> Result<Response<ServerReply>, Status> {
1204 let election_state = self.state.lock().await.election_state.clone();
1205 if let ElectionState::Follower(leader) = election_state {
1206 let client = self.get_peer_client(&leader).await;
1207 if let Ok(client_inner) = client {
1208 return client_inner.lock().await.add_server(request).await;
1209 }
1210
1211 return Err(Status::internal("failed to forward request to leader"));
1212 }
1213
1214 debug!("adding server to cluster");
1215 let new_server = request.into_inner();
1216 let new_uuid = Uuid::from_str(&new_server.id).unwrap();
1217
1218 let mut cluster_config: HashMap<Uuid, Ipv4Addr> = self.get_cluster().await.unwrap();
1219
1220 let id = Uuid::from_str(&new_server.id).unwrap();
1221 let addr = Ipv4Addr::from_str(&new_server.address).unwrap();
1222
1223 if let std::collections::hash_map::Entry::Vacant(e) = cluster_config.entry(id) {
1224 e.insert(addr);
1225 self.log(Action::Configure(cluster_config.clone())).await;
1226 }
1227
1228 let last_log_index = self.get_last_log_index().await;
1229 self.state
1230 .lock()
1231 .await
1232 .next_index
1233 .insert(new_uuid, last_log_index + 1);
1234 self.state.lock().await.match_index.insert(new_uuid, 0);
1235
1236 Ok(Response::new(ServerReply {
1237 success: true,
1238 cluster: Some(bincode::serialize(&cluster_config).unwrap()),
1239 }))
1240 }
1241
1242 async fn remove_server(
1243 &self,
1244 request: Request<ServerRequest>,
1245 ) -> Result<Response<ServerReply>, Status> {
1246 let election_state = self.state.lock().await.election_state.clone();
1247 if let ElectionState::Follower(leader) = election_state {
1248 let client = self.get_peer_client(&leader).await;
1249 if let Ok(client_inner) = client {
1250 return client_inner.lock().await.remove_server(request).await;
1251 }
1252 return Err(Status::internal("failed to forward request to leader"));
1253 }
1254
1255 let remove_request = request.into_inner();
1256 let remove_id = Uuid::from_str(&remove_request.id)
1257 .map_err(|_| Status::invalid_argument("invalid server id"))?;
1258
1259 let mut cluster_config = self
1260 .get_cluster()
1261 .await
1262 .map_err(|_| Status::internal("failed to get cluster config"))?;
1263
1264 if cluster_config.remove(&remove_id).is_none() {
1265 return Err(Status::not_found("server not found in cluster"));
1266 }
1267
1268 self.log(Action::Configure(cluster_config.clone())).await;
1269
1270 {
1271 let mut state = self.state.lock().await;
1272 state.next_index.remove(&remove_id);
1273 state.match_index.remove(&remove_id);
1274 }
1275 self.drop_peer_client(&remove_id).await;
1276
1277 Ok(Response::new(ServerReply {
1278 success: true,
1279 cluster: Some(
1280 bincode::serialize(&cluster_config)
1281 .map_err(|_| Status::internal("failed to serialize cluster"))?,
1282 ),
1283 }))
1284 }
1285
1286 async fn get(&self, request: Request<GetRequest>) -> Result<Response<GetReply>, Status> {
1292 let election_state = self.state.lock().await.election_state.clone();
1299 if let ElectionState::Follower(leader) = election_state {
1300 let client = self.get_peer_client(&leader).await;
1301 if let Ok(client_inner) = client {
1302 return client_inner.lock().await.get(request).await;
1303 }
1304
1305 return Err(Status::internal("failed to forward request to leader"));
1306 }
1307
1308 let get_request = request.into_inner();
1309 let mut tree_parts: Vec<&str> = get_request.key.split("/").collect();
1310 let key = tree_parts.pop().unwrap();
1311
1312 let mut tree_name = tree_parts.join("/");
1313 tree_name = match tree_name.len() {
1314 0 => "base".to_string(),
1315 _ => format!("base_{}", tree_name.replace("/", "_")),
1316 };
1317
1318 if let Ok(tree) = self.state.lock().await.conn.open_tree(tree_name) {
1319 let value = tree.get(key);
1320 match value {
1321 Ok(inner1) => match inner1 {
1322 Some(data) => Ok(Response::new(GetReply {
1323 value: Some(data.to_vec()),
1324 })),
1325 None => Ok(Response::new(GetReply { value: None })),
1326 },
1327 Err(_) => Err(Status::internal("failed to query sled db")),
1328 }
1329 } else {
1330 Err(Status::internal("failed to open sled tree"))
1331 }
1332 }
1333
1334 async fn insert(
1335 &self,
1336 request: Request<InsertRequest>,
1337 ) -> Result<Response<InsertReply>, Status> {
1338 let election_state = self.state.lock().await.election_state.clone();
1340 if let ElectionState::Follower(leader) = election_state {
1341 let client = self.get_peer_client(&leader).await;
1342 if let Ok(client_inner) = client {
1343 return client_inner.lock().await.insert(request).await;
1344 }
1345
1346 return Err(Status::internal("failed to forward request to leader"));
1347 }
1348
1349 let insert_request = request.into_inner();
1350 let commit_arc = self
1351 .log(Action::Insert(insert_request.key, insert_request.value))
1352 .await;
1353
1354 let (_, notify) = &*commit_arc;
1355 tokio::time::timeout(Duration::from_secs(5), notify.notified())
1356 .await
1357 .map_err(|_| Status::deadline_exceeded("timed out waiting for commit"))?;
1358 Ok(Response::new(InsertReply { success: true }))
1359 }
1360
1361 async fn delete(
1362 &self,
1363 request: Request<DeleteRequest>,
1364 ) -> Result<Response<DeleteReply>, Status> {
1365 let election_state = self.state.lock().await.election_state.clone();
1367 if let ElectionState::Follower(leader) = election_state {
1368 let client = self.get_peer_client(&leader).await;
1369 if let Ok(client_inner) = client {
1370 return client_inner.lock().await.delete(request).await;
1371 }
1372
1373 return Err(Status::internal("failed to forward request to leader"));
1374 }
1375
1376 let delete_request = request.into_inner();
1377 let commit_arc = self.log(Action::Delete(delete_request.key)).await;
1378
1379 let (_, notify) = &*commit_arc;
1380 tokio::time::timeout(Duration::from_secs(5), notify.notified())
1381 .await
1382 .map_err(|_| Status::deadline_exceeded("timed out waiting for commit"))?;
1383 Ok(Response::new(DeleteReply { success: true }))
1384 }
1385
1386 async fn get_prefixes(&self, request: Request<Empty>) -> Result<Response<PrefixReply>, Status> {
1387 let election_state = self.state.lock().await.election_state.clone();
1389 if let ElectionState::Follower(leader) = election_state {
1390 let client = self.get_peer_client(&leader).await;
1391 if let Ok(client_inner) = client {
1392 return client_inner.lock().await.get_prefixes(request).await;
1393 }
1394
1395 return Err(Status::internal("failed to forward request to leader"));
1396 }
1397
1398 match self.get_prefixes().await {
1399 Ok(prefixes) => Ok(Response::new(PrefixReply { prefixes })),
1400 Err(_) => Err(Status::internal("failed to get prefixes")),
1401 }
1402 }
1403
1404 async fn list_keys(
1405 &self,
1406 request: Request<ListKeysRequest>,
1407 ) -> Result<Response<ListKeysReply>, Status> {
1408 let election_state = self.state.lock().await.election_state.clone();
1410 if let ElectionState::Follower(leader) = election_state {
1411 let client = self.get_peer_client(&leader).await;
1412 if let Ok(client_inner) = client {
1413 return client_inner.lock().await.list_keys(request).await;
1414 }
1415
1416 return Err(Status::internal("failed to forward request to leader"));
1417 }
1418
1419 match self.list_keys(request.into_inner().prefix.as_str()).await {
1420 Ok(keys) => Ok(Response::new(ListKeysReply { keys })),
1421 Err(_) => Err(Status::internal("failed to get keys")),
1422 }
1423 }
1424
1425 async fn subscribe(
1427 &self,
1428 request: Request<SubscribeRequest>,
1429 ) -> Result<Response<Self::SubscribeStream>, Status> {
1430 let prefix = request.into_inner().prefix;
1431
1432 let (sender, receiver) = mpsc::channel(32);
1433 let mut subscribers = self.subscribers.lock().await;
1434 if !subscribers.contains_key(&prefix) {
1435 subscribers.insert(prefix.clone(), Vec::new());
1436 }
1437
1438 let senders = subscribers.get_mut(&prefix).unwrap();
1439 senders.push(sender);
1440
1441 let stream = ReceiverStream::new(receiver).map(Ok);
1442
1443 Ok(Response::new(Box::pin(stream)))
1444 }
1445}