Skip to main content

synaptic_graph/
store_checkpointer.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use synaptic_core::{Store, SynapticError};
5
6use crate::checkpoint::{Checkpoint, CheckpointConfig, Checkpointer};
7
8/// `Checkpointer` implementation backed by any [`Store`].
9///
10/// Checkpoints are stored under namespace `["checkpoints", "{thread_id}"]`
11/// with the checkpoint ID as the key.
12///
13/// This replaces `MemorySaver` (in-memory only) and `FileSaver` (file-only)
14/// with a single implementation that works with any Store backend.
15pub struct StoreCheckpointer {
16    store: Arc<dyn Store>,
17}
18
19impl StoreCheckpointer {
20    /// Create a new checkpointer backed by the given store.
21    pub fn new(store: Arc<dyn Store>) -> Self {
22        Self { store }
23    }
24}
25
26#[async_trait]
27impl Checkpointer for StoreCheckpointer {
28    async fn put(
29        &self,
30        config: &CheckpointConfig,
31        checkpoint: &Checkpoint,
32    ) -> Result<(), SynapticError> {
33        let value = serde_json::to_value(checkpoint)
34            .map_err(|e| SynapticError::Graph(format!("failed to serialize checkpoint: {e}")))?;
35        self.store
36            .put(&["checkpoints", &config.thread_id], &checkpoint.id, value)
37            .await
38    }
39
40    async fn get(&self, config: &CheckpointConfig) -> Result<Option<Checkpoint>, SynapticError> {
41        // If a specific checkpoint_id is requested, fetch it directly
42        if let Some(ref target_id) = config.checkpoint_id {
43            let item = self
44                .store
45                .get(&["checkpoints", &config.thread_id], target_id)
46                .await?;
47            return match item {
48                Some(item) => {
49                    let checkpoint: Checkpoint =
50                        serde_json::from_value(item.value).map_err(|e| {
51                            SynapticError::Graph(format!("failed to deserialize checkpoint: {e}"))
52                        })?;
53                    Ok(Some(checkpoint))
54                }
55                None => Ok(None),
56            };
57        }
58
59        // Otherwise return the latest — search all, sort by ID (timestamp-hex), take last
60        let items = self
61            .store
62            .search(&["checkpoints", &config.thread_id], None, 10_000)
63            .await?;
64
65        if items.is_empty() {
66            return Ok(None);
67        }
68
69        // IDs are timestamp-hex format, alphabetical = chronological
70        let latest = items.into_iter().max_by(|a, b| a.key.cmp(&b.key)).unwrap();
71
72        let checkpoint: Checkpoint = serde_json::from_value(latest.value)
73            .map_err(|e| SynapticError::Graph(format!("failed to deserialize checkpoint: {e}")))?;
74        Ok(Some(checkpoint))
75    }
76
77    async fn list(&self, config: &CheckpointConfig) -> Result<Vec<Checkpoint>, SynapticError> {
78        let items = self
79            .store
80            .search(&["checkpoints", &config.thread_id], None, 10_000)
81            .await?;
82
83        let mut checkpoints: Vec<Checkpoint> = items
84            .into_iter()
85            .map(|item| {
86                serde_json::from_value(item.value).map_err(|e| {
87                    SynapticError::Graph(format!("failed to deserialize checkpoint: {e}"))
88                })
89            })
90            .collect::<Result<Vec<_>, _>>()?;
91
92        // Sort by ID (oldest first) — IDs are timestamp-hex, alphabetical = chronological
93        checkpoints.sort_by(|a, b| a.id.cmp(&b.id));
94        Ok(checkpoints)
95    }
96}