Skip to main content

synwire_checkpoint_sqlite/
saver.rs

1//! SQLite-backed checkpoint saver implementation.
2
3use std::path::Path;
4use std::sync::Arc;
5
6use r2d2::Pool;
7use r2d2_sqlite::SqliteConnectionManager;
8use synwire_checkpoint::base::BaseCheckpointSaver;
9use synwire_checkpoint::types::{
10    Checkpoint, CheckpointConfig, CheckpointError, CheckpointMetadata, CheckpointTuple,
11};
12
13use crate::schema::CREATE_CHECKPOINTS_TABLE;
14
/// Size cap, in bytes, applied to serialized checkpoints when the caller does
/// not pick one explicitly: 16 MiB (`16 << 20` bytes).
const DEFAULT_MAX_CHECKPOINT_SIZE: usize = 16 << 20;
17
18/// SQLite-backed checkpoint saver.
19///
20/// Persists checkpoints to a `SQLite` database file with configurable
21/// maximum checkpoint size. The database file is created with mode 0600
22/// permissions on Unix systems.
23#[derive(Debug, Clone)]
24pub struct SqliteSaver {
25    pool: Arc<Pool<SqliteConnectionManager>>,
26    max_checkpoint_size: usize,
27}
28
29impl SqliteSaver {
30    /// Create a new `SqliteSaver` at the given path.
31    ///
32    /// Creates the database file (with 0600 permissions on Unix) and
33    /// initialises the schema if it does not already exist.
34    ///
35    /// # Errors
36    ///
37    /// Returns `CheckpointError::Storage` if the database cannot be opened
38    /// or the schema cannot be created.
39    pub fn new(path: &Path) -> Result<Self, CheckpointError> {
40        Self::with_max_size(path, DEFAULT_MAX_CHECKPOINT_SIZE)
41    }
42
43    /// Create a new `SqliteSaver` with a custom maximum checkpoint size.
44    ///
45    /// # Errors
46    ///
47    /// Returns `CheckpointError::Storage` if the database cannot be opened
48    /// or the schema cannot be created.
49    pub fn with_max_size(path: &Path, max_checkpoint_size: usize) -> Result<Self, CheckpointError> {
50        // Set file permissions to 0600 on Unix before opening.
51        #[cfg(unix)]
52        {
53            if !path.exists() {
54                // Create the file first so we can set permissions.
55                let _file = std::fs::File::create(path)
56                    .map_err(|e| CheckpointError::Storage(e.to_string()))?;
57                Self::set_permissions_0600(path)?;
58            }
59        }
60
61        let manager = SqliteConnectionManager::file(path);
62        let pool = Pool::builder()
63            .max_size(4)
64            .build(manager)
65            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
66
67        // Initialise schema.
68        let conn = pool
69            .get()
70            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
71        conn.execute_batch(CREATE_CHECKPOINTS_TABLE)
72            .map_err(|e| CheckpointError::Storage(e.to_string()))?;
73
74        Ok(Self {
75            pool: Arc::new(pool),
76            max_checkpoint_size,
77        })
78    }
79
80    /// Set file permissions to 0600 on Unix.
81    #[cfg(unix)]
82    fn set_permissions_0600(path: &Path) -> Result<(), CheckpointError> {
83        use std::os::unix::fs::PermissionsExt;
84        let perms = std::fs::Permissions::from_mode(0o600);
85        std::fs::set_permissions(path, perms).map_err(|e| CheckpointError::Storage(e.to_string()))
86    }
87}
88
89#[allow(clippy::significant_drop_tightening)]
90impl BaseCheckpointSaver for SqliteSaver {
91    fn get_tuple<'a>(
92        &'a self,
93        config: &'a CheckpointConfig,
94    ) -> synwire_core::BoxFuture<'a, Result<Option<CheckpointTuple>, CheckpointError>> {
95        Box::pin(async move {
96            let conn = self
97                .pool
98                .get()
99                .map_err(|e| CheckpointError::Storage(e.to_string()))?;
100
101            let (query, params): (&str, Vec<Box<dyn rusqlite::types::ToSql>>) =
102                config.checkpoint_id.as_ref().map_or_else(
103                    || -> (&str, Vec<Box<dyn rusqlite::types::ToSql>>) {
104                        (
105                            "SELECT checkpoint_id, data, metadata, parent_checkpoint_id \
106                             FROM checkpoints WHERE thread_id = ?1 \
107                             ORDER BY rowid DESC LIMIT 1",
108                            vec![Box::new(config.thread_id.clone())],
109                        )
110                    },
111                    |checkpoint_id| {
112                        (
113                            "SELECT checkpoint_id, data, metadata, parent_checkpoint_id \
114                             FROM checkpoints WHERE thread_id = ?1 AND checkpoint_id = ?2",
115                            vec![
116                                Box::new(config.thread_id.clone()),
117                                Box::new(checkpoint_id.clone()),
118                            ],
119                        )
120                    },
121                );
122
123            let param_refs: Vec<&dyn rusqlite::types::ToSql> =
124                params.iter().map(AsRef::as_ref).collect();
125
126            let mut stmt = conn
127                .prepare(query)
128                .map_err(|e| CheckpointError::Storage(e.to_string()))?;
129
130            let result = stmt
131                .query_row(&*param_refs, |row| {
132                    let checkpoint_id: String = row.get(0)?;
133                    let data: Vec<u8> = row.get(1)?;
134                    let metadata_json: String = row.get(2)?;
135                    let parent_id: Option<String> = row.get(3)?;
136                    Ok((checkpoint_id, data, metadata_json, parent_id))
137                })
138                .optional()
139                .map_err(|e| CheckpointError::Storage(e.to_string()))?;
140
141            let Some((checkpoint_id, data, metadata_json, parent_id)) = result else {
142                return Ok(None);
143            };
144
145            let checkpoint: Checkpoint = serde_json::from_slice(&data)
146                .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
147            let metadata: CheckpointMetadata = serde_json::from_str(&metadata_json)
148                .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
149
150            let tuple_config = CheckpointConfig {
151                thread_id: config.thread_id.clone(),
152                checkpoint_id: Some(checkpoint_id),
153            };
154            let parent_config = parent_id.map(|pid| CheckpointConfig {
155                thread_id: config.thread_id.clone(),
156                checkpoint_id: Some(pid),
157            });
158
159            Ok(Some(CheckpointTuple {
160                config: tuple_config,
161                checkpoint,
162                metadata,
163                parent_config,
164            }))
165        })
166    }
167
168    fn list<'a>(
169        &'a self,
170        config: &'a CheckpointConfig,
171        limit: Option<usize>,
172    ) -> synwire_core::BoxFuture<'a, Result<Vec<CheckpointTuple>, CheckpointError>> {
173        Box::pin(async move {
174            let conn = self
175                .pool
176                .get()
177                .map_err(|e| CheckpointError::Storage(e.to_string()))?;
178
179            let limit_val: i64 = limit
180                .and_then(|l| i64::try_from(l).ok())
181                .unwrap_or(i64::MAX);
182
183            let mut stmt = conn
184                .prepare(
185                    "SELECT checkpoint_id, data, metadata, parent_checkpoint_id \
186                     FROM checkpoints WHERE thread_id = ?1 \
187                     ORDER BY rowid DESC LIMIT ?2",
188                )
189                .map_err(|e| CheckpointError::Storage(e.to_string()))?;
190
191            let rows = stmt
192                .query_map(rusqlite::params![config.thread_id, limit_val], |row| {
193                    let checkpoint_id: String = row.get(0)?;
194                    let data: Vec<u8> = row.get(1)?;
195                    let metadata_json: String = row.get(2)?;
196                    let parent_id: Option<String> = row.get(3)?;
197                    Ok((checkpoint_id, data, metadata_json, parent_id))
198                })
199                .map_err(|e| CheckpointError::Storage(e.to_string()))?;
200
201            let mut tuples = Vec::new();
202            for row in rows {
203                let (checkpoint_id, data, metadata_json, parent_id) =
204                    row.map_err(|e| CheckpointError::Storage(e.to_string()))?;
205
206                let checkpoint: Checkpoint = serde_json::from_slice(&data)
207                    .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
208                let metadata: CheckpointMetadata = serde_json::from_str(&metadata_json)
209                    .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
210
211                let tuple_config = CheckpointConfig {
212                    thread_id: config.thread_id.clone(),
213                    checkpoint_id: Some(checkpoint_id),
214                };
215                let parent_config = parent_id.map(|pid| CheckpointConfig {
216                    thread_id: config.thread_id.clone(),
217                    checkpoint_id: Some(pid),
218                });
219
220                tuples.push(CheckpointTuple {
221                    config: tuple_config,
222                    checkpoint,
223                    metadata,
224                    parent_config,
225                });
226            }
227
228            Ok(tuples)
229        })
230    }
231
232    fn put<'a>(
233        &'a self,
234        config: &'a CheckpointConfig,
235        checkpoint: Checkpoint,
236        metadata: CheckpointMetadata,
237    ) -> synwire_core::BoxFuture<'a, Result<CheckpointConfig, CheckpointError>> {
238        Box::pin(async move {
239            let data = serde_json::to_vec(&checkpoint)
240                .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
241
242            if data.len() > self.max_checkpoint_size {
243                return Err(CheckpointError::StateTooLarge {
244                    size: data.len(),
245                    max: self.max_checkpoint_size,
246                });
247            }
248
249            let metadata_json = serde_json::to_string(&metadata)
250                .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
251
252            // Determine parent: the latest existing checkpoint for this thread.
253            let conn = self
254                .pool
255                .get()
256                .map_err(|e| CheckpointError::Storage(e.to_string()))?;
257
258            let parent_id: Option<String> = conn
259                .prepare(
260                    "SELECT checkpoint_id FROM checkpoints \
261                     WHERE thread_id = ?1 ORDER BY rowid DESC LIMIT 1",
262                )
263                .map_err(|e| CheckpointError::Storage(e.to_string()))?
264                .query_row(rusqlite::params![config.thread_id], |row| row.get(0))
265                .optional()
266                .map_err(|e| CheckpointError::Storage(e.to_string()))?;
267
268            let _rows_changed = conn
269                .execute(
270                    "INSERT OR REPLACE INTO checkpoints \
271                     (thread_id, checkpoint_id, data, metadata, parent_checkpoint_id) \
272                     VALUES (?1, ?2, ?3, ?4, ?5)",
273                    rusqlite::params![
274                        config.thread_id,
275                        checkpoint.id,
276                        data,
277                        metadata_json,
278                        parent_id,
279                    ],
280                )
281                .map_err(|e| CheckpointError::Storage(e.to_string()))?;
282
283            Ok(CheckpointConfig {
284                thread_id: config.thread_id.clone(),
285                checkpoint_id: Some(checkpoint.id),
286            })
287        })
288    }
289}
290
291/// Extension trait for `rusqlite` optional query results.
292trait OptionalExt<T> {
293    /// Convert a `QueryReturnedNoRows` error to `Ok(None)`.
294    fn optional(self) -> Result<Option<T>, rusqlite::Error>;
295}
296
297impl<T> OptionalExt<T> for Result<T, rusqlite::Error> {
298    fn optional(self) -> Result<Option<T>, rusqlite::Error> {
299        match self {
300            Ok(v) => Ok(Some(v)),
301            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
302            Err(e) => Err(e),
303        }
304    }
305}
306
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use synwire_checkpoint::types::CheckpointSource;

    /// Build a checkpoint with the given id holding an empty `messages`
    /// channel, paired with loop-sourced metadata for `step`.
    fn make_checkpoint(id: &str, step: i64) -> (Checkpoint, CheckpointMetadata) {
        let mut checkpoint = Checkpoint::new(id.to_owned());
        let _previous = checkpoint
            .channel_values
            .insert("messages".into(), serde_json::json!([]));

        (
            checkpoint,
            CheckpointMetadata {
                source: CheckpointSource::Loop,
                step,
                writes: HashMap::new(),
                parents: HashMap::new(),
            },
        )
    }

    /// T221: `SqliteSaver` put/get/list.
    #[tokio::test]
    async fn put_get_list() {
        let tmp = tempfile::tempdir().unwrap();
        let saver = SqliteSaver::new(&tmp.path().join("test.db")).unwrap();

        let latest_cfg = CheckpointConfig {
            thread_id: "thread-1".into(),
            checkpoint_id: None,
        };

        // First write: the returned config addresses the stored checkpoint.
        let (first, first_meta) = make_checkpoint("cp-1", 0);
        let stored = saver.put(&latest_cfg, first, first_meta).await.unwrap();
        assert_eq!(stored.checkpoint_id.as_deref(), Some("cp-1"));

        // With a single row, "latest" is that row.
        let latest = saver.get_tuple(&latest_cfg).await.unwrap().unwrap();
        assert_eq!(latest.checkpoint.id, "cp-1");

        // Second write becomes the latest and points back at the first.
        let (second, second_meta) = make_checkpoint("cp-2", 1);
        let _stored_second = saver.put(&latest_cfg, second, second_meta).await.unwrap();

        let latest = saver.get_tuple(&latest_cfg).await.unwrap().unwrap();
        assert_eq!(latest.checkpoint.id, "cp-2");
        assert_eq!(
            latest
                .parent_config
                .and_then(|parent| parent.checkpoint_id)
                .as_deref(),
            Some("cp-1")
        );

        // Lookup by explicit checkpoint id.
        let by_id_cfg = CheckpointConfig {
            thread_id: "thread-1".into(),
            checkpoint_id: Some("cp-1".into()),
        };
        let by_id = saver.get_tuple(&by_id_cfg).await.unwrap().unwrap();
        assert_eq!(by_id.checkpoint.id, "cp-1");

        // Unlimited listing returns everything, newest first.
        let everything = saver.list(&latest_cfg, None).await.unwrap();
        let listed_ids: Vec<&str> = everything
            .iter()
            .map(|tuple| tuple.checkpoint.id.as_str())
            .collect();
        assert_eq!(listed_ids, ["cp-2", "cp-1"]);

        // A limit of one returns only the newest.
        let newest_only = saver.list(&latest_cfg, Some(1)).await.unwrap();
        let limited_ids: Vec<&str> = newest_only
            .iter()
            .map(|tuple| tuple.checkpoint.id.as_str())
            .collect();
        assert_eq!(limited_ids, ["cp-2"]);

        // An unknown thread has no checkpoint at all.
        let unknown_cfg = CheckpointConfig {
            thread_id: "no-such-thread".into(),
            checkpoint_id: None,
        };
        assert!(saver.get_tuple(&unknown_cfg).await.unwrap().is_none());
    }

    /// T222: `SqliteSaver` file permissions are 0600.
    #[cfg(unix)]
    #[tokio::test]
    async fn file_permissions_0600() {
        use std::os::unix::fs::PermissionsExt;

        let tmp = tempfile::tempdir().unwrap();
        let db_path = tmp.path().join("perms.db");
        let _saver = SqliteSaver::new(&db_path).unwrap();

        // Only the permission bits matter; mask off the file-type bits.
        let mode = std::fs::metadata(&db_path).unwrap().permissions().mode() & 0o777;
        assert_eq!(mode, 0o600, "expected 0600, got {mode:o}");
    }

    /// T224: `max_checkpoint_size` enforcement.
    #[tokio::test]
    async fn max_checkpoint_size_enforcement() {
        let tmp = tempfile::tempdir().unwrap();
        // A 10-byte cap is smaller than any serialized checkpoint.
        let saver = SqliteSaver::with_max_size(&tmp.path().join("size.db"), 10).unwrap();

        let cfg = CheckpointConfig {
            thread_id: "thread-1".into(),
            checkpoint_id: None,
        };

        let (oversized, meta) = make_checkpoint("cp-1", 0);
        let err = saver.put(&cfg, oversized, meta).await.unwrap_err();
        assert!(
            matches!(err, CheckpointError::StateTooLarge { .. }),
            "expected StateTooLarge, got {err:?}"
        );
    }
}