1use std::collections::BTreeMap;
11use std::fmt::Debug;
12use std::io::Cursor;
13use std::ops::RangeBounds;
14use std::sync::Arc;
15
16use openraft::log_id::RaftLogId;
17use openraft::storage::{LogFlushed, LogState, RaftLogStorage, RaftStateMachine, Snapshot};
18use openraft::{
19 Entry, EntryPayload, LogId, OptionalSend, RaftLogReader, RaftSnapshotBuilder, RaftTypeConfig,
20 SnapshotMeta, StorageError, StoredMembership, Vote,
21};
22use tokio::sync::RwLock;
23use tracing::debug;
24
25use crate::types::NodeId;
26
27pub struct MemLogReader<C: RaftTypeConfig<NodeId = NodeId>> {
33 log: Arc<RwLock<LogData<C>>>,
34}
35
36impl<C: RaftTypeConfig<NodeId = NodeId>> Clone for MemLogReader<C> {
37 fn clone(&self) -> Self {
38 Self {
39 log: Arc::clone(&self.log),
40 }
41 }
42}
43
44impl<C> RaftLogReader<C> for MemLogReader<C>
45where
46 C: RaftTypeConfig<NodeId = NodeId>,
47 C::Entry: Clone,
48{
49 async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + OptionalSend>(
50 &mut self,
51 range: RB,
52 ) -> Result<Vec<C::Entry>, StorageError<NodeId>> {
53 let log = self.log.read().await;
54 let entries = log.entries.range(range).map(|(_, e)| e.clone()).collect();
55 Ok(entries)
56 }
57}
58
59struct LogData<C: RaftTypeConfig<NodeId = NodeId>> {
65 last_purged_log_id: Option<LogId<NodeId>>,
66 entries: BTreeMap<u64, C::Entry>,
67 vote: Option<Vote<NodeId>>,
68 committed: Option<LogId<NodeId>>,
69}
70
71impl<C: RaftTypeConfig<NodeId = NodeId>> Default for LogData<C> {
72 fn default() -> Self {
73 Self {
74 last_purged_log_id: None,
75 entries: BTreeMap::new(),
76 vote: None,
77 committed: None,
78 }
79 }
80}
81
82pub struct SmData<C: RaftTypeConfig<NodeId = NodeId>, S> {
87 pub last_applied_log: Option<LogId<NodeId>>,
89 pub last_membership: StoredMembership<NodeId, C::Node>,
91 pub state: S,
93 pub current_snapshot: Option<StoredSnapshot<C>>,
95}
96
97impl<C: RaftTypeConfig<NodeId = NodeId>, S: Default> Default for SmData<C, S> {
98 fn default() -> Self {
99 Self {
100 last_applied_log: None,
101 last_membership: StoredMembership::default(),
102 state: S::default(),
103 current_snapshot: None,
104 }
105 }
106}
107
108pub struct StoredSnapshot<C: RaftTypeConfig> {
110 pub meta: SnapshotMeta<C::NodeId, C::Node>,
111 pub data: Vec<u8>,
112}
113
114impl<C: RaftTypeConfig> Clone for StoredSnapshot<C> {
115 fn clone(&self) -> Self {
116 Self {
117 meta: self.meta.clone(),
118 data: self.data.clone(),
119 }
120 }
121}
122
123pub struct MemLogStore<C: RaftTypeConfig<NodeId = NodeId>> {
131 log: Arc<RwLock<LogData<C>>>,
132}
133
134impl<C: RaftTypeConfig<NodeId = NodeId>> MemLogStore<C> {
135 #[must_use]
137 pub fn new() -> Self {
138 Self {
139 log: Arc::new(RwLock::new(LogData::default())),
140 }
141 }
142}
143
144impl<C: RaftTypeConfig<NodeId = NodeId>> Default for MemLogStore<C> {
145 fn default() -> Self {
146 Self::new()
147 }
148}
149
150impl<C: RaftTypeConfig<NodeId = NodeId>> Clone for MemLogStore<C> {
151 fn clone(&self) -> Self {
152 Self {
153 log: Arc::clone(&self.log),
154 }
155 }
156}
157
158impl<C> RaftLogReader<C> for MemLogStore<C>
159where
160 C: RaftTypeConfig<NodeId = NodeId>,
161 C::Entry: Clone,
162{
163 async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + OptionalSend>(
164 &mut self,
165 range: RB,
166 ) -> Result<Vec<C::Entry>, StorageError<NodeId>> {
167 let log = self.log.read().await;
168 let entries = log.entries.range(range).map(|(_, e)| e.clone()).collect();
169 Ok(entries)
170 }
171}
172
173impl<C> RaftLogStorage<C> for MemLogStore<C>
174where
175 C: RaftTypeConfig<NodeId = NodeId>,
176 C::Entry: Clone,
177{
178 type LogReader = MemLogReader<C>;
179
180 async fn get_log_state(&mut self) -> Result<LogState<C>, StorageError<NodeId>> {
181 let log = self.log.read().await;
182 let last = log
183 .entries
184 .iter()
185 .next_back()
186 .map(|(_, ent)| *ent.get_log_id());
187
188 Ok(LogState {
189 last_purged_log_id: log.last_purged_log_id,
190 last_log_id: last,
191 })
192 }
193
194 async fn get_log_reader(&mut self) -> Self::LogReader {
195 MemLogReader {
196 log: Arc::clone(&self.log),
197 }
198 }
199
200 async fn save_vote(&mut self, vote: &Vote<NodeId>) -> Result<(), StorageError<NodeId>> {
201 let mut log = self.log.write().await;
202 log.vote = Some(*vote);
203 Ok(())
204 }
205
206 async fn read_vote(&mut self) -> Result<Option<Vote<NodeId>>, StorageError<NodeId>> {
207 let log = self.log.read().await;
208 Ok(log.vote)
209 }
210
211 async fn save_committed(
212 &mut self,
213 committed: Option<LogId<NodeId>>,
214 ) -> Result<(), StorageError<NodeId>> {
215 let mut log = self.log.write().await;
216 log.committed = committed;
217 Ok(())
218 }
219
220 async fn read_committed(&mut self) -> Result<Option<LogId<NodeId>>, StorageError<NodeId>> {
221 let log = self.log.read().await;
222 Ok(log.committed)
223 }
224
225 async fn append<I>(
226 &mut self,
227 entries: I,
228 callback: LogFlushed<C>,
229 ) -> Result<(), StorageError<NodeId>>
230 where
231 I: IntoIterator<Item = C::Entry> + OptionalSend,
232 I::IntoIter: OptionalSend,
233 {
234 let mut log = self.log.write().await;
235 for entry in entries {
236 let idx = entry.get_log_id().index;
237 log.entries.insert(idx, entry);
238 }
239 callback.log_io_completed(Ok(()));
241 Ok(())
242 }
243
244 async fn truncate(&mut self, log_id: LogId<NodeId>) -> Result<(), StorageError<NodeId>> {
245 let mut log = self.log.write().await;
246 let keys: Vec<u64> = log.entries.range(log_id.index..).map(|(k, _)| *k).collect();
247 for key in keys {
248 log.entries.remove(&key);
249 }
250 Ok(())
251 }
252
253 async fn purge(&mut self, log_id: LogId<NodeId>) -> Result<(), StorageError<NodeId>> {
254 let mut log = self.log.write().await;
255 let keys: Vec<u64> = log
256 .entries
257 .range(..=log_id.index)
258 .map(|(k, _)| *k)
259 .collect();
260 for key in keys {
261 log.entries.remove(&key);
262 }
263 log.last_purged_log_id = Some(log_id);
264 Ok(())
265 }
266}
267
268pub struct MemStateMachine<C, S, F>
282where
283 C: RaftTypeConfig<NodeId = NodeId>,
284 S: Default + Clone + serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
285 F: Fn(&mut S, &C::D) -> C::R + Send + Sync + 'static,
286{
287 sm: Arc<RwLock<SmData<C, S>>>,
288 apply_fn: Arc<F>,
289}
290
291impl<C, S, F> MemStateMachine<C, S, F>
292where
293 C: RaftTypeConfig<NodeId = NodeId>,
294 S: Default + Clone + serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
295 F: Fn(&mut S, &C::D) -> C::R + Send + Sync + 'static,
296{
297 pub fn new(apply_fn: F) -> Self {
299 Self {
300 sm: Arc::new(RwLock::new(SmData::default())),
301 apply_fn: Arc::new(apply_fn),
302 }
303 }
304
305 #[must_use]
307 pub fn data(&self) -> Arc<RwLock<SmData<C, S>>> {
308 Arc::clone(&self.sm)
309 }
310}
311
312impl<C, S, F> Clone for MemStateMachine<C, S, F>
313where
314 C: RaftTypeConfig<NodeId = NodeId>,
315 S: Default + Clone + serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
316 F: Fn(&mut S, &C::D) -> C::R + Send + Sync + 'static,
317{
318 fn clone(&self) -> Self {
319 Self {
320 sm: Arc::clone(&self.sm),
321 apply_fn: Arc::clone(&self.apply_fn),
322 }
323 }
324}
325
326pub struct MemSnapshotBuilder<C, S, F>
330where
331 C: RaftTypeConfig<NodeId = NodeId>,
332 S: Default + Clone + serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
333 F: Fn(&mut S, &C::D) -> C::R + Send + Sync + 'static,
334{
335 sm: Arc<RwLock<SmData<C, S>>>,
336 _phantom: std::marker::PhantomData<F>,
337}
338
339impl<C, S, F> RaftSnapshotBuilder<C> for MemSnapshotBuilder<C, S, F>
340where
341 C: RaftTypeConfig<NodeId = NodeId, SnapshotData = Cursor<Vec<u8>>>,
342 S: Default + Clone + serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
343 F: Fn(&mut S, &C::D) -> C::R + Send + Sync + 'static,
344{
345 async fn build_snapshot(&mut self) -> Result<Snapshot<C>, StorageError<NodeId>> {
346 let sm = self.sm.read().await;
347
348 let data = postcard2::to_vec(&sm.state).map_err(|e| {
349 StorageError::from_io_error(
350 openraft::ErrorSubject::StateMachine,
351 openraft::ErrorVerb::Read,
352 std::io::Error::other(e),
353 )
354 })?;
355
356 let snapshot_id = if let Some(ref last) = sm.last_applied_log {
357 format!("{}-{}", last.leader_id, last.index)
358 } else {
359 "0-0".to_string()
360 };
361
362 let meta = SnapshotMeta {
363 last_log_id: sm.last_applied_log,
364 last_membership: sm.last_membership.clone(),
365 snapshot_id,
366 };
367
368 debug!(
369 last_log_id = ?meta.last_log_id,
370 "Built in-memory snapshot"
371 );
372
373 Ok(Snapshot {
374 meta,
375 snapshot: Box::new(Cursor::new(data)),
376 })
377 }
378}
379
380impl<C, S, F> RaftStateMachine<C> for MemStateMachine<C, S, F>
383where
384 C: RaftTypeConfig<NodeId = NodeId, SnapshotData = Cursor<Vec<u8>>, Entry = Entry<C>>,
385 C::D: Clone + serde::Serialize + serde::de::DeserializeOwned,
386 C::R: Default,
387 S: Default + Clone + serde::Serialize + serde::de::DeserializeOwned + Send + Sync + 'static,
388 F: Fn(&mut S, &C::D) -> C::R + Send + Sync + 'static,
389{
390 type SnapshotBuilder = MemSnapshotBuilder<C, S, F>;
391
392 async fn applied_state(
393 &mut self,
394 ) -> Result<(Option<LogId<NodeId>>, StoredMembership<NodeId, C::Node>), StorageError<NodeId>>
395 {
396 let sm = self.sm.read().await;
397 Ok((sm.last_applied_log, sm.last_membership.clone()))
398 }
399
400 async fn apply<I>(&mut self, entries: I) -> Result<Vec<C::R>, StorageError<NodeId>>
401 where
402 I: IntoIterator<Item = C::Entry> + OptionalSend,
403 I::IntoIter: OptionalSend,
404 {
405 let mut sm = self.sm.write().await;
406 let mut responses = Vec::new();
407
408 for entry in entries {
409 sm.last_applied_log = Some(entry.log_id);
410
411 match entry.payload {
412 EntryPayload::Normal(ref data) => {
413 let resp = (self.apply_fn)(&mut sm.state, data);
414 responses.push(resp);
415 }
416 EntryPayload::Membership(ref mem) => {
417 sm.last_membership = StoredMembership::new(Some(entry.log_id), mem.clone());
418 responses.push(C::R::default());
419 }
420 EntryPayload::Blank => {
421 responses.push(C::R::default());
422 }
423 }
424 }
425
426 Ok(responses)
427 }
428
429 async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
430 MemSnapshotBuilder {
431 sm: Arc::clone(&self.sm),
432 _phantom: std::marker::PhantomData,
433 }
434 }
435
436 async fn begin_receiving_snapshot(
437 &mut self,
438 ) -> Result<Box<Cursor<Vec<u8>>>, StorageError<NodeId>> {
439 Ok(Box::new(Cursor::new(Vec::new())))
440 }
441
442 async fn install_snapshot(
443 &mut self,
444 meta: &SnapshotMeta<NodeId, C::Node>,
445 snapshot: Box<Cursor<Vec<u8>>>,
446 ) -> Result<(), StorageError<NodeId>> {
447 let data = snapshot.into_inner();
448 let state: S = postcard2::from_bytes(&data).map_err(|e| {
449 StorageError::from_io_error(
450 openraft::ErrorSubject::Snapshot(None),
451 openraft::ErrorVerb::Read,
452 std::io::Error::other(e),
453 )
454 })?;
455
456 let mut sm = self.sm.write().await;
457 sm.last_applied_log = meta.last_log_id;
458 sm.last_membership = meta.last_membership.clone();
459 sm.state = state.clone();
460
461 let snapshot_data = postcard2::to_vec(&state).map_err(|e| {
463 StorageError::from_io_error(
464 openraft::ErrorSubject::Snapshot(None),
465 openraft::ErrorVerb::Write,
466 std::io::Error::other(e),
467 )
468 })?;
469 sm.current_snapshot = Some(StoredSnapshot {
470 meta: meta.clone(),
471 data: snapshot_data,
472 });
473
474 debug!(
475 last_log_id = ?meta.last_log_id,
476 "Installed snapshot into in-memory state machine"
477 );
478
479 Ok(())
480 }
481
482 async fn get_current_snapshot(&mut self) -> Result<Option<Snapshot<C>>, StorageError<NodeId>> {
483 let sm = self.sm.read().await;
484 match &sm.current_snapshot {
485 Some(stored) => Ok(Some(Snapshot {
486 meta: stored.meta.clone(),
487 snapshot: Box::new(Cursor::new(stored.data.clone())),
488 })),
489 None => Ok(None),
490 }
491 }
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use openraft::storage::RaftLogStorage;

    // Minimal type config: String commands and String responses.
    openraft::declare_raft_types!(
        pub TestTypeConfig:
            D = String,
            R = String,
    );

    /// A fresh store has neither purged nor stored any log.
    #[tokio::test]
    async fn test_log_store_empty_state() {
        let mut log_store = MemLogStore::<TestTypeConfig>::new();
        let log_state = RaftLogStorage::get_log_state(&mut log_store)
            .await
            .unwrap();
        assert!(log_state.last_purged_log_id.is_none());
        assert!(log_state.last_log_id.is_none());
    }

    /// A saved vote is read back unchanged.
    #[tokio::test]
    async fn test_vote_round_trip() {
        let mut log_store = MemLogStore::<TestTypeConfig>::new();

        // Nothing has been persisted yet.
        let initial = RaftLogStorage::read_vote(&mut log_store).await.unwrap();
        assert!(initial.is_none());

        let saved = Vote::new(1, 1);
        RaftLogStorage::save_vote(&mut log_store, &saved)
            .await
            .unwrap();

        let loaded = RaftLogStorage::read_vote(&mut log_store).await.unwrap();
        assert_eq!(loaded, Some(saved));
    }

    /// A new state machine starts from the default (empty) state.
    #[tokio::test]
    async fn test_state_machine_new() {
        let state_machine = MemStateMachine::<TestTypeConfig, Vec<String>, _>::new(
            |state: &mut Vec<String>, data: &String| {
                state.push(data.clone());
                format!("applied: {data}")
            },
        );

        let shared = state_machine.data();
        let guard = shared.read().await;
        assert!(guard.state.is_empty());
    }
}