Skip to main content

zlayer_consensus/storage/
mem_store.rs

1//! In-memory storage implementation using the openraft v2 storage API.
2//!
3//! This store keeps all Raft log entries, votes, and state machine snapshots
4//! in memory using `BTreeMap` and `RwLock`. It is generic over the
5//! `RaftTypeConfig` so any application can plug in its own request/response types.
6//!
7//! **Not suitable for production** -- data is lost on process restart.
8//! Use [`RedbStore`](super::redb_store) for durable storage.
9
10use std::collections::BTreeMap;
11use std::fmt::Debug;
12use std::io::Cursor;
13use std::ops::RangeBounds;
14use std::sync::Arc;
15
16use openraft::log_id::RaftLogId;
17use openraft::storage::{LogFlushed, LogState, RaftLogStorage, RaftStateMachine, Snapshot};
18use openraft::{
19    Entry, EntryPayload, LogId, OptionalSend, RaftLogReader, RaftSnapshotBuilder, RaftTypeConfig,
20    SnapshotMeta, StorageError, StoredMembership, Vote,
21};
22use tokio::sync::RwLock;
23use tracing::debug;
24
25use crate::types::NodeId;
26
27// ---------------------------------------------------------------------------
28// Log reader (clone-able handle used by replication tasks)
29// ---------------------------------------------------------------------------
30
31/// A cloneable log reader for the in-memory store.
32pub struct MemLogReader<C: RaftTypeConfig<NodeId = NodeId>> {
33    log: Arc<RwLock<LogData<C>>>,
34}
35
36impl<C: RaftTypeConfig<NodeId = NodeId>> Clone for MemLogReader<C> {
37    fn clone(&self) -> Self {
38        Self {
39            log: Arc::clone(&self.log),
40        }
41    }
42}
43
44impl<C> RaftLogReader<C> for MemLogReader<C>
45where
46    C: RaftTypeConfig<NodeId = NodeId>,
47    C::Entry: Clone,
48{
49    async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + OptionalSend>(
50        &mut self,
51        range: RB,
52    ) -> Result<Vec<C::Entry>, StorageError<NodeId>> {
53        let log = self.log.read().await;
54        let entries = log.entries.range(range).map(|(_, e)| e.clone()).collect();
55        Ok(entries)
56    }
57}
58
59// ---------------------------------------------------------------------------
60// Internal data structs
61// ---------------------------------------------------------------------------
62
63/// Internal log data protected by a `RwLock`.
64struct LogData<C: RaftTypeConfig<NodeId = NodeId>> {
65    last_purged_log_id: Option<LogId<NodeId>>,
66    entries: BTreeMap<u64, C::Entry>,
67    vote: Option<Vote<NodeId>>,
68    committed: Option<LogId<NodeId>>,
69}
70
71impl<C: RaftTypeConfig<NodeId = NodeId>> Default for LogData<C> {
72    fn default() -> Self {
73        Self {
74            last_purged_log_id: None,
75            entries: BTreeMap::new(),
76            vote: None,
77            committed: None,
78        }
79    }
80}
81
82/// Internal state-machine data protected by a `RwLock`.
83///
84/// The generic parameter `S` is the application's state type. It must be
85/// `Default + Clone + serde::Serialize + serde::de::DeserializeOwned`.
86pub struct SmData<C: RaftTypeConfig<NodeId = NodeId>, S> {
87    /// Last applied log id.
88    pub last_applied_log: Option<LogId<NodeId>>,
89    /// Last membership config.
90    pub last_membership: StoredMembership<NodeId, C::Node>,
91    /// Application-specific state.
92    pub state: S,
93    /// Current snapshot bytes (serialized `S`), if any.
94    pub current_snapshot: Option<StoredSnapshot<C>>,
95}
96
97impl<C: RaftTypeConfig<NodeId = NodeId>, S: Default> Default for SmData<C, S> {
98    fn default() -> Self {
99        Self {
100            last_applied_log: None,
101            last_membership: StoredMembership::default(),
102            state: S::default(),
103            current_snapshot: None,
104        }
105    }
106}
107
108/// A serialized snapshot stored in memory.
109pub struct StoredSnapshot<C: RaftTypeConfig> {
110    pub meta: SnapshotMeta<C::NodeId, C::Node>,
111    pub data: Vec<u8>,
112}
113
114impl<C: RaftTypeConfig> Clone for StoredSnapshot<C> {
115    fn clone(&self) -> Self {
116        Self {
117            meta: self.meta.clone(),
118            data: self.data.clone(),
119        }
120    }
121}
122
123// ---------------------------------------------------------------------------
124// MemLogStore -- implements RaftLogStorage
125// ---------------------------------------------------------------------------
126
127/// In-memory Raft log store (v2 API).
128///
129/// Implements `RaftLogReader` and `RaftLogStorage`.
130pub struct MemLogStore<C: RaftTypeConfig<NodeId = NodeId>> {
131    log: Arc<RwLock<LogData<C>>>,
132}
133
134impl<C: RaftTypeConfig<NodeId = NodeId>> MemLogStore<C> {
135    /// Create a new, empty log store.
136    #[must_use]
137    pub fn new() -> Self {
138        Self {
139            log: Arc::new(RwLock::new(LogData::default())),
140        }
141    }
142}
143
144impl<C: RaftTypeConfig<NodeId = NodeId>> Default for MemLogStore<C> {
145    fn default() -> Self {
146        Self::new()
147    }
148}
149
150impl<C: RaftTypeConfig<NodeId = NodeId>> Clone for MemLogStore<C> {
151    fn clone(&self) -> Self {
152        Self {
153            log: Arc::clone(&self.log),
154        }
155    }
156}
157
158impl<C> RaftLogReader<C> for MemLogStore<C>
159where
160    C: RaftTypeConfig<NodeId = NodeId>,
161    C::Entry: Clone,
162{
163    async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + OptionalSend>(
164        &mut self,
165        range: RB,
166    ) -> Result<Vec<C::Entry>, StorageError<NodeId>> {
167        let log = self.log.read().await;
168        let entries = log.entries.range(range).map(|(_, e)| e.clone()).collect();
169        Ok(entries)
170    }
171}
172
173impl<C> RaftLogStorage<C> for MemLogStore<C>
174where
175    C: RaftTypeConfig<NodeId = NodeId>,
176    C::Entry: Clone,
177{
178    type LogReader = MemLogReader<C>;
179
180    async fn get_log_state(&mut self) -> Result<LogState<C>, StorageError<NodeId>> {
181        let log = self.log.read().await;
182        let last = log
183            .entries
184            .iter()
185            .next_back()
186            .map(|(_, ent)| *ent.get_log_id());
187
188        Ok(LogState {
189            last_purged_log_id: log.last_purged_log_id,
190            last_log_id: last,
191        })
192    }
193
194    async fn get_log_reader(&mut self) -> Self::LogReader {
195        MemLogReader {
196            log: Arc::clone(&self.log),
197        }
198    }
199
200    async fn save_vote(&mut self, vote: &Vote<NodeId>) -> Result<(), StorageError<NodeId>> {
201        let mut log = self.log.write().await;
202        log.vote = Some(*vote);
203        Ok(())
204    }
205
206    async fn read_vote(&mut self) -> Result<Option<Vote<NodeId>>, StorageError<NodeId>> {
207        let log = self.log.read().await;
208        Ok(log.vote)
209    }
210
211    async fn save_committed(
212        &mut self,
213        committed: Option<LogId<NodeId>>,
214    ) -> Result<(), StorageError<NodeId>> {
215        let mut log = self.log.write().await;
216        log.committed = committed;
217        Ok(())
218    }
219
220    async fn read_committed(&mut self) -> Result<Option<LogId<NodeId>>, StorageError<NodeId>> {
221        let log = self.log.read().await;
222        Ok(log.committed)
223    }
224
225    async fn append<I>(
226        &mut self,
227        entries: I,
228        callback: LogFlushed<C>,
229    ) -> Result<(), StorageError<NodeId>>
230    where
231        I: IntoIterator<Item = C::Entry> + OptionalSend,
232        I::IntoIter: OptionalSend,
233    {
234        let mut log = self.log.write().await;
235        for entry in entries {
236            let idx = entry.get_log_id().index;
237            log.entries.insert(idx, entry);
238        }
239        // In-memory store: data is immediately "flushed".
240        callback.log_io_completed(Ok(()));
241        Ok(())
242    }
243
244    async fn truncate(&mut self, log_id: LogId<NodeId>) -> Result<(), StorageError<NodeId>> {
245        let mut log = self.log.write().await;
246        let keys: Vec<u64> = log.entries.range(log_id.index..).map(|(k, _)| *k).collect();
247        for key in keys {
248            log.entries.remove(&key);
249        }
250        Ok(())
251    }
252
253    async fn purge(&mut self, log_id: LogId<NodeId>) -> Result<(), StorageError<NodeId>> {
254        let mut log = self.log.write().await;
255        let keys: Vec<u64> = log
256            .entries
257            .range(..=log_id.index)
258            .map(|(k, _)| *k)
259            .collect();
260        for key in keys {
261            log.entries.remove(&key);
262        }
263        log.last_purged_log_id = Some(log_id);
264        Ok(())
265    }
266}
267
268// ---------------------------------------------------------------------------
269// MemStateMachine -- implements RaftStateMachine
270// ---------------------------------------------------------------------------
271
272/// In-memory Raft state machine (v2 API).
273///
274/// Generic over:
275/// - `C`: openraft `RaftTypeConfig`
276/// - `S`: application state type
277/// - `F`: a function `fn(&mut S, &C::D) -> C::R` that applies a log entry to the state.
278///
279/// The `apply_fn` approach keeps the state machine generic: the application provides
280/// a closure that knows how to mutate `S` given a log entry payload of type `C::D`.
281pub struct MemStateMachine<C, S, F>
282where
283    C: RaftTypeConfig<NodeId = NodeId>,
284    S: Default + Clone + serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
285    F: Fn(&mut S, &C::D) -> C::R + Send + Sync + 'static,
286{
287    sm: Arc<RwLock<SmData<C, S>>>,
288    apply_fn: Arc<F>,
289}
290
291impl<C, S, F> MemStateMachine<C, S, F>
292where
293    C: RaftTypeConfig<NodeId = NodeId>,
294    S: Default + Clone + serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
295    F: Fn(&mut S, &C::D) -> C::R + Send + Sync + 'static,
296{
297    /// Create a new state machine with the given apply function.
298    pub fn new(apply_fn: F) -> Self {
299        Self {
300            sm: Arc::new(RwLock::new(SmData::default())),
301            apply_fn: Arc::new(apply_fn),
302        }
303    }
304
305    /// Get a read handle to the inner state machine data (for reading application state).
306    #[must_use]
307    pub fn data(&self) -> Arc<RwLock<SmData<C, S>>> {
308        Arc::clone(&self.sm)
309    }
310}
311
312impl<C, S, F> Clone for MemStateMachine<C, S, F>
313where
314    C: RaftTypeConfig<NodeId = NodeId>,
315    S: Default + Clone + serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
316    F: Fn(&mut S, &C::D) -> C::R + Send + Sync + 'static,
317{
318    fn clone(&self) -> Self {
319        Self {
320            sm: Arc::clone(&self.sm),
321            apply_fn: Arc::clone(&self.apply_fn),
322        }
323    }
324}
325
326// -- Snapshot builder -------------------------------------------------------
327
328/// Snapshot builder for the in-memory state machine.
329pub struct MemSnapshotBuilder<C, S, F>
330where
331    C: RaftTypeConfig<NodeId = NodeId>,
332    S: Default + Clone + serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
333    F: Fn(&mut S, &C::D) -> C::R + Send + Sync + 'static,
334{
335    sm: Arc<RwLock<SmData<C, S>>>,
336    _phantom: std::marker::PhantomData<F>,
337}
338
339impl<C, S, F> RaftSnapshotBuilder<C> for MemSnapshotBuilder<C, S, F>
340where
341    C: RaftTypeConfig<NodeId = NodeId, SnapshotData = Cursor<Vec<u8>>>,
342    S: Default + Clone + serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
343    F: Fn(&mut S, &C::D) -> C::R + Send + Sync + 'static,
344{
345    async fn build_snapshot(&mut self) -> Result<Snapshot<C>, StorageError<NodeId>> {
346        let sm = self.sm.read().await;
347
348        let data = postcard2::to_vec(&sm.state).map_err(|e| {
349            StorageError::from_io_error(
350                openraft::ErrorSubject::StateMachine,
351                openraft::ErrorVerb::Read,
352                std::io::Error::other(e),
353            )
354        })?;
355
356        let snapshot_id = if let Some(ref last) = sm.last_applied_log {
357            format!("{}-{}", last.leader_id, last.index)
358        } else {
359            "0-0".to_string()
360        };
361
362        let meta = SnapshotMeta {
363            last_log_id: sm.last_applied_log,
364            last_membership: sm.last_membership.clone(),
365            snapshot_id,
366        };
367
368        debug!(
369            last_log_id = ?meta.last_log_id,
370            "Built in-memory snapshot"
371        );
372
373        Ok(Snapshot {
374            meta,
375            snapshot: Box::new(Cursor::new(data)),
376        })
377    }
378}
379
380// -- RaftStateMachine impl --------------------------------------------------
381
382impl<C, S, F> RaftStateMachine<C> for MemStateMachine<C, S, F>
383where
384    C: RaftTypeConfig<NodeId = NodeId, SnapshotData = Cursor<Vec<u8>>, Entry = Entry<C>>,
385    C::D: Clone + serde::Serialize + serde::de::DeserializeOwned,
386    C::R: Default,
387    S: Default + Clone + serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
388    F: Fn(&mut S, &C::D) -> C::R + Send + Sync + 'static,
389{
390    type SnapshotBuilder = MemSnapshotBuilder<C, S, F>;
391
392    async fn applied_state(
393        &mut self,
394    ) -> Result<(Option<LogId<NodeId>>, StoredMembership<NodeId, C::Node>), StorageError<NodeId>>
395    {
396        let sm = self.sm.read().await;
397        Ok((sm.last_applied_log, sm.last_membership.clone()))
398    }
399
400    async fn apply<I>(&mut self, entries: I) -> Result<Vec<C::R>, StorageError<NodeId>>
401    where
402        I: IntoIterator<Item = C::Entry> + OptionalSend,
403        I::IntoIter: OptionalSend,
404    {
405        let mut sm = self.sm.write().await;
406        let mut responses = Vec::new();
407
408        for entry in entries {
409            sm.last_applied_log = Some(entry.log_id);
410
411            match entry.payload {
412                EntryPayload::Normal(ref data) => {
413                    let resp = (self.apply_fn)(&mut sm.state, data);
414                    responses.push(resp);
415                }
416                EntryPayload::Membership(ref mem) => {
417                    sm.last_membership = StoredMembership::new(Some(entry.log_id), mem.clone());
418                    responses.push(C::R::default());
419                }
420                EntryPayload::Blank => {
421                    responses.push(C::R::default());
422                }
423            }
424        }
425
426        Ok(responses)
427    }
428
429    async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
430        MemSnapshotBuilder {
431            sm: Arc::clone(&self.sm),
432            _phantom: std::marker::PhantomData,
433        }
434    }
435
436    async fn begin_receiving_snapshot(
437        &mut self,
438    ) -> Result<Box<Cursor<Vec<u8>>>, StorageError<NodeId>> {
439        Ok(Box::new(Cursor::new(Vec::new())))
440    }
441
442    async fn install_snapshot(
443        &mut self,
444        meta: &SnapshotMeta<NodeId, C::Node>,
445        snapshot: Box<Cursor<Vec<u8>>>,
446    ) -> Result<(), StorageError<NodeId>> {
447        let data = snapshot.into_inner();
448        let state: S = postcard2::from_bytes(&data).map_err(|e| {
449            StorageError::from_io_error(
450                openraft::ErrorSubject::Snapshot(None),
451                openraft::ErrorVerb::Read,
452                std::io::Error::other(e),
453            )
454        })?;
455
456        let mut sm = self.sm.write().await;
457        sm.last_applied_log = meta.last_log_id;
458        sm.last_membership = meta.last_membership.clone();
459        sm.state = state.clone();
460
461        // Store the snapshot
462        let snapshot_data = postcard2::to_vec(&state).map_err(|e| {
463            StorageError::from_io_error(
464                openraft::ErrorSubject::Snapshot(None),
465                openraft::ErrorVerb::Write,
466                std::io::Error::other(e),
467            )
468        })?;
469        sm.current_snapshot = Some(StoredSnapshot {
470            meta: meta.clone(),
471            data: snapshot_data,
472        });
473
474        debug!(
475            last_log_id = ?meta.last_log_id,
476            "Installed snapshot into in-memory state machine"
477        );
478
479        Ok(())
480    }
481
482    async fn get_current_snapshot(&mut self) -> Result<Option<Snapshot<C>>, StorageError<NodeId>> {
483        let sm = self.sm.read().await;
484        match &sm.current_snapshot {
485            Some(stored) => Ok(Some(Snapshot {
486                meta: stored.meta.clone(),
487                snapshot: Box::new(Cursor::new(stored.data.clone())),
488            })),
489            None => Ok(None),
490        }
491    }
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use openraft::storage::RaftLogStorage;

    // Smallest useful type config: string requests and string responses.
    openraft::declare_raft_types!(
        pub TestTypeConfig:
            D = String,
            R = String,
    );

    /// A freshly created log store reports no purged id and no last log id.
    #[tokio::test]
    async fn test_log_store_empty_state() {
        let mut log_store = MemLogStore::<TestTypeConfig>::new();
        let log_state = RaftLogStorage::get_log_state(&mut log_store)
            .await
            .unwrap();
        assert!(log_state.last_purged_log_id.is_none());
        assert!(log_state.last_log_id.is_none());
    }

    /// A saved vote is returned unchanged by the next read.
    #[tokio::test]
    async fn test_vote_round_trip() {
        let mut log_store = MemLogStore::<TestTypeConfig>::new();

        // Nothing saved yet.
        let before = RaftLogStorage::read_vote(&mut log_store).await.unwrap();
        assert!(before.is_none());

        let saved = Vote::new(1, 1);
        RaftLogStorage::save_vote(&mut log_store, &saved)
            .await
            .unwrap();

        let after = RaftLogStorage::read_vote(&mut log_store).await.unwrap();
        assert_eq!(after, Some(saved));
    }

    /// A new state machine starts with the default (empty) application state.
    #[tokio::test]
    async fn test_state_machine_new() {
        let machine = MemStateMachine::<TestTypeConfig, Vec<String>, _>::new(
            |state: &mut Vec<String>, data: &String| {
                state.push(data.clone());
                format!("applied: {data}")
            },
        );

        let shared = machine.data();
        let guard = shared.read().await;
        assert!(guard.state.is_empty());
    }
}
545}