synaptic_sqlite/
checkpointer.rs1use std::sync::{Arc, Mutex};
2
3use async_trait::async_trait;
4use rusqlite::{params, Connection};
5use synaptic_core::SynapticError;
6use synaptic_graph::{Checkpoint, CheckpointConfig, Checkpointer};
7
8pub struct SqliteCheckpointer {
29 conn: Arc<Mutex<Connection>>,
30}
31
32impl SqliteCheckpointer {
33 pub fn new(path: impl AsRef<std::path::Path>) -> Result<Self, SynapticError> {
35 let conn = Connection::open(path)
36 .map_err(|e| SynapticError::Store(format!("SQLite open: {e}")))?;
37 conn.execute_batch(
38 "CREATE TABLE IF NOT EXISTS synaptic_checkpoints (
39 thread_id TEXT NOT NULL,
40 checkpoint_id TEXT NOT NULL,
41 state TEXT NOT NULL,
42 created_at INTEGER NOT NULL,
43 PRIMARY KEY (thread_id, checkpoint_id)
44 );
45 CREATE TABLE IF NOT EXISTS synaptic_checkpoint_idx (
46 thread_id TEXT NOT NULL,
47 checkpoint_id TEXT NOT NULL,
48 seq INTEGER NOT NULL,
49 PRIMARY KEY (thread_id, checkpoint_id)
50 );",
51 )
52 .map_err(|e| SynapticError::Store(format!("SQLite create tables: {e}")))?;
53 Ok(Self {
54 conn: Arc::new(Mutex::new(conn)),
55 })
56 }
57
58 pub fn in_memory() -> Result<Self, SynapticError> {
60 Self::new(":memory:")
61 }
62}
63
64#[async_trait]
65impl Checkpointer for SqliteCheckpointer {
66 async fn put(
67 &self,
68 config: &CheckpointConfig,
69 checkpoint: &Checkpoint,
70 ) -> Result<(), SynapticError> {
71 let conn = Arc::clone(&self.conn);
72 let thread_id = config.thread_id.clone();
73 let checkpoint_id = checkpoint.id.clone();
74 let data = serde_json::to_string(checkpoint)
75 .map_err(|e| SynapticError::Store(format!("Serialize: {e}")))?;
76 let now = std::time::SystemTime::now()
77 .duration_since(std::time::UNIX_EPOCH)
78 .unwrap_or_default()
79 .as_secs() as i64;
80
81 tokio::task::spawn_blocking(move || {
82 let conn = conn
83 .lock()
84 .map_err(|e| SynapticError::Store(format!("Lock: {e}")))?;
85
86 conn.execute(
87 "INSERT OR REPLACE INTO synaptic_checkpoints \
88 (thread_id, checkpoint_id, state, created_at) \
89 VALUES (?1, ?2, ?3, ?4)",
90 params![thread_id, checkpoint_id, data, now],
91 )
92 .map_err(|e| SynapticError::Store(format!("SQLite INSERT: {e}")))?;
93
94 let max_seq: i64 = conn
96 .query_row(
97 "SELECT COALESCE(MAX(seq), -1) FROM synaptic_checkpoint_idx \
98 WHERE thread_id = ?1",
99 params![thread_id],
100 |row| row.get(0),
101 )
102 .unwrap_or(-1);
103
104 conn.execute(
105 "INSERT OR IGNORE INTO synaptic_checkpoint_idx \
106 (thread_id, checkpoint_id, seq) VALUES (?1, ?2, ?3)",
107 params![thread_id, checkpoint_id, max_seq + 1],
108 )
109 .map_err(|e| SynapticError::Store(format!("SQLite INSERT idx: {e}")))?;
110
111 Ok(())
112 })
113 .await
114 .map_err(|e| SynapticError::Store(format!("spawn_blocking: {e}")))?
115 }
116
117 async fn get(&self, config: &CheckpointConfig) -> Result<Option<Checkpoint>, SynapticError> {
118 let conn = Arc::clone(&self.conn);
119 let thread_id = config.thread_id.clone();
120 let checkpoint_id = config.checkpoint_id.clone();
121
122 tokio::task::spawn_blocking(move || {
123 let conn = conn
124 .lock()
125 .map_err(|e| SynapticError::Store(format!("Lock: {e}")))?;
126
127 let resolved_id: Option<String> = if let Some(ref id) = checkpoint_id {
129 Some(id.clone())
130 } else {
131 conn.query_row(
132 "SELECT checkpoint_id FROM synaptic_checkpoint_idx \
133 WHERE thread_id = ?1 ORDER BY seq DESC LIMIT 1",
134 params![thread_id],
135 |row| row.get(0),
136 )
137 .ok()
138 };
139
140 let id = match resolved_id {
141 Some(id) => id,
142 None => return Ok(None),
143 };
144
145 let data: Option<String> = conn
146 .query_row(
147 "SELECT state FROM synaptic_checkpoints \
148 WHERE thread_id = ?1 AND checkpoint_id = ?2",
149 params![thread_id, id],
150 |row| row.get(0),
151 )
152 .ok();
153
154 match data {
155 None => Ok(None),
156 Some(json) => {
157 let cp: Checkpoint = serde_json::from_str(&json)
158 .map_err(|e| SynapticError::Store(format!("Deserialize: {e}")))?;
159 Ok(Some(cp))
160 }
161 }
162 })
163 .await
164 .map_err(|e| SynapticError::Store(format!("spawn_blocking: {e}")))?
165 }
166
167 async fn list(&self, config: &CheckpointConfig) -> Result<Vec<Checkpoint>, SynapticError> {
168 let conn = Arc::clone(&self.conn);
169 let thread_id = config.thread_id.clone();
170
171 tokio::task::spawn_blocking(move || {
172 let conn = conn
173 .lock()
174 .map_err(|e| SynapticError::Store(format!("Lock: {e}")))?;
175
176 let mut stmt = conn
177 .prepare(
178 "SELECT c.state \
179 FROM synaptic_checkpoints c \
180 JOIN synaptic_checkpoint_idx i \
181 ON c.thread_id = i.thread_id AND c.checkpoint_id = i.checkpoint_id \
182 WHERE c.thread_id = ?1 \
183 ORDER BY i.seq ASC",
184 )
185 .map_err(|e| SynapticError::Store(format!("SQLite prepare: {e}")))?;
186
187 let checkpoints: Result<Vec<Checkpoint>, SynapticError> = stmt
188 .query_map(params![thread_id], |row| row.get::<_, String>(0))
189 .map_err(|e| SynapticError::Store(format!("SQLite query: {e}")))?
190 .filter_map(|r| r.ok())
191 .map(|json| {
192 serde_json::from_str(&json)
193 .map_err(|e| SynapticError::Store(format!("Deserialize: {e}")))
194 })
195 .collect();
196
197 checkpoints
198 })
199 .await
200 .map_err(|e| SynapticError::Store(format!("spawn_blocking: {e}")))?
201 }
202}