//! skiff.rs — a single node of a Raft-replicated key/value store (skiff_rs).
// gRPC message and service types generated from the `skiff` protobuf package
// by tonic at build time (see build.rs / tonic-build).
pub(crate) mod skiff_proto {
    tonic::include_proto!("skiff");
}
4
5use crate::error::Error;
6use rand::{rngs::StdRng, Rng, SeedableRng};
7use serde::{Deserialize, Serialize};
8use skiff_proto::{
9    skiff_client::SkiffClient, skiff_server::SkiffServer, DeleteReply, DeleteRequest, GetReply,
10    GetRequest, InsertReply, InsertRequest,
11};
12use skiff_proto::{
13    Empty, EntryReply, EntryRequest, ListKeysReply, ListKeysRequest, PrefixReply, ServerReply,
14    ServerRequest, SubscribeReply, SubscribeRequest, VoteReply, VoteRequest,
15};
16use std::cmp::min;
17use std::pin::Pin;
18use std::time::Duration;
19use std::{collections::HashMap, str::FromStr};
20use std::{
21    net::{Ipv4Addr, SocketAddr, SocketAddrV4},
22    sync::Arc,
23};
24use tokio_stream::wrappers::ReceiverStream;
25use tokio_stream::{Stream, StreamExt};
26
27use tokio::sync::{mpsc, watch, Mutex, Notify};
28use tonic::{transport::Channel, Request, Response, Status};
29use tracing::{debug, info, trace};
30use uuid::Uuid;
31
// Sled tree holding Raft hard state (term, vote, last applied, node id).
const RAFT_META_TREE: &str = "__raft_meta";
// Sled tree holding persisted log entries, keyed by big-endian u32 index so
// byte order matches numeric order.
const RAFT_LOG_TREE: &str = "__raft_log";
// Keys within RAFT_META_TREE.
const KEY_CURRENT_TERM: &[u8] = b"current_term";
const KEY_VOTED_FOR: &[u8] = b"voted_for";
const KEY_LAST_APPLIED: &[u8] = b"last_applied";
const KEY_NODE_ID: &[u8] = b"node_id";
38
/// On-disk (bincode) form of a [`Log`] entry.  The in-memory `committed`
/// flag/notifier is deliberately not persisted; it is rebuilt on load.
#[derive(Serialize, Deserialize)]
struct PersistedLog {
    // Log position; also the big-endian sled key in RAFT_LOG_TREE.
    index: u32,
    // Raft term the entry was created in.
    term: u32,
    // The replicated state-machine command.
    action: Action,
}
45
46fn persist_hard_state(conn: &sled::Db, term: u32, voted_for: Option<Uuid>) -> Result<(), Error> {
47    let meta = conn.open_tree(RAFT_META_TREE)?;
48    meta.insert(KEY_CURRENT_TERM, &term.to_be_bytes())?;
49    match voted_for {
50        Some(id) => {
51            meta.insert(KEY_VOTED_FOR, id.as_bytes().as_slice())?;
52        }
53        None => {
54            meta.remove(KEY_VOTED_FOR)?;
55        }
56    };
57    Ok(())
58}
59
60fn persist_log_entry(conn: &sled::Db, log: &Log) -> Result<(), Error> {
61    let tree = conn.open_tree(RAFT_LOG_TREE)?;
62    let persisted = PersistedLog {
63        index: log.index,
64        term: log.term,
65        action: log.action.clone(),
66    };
67    tree.insert(log.index.to_be_bytes(), bincode::serialize(&persisted)?)?;
68    Ok(())
69}
70
71fn truncate_log_from(conn: &sled::Db, from_index: u32) -> Result<(), Error> {
72    let tree = conn.open_tree(RAFT_LOG_TREE)?;
73    let keys: Vec<_> = tree
74        .range(from_index.to_be_bytes()..)
75        .keys()
76        .collect::<Result<Vec<_>, _>>()?;
77    for key in keys {
78        tree.remove(key)?;
79    }
80    Ok(())
81}
82
83fn persist_last_applied(conn: &sled::Db, last_applied: u32) -> Result<(), Error> {
84    let meta = conn.open_tree(RAFT_META_TREE)?;
85    meta.insert(KEY_LAST_APPLIED, &last_applied.to_be_bytes())?;
86    Ok(())
87}
88
89fn load_or_create_id(conn: &sled::Db) -> Result<Uuid, Error> {
90    let meta = conn.open_tree(RAFT_META_TREE)?;
91    if let Some(bytes) = meta.get(KEY_NODE_ID)? {
92        return Uuid::from_slice(bytes.as_ref()).map_err(|_| Error::DeserializeFailed);
93    }
94    let id = Uuid::new_v4();
95    meta.insert(KEY_NODE_ID, id.as_bytes().as_slice())?;
96    conn.flush()?;
97    Ok(id)
98}
99
100fn load_raft_state(conn: &sled::Db) -> Result<(u32, Option<Uuid>, Vec<Log>, u32), Error> {
101    let meta = conn.open_tree(RAFT_META_TREE)?;
102
103    let current_term = meta
104        .get(KEY_CURRENT_TERM)?
105        .map(|b| u32::from_be_bytes(b.as_ref().try_into().unwrap_or([0; 4])))
106        .unwrap_or(0);
107
108    let voted_for = meta
109        .get(KEY_VOTED_FOR)?
110        .and_then(|b| Uuid::from_slice(b.as_ref()).ok());
111
112    let last_applied = meta
113        .get(KEY_LAST_APPLIED)?
114        .map(|b| u32::from_be_bytes(b.as_ref().try_into().unwrap_or([0; 4])))
115        .unwrap_or(0);
116
117    let log_tree = conn.open_tree(RAFT_LOG_TREE)?;
118    let mut log = Vec::new();
119    for result in log_tree.iter() {
120        let (_, value) = result?;
121        let persisted: PersistedLog = bincode::deserialize(&value)?;
122        log.push(Log {
123            index: persisted.index,
124            term: persisted.term,
125            action: persisted.action,
126            committed: Arc::new((Mutex::new(true), Notify::new())),
127        });
128    }
129
130    Ok((current_term, voted_for, log, last_applied))
131}
132
/// A replicated state-machine command carried by a log entry.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
enum Action {
    /// Set the key (first field) to the value bytes (second field).
    Insert(String, Vec<u8>),
    /// Remove the key.
    Delete(String),
    /// Replace the cluster membership (node id -> address).
    Configure(HashMap<Uuid, Ipv4Addr>),
}
139
/// An in-memory Raft log entry.
#[derive(Debug, Clone)]
struct Log {
    // Monotonically increasing log position (0 is the bootstrap sentinel).
    index: u32,
    // Term in which the entry was created.
    term: u32,
    // The command to apply when the entry commits.
    action: Action,
    // Commit flag plus a Notify; shared with callers of `Skiff::log` so they
    // can await the entry being applied.
    committed: Arc<(Mutex<bool>, Notify)>,
}
147
/// The current role of a node in the Raft protocol.
// NOTE(review): serde derives suggest this is sent/stored somewhere outside
// this file — confirm before changing variant shapes.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ElectionState {
    /// The node is campaigning for leadership.
    Candidate,
    /// The node is the current cluster leader and accepts writes.
    Leader,
    /// The node is a follower of the leader identified by the inner [`Uuid`].
    /// A `Uuid::nil()` means no leader has been elected yet.
    Follower(Uuid),
}
159
/// All mutable Raft state for one node, guarded by a single async mutex.
#[derive(Debug, Clone)]
struct State {
    // Current role; Follower(Uuid::nil()) means "no known leader".
    election_state: ElectionState,
    // Latest term this node has seen.
    current_term: u32,
    // Candidate voted for in the current term, if any.
    voted_for: Option<Uuid>,
    // Highest log index known to be committed.
    committed_index: u32,
    // Highest log index applied to the sled state machine.
    last_applied: u32,

