synwire_checkpoint_sqlite/
saver.rs1use std::path::Path;
4use std::sync::Arc;
5
6use r2d2::Pool;
7use r2d2_sqlite::SqliteConnectionManager;
8use synwire_checkpoint::base::BaseCheckpointSaver;
9use synwire_checkpoint::types::{
10 Checkpoint, CheckpointConfig, CheckpointError, CheckpointMetadata, CheckpointTuple,
11};
12
13use crate::schema::CREATE_CHECKPOINTS_TABLE;
14
/// Default cap, in bytes, on a serialised checkpoint accepted by `put` (16 MiB).
///
/// Used when the saver is constructed through `SqliteSaver::new`.
const DEFAULT_MAX_CHECKPOINT_SIZE: usize = 16 << 20;
17
18#[derive(Debug, Clone)]
24pub struct SqliteSaver {
25 pool: Arc<Pool<SqliteConnectionManager>>,
26 max_checkpoint_size: usize,
27}
28
29impl SqliteSaver {
30 pub fn new(path: &Path) -> Result<Self, CheckpointError> {
40 Self::with_max_size(path, DEFAULT_MAX_CHECKPOINT_SIZE)
41 }
42
43 pub fn with_max_size(path: &Path, max_checkpoint_size: usize) -> Result<Self, CheckpointError> {
50 #[cfg(unix)]
52 {
53 if !path.exists() {
54 let _file = std::fs::File::create(path)
56 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
57 Self::set_permissions_0600(path)?;
58 }
59 }
60
61 let manager = SqliteConnectionManager::file(path);
62 let pool = Pool::builder()
63 .max_size(4)
64 .build(manager)
65 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
66
67 let conn = pool
69 .get()
70 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
71 conn.execute_batch(CREATE_CHECKPOINTS_TABLE)
72 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
73
74 Ok(Self {
75 pool: Arc::new(pool),
76 max_checkpoint_size,
77 })
78 }
79
80 #[cfg(unix)]
82 fn set_permissions_0600(path: &Path) -> Result<(), CheckpointError> {
83 use std::os::unix::fs::PermissionsExt;
84 let perms = std::fs::Permissions::from_mode(0o600);
85 std::fs::set_permissions(path, perms).map_err(|e| CheckpointError::Storage(e.to_string()))
86 }
87}
88
89#[allow(clippy::significant_drop_tightening)]
90impl BaseCheckpointSaver for SqliteSaver {
91 fn get_tuple<'a>(
92 &'a self,
93 config: &'a CheckpointConfig,
94 ) -> synwire_core::BoxFuture<'a, Result<Option<CheckpointTuple>, CheckpointError>> {
95 Box::pin(async move {
96 let conn = self
97 .pool
98 .get()
99 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
100
101 let (query, params): (&str, Vec<Box<dyn rusqlite::types::ToSql>>) =
102 config.checkpoint_id.as_ref().map_or_else(
103 || -> (&str, Vec<Box<dyn rusqlite::types::ToSql>>) {
104 (
105 "SELECT checkpoint_id, data, metadata, parent_checkpoint_id \
106 FROM checkpoints WHERE thread_id = ?1 \
107 ORDER BY rowid DESC LIMIT 1",
108 vec![Box::new(config.thread_id.clone())],
109 )
110 },
111 |checkpoint_id| {
112 (
113 "SELECT checkpoint_id, data, metadata, parent_checkpoint_id \
114 FROM checkpoints WHERE thread_id = ?1 AND checkpoint_id = ?2",
115 vec![
116 Box::new(config.thread_id.clone()),
117 Box::new(checkpoint_id.clone()),
118 ],
119 )
120 },
121 );
122
123 let param_refs: Vec<&dyn rusqlite::types::ToSql> =
124 params.iter().map(AsRef::as_ref).collect();
125
126 let mut stmt = conn
127 .prepare(query)
128 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
129
130 let result = stmt
131 .query_row(&*param_refs, |row| {
132 let checkpoint_id: String = row.get(0)?;
133 let data: Vec<u8> = row.get(1)?;
134 let metadata_json: String = row.get(2)?;
135 let parent_id: Option<String> = row.get(3)?;
136 Ok((checkpoint_id, data, metadata_json, parent_id))
137 })
138 .optional()
139 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
140
141 let Some((checkpoint_id, data, metadata_json, parent_id)) = result else {
142 return Ok(None);
143 };
144
145 let checkpoint: Checkpoint = serde_json::from_slice(&data)
146 .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
147 let metadata: CheckpointMetadata = serde_json::from_str(&metadata_json)
148 .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
149
150 let tuple_config = CheckpointConfig {
151 thread_id: config.thread_id.clone(),
152 checkpoint_id: Some(checkpoint_id),
153 };
154 let parent_config = parent_id.map(|pid| CheckpointConfig {
155 thread_id: config.thread_id.clone(),
156 checkpoint_id: Some(pid),
157 });
158
159 Ok(Some(CheckpointTuple {
160 config: tuple_config,
161 checkpoint,
162 metadata,
163 parent_config,
164 }))
165 })
166 }
167
168 fn list<'a>(
169 &'a self,
170 config: &'a CheckpointConfig,
171 limit: Option<usize>,
172 ) -> synwire_core::BoxFuture<'a, Result<Vec<CheckpointTuple>, CheckpointError>> {
173 Box::pin(async move {
174 let conn = self
175 .pool
176 .get()
177 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
178
179 let limit_val: i64 = limit
180 .and_then(|l| i64::try_from(l).ok())
181 .unwrap_or(i64::MAX);
182
183 let mut stmt = conn
184 .prepare(
185 "SELECT checkpoint_id, data, metadata, parent_checkpoint_id \
186 FROM checkpoints WHERE thread_id = ?1 \
187 ORDER BY rowid DESC LIMIT ?2",
188 )
189 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
190
191 let rows = stmt
192 .query_map(rusqlite::params![config.thread_id, limit_val], |row| {
193 let checkpoint_id: String = row.get(0)?;
194 let data: Vec<u8> = row.get(1)?;
195 let metadata_json: String = row.get(2)?;
196 let parent_id: Option<String> = row.get(3)?;
197 Ok((checkpoint_id, data, metadata_json, parent_id))
198 })
199 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
200
201 let mut tuples = Vec::new();
202 for row in rows {
203 let (checkpoint_id, data, metadata_json, parent_id) =
204 row.map_err(|e| CheckpointError::Storage(e.to_string()))?;
205
206 let checkpoint: Checkpoint = serde_json::from_slice(&data)
207 .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
208 let metadata: CheckpointMetadata = serde_json::from_str(&metadata_json)
209 .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
210
211 let tuple_config = CheckpointConfig {
212 thread_id: config.thread_id.clone(),
213 checkpoint_id: Some(checkpoint_id),
214 };
215 let parent_config = parent_id.map(|pid| CheckpointConfig {
216 thread_id: config.thread_id.clone(),
217 checkpoint_id: Some(pid),
218 });
219
220 tuples.push(CheckpointTuple {
221 config: tuple_config,
222 checkpoint,
223 metadata,
224 parent_config,
225 });
226 }
227
228 Ok(tuples)
229 })
230 }
231
232 fn put<'a>(
233 &'a self,
234 config: &'a CheckpointConfig,
235 checkpoint: Checkpoint,
236 metadata: CheckpointMetadata,
237 ) -> synwire_core::BoxFuture<'a, Result<CheckpointConfig, CheckpointError>> {
238 Box::pin(async move {
239 let data = serde_json::to_vec(&checkpoint)
240 .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
241
242 if data.len() > self.max_checkpoint_size {
243 return Err(CheckpointError::StateTooLarge {
244 size: data.len(),
245 max: self.max_checkpoint_size,
246 });
247 }
248
249 let metadata_json = serde_json::to_string(&metadata)
250 .map_err(|e| CheckpointError::Serialization(e.to_string()))?;
251
252 let conn = self
254 .pool
255 .get()
256 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
257
258 let parent_id: Option<String> = conn
259 .prepare(
260 "SELECT checkpoint_id FROM checkpoints \
261 WHERE thread_id = ?1 ORDER BY rowid DESC LIMIT 1",
262 )
263 .map_err(|e| CheckpointError::Storage(e.to_string()))?
264 .query_row(rusqlite::params![config.thread_id], |row| row.get(0))
265 .optional()
266 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
267
268 let _rows_changed = conn
269 .execute(
270 "INSERT OR REPLACE INTO checkpoints \
271 (thread_id, checkpoint_id, data, metadata, parent_checkpoint_id) \
272 VALUES (?1, ?2, ?3, ?4, ?5)",
273 rusqlite::params![
274 config.thread_id,
275 checkpoint.id,
276 data,
277 metadata_json,
278 parent_id,
279 ],
280 )
281 .map_err(|e| CheckpointError::Storage(e.to_string()))?;
282
283 Ok(CheckpointConfig {
284 thread_id: config.thread_id.clone(),
285 checkpoint_id: Some(checkpoint.id),
286 })
287 })
288 }
289}
290
291trait OptionalExt<T> {
293 fn optional(self) -> Result<Option<T>, rusqlite::Error>;
295}
296
297impl<T> OptionalExt<T> for Result<T, rusqlite::Error> {
298 fn optional(self) -> Result<Option<T>, rusqlite::Error> {
299 match self {
300 Ok(v) => Ok(Some(v)),
301 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
302 Err(e) => Err(e),
303 }
304 }
305}
306
#[cfg(test)]
// Tests intentionally panic (via `unwrap`) on any unexpected error.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use synwire_checkpoint::types::CheckpointSource;

    /// Builds checkpoint `id` holding an empty `messages` channel, paired with loop
    /// metadata for `step`.
    fn make_checkpoint(id: &str, step: i64) -> (Checkpoint, CheckpointMetadata) {
        let mut checkpoint = Checkpoint::new(id.to_owned());
        let _replaced = checkpoint
            .channel_values
            .insert("messages".into(), serde_json::json!([]));
        let metadata = CheckpointMetadata {
            source: CheckpointSource::Loop,
            step,
            writes: HashMap::new(),
            parents: HashMap::new(),
        };
        (checkpoint, metadata)
    }

    /// Config addressing `thread_id`, optionally pinned to one checkpoint.
    fn thread_config(thread_id: &str, checkpoint_id: Option<&str>) -> CheckpointConfig {
        CheckpointConfig {
            thread_id: thread_id.into(),
            checkpoint_id: checkpoint_id.map(Into::into),
        }
    }

    #[tokio::test]
    async fn put_get_list() {
        let tmp = tempfile::tempdir().unwrap();
        let saver = SqliteSaver::new(&tmp.path().join("test.db")).unwrap();
        let head_config = thread_config("thread-1", None);

        // The first checkpoint becomes the head of the thread.
        let (first, first_meta) = make_checkpoint("cp-1", 0);
        let stored = saver.put(&head_config, first, first_meta).await.unwrap();
        assert_eq!(stored.checkpoint_id.as_deref(), Some("cp-1"));

        let head = saver.get_tuple(&head_config).await.unwrap().unwrap();
        assert_eq!(head.checkpoint.id, "cp-1");

        // The second checkpoint supersedes the first and points back at it.
        let (second, second_meta) = make_checkpoint("cp-2", 1);
        let _stored_again = saver.put(&head_config, second, second_meta).await.unwrap();

        let head = saver.get_tuple(&head_config).await.unwrap().unwrap();
        assert_eq!(head.checkpoint.id, "cp-2");
        assert!(head.parent_config.is_some());
        assert_eq!(
            head.parent_config.unwrap().checkpoint_id.as_deref(),
            Some("cp-1")
        );

        // A pinned lookup still reaches the older checkpoint.
        let pinned_config = thread_config("thread-1", Some("cp-1"));
        let pinned = saver.get_tuple(&pinned_config).await.unwrap().unwrap();
        assert_eq!(pinned.checkpoint.id, "cp-1");

        // Listing is newest-first and honours the limit.
        let everything = saver.list(&head_config, None).await.unwrap();
        assert_eq!(everything.len(), 2);
        assert_eq!(everything[0].checkpoint.id, "cp-2");
        assert_eq!(everything[1].checkpoint.id, "cp-1");

        let newest_only = saver.list(&head_config, Some(1)).await.unwrap();
        assert_eq!(newest_only.len(), 1);
        assert_eq!(newest_only[0].checkpoint.id, "cp-2");

        // An unknown thread yields nothing rather than an error.
        let unknown_config = thread_config("no-such-thread", None);
        assert!(saver.get_tuple(&unknown_config).await.unwrap().is_none());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn file_permissions_0600() {
        use std::os::unix::fs::PermissionsExt;

        let tmp = tempfile::tempdir().unwrap();
        let db_file = tmp.path().join("perms.db");
        let _saver = SqliteSaver::new(&db_file).unwrap();

        let mode = std::fs::metadata(&db_file).unwrap().permissions().mode() & 0o777;
        assert_eq!(mode, 0o600, "expected 0600, got {mode:o}");
    }

    #[tokio::test]
    async fn max_checkpoint_size_enforcement() {
        let tmp = tempfile::tempdir().unwrap();
        let saver = SqliteSaver::with_max_size(&tmp.path().join("size.db"), 10).unwrap();
        let head_config = thread_config("thread-1", None);

        let (checkpoint, metadata) = make_checkpoint("cp-1", 0);
        let err = saver
            .put(&head_config, checkpoint, metadata)
            .await
            .unwrap_err();
        assert!(
            matches!(err, CheckpointError::StateTooLarge { .. }),
            "expected StateTooLarge, got {err:?}"
        );
    }
}