synaptic_graph/
checkpoint.rs1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use synaptic_core::SynapticError;
6use tokio::sync::RwLock;
7
8#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
10pub struct CheckpointConfig {
11 pub thread_id: String,
12 pub checkpoint_id: Option<String>,
15}
16
17impl CheckpointConfig {
18 pub fn new(thread_id: impl Into<String>) -> Self {
19 Self {
20 thread_id: thread_id.into(),
21 checkpoint_id: None,
22 }
23 }
24
25 pub fn with_checkpoint_id(
27 thread_id: impl Into<String>,
28 checkpoint_id: impl Into<String>,
29 ) -> Self {
30 Self {
31 thread_id: thread_id.into(),
32 checkpoint_id: Some(checkpoint_id.into()),
33 }
34 }
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct Checkpoint {
40 pub id: String,
42 pub state: serde_json::Value,
44 pub next_node: Option<String>,
46 pub parent_id: Option<String>,
48 pub metadata: HashMap<String, serde_json::Value>,
50}
51
52impl Checkpoint {
53 pub fn new(state: serde_json::Value, next_node: Option<String>) -> Self {
55 Self {
56 id: generate_checkpoint_id(),
57 state,
58 next_node,
59 parent_id: None,
60 metadata: HashMap::new(),
61 }
62 }
63
64 pub fn with_parent(mut self, parent_id: impl Into<String>) -> Self {
66 self.parent_id = Some(parent_id.into());
67 self
68 }
69
70 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
72 self.metadata.insert(key.into(), value);
73 self
74 }
75}
76
/// Produces a checkpoint id of the form `<timestamp>-<sequence>`, both in lowercase hex.
///
/// The timestamp is the number of nanoseconds since the Unix epoch (0 if the system clock
/// reports a time before the epoch). The sequence is a process-wide counter, zero-padded to
/// at least four digits, that increases by one on every call.
fn generate_checkpoint_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static NEXT_SEQUENCE: AtomicU64 = AtomicU64::new(0);

    let nanos_since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        // Clock set before the epoch: fall back to a zero timestamp.
        Err(_) => 0,
    };
    // Relaxed is enough: only the atomicity of the increment matters, not ordering
    // relative to other memory operations.
    let sequence = NEXT_SEQUENCE.fetch_add(1, Ordering::Relaxed);

    format!("{nanos_since_epoch:x}-{sequence:04x}")
}
90
91#[async_trait]
93pub trait Checkpointer: Send + Sync {
94 async fn put(
96 &self,
97 config: &CheckpointConfig,
98 checkpoint: &Checkpoint,
99 ) -> Result<(), SynapticError>;
100
101 async fn get(&self, config: &CheckpointConfig) -> Result<Option<Checkpoint>, SynapticError>;
104
105 async fn list(&self, config: &CheckpointConfig) -> Result<Vec<Checkpoint>, SynapticError>;
107}
108
109#[derive(Default)]
111pub struct MemorySaver {
112 store: RwLock<HashMap<String, Vec<Checkpoint>>>,
113}
114
115impl MemorySaver {
116 pub fn new() -> Self {
117 Self::default()
118 }
119}
120
121#[async_trait]
122impl Checkpointer for MemorySaver {
123 async fn put(
124 &self,
125 config: &CheckpointConfig,
126 checkpoint: &Checkpoint,
127 ) -> Result<(), SynapticError> {
128 let mut store = self.store.write().await;
129 store
130 .entry(config.thread_id.clone())
131 .or_default()
132 .push(checkpoint.clone());
133 Ok(())
134 }
135
136 async fn get(&self, config: &CheckpointConfig) -> Result<Option<Checkpoint>, SynapticError> {
137 let store = self.store.read().await;
138 let checkpoints = match store.get(&config.thread_id) {
139 Some(v) => v,
140 None => return Ok(None),
141 };
142
143 if let Some(ref target_id) = config.checkpoint_id {
145 return Ok(checkpoints.iter().find(|c| &c.id == target_id).cloned());
146 }
147
148 Ok(checkpoints.last().cloned())
150 }
151
152 async fn list(&self, config: &CheckpointConfig) -> Result<Vec<Checkpoint>, SynapticError> {
153 let store = self.store.read().await;
154 Ok(store.get(&config.thread_id).cloned().unwrap_or_default())
155 }
156}