    // Cached gRPC clients, keyed by peer id (see get_peer_client / drop_peer_client).
    peer_clients: HashMap<Uuid, Arc<Mutex<SkiffClient<Channel>>>>,

    // Leader-side replication bookkeeping per peer: next log index to send,
    // and highest index known to be replicated.
    next_index: HashMap<Uuid, u32>,
    match_index: HashMap<Uuid, u32>,

    // In-memory log; index 0 is an unpersisted bootstrap Configure sentinel.
    log: Vec<Log>,
    // Sled database backing both the KV state machine and Raft metadata.
    conn: sled::Db,
}
176
/// A single node in a skiff cluster.
///
/// Construct one with [`Builder`](crate::Builder) and call [`start`](Skiff::start)
/// to begin serving requests.  The node spawns background tasks for leader
/// election and heartbeat management; call [`shutdown`](Skiff::shutdown) before
/// dropping to allow those tasks and the sled database to close cleanly.
#[derive(Debug, Clone)]
pub struct Skiff {
    // Stable node identity, persisted in sled (see load_or_create_id).
    id: Uuid,
    // Address and port this node binds its gRPC server to.
    address: Ipv4Addr,
    port: u16,
    // All mutable Raft state behind one async mutex.
    state: Arc<Mutex<State>>,
    // Heartbeat channel: any byte sent resets the election timer.
    tx_entries: Arc<Mutex<mpsc::Sender<u8>>>,
    rx_entries: Arc<Mutex<mpsc::Receiver<u8>>>,
    // Prefix subscribers; dead senders are pruned during commit_logs.
    subscribers: Arc<Mutex<HashMap<String, Vec<mpsc::Sender<SubscribeReply>>>>>,
    // Watch channel used by shutdown() to stop the server and background tasks.
    shutdown_tx: Arc<watch::Sender<bool>>,
    shutdown_rx: watch::Receiver<bool>,
}
195
196impl Skiff {
    /// Build a node: open sled at `data_dir`, restore any persisted Raft
    /// state, and seed the in-memory log with a bootstrap cluster config.
    ///
    /// # Errors
    ///
    /// Propagates sled open/read failures and deserialization errors from the
    /// persisted state.
    pub(crate) fn new(
        address: Ipv4Addr,
        port: u16,
        data_dir: String,
        peers: Vec<Ipv4Addr>,
    ) -> Result<Self, Error> {
        let conn = sled::open(data_dir)?;
        // Identity is stable across restarts (stored in the meta tree).
        let id = load_or_create_id(&conn)?;
        let (tx_entries, rx_entries) = mpsc::channel(32);
        let (shutdown_tx, shutdown_rx) = watch::channel(false);

        let (current_term, voted_for, persisted_log, last_applied) = load_raft_state(&conn)?;

        // The index-0 entry is an in-memory sentinel carrying the bootstrap cluster config derived
        // from the peers argument. Persisted log entries (index >= 1) are appended on top and take
        // precedence in get_cluster() via rev().find(), so the latest Configure entry always wins.
        // NOTE(review): peers get fresh random UUIDs on every startup here; real peer ids are
        // presumably learned later via Configure entries — confirm this is intentional.
        let mut cluster: HashMap<Uuid, Ipv4Addr> = peers
            .into_iter()
            .map(|addr| (Uuid::new_v4(), addr))
            .collect();

        cluster.insert(id, address);

        let mut log = vec![Log {
            term: 0,
            index: 0,
            action: Action::Configure(cluster),
            committed: Arc::new((Mutex::new(true), Notify::new())),
        }];
        log.extend(persisted_log);

        Ok(Skiff {
            id,
            address,
            port,
            state: Arc::new(Mutex::new(State {
                election_state: ElectionState::Follower(Uuid::nil()),
                current_term,
                voted_for,
                // Start conservative: committed_index catches up via leader heartbeats
                committed_index: last_applied,
                last_applied,
                peer_clients: HashMap::new(),
                next_index: HashMap::new(),
                match_index: HashMap::new(),
                log,
                conn,
            })),
            tx_entries: Arc::new(Mutex::new(tx_entries)),
            rx_entries: Arc::new(Mutex::new(rx_entries)),
            subscribers: Arc::new(Mutex::new(HashMap::new())),
            shutdown_tx: Arc::new(shutdown_tx),
            shutdown_rx,
        })
    }
252
253    // todo: initializing a new cluster with a known config without needing to send add_server rpc
254
255    /// Return the stable, persistent UUID that identifies this node.
256    ///
257    /// The ID is generated on first startup and stored in sled so it survives
258    /// restarts.
259    pub fn get_id(&self) -> Uuid {
260        self.id
261    }
262
263    /// Signal the node to stop its background tasks and close the gRPC server.
264    ///
265    /// This must be called before dropping the node in tests or applications
266    /// that restart nodes, because sled holds a file lock that is only released
267    /// once every `Arc` clone of the database handle has been dropped.
268    /// `shutdown` triggers a clean teardown so those clones are released
269    /// promptly.
270    pub fn shutdown(&self) {
271        let _ = self.shutdown_tx.send(true);
272    }
273
274    /// Return the IPv4 address this node is bound to.
275    pub fn get_address(&self) -> Ipv4Addr {
276        self.address
277    }
278
279    /// Return `true` if the cluster has an active leader.
280    ///
281    /// This is `true` when this node is the leader, or when it is a follower
282    /// that has acknowledged a specific leader.  It is `false` during initial
283    /// startup or while an election is in progress.
284    pub async fn is_leader_elected(&self) -> bool {
285        let election_state = self.get_election_state().await;
286        match election_state {
287            ElectionState::Leader => true,
288            ElectionState::Candidate => false,
289            ElectionState::Follower(id) => {
290                if Uuid::nil() == id {
291                    return false;
292                }
293
294                true
295            }
296        }
297    }
298
299    /// Block until the cluster has an elected leader, or until `timeout` elapses.
300    ///
301    /// Polls [`is_leader_elected`](Skiff::is_leader_elected) every 50 ms.
302    /// Call this after spawning [`start`](Skiff::start) to ensure the cluster
303    /// is ready before connecting a client.
304    ///
305    /// # Errors
306    ///
307    /// Returns [`Error::LeaderElectionTimeout`] if no leader is elected within
308    /// `timeout`.
309    ///
310    /// # Example
311    ///
312    /// ```no_run
313    /// use skiff_rs::Builder;
314    /// use std::time::Duration;
315    ///
316    /// # #[tokio::main]
317    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
318    /// let node = Builder::new()
319    ///     .set_dir("/tmp/my-node")
320    ///     .bind("127.0.0.1".parse()?)
321    ///     .build()?;
322    ///
323    /// let node_ref = node.clone();
324    /// tokio::spawn(async move { node_ref.start().await });
325    ///
326    /// // Block until a leader is elected before connecting a client.
327    /// node.wait_for_leader(Duration::from_secs(2)).await?;
328    /// # Ok(())
329    /// # }
330    /// ```
331    pub async fn wait_for_leader(&self, timeout: Duration) -> Result<(), Error> {
332        tokio::time::timeout(timeout, async {
333            loop {
334                if self.is_leader_elected().await {
335                    return;
336                }
337                tokio::time::sleep(Duration::from_millis(50)).await;
338            }
339        })
340        .await
341        .map_err(|_| Error::LeaderElectionTimeout)
342    }
343
344    /// Return the current cluster membership as a map of node ID → address.
345    ///
346    /// The membership is derived from the most recent `Configure` entry in the
347    /// Raft log.
348    ///
349    /// # Errors
350    ///
351    /// Returns [`Error::MissingClusterConfig`] if no configuration entry is
352    /// found (this should not happen under normal operation).
353    pub async fn get_cluster(&self) -> Result<HashMap<Uuid, Ipv4Addr>, Error> {
354        let config = match self
355            .state
356            .lock()
357            .await
358            .log
359            .iter()
360            .rev()
361            .find(|log| matches!(log.action, Action::Configure(_)))
362        {
363            Some(log) => match &log.action {
364                Action::Configure(config) => config.clone(),
365                _ => return Err(Error::MissingClusterConfig),
366            },
367            _ => return Err(Error::MissingClusterConfig),
368        };
369
370        Ok(config)
371    }
372
373    async fn get_peers(&self) -> HashMap<Uuid, Ipv4Addr> {
374        self.get_cluster()
375            .await
376            // Todo: figure out if unwrap is sufficient or if this needs to return Result
377            // affects get_peer_client()
378            .unwrap()
379            .into_iter()
380            .filter(|(_, addr)| *addr != self.address)
381            .collect()
382    }
383
384    async fn get_peer_client(
385        &self,
386        peer: &Uuid,
387    ) -> Result<Arc<Mutex<SkiffClient<Channel>>>, Error> {
388        let peers = self.get_peers().await;
389
390        if !peers.contains_key(peer) {
391            return Err(Error::PeerNotFound);
392        }
393
394        if let Some(client) = self.state.lock().await.peer_clients.get(peer) {
395            return Ok(client.clone());
396        }
397
398        match SkiffClient::connect(format!(
399            "http://{}",
400            SocketAddrV4::new(*peers.get(peer).unwrap(), self.port)
401        ))
402        .await
403        {
404            Ok(client) => {
405                let arc = Arc::new(Mutex::new(client));
406                self.state
407                    .lock()
408                    .await
409                    .peer_clients
410                    .insert(peer.to_owned(), arc.clone());
411                Ok(arc)
412            }
413            Err(_) => Err(Error::PeerConnectFailed),
414        }
415    }
416
417    async fn drop_peer_client(&self, id: &Uuid) {
418        let mut lock = self.state.lock().await;
419        let _ = lock.peer_clients.remove(id);
420    }
421
422    /// Return the node's current [`ElectionState`].
423    pub async fn get_election_state(&self) -> ElectionState {
424        self.state.lock().await.election_state.clone()
425    }
426
427    async fn set_election_state(&self, state: ElectionState) {
428        self.state.lock().await.election_state = state;
429    }
430
431    async fn get_current_term(&self) -> u32 {
432        self.state.lock().await.current_term
433    }
434
435    async fn increment_term(&self) {
436        self.state.lock().await.current_term += 1;
437    }
438
439    async fn set_current_term(&self, term: u32) {
440        self.state.lock().await.current_term = term;
441    }
442
443    // todo: this is a slightly naive approach, retrieving index for log[-1], might not be 100% accurate
444    // should be, since index increases monotonically, but might not in the event of cluster issues?
445    async fn get_last_log_index(&self) -> u32 {
446        self.state
447            .lock()
448            .await
449            .log
450            .last()
451            .map(|last| last.index)
452            .unwrap_or(0)
453    }
454
455    async fn get_last_log_term(&self) -> u32 {
456        self.state
457            .lock()
458            .await
459            .log
460            .last()
461            .map(|last| last.term)
462            .unwrap_or(0)
463    }
464
465    async fn get_commit_index(&self) -> u32 {
466        self.state.lock().await.committed_index
467    }
468
469    async fn vote_for(&self, candidate_id: Option<Uuid>) {
470        self.state.lock().await.voted_for = candidate_id;
471    }
472
473    async fn get_voted_for(&self) -> Option<Uuid> {
474        self.state.lock().await.voted_for
475    }
476
477    async fn log(&self, action: Action) -> Arc<(Mutex<bool>, Notify)> {
478        let mut lock = self.state.lock().await;
479        let current_term = lock.current_term;
480        let last_index = lock.log.last().map(|last: &Log| last.index).unwrap_or(0);
481
482        let commit_pair = Arc::new((Mutex::new(false), Notify::new()));
483        lock.log.push(Log {
484            index: last_index + 1,
485            term: current_term,
486            action,
487            committed: commit_pair.clone(),
488        });
489
490        commit_pair
491    }
492
493    async fn get_prefixes(&self) -> Result<Vec<String>, Error> {
494        match self.state.lock().await.conn.get("trees")? {
495            Some(tree_vec) => match bincode::deserialize::<Vec<String>>(&tree_vec) {
496                Ok(trees) => Ok(trees),
497                Err(_) => Err(Error::DeserializeFailed),
498            },
499            None => Ok(vec![]),
500        }
501    }
502
503    async fn list_keys(&self, prefix: &str) -> Result<Vec<String>, Error> {
504        let mut keys = vec![];
505        let trees = self.get_prefixes().await?;
506
507        // Todo: de-duplicate code
508        if prefix.is_empty() || prefix == "/" {
509            let tree = self.state.lock().await.conn.open_tree("base")?;
510
511            tree.into_iter()
512                .keys()
513                .map(|key| String::from_utf8(key.unwrap().to_vec()).unwrap())
514                .for_each(|key| {
515                    keys.push(key);
516                });
517        }
518
519        for tree_name in &trees {
520            let prefix_trimmed = match prefix.ends_with("/") {
521                true => prefix.trim_end_matches("/"),
522                false => prefix,
523            };
524
525            if tree_name.starts_with(prefix_trimmed) {
526                let tree = self
527                    .state
528                    .lock()
529                    .await
530                    .conn
531                    .open_tree(format!("base_{}", tree_name.replace("/", "_")))?;
532
533                tree.into_iter()
534                    .keys()
535                    .map(|key| String::from_utf8(key.unwrap().to_vec()).unwrap())
536                    .for_each(|key| {
537                        keys.push(format!("{}/{}", tree_name, key));
538                    });
539            }
540        }
541
542        Ok(keys)
543    }
544
    /// Build the AppendEntries payload for `peer`: the `(index, term)` of the
    /// newest entry the peer is assumed to already hold, plus every newer
    /// entry converted to its protobuf wire form.
    ///
    /// NOTE(review): panics if `next_index` has no entry for `peer`; this
    /// assumes the leader populated `next_index` for every known peer before
    /// calling — confirm at call sites.
    async fn get_logs(&self, peer: &Uuid) -> (u32, u32, Vec<skiff_proto::Log>) {
        let lock = self.state.lock().await;
        let log_next_index: &u32 = lock.next_index.get(peer).unwrap();
        let mut prev_log_index = 0;
        let mut prev_log_term = 0;
        let mut new_logs: Vec<skiff_proto::Log> = vec![];

        for log in &lock.log {
            // Get latest log that should already be in peer's log
            if log.index < *log_next_index && log.index > prev_log_index {
                prev_log_index = log.index;
                prev_log_term = log.term;
            } else if log.index >= *log_next_index {
                // Entry the peer is missing: translate to the wire format.
                new_logs.push(match &log.action {
                    Action::Insert(key, value) => skiff_proto::Log {
                        index: log.index,
                        term: log.term,
                        action: skiff_proto::Action::Insert as i32,
                        key: key.clone(),
                        value: Some(value.clone()),
                    },
                    Action::Delete(key) => skiff_proto::Log {
                        index: log.index,
                        term: log.term,
                        action: skiff_proto::Action::Delete as i32,
                        key: key.clone(),
                        value: None,
                    },
                    // Configure payloads travel as bincode under a fixed key.
                    Action::Configure(config) => skiff_proto::Log {
                        index: log.index,
                        term: log.term,
                        action: skiff_proto::Action::Configure as i32,
                        key: "cluster".to_string(),
                        value: Some(bincode::serialize(&config).unwrap()),
                    },
                });
            }
        }

        (prev_log_index, prev_log_term, new_logs)
    }
586
    /// Apply committed-but-unapplied log entries to the sled state machine,
    /// notify prefix subscribers and commit-waiters, then persist and advance
    /// the `last_applied` watermark.
    ///
    /// NOTE(review): `committed_index` and `last_applied` are read under two
    /// separate lock acquisitions; this is only safe if a single task drives
    /// this method — confirm.
    /// NOTE(review): the filter below has no upper bound at `committed_index`,
    /// so entries beyond the commit point (if any exist in the log) are also
    /// applied — verify against how the leader advances `committed_index`.
    async fn commit_logs(&self) -> Result<(), Error> {
        let committed_index = self.state.lock().await.committed_index;
        let last_applied = self.state.lock().await.last_applied;
        if committed_index <= last_applied {
            return Ok(());
        }

        trace!("committing log entries");
        // Snapshot (position, action) pairs so the state lock is not held
        // across the awaits below.
        let new_logs: Vec<(usize, Action)> = self
            .state
            .lock()
            .await
            .log
            .iter()
            .enumerate()
            .filter(|(_, log)| log.index > last_applied)
            .map(|(i, log)| (i, log.action.clone()))
            .collect();

        for (i, action) in new_logs {
            match action {
                // Todo: handle and document difference between "key" and "/key"
                Action::Insert(key, value) => {
                    let full_key = key.clone();
                    // Split "a/b/key" into prefix "a/b" and final segment "key".
                    let mut tree_parts: Vec<&str> = key.split("/").collect();
                    let key = tree_parts.pop().unwrap();

                    let prefix = tree_parts.join("/");
                    let tree_name =
                        match prefix.len() {
                            0 => "base".to_string(),
                            _ => {
                                // Register the prefix in the "trees" index so
                                // list_keys / get_prefixes can find it.
                                let _ = self.state.lock().await.conn.update_and_fetch(
                                    "trees",
                                    |trees| match trees {
                                        Some(tree_vec) => {
                                            let mut updated_tree_vec =
                                                bincode::deserialize::<Vec<String>>(tree_vec)
                                                    .unwrap();

                                            if !updated_tree_vec.contains(&prefix) {
                                                updated_tree_vec.push(prefix.clone());
                                            }

                                            Some(bincode::serialize(&updated_tree_vec).unwrap())
                                        }
                                        None => Some(bincode::serialize(&vec![&prefix]).unwrap()),
                                    },
                                );

                                format!("base_{}", &prefix.replace("/", "_"))
                            }
                        };

                    // Todo: maybe alert subscribers on delete
                    // Fan out to matching subscribers, dropping senders whose
                    // receivers have gone away.
                    let mut subscribers = self.subscribers.lock().await;
                    for (sub_prefix, senders) in subscribers.iter_mut() {
                        // Todo: this could cause issues when the provided prefix is parent/ and tree is ex. parents/ or parent1/, etc
                        if prefix.starts_with(sub_prefix.trim_end_matches("/")) {
                            let mut live_senders = Vec::new();
                            for sender in senders.drain(..) {
                                if sender
                                    .send(SubscribeReply {
                                        key: full_key.clone(),
                                        action: skiff_proto::Action::Insert as i32,
                                        value: Some(value.clone()),
                                    })
                                    .await
                                    .is_ok()
                                {
                                    live_senders.push(sender);
                                } // else: receiver dropped, discard sender
                            }
                            *senders = live_senders;
                        }
                    }

                    let tree = self.state.lock().await.conn.open_tree(tree_name)?;
                    tree.insert(key, value)?;
                }
                Action::Delete(key) => {
                    let mut tree_parts: Vec<&str> = key.split("/").collect();
                    let key = tree_parts.pop().unwrap();

                    let prefix = tree_parts.join("/");
                    let tree_name = match prefix.len() {
                        0 => "base".to_string(),
                        _ => format!("base_{}", &prefix.replace("/", "_")),
                    };

                    let trees = self.state.lock().await.conn.tree_names();
                    if trees.contains(&tree_name.as_bytes().into()) {
                        let lock = self.state.lock().await;
                        let tree = lock.conn.open_tree(&tree_name)?;
                        tree.remove(key)?;

                        // Garbage-collect an emptied tree and de-register its
                        // prefix from the "trees" index.
                        if tree.is_empty() {
                            let _ = lock.conn.drop_tree(&tree_name);
                            let _ = lock.conn.update_and_fetch("trees", |trees| match trees {
                                Some(tree_vec) => {
                                    let mut updated_tree_vec =
                                        bincode::deserialize::<Vec<String>>(tree_vec).unwrap();

                                    trace!(trees = ?updated_tree_vec, name = %tree_name, "dropping tree");
                                    if let Some(index) =
                                        updated_tree_vec.iter().position(|name| name == &prefix)
                                    {
                                        updated_tree_vec.remove(index);
                                    }

                                    Some(bincode::serialize(&updated_tree_vec).unwrap())
                                }
                                None => Some(bincode::serialize(&Vec::<String>::new()).unwrap()),
                            });
                        }
                    }
                }
                Action::Configure(config) => {
                    // Todo: also save committed_index and last_applied for resuming operation later
                    self.state
                        .lock()
                        .await
                        .conn
                        .insert("cluster", bincode::serialize(&config).unwrap())?;
                }
            }

            // Wake any writer awaiting this entry's commitment.
            if let Some(log) = self.state.lock().await.log.get_mut(i) {
                let (committed, notify) = &*log.committed;
                *committed.lock().await = true;
                notify.notify_one();
            }
        }

        // Persist and flush the watermark before advancing it in memory.
        let conn = self.state.lock().await.conn.clone();
        persist_last_applied(&conn, committed_index)?;
        conn.flush_async().await.map_err(Error::SledError)?;
        self.state.lock().await.last_applied = committed_index;

        Ok(())
    }
728
729    async fn reset_heartbeat_timer(&self) {
730        let _ = self.tx_entries.lock().await.send(1).await; // Can be anything
731    }
732
733    fn initialize_service(&self) -> SkiffServer<Skiff> {
734        let skiff = self.clone();
735        drop(tokio::spawn(async move {
736            // Todo: when we're restoring a cluster from previous operation (ex. after outage / migration)
737            // the cluster len for all nodes will be > 1, but no leader exists yet, so add_server rpc fails.
738            // Might be ok if leader election happens subsequently, but this should a) be verified
739            // and b) logic for this scenario should be made obvious
740            if skiff.get_cluster().await.unwrap().len() > 1 {
741                debug!("joining cluster");
742
743                let mut joined_cluster = false;
744                for id in skiff.get_peers().await.keys() {
745                    debug!(?id, "asking peer to add us to cluster");
746
747                    let mut request = Request::new(ServerRequest {
748                        id: skiff.id.to_string(),
749                        address: skiff.address.to_string(),
750                    });
751
752                    request.set_timeout(Duration::from_millis(300));
753
754                    let client_arc = skiff.get_peer_client(id).await.unwrap();
755                    let mut client = client_arc.lock().await;
756                    match client.add_server(request).await {
757                        Ok(response) => {
758                            let inner = response.into_inner();
759                            if inner.success {
760                                if let Some(cluster) = inner.cluster {
761                                    // Todo: this will get logged twice. Once here and once from append_entry
762                                    // It shouldn't be an issue, but this is an unecessary duplication
763                                    skiff
764                                        .log(Action::Configure(
765                                            bincode::deserialize(&cluster).unwrap(),
766                                        ))
767                                        .await;
768
769                                    joined_cluster = true;
770
771                                    break;
772                                }
773                            }
774                        }
775                        Err(_) => continue,
776                    }
777                }
778
779                if !joined_cluster {
780                    return Err(Error::ClusterJoinFailed);
781                }
782            }
783            // Todo: if we are lone node in cluster we can skip election timeout and make ourselves leader
784
785            let mut server1 = skiff.clone();
786            let _elections: tokio::task::JoinHandle<Result<(), Error>> = tokio::spawn(async move {
787                server1.election_manager().await?;
788                Ok(())
789            });
790
791            Ok(())
792        }));
793
794        SkiffServer::new(self.clone())
795    }
796
797    /// Start the node's gRPC server and block until [`shutdown`](Skiff::shutdown) is called.
798    ///
799    /// This method must be called (usually inside a `tokio::spawn`) for the
800    /// node to participate in the cluster.  It:
801    ///
802    /// 1. Spawns a background task that joins the cluster (if peers were
803    ///    provided) and then runs the election/heartbeat loop.
804    /// 2. Binds a tonic gRPC server on the configured address and port.
805    /// 3. Returns only after [`shutdown`](Skiff::shutdown) is called and all
806    ///    in-flight connections have been closed.
807    ///
808    /// # Errors
809    ///
810    /// Returns [`Error::RPCBindFailed`] if the gRPC server cannot bind.
811    pub async fn start(self) -> Result<(), Error> {
812        let service = self.initialize_service();
813        let mut shutdown = self.shutdown_rx.clone();
814
815        let bind_address = SocketAddr::new(self.address.into(), self.port);
816        tonic::transport::Server::builder()
817            .add_service(service)
818            .serve_with_shutdown(bind_address, async move {
819                // Wait for shutdown signal, then let tonic close all connections cleanly.
820                let _ = shutdown.changed().await;
821            })
822            .await?;
823
824        Ok(())
825    }
826
827    #[allow(unreachable_code)]
828    async fn election_manager(&mut self) -> Result<(), Error> {
829        let mut rng = StdRng::from_entropy();
830
831        let mut rx_entries = self.rx_entries.lock().await;
832        let mut shutdown = self.shutdown_rx.clone();
833
834        loop {
835            self.commit_logs().await?;
836
837            if self.get_election_state().await == ElectionState::Leader {
838                // todo: refactor this block:
839                tokio::select! {
840                    _ = tokio::time::sleep(Duration::from_millis(75)) => {
841
842                        // todo: move peer connection + sending to separate thread so connection timeout
843                        // doesn't result in election timeout
844                        let last_log_index = self.get_last_log_index().await;
845                        let committed_index = self.get_commit_index().await;
846                        let current_term = self.get_current_term().await;
847                        for peer in self.get_peers().await.keys() {
848                            let (peer_last_log_index, peer_last_log_term, entries) = self.get_logs(peer).await;
849                            let num_entries = entries.len();
850                            let mut request = Request::new(EntryRequest {
851                                term: current_term,
852                                leader_id: self.id.to_string(),
853                                prev_log_index: peer_last_log_index,
854                                prev_log_term: peer_last_log_term,
855                                entries,
856                                leader_commit: committed_index
857                            });
858                            // Todo: see if there's a more idiomatic way to set timeout
859                            request.set_timeout(Duration::from_millis(300));
860
861                            let client_arc = match self.get_peer_client(peer).await {
862                                Ok(c) => c,
863                                Err(_) => { self.drop_peer_client(peer).await; continue; }
864                            };
865                            let mut client = client_arc.lock().await;
866                            match client.append_entry(request).await {
867                                Ok(response) => {
868                                    match response.into_inner().success {
869                                        true => {
870                                            if num_entries > 0 {
871                                                if let Some(value) = self.state.lock().await.next_index.get_mut(peer) {
872                                                    *value = last_log_index + 1;
873                                                }
874
875                                                if let Some(value) = self.state.lock().await.match_index.get_mut(peer) {
876                                                    *value = last_log_index;
877                                                }
878                                            }
879                                        },
880                                        false => {
881                                            // Decrement next_index for peer
882                                            if let Some(value) = self.state.lock().await.next_index.get_mut(peer) {
883                                                *value -= 1;
884                                            }
885                                        }
886                                    }
887                                },
888                                Err(_) => { self.drop_peer_client(peer).await; }
889                            }
890                        }
891
892                        // Check if any logs should be commited
893                        // Iterating backwards from last log index to (commited_index + 1)
894                        let num_peers = self.get_cluster().await?.len();
895                        for i in ((committed_index + 1) ..=last_log_index).rev() {
896                            // The number of peers where at least log index i has been applied (+1 for leader)
897                            let num_peers_applied = self.state.lock().await.match_index.iter().filter(|(_, &applied_index)| applied_index >= i).collect::<Vec<_>>().len() + 1;
898                            // Make sure that i is in this term
899                            let correct_term = self.state.lock().await.log.iter().filter(|log| log.index == i && log.term != current_term).collect::<Vec<_>>().is_empty();
900                                if (num_peers_applied > num_peers / 2) && correct_term {
901                                self.state.lock().await.committed_index = i;
902                                break
903                            }
904                        }
905                    }
906                    _ = shutdown.changed() => { return Ok(()); }
907                }
908            } else {
909                tokio::select! {
910                    Some(1) = rx_entries.recv() => {
911                        continue;
912                    }
913
914                    _ = tokio::time::sleep(Duration::from_millis(rng.gen_range(150..300))) => {
915                        debug!("election timeout, starting election");
916                        self.run_election().await?;
917                    }
918
919                    _ = shutdown.changed() => { return Ok(()); }
920                };
921            }
922        }
923
924        Ok(())
925    }
926
    /// Transition to candidate, solicit votes from every peer, and — on winning
    /// a strict majority of the cluster — become leader, reinitialize the
    /// per-peer replication indices, and broadcast an initial empty heartbeat.
    ///
    /// Votes are requested sequentially with a 300 ms per-RPC timeout; peers
    /// that fail to connect or respond are skipped and their cached clients
    /// dropped, counting as a non-vote.
    async fn run_election(&self) -> Result<(), Error> {
        // Candidate setup: bump term and vote for ourselves first.
        self.set_election_state(ElectionState::Candidate).await;
        self.increment_term().await;
        self.vote_for(Some(self.id)).await;
        let term = self.get_current_term().await;
        let state = self.get_election_state().await;
        debug!(?state, term, "starting election");

        let mut num_votes: u32 = 1; // including self

        for peer in self.get_peers().await.keys() {
            trace!(?peer, "requesting vote");

            let mut request = Request::new(VoteRequest {
                term: self.get_current_term().await,
                candidate_id: self.id.to_string(),
                last_log_index: self.get_last_log_index().await,
                last_log_term: self.get_last_log_term().await,
            });
            request.set_timeout(Duration::from_millis(300));

            // Unreachable peers count as a non-vote; drop their cached client
            // so the next round reconnects from scratch.
            let client = match self.get_peer_client(peer).await {
                Ok(c) => c,
                Err(_) => {
                    self.drop_peer_client(peer).await;
                    continue;
                }
            };
            let response = match client.lock().await.request_vote(request).await {
                Ok(response) => response.into_inner(),
                Err(_) => {
                    self.drop_peer_client(peer).await;
                    continue;
                }
            };

            if response.vote_granted {
                debug!(peer = ?&peer, "received vote");
                num_votes += 1;
            }
        }

        // Strict majority of the full cluster (self included) wins the election.
        if num_votes > self.get_cluster().await?.len() as u32 / 2 {
            info!("elected leader");
            self.set_election_state(ElectionState::Leader).await;
            self.vote_for(None).await;

            let last_log_index = self.get_last_log_index().await;
            let last_log_term = self.get_last_log_term().await;
            let committed_index = self.get_commit_index().await;
            let current_term = self.get_current_term().await;

            // Reinitialize leader replication state: next_index to just past
            // our log, match_index to 0 (nothing known replicated yet).
            {
                let peers = self.get_peers().await;
                let mut lock = self.state.lock().await;
                lock.next_index = peers
                    .keys()
                    .map(|peer_id| (*peer_id, last_log_index + 1))
                    .collect::<HashMap<Uuid, u32>>();
                lock.match_index = peers
                    .into_keys()
                    .map(|peer_id| (peer_id, 0))
                    .collect::<HashMap<Uuid, u32>>();
            }

            // Send empty heartbeat to assert leadership immediately; failures
            // are tolerated and retried by the regular heartbeat loop.
            for id in self.get_peers().await.keys() {
                let mut request = Request::new(EntryRequest {
                    term: current_term,
                    leader_id: self.id.to_string(),
                    prev_log_index: last_log_index,
                    prev_log_term: last_log_term,
                    entries: vec![],
                    leader_commit: committed_index,
                });
                request.set_timeout(Duration::from_millis(300));

                let client_arc = match self.get_peer_client(id).await {
                    Ok(c) => c,
                    Err(_) => {
                        self.drop_peer_client(id).await;
                        continue;
                    }
                };
                let mut client = client_arc.lock().await;
                if client.append_entry(request).await.is_err() {
                    self.drop_peer_client(id).await;
                }
            }
        }

        Ok(())
    }
1020}
1021
1022// todo: access control. checking if id matches leaderid, if request comes from within cluster, etc
1023#[tonic::async_trait]
1024impl skiff_proto::skiff_server::Skiff for Skiff {
1025    type SubscribeStream = Pin<Box<dyn Stream<Item = Result<SubscribeReply, Status>> + Send>>;
1026
1027    async fn request_vote(
1028        &self,
1029        request: Request<VoteRequest>,
1030    ) -> Result<Response<VoteReply>, Status> {
1031        trace!("received vote request");
1032
1033        let current_term = self.get_current_term().await;
1034        let voted_for = self.get_voted_for().await;
1035        let last_log_index = self.get_last_log_index().await;
1036        let last_log_term = self.get_last_log_term().await;
1037        let conn = self.state.lock().await.conn.clone();
1038
1039        let vote_request = request.into_inner();
1040
1041        let candidate_id = Uuid::from_str(&vote_request.candidate_id)
1042            .map_err(|_| Status::invalid_argument("invalid candidate id"))?;
1043
1044        // Grant vote if:
1045        // - Candidate's term is higher (fresh term, voted_for resets), OR
1046        // - Same term and we haven't voted yet or already voted for this candidate
1047        let can_vote = vote_request.term > current_term
1048            || (vote_request.term == current_term
1049                && (voted_for.is_none() || voted_for == Some(candidate_id)));
1050
1051        if can_vote
1052            && vote_request.last_log_index >= last_log_index
1053            && vote_request.last_log_term >= last_log_term
1054        {
1055            debug!(?candidate_id, "granting vote");
1056
1057            // Persist hard state before responding
1058            persist_hard_state(&conn, vote_request.term, Some(candidate_id))
1059                .map_err(|_| Status::internal("failed to persist hard state"))?;
1060            conn.flush_async()
1061                .await
1062                .map_err(|_| Status::internal("failed to flush"))?;
1063
1064            self.vote_for(Some(candidate_id)).await;
1065            self.set_election_state(ElectionState::Follower(candidate_id))
1066                .await;
1067            self.set_current_term(vote_request.term).await;
1068
1069            return Ok(Response::new(VoteReply {
1070                term: vote_request.term,
1071                vote_granted: true,
1072            }));
1073        }
1074
1075        Ok(Response::new(VoteReply {
1076            term: current_term,
1077            vote_granted: false,
1078        }))
1079    }
1080
1081    async fn append_entry(
1082        &self,
1083        request: Request<EntryRequest>,
1084    ) -> Result<Response<EntryReply>, Status> {
1085        let entry_request = request.into_inner();
1086        let current_term = self.get_current_term().await;
1087        let conn = self.state.lock().await.conn.clone();
1088
1089        if entry_request.term < current_term {
1090            return Ok(Response::new(EntryReply {
1091                term: current_term,
1092                success: false,
1093            }));
1094        }
1095
1096        // Update term and clear voted_for if the leader has a newer term
1097        let term_changed = entry_request.term > current_term;
1098        if term_changed {
1099            self.set_current_term(entry_request.term).await;
1100            self.vote_for(None).await;
1101        }
1102
1103        // Confirmed that we're receiving requests from a verified leader
1104        self.set_election_state(ElectionState::Follower(
1105            Uuid::from_str(&entry_request.leader_id).unwrap(),
1106        ))
1107        .await;
1108        self.reset_heartbeat_timer().await;
1109
1110        if entry_request.prev_log_index > 0 {
1111            let mut found_matching_log = false;
1112            for log in &self.state.lock().await.log {
1113                if log.index == entry_request.prev_log_index
1114                    && log.term == entry_request.prev_log_term
1115                {
1116                    found_matching_log = true;
1117                    break;
1118                }
1119            }
1120
1121            if !found_matching_log {
1122                return Ok(Response::new(EntryReply {
1123                    term: entry_request.term,
1124                    success: false,
1125                }));
1126            }
1127        }
1128
1129        for new_log in &entry_request.entries {
1130            let mut drop_index: Option<u32> = None;
1131            for current_log in &self.state.lock().await.log {
1132                // Conflict: same index but different term — truncate from here
1133                if current_log.index == new_log.index && current_log.term != new_log.term {
1134                    drop_index = Some(current_log.index);
1135                }
1136            }
1137
1138            if let Some(drop_index) = drop_index {
1139                self.state
1140                    .lock()
1141                    .await
1142                    .log
1143                    .retain(|log| log.index < drop_index);
1144                truncate_log_from(&conn, drop_index)
1145                    .map_err(|_| Status::internal("failed to truncate log"))?;
1146            }
1147        }
1148
1149        let last_log_index = self.get_last_log_index().await;
1150        let new_term = entry_request.term;
1151        let mut appended_entries = false;
1152
1153        for new_log in entry_request.entries {
1154            if new_log.index > last_log_index {
1155                trace!("appending log entry");
1156                let log_entry = Log {
1157                    index: new_log.index,
1158                    term: new_log.term,
1159                    action: match skiff_proto::Action::try_from(new_log.action) {
1160                        Ok(skiff_proto::Action::Insert) => {
1161                            Action::Insert(new_log.key, new_log.value.unwrap())
1162                        }
1163                        Ok(skiff_proto::Action::Delete) => Action::Delete(new_log.key),
1164                        Ok(skiff_proto::Action::Configure) => Action::Configure(
1165                            bincode::deserialize(&new_log.value.unwrap()).unwrap(),
1166                        ),
1167                        Err(_) => return Err(Status::invalid_argument("Invalid action")),
1168                    },
1169                    committed: Arc::new((Mutex::new(false), Notify::new())),
1170                };
1171                persist_log_entry(&conn, &log_entry)
1172                    .map_err(|_| Status::internal("failed to persist log entry"))?;
1173                self.state.lock().await.log.push(log_entry);
1174                appended_entries = true;
1175            }
1176        }
1177
1178        // Flush to stable storage before responding — required by Raft safety
1179        if term_changed || appended_entries {
1180            if term_changed {
1181                persist_hard_state(&conn, new_term, None)
1182                    .map_err(|_| Status::internal("failed to persist hard state"))?;
1183            }
1184            conn.flush_async()
1185                .await
1186                .map_err(|_| Status::internal("failed to flush"))?;
1187        }
1188
1189        if entry_request.leader_commit > self.get_commit_index().await {
1190            self.state.lock().await.committed_index =
1191                min(entry_request.leader_commit, self.get_last_log_index().await);
1192        }
1193
1194        Ok(Response::new(EntryReply {
1195            term: new_term,
1196            success: true,
1197        }))
1198    }
1199
1200    async fn add_server(
1201        &self,
1202        request: Request<ServerRequest>,
1203    ) -> Result<Response<ServerReply>, Status> {
1204        let election_state = self.state.lock().await.election_state.clone();
1205        if let ElectionState::Follower(leader) = election_state {
1206            let client = self.get_peer_client(&leader).await;
1207            if let Ok(client_inner) = client {
1208                return client_inner.lock().await.add_server(request).await;
1209            }
1210
1211            return Err(Status::internal("failed to forward request to leader"));
1212        }
1213
1214        debug!("adding server to cluster");
1215        let new_server = request.into_inner();
1216        let new_uuid = Uuid::from_str(&new_server.id).unwrap();
1217
1218        let mut cluster_config: HashMap<Uuid, Ipv4Addr> = self.get_cluster().await.unwrap();
1219
1220        let id = Uuid::from_str(&new_server.id).unwrap();
1221        let addr = Ipv4Addr::from_str(&new_server.address).unwrap();
1222
1223        if let std::collections::hash_map::Entry::Vacant(e) = cluster_config.entry(id) {
1224            e.insert(addr);
1225            self.log(Action::Configure(cluster_config.clone())).await;
1226        }
1227
1228        let last_log_index = self.get_last_log_index().await;
1229        self.state
1230            .lock()
1231            .await
1232            .next_index
1233            .insert(new_uuid, last_log_index + 1);
1234        self.state.lock().await.match_index.insert(new_uuid, 0);
1235
1236        Ok(Response::new(ServerReply {
1237            success: true,
1238            cluster: Some(bincode::serialize(&cluster_config).unwrap()),
1239        }))
1240    }
1241
1242    async fn remove_server(
1243        &self,
1244        request: Request<ServerRequest>,
1245    ) -> Result<Response<ServerReply>, Status> {
1246        let election_state = self.state.lock().await.election_state.clone();
1247        if let ElectionState::Follower(leader) = election_state {
1248            let client = self.get_peer_client(&leader).await;
1249            if let Ok(client_inner) = client {
1250                return client_inner.lock().await.remove_server(request).await;
1251            }
1252            return Err(Status::internal("failed to forward request to leader"));
1253        }
1254
1255        let remove_request = request.into_inner();
1256        let remove_id = Uuid::from_str(&remove_request.id)
1257            .map_err(|_| Status::invalid_argument("invalid server id"))?;
1258
1259        let mut cluster_config = self
1260            .get_cluster()
1261            .await
1262            .map_err(|_| Status::internal("failed to get cluster config"))?;
1263
1264        if cluster_config.remove(&remove_id).is_none() {
1265            return Err(Status::not_found("server not found in cluster"));
1266        }
1267
1268        self.log(Action::Configure(cluster_config.clone())).await;
1269
1270        {
1271            let mut state = self.state.lock().await;
1272            state.next_index.remove(&remove_id);
1273            state.match_index.remove(&remove_id);
1274        }
1275        self.drop_peer_client(&remove_id).await;
1276
1277        Ok(Response::new(ServerReply {
1278            success: true,
1279            cluster: Some(
1280                bincode::serialize(&cluster_config)
1281                    .map_err(|_| Status::internal("failed to serialize cluster"))?,
1282            ),
1283        }))
1284    }
1285
1286    // Todo: maybe add watch_prefix function that communicates changes to clients
1287
1288    // Todo: Forwarding to the leader fails when... there is no leader, or when we are a candiate
1289    // This is an issue when calling add_server from a follower to a server before the latter has elected itself
1290
1291    async fn get(&self, request: Request<GetRequest>) -> Result<Response<GetReply>, Status> {
1292        // If follower, connect to leader and forward request
1293        // Todo: ideally get requests could be done locally w/o forwarding, which would improve performance
1294        // However, if the client makes an insert or delete request then immediately makes a get request,
1295        // the change could have been logged locally and committed by the leader without the change
1296        // being made to the state machine (sled) locally (until the subsequent append_entries call),
1297        // resulting in an outdated get. The current workaround is to just forward the request.
1298        let election_state = self.state.lock().await.election_state.clone();
1299        if let ElectionState::Follower(leader) = election_state {
1300            let client = self.get_peer_client(&leader).await;
1301            if let Ok(client_inner) = client {
1302                return client_inner.lock().await.get(request).await;
1303            }
1304
1305            return Err(Status::internal("failed to forward request to leader"));
1306        }
1307
1308        let get_request = request.into_inner();
1309        let mut tree_parts: Vec<&str> = get_request.key.split("/").collect();
1310        let key = tree_parts.pop().unwrap();
1311
1312        let mut tree_name = tree_parts.join("/");
1313        tree_name = match tree_name.len() {
1314            0 => "base".to_string(),
1315            _ => format!("base_{}", tree_name.replace("/", "_")),
1316        };
1317
1318        if let Ok(tree) = self.state.lock().await.conn.open_tree(tree_name) {
1319            let value = tree.get(key);
1320            match value {
1321                Ok(inner1) => match inner1 {
1322                    Some(data) => Ok(Response::new(GetReply {
1323                        value: Some(data.to_vec()),
1324                    })),
1325                    None => Ok(Response::new(GetReply { value: None })),
1326                },
1327                Err(_) => Err(Status::internal("failed to query sled db")),
1328            }
1329        } else {
1330            Err(Status::internal("failed to open sled tree"))
1331        }
1332    }
1333
1334    async fn insert(
1335        &self,
1336        request: Request<InsertRequest>,
1337    ) -> Result<Response<InsertReply>, Status> {
1338        // If follower, connect to leader and forward request
1339        let election_state = self.state.lock().await.election_state.clone();
1340        if let ElectionState::Follower(leader) = election_state {
1341            let client = self.get_peer_client(&leader).await;
1342            if let Ok(client_inner) = client {
1343                return client_inner.lock().await.insert(request).await;
1344            }
1345
1346            return Err(Status::internal("failed to forward request to leader"));
1347        }
1348
1349        let insert_request = request.into_inner();
1350        let commit_arc = self
1351            .log(Action::Insert(insert_request.key, insert_request.value))
1352            .await;
1353
1354        let (_, notify) = &*commit_arc;
1355        tokio::time::timeout(Duration::from_secs(5), notify.notified())
1356            .await
1357            .map_err(|_| Status::deadline_exceeded("timed out waiting for commit"))?;
1358        Ok(Response::new(InsertReply { success: true }))
1359    }
1360
1361    async fn delete(
1362        &self,
1363        request: Request<DeleteRequest>,
1364    ) -> Result<Response<DeleteReply>, Status> {
1365        // If follower, connect to leader and forward request
1366        let election_state = self.state.lock().await.election_state.clone();
1367        if let ElectionState::Follower(leader) = election_state {
1368            let client = self.get_peer_client(&leader).await;
1369            if let Ok(client_inner) = client {
1370                return client_inner.lock().await.delete(request).await;
1371            }
1372
1373            return Err(Status::internal("failed to forward request to leader"));
1374        }
1375
1376        let delete_request = request.into_inner();
1377        let commit_arc = self.log(Action::Delete(delete_request.key)).await;
1378
1379        let (_, notify) = &*commit_arc;
1380        tokio::time::timeout(Duration::from_secs(5), notify.notified())
1381            .await
1382            .map_err(|_| Status::deadline_exceeded("timed out waiting for commit"))?;
1383        Ok(Response::new(DeleteReply { success: true }))
1384    }
1385
1386    async fn get_prefixes(&self, request: Request<Empty>) -> Result<Response<PrefixReply>, Status> {
1387        // If follower, connect to leader and forward request
1388        let election_state = self.state.lock().await.election_state.clone();
1389        if let ElectionState::Follower(leader) = election_state {
1390            let client = self.get_peer_client(&leader).await;
1391            if let Ok(client_inner) = client {
1392                return client_inner.lock().await.get_prefixes(request).await;
1393            }
1394
1395            return Err(Status::internal("failed to forward request to leader"));
1396        }
1397
1398        match self.get_prefixes().await {
1399            Ok(prefixes) => Ok(Response::new(PrefixReply { prefixes })),
1400            Err(_) => Err(Status::internal("failed to get prefixes")),
1401        }
1402    }
1403
1404    async fn list_keys(
1405        &self,
1406        request: Request<ListKeysRequest>,
1407    ) -> Result<Response<ListKeysReply>, Status> {
1408        // If follower, connect to leader and forward request
1409        let election_state = self.state.lock().await.election_state.clone();
1410        if let ElectionState::Follower(leader) = election_state {
1411            let client = self.get_peer_client(&leader).await;
1412            if let Ok(client_inner) = client {
1413                return client_inner.lock().await.list_keys(request).await;
1414            }
1415
1416            return Err(Status::internal("failed to forward request to leader"));
1417        }
1418
1419        match self.list_keys(request.into_inner().prefix.as_str()).await {
1420            Ok(keys) => Ok(Response::new(ListKeysReply { keys })),
1421            Err(_) => Err(Status::internal("failed to get keys")),
1422        }
1423    }
1424
1425    // This shouldn't need forwarding to leader
1426    async fn subscribe(
1427        &self,
1428        request: Request<SubscribeRequest>,
1429    ) -> Result<Response<Self::SubscribeStream>, Status> {
1430        let prefix = request.into_inner().prefix;
1431
1432        let (sender, receiver) = mpsc::channel(32);
1433        let mut subscribers = self.subscribers.lock().await;
1434        if !subscribers.contains_key(&prefix) {
1435            subscribers.insert(prefix.clone(), Vec::new());
1436        }
1437
1438        let senders = subscribers.get_mut(&prefix).unwrap();
1439        senders.push(sender);
1440
1441        let stream = ReceiverStream::new(receiver).map(Ok);
1442
1443        Ok(Response::new(Box::pin(stream)))
1444    }
1445}