// synaptic_graph/checkpoint.rs

1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use synaptic_core::SynapticError;
6use tokio::sync::RwLock;
7
8/// Configuration identifying a checkpoint (thread/conversation).
9#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
10pub struct CheckpointConfig {
11    pub thread_id: String,
12    /// Optional: target a specific checkpoint for time-travel.
13    /// When `None`, operations target the latest checkpoint.
14    pub checkpoint_id: Option<String>,
15}
16
17impl CheckpointConfig {
18    pub fn new(thread_id: impl Into<String>) -> Self {
19        Self {
20            thread_id: thread_id.into(),
21            checkpoint_id: None,
22        }
23    }
24
25    /// Create a config targeting a specific checkpoint (for time-travel).
26    pub fn with_checkpoint_id(
27        thread_id: impl Into<String>,
28        checkpoint_id: impl Into<String>,
29    ) -> Self {
30        Self {
31            thread_id: thread_id.into(),
32            checkpoint_id: Some(checkpoint_id.into()),
33        }
34    }
35}
36
37/// A snapshot of graph state at a point in execution.
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct Checkpoint {
40    /// Unique identifier for this checkpoint.
41    pub id: String,
42    /// Serialized graph state.
43    pub state: serde_json::Value,
44    /// The next node to execute (or None if graph completed).
45    pub next_node: Option<String>,
46    /// ID of the previous checkpoint (for traversing history).
47    pub parent_id: Option<String>,
48    /// Metadata about this checkpoint (node name, timestamp, etc.).
49    pub metadata: HashMap<String, serde_json::Value>,
50}
51
52impl Checkpoint {
53    /// Create a new checkpoint with auto-generated ID.
54    pub fn new(state: serde_json::Value, next_node: Option<String>) -> Self {
55        Self {
56            id: generate_checkpoint_id(),
57            state,
58            next_node,
59            parent_id: None,
60            metadata: HashMap::new(),
61        }
62    }
63
64    /// Set the parent checkpoint ID.
65    pub fn with_parent(mut self, parent_id: impl Into<String>) -> Self {
66        self.parent_id = Some(parent_id.into());
67        self
68    }
69
70    /// Add metadata to the checkpoint.
71    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
72        self.metadata.insert(key.into(), value);
73        self
74    }
75}
76
/// Produce a checkpoint ID of the form `<nanos-hex>-<seq-hex>`.
///
/// The first component is the wall-clock time in nanoseconds since the Unix
/// epoch (0 if the clock reads earlier than the epoch), in lowercase hex. The
/// second is a process-wide counter, zero-padded to at least four hex digits.
/// The counter keeps IDs distinct within this process even when two calls
/// observe the same timestamp.
fn generate_checkpoint_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static NEXT_SEQ: AtomicU64 = AtomicU64::new(0);

    let nanos: u128 = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos(),
        // Clock set before 1970: fall back to a zero timestamp rather than fail.
        Err(_) => 0,
    };
    // Only uniqueness matters for the counter, so relaxed ordering suffices.
    let seq: u64 = NEXT_SEQ.fetch_add(1, Ordering::Relaxed);
    format!("{:x}-{:04x}", nanos, seq)
}
90
/// Trait for persisting graph state checkpoints.
///
/// Checkpoints are addressed by a [`CheckpointConfig`]: `thread_id` selects the
/// thread, and `checkpoint_id` (where a method honours it) selects one
/// checkpoint within it. The `Send + Sync` bound allows one checkpointer to be
/// shared across threads and tasks. All methods report failures as
/// `SynapticError`.
#[async_trait]
pub trait Checkpointer: Send + Sync {
    /// Save a checkpoint for the given thread (`config.thread_id`).
    async fn put(
        &self,
        config: &CheckpointConfig,
        checkpoint: &Checkpoint,
    ) -> Result<(), SynapticError>;

    /// Get a checkpoint. If `config.checkpoint_id` is set, returns that specific
    /// checkpoint; otherwise returns the latest checkpoint for the thread.
    ///
    /// NOTE(review): `Ok(None)` presumably means "no matching checkpoint"
    /// (unknown thread or ID), as opposed to an `Err` for backend failures.
    async fn get(&self, config: &CheckpointConfig) -> Result<Option<Checkpoint>, SynapticError>;

    /// List all checkpoints for a thread, ordered oldest to newest.
    async fn list(&self, config: &CheckpointConfig) -> Result<Vec<Checkpoint>, SynapticError>;
}
108
109/// In-memory checkpointer (for development/testing).
110#[derive(Default)]
111pub struct MemorySaver {
112    store: RwLock<HashMap<String, Vec<Checkpoint>>>,
113}
114
115impl MemorySaver {
116    pub fn new() -> Self {
117        Self::default()
118    }
119}
120
121#[async_trait]
122impl Checkpointer for MemorySaver {
123    async fn put(
124        &self,
125        config: &CheckpointConfig,
126        checkpoint: &Checkpoint,
127    ) -> Result<(), SynapticError> {
128        let mut store = self.store.write().await;
129        store
130            .entry(config.thread_id.clone())
131            .or_default()
132            .push(checkpoint.clone());
133        Ok(())
134    }
135
136    async fn get(&self, config: &CheckpointConfig) -> Result<Option<Checkpoint>, SynapticError> {
137        let store = self.store.read().await;
138        let checkpoints = match store.get(&config.thread_id) {
139            Some(v) => v,
140            None => return Ok(None),
141        };
142
143        // If a specific checkpoint_id is requested, find it
144        if let Some(ref target_id) = config.checkpoint_id {
145            return Ok(checkpoints.iter().find(|c| &c.id == target_id).cloned());
146        }
147
148        // Otherwise return the latest
149        Ok(checkpoints.last().cloned())
150    }
151
152    async fn list(&self, config: &CheckpointConfig) -> Result<Vec<Checkpoint>, SynapticError> {
153        let store = self.store.read().await;
154        Ok(store.get(&config.thread_id).cloned().unwrap_or_default())
155    }
156}