Skip to main content

synaptic_sqlite/
checkpointer.rs

use std::sync::{Arc, Mutex};

use async_trait::async_trait;
use rusqlite::{params, Connection, OptionalExtension};
use synaptic_core::SynapticError;
use synaptic_graph::{Checkpoint, CheckpointConfig, Checkpointer};
7
8/// SQLite-backed graph checkpointer.
9///
10/// Stores graph state checkpoints in a local SQLite database file (or in-memory
11/// for testing). Uses `tokio::task::spawn_blocking` to avoid blocking the async
12/// runtime during SQLite operations.
13///
14/// # Example
15///
16/// ```rust,no_run
17/// use synaptic_sqlite::SqliteCheckpointer;
18///
19/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
20/// // File-based (persists across restarts)
21/// let cp = SqliteCheckpointer::new("/var/lib/myapp/checkpoints.db")?;
22///
23/// // In-memory (for testing)
24/// let cp = SqliteCheckpointer::in_memory()?;
25/// # Ok(())
26/// # }
27/// ```
28pub struct SqliteCheckpointer {
29    conn: Arc<Mutex<Connection>>,
30}
31
32impl SqliteCheckpointer {
33    /// Create a new checkpointer backed by a SQLite database file.
34    pub fn new(path: impl AsRef<std::path::Path>) -> Result<Self, SynapticError> {
35        let conn = Connection::open(path)
36            .map_err(|e| SynapticError::Store(format!("SQLite open: {e}")))?;
37        conn.execute_batch(
38            "CREATE TABLE IF NOT EXISTS synaptic_checkpoints (
39                thread_id     TEXT    NOT NULL,
40                checkpoint_id TEXT    NOT NULL,
41                state         TEXT    NOT NULL,
42                created_at    INTEGER NOT NULL,
43                PRIMARY KEY (thread_id, checkpoint_id)
44            );
45            CREATE TABLE IF NOT EXISTS synaptic_checkpoint_idx (
46                thread_id     TEXT    NOT NULL,
47                checkpoint_id TEXT    NOT NULL,
48                seq           INTEGER NOT NULL,
49                PRIMARY KEY (thread_id, checkpoint_id)
50            );",
51        )
52        .map_err(|e| SynapticError::Store(format!("SQLite create tables: {e}")))?;
53        Ok(Self {
54            conn: Arc::new(Mutex::new(conn)),
55        })
56    }
57
58    /// Create an in-memory checkpointer (useful for testing).
59    pub fn in_memory() -> Result<Self, SynapticError> {
60        Self::new(":memory:")
61    }
62}
63
64#[async_trait]
65impl Checkpointer for SqliteCheckpointer {
66    async fn put(
67        &self,
68        config: &CheckpointConfig,
69        checkpoint: &Checkpoint,
70    ) -> Result<(), SynapticError> {
71        let conn = Arc::clone(&self.conn);
72        let thread_id = config.thread_id.clone();
73        let checkpoint_id = checkpoint.id.clone();
74        let data = serde_json::to_string(checkpoint)
75            .map_err(|e| SynapticError::Store(format!("Serialize: {e}")))?;
76        let now = std::time::SystemTime::now()
77            .duration_since(std::time::UNIX_EPOCH)
78            .unwrap_or_default()
79            .as_secs() as i64;
80
81        tokio::task::spawn_blocking(move || {
82            let conn = conn
83                .lock()
84                .map_err(|e| SynapticError::Store(format!("Lock: {e}")))?;
85
86            conn.execute(
87                "INSERT OR REPLACE INTO synaptic_checkpoints \
88                 (thread_id, checkpoint_id, state, created_at) \
89                 VALUES (?1, ?2, ?3, ?4)",
90                params![thread_id, checkpoint_id, data, now],
91            )
92            .map_err(|e| SynapticError::Store(format!("SQLite INSERT: {e}")))?;
93
94            // Determine next sequence number for this thread
95            let max_seq: i64 = conn
96                .query_row(
97                    "SELECT COALESCE(MAX(seq), -1) FROM synaptic_checkpoint_idx \
98                     WHERE thread_id = ?1",
99                    params![thread_id],
100                    |row| row.get(0),
101                )
102                .unwrap_or(-1);
103
104            conn.execute(
105                "INSERT OR IGNORE INTO synaptic_checkpoint_idx \
106                 (thread_id, checkpoint_id, seq) VALUES (?1, ?2, ?3)",
107                params![thread_id, checkpoint_id, max_seq + 1],
108            )
109            .map_err(|e| SynapticError::Store(format!("SQLite INSERT idx: {e}")))?;
110
111            Ok(())
112        })
113        .await
114        .map_err(|e| SynapticError::Store(format!("spawn_blocking: {e}")))?
115    }
116
117    async fn get(&self, config: &CheckpointConfig) -> Result<Option<Checkpoint>, SynapticError> {
118        let conn = Arc::clone(&self.conn);
119        let thread_id = config.thread_id.clone();
120        let checkpoint_id = config.checkpoint_id.clone();
121
122        tokio::task::spawn_blocking(move || {
123            let conn = conn
124                .lock()
125                .map_err(|e| SynapticError::Store(format!("Lock: {e}")))?;
126
127            // Resolve checkpoint ID: explicit or latest by seq
128            let resolved_id: Option<String> = if let Some(ref id) = checkpoint_id {
129                Some(id.clone())
130            } else {
131                conn.query_row(
132                    "SELECT checkpoint_id FROM synaptic_checkpoint_idx \
133                     WHERE thread_id = ?1 ORDER BY seq DESC LIMIT 1",
134                    params![thread_id],
135                    |row| row.get(0),
136                )
137                .ok()
138            };
139
140            let id = match resolved_id {
141                Some(id) => id,
142                None => return Ok(None),
143            };
144
145            let data: Option<String> = conn
146                .query_row(
147                    "SELECT state FROM synaptic_checkpoints \
148                     WHERE thread_id = ?1 AND checkpoint_id = ?2",
149                    params![thread_id, id],
150                    |row| row.get(0),
151                )
152                .ok();
153
154            match data {
155                None => Ok(None),
156                Some(json) => {
157                    let cp: Checkpoint = serde_json::from_str(&json)
158                        .map_err(|e| SynapticError::Store(format!("Deserialize: {e}")))?;
159                    Ok(Some(cp))
160                }
161            }
162        })
163        .await
164        .map_err(|e| SynapticError::Store(format!("spawn_blocking: {e}")))?
165    }
166
167    async fn list(&self, config: &CheckpointConfig) -> Result<Vec<Checkpoint>, SynapticError> {
168        let conn = Arc::clone(&self.conn);
169        let thread_id = config.thread_id.clone();
170
171        tokio::task::spawn_blocking(move || {
172            let conn = conn
173                .lock()
174                .map_err(|e| SynapticError::Store(format!("Lock: {e}")))?;
175
176            let mut stmt = conn
177                .prepare(
178                    "SELECT c.state \
179                     FROM synaptic_checkpoints c \
180                     JOIN synaptic_checkpoint_idx i \
181                       ON c.thread_id = i.thread_id AND c.checkpoint_id = i.checkpoint_id \
182                     WHERE c.thread_id = ?1 \
183                     ORDER BY i.seq ASC",
184                )
185                .map_err(|e| SynapticError::Store(format!("SQLite prepare: {e}")))?;
186
187            let checkpoints: Result<Vec<Checkpoint>, SynapticError> = stmt
188                .query_map(params![thread_id], |row| row.get::<_, String>(0))
189                .map_err(|e| SynapticError::Store(format!("SQLite query: {e}")))?
190                .filter_map(|r| r.ok())
191                .map(|json| {
192                    serde_json::from_str(&json)
193                        .map_err(|e| SynapticError::Store(format!("Deserialize: {e}")))
194                })
195                .collect();
196
197            checkpoints
198        })
199        .await
200        .map_err(|e| SynapticError::Store(format!("spawn_blocking: {e}")))?
201    }
202}