synaptic_graph/
checkpoint.rs1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use synaptic_core::SynapseError;
6use tokio::sync::RwLock;
7
8#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
10pub struct CheckpointConfig {
11 pub thread_id: String,
12}
13
14impl CheckpointConfig {
15 pub fn new(thread_id: impl Into<String>) -> Self {
16 Self {
17 thread_id: thread_id.into(),
18 }
19 }
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct Checkpoint {
25 pub state: serde_json::Value,
26 pub next_node: Option<String>,
27}
28
29#[async_trait]
31pub trait Checkpointer: Send + Sync {
32 async fn put(
33 &self,
34 config: &CheckpointConfig,
35 checkpoint: &Checkpoint,
36 ) -> Result<(), SynapseError>;
37 async fn get(&self, config: &CheckpointConfig) -> Result<Option<Checkpoint>, SynapseError>;
38 async fn list(&self, config: &CheckpointConfig) -> Result<Vec<Checkpoint>, SynapseError>;
39}
40
41#[derive(Default)]
43pub struct MemorySaver {
44 store: RwLock<HashMap<String, Vec<Checkpoint>>>,
45}
46
47impl MemorySaver {
48 pub fn new() -> Self {
49 Self::default()
50 }
51}
52
53#[async_trait]
54impl Checkpointer for MemorySaver {
55 async fn put(
56 &self,
57 config: &CheckpointConfig,
58 checkpoint: &Checkpoint,
59 ) -> Result<(), SynapseError> {
60 let mut store = self.store.write().await;
61 store
62 .entry(config.thread_id.clone())
63 .or_default()
64 .push(checkpoint.clone());
65 Ok(())
66 }
67
68 async fn get(&self, config: &CheckpointConfig) -> Result<Option<Checkpoint>, SynapseError> {
69 let store = self.store.read().await;
70 Ok(store.get(&config.thread_id).and_then(|v| v.last().cloned()))
71 }
72
73 async fn list(&self, config: &CheckpointConfig) -> Result<Vec<Checkpoint>, SynapseError> {
74 let store = self.store.read().await;
75 Ok(store.get(&config.thread_id).cloned().unwrap_or_default())
76 }
77}