smol_workflow_engine/durable/
sqlite.rs1use anyhow::{anyhow, bail, Context};
8use rusqlite::{params, Connection};
9use std::path::{Path, PathBuf};
10use std::time::{SystemTime, UNIX_EPOCH};
11
12mod embedded_migrations {
13 include!(concat!(env!("OUT_DIR"), "/smol_workflow_migrations.rs"));
14}
15
/// DDL for the bookkeeping table that tracks which migrations have been applied.
///
/// `IF NOT EXISTS` makes the statement safe to run on every call path that
/// needs the table, whether or not it has been created before.
const MIGRATIONS_TABLE_SQL: &str = r#"
CREATE TABLE IF NOT EXISTS sw_migrations (
    id INTEGER PRIMARY KEY,
    introduced_version TEXT NOT NULL,
    applied_time INTEGER NOT NULL
)
"#;
23
/// One row of the `sw_migrations` bookkeeping table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MigrationRecord {
    /// Identifier of the applied migration (primary key of the table).
    pub id: i64,
    /// Version string stored alongside the migration when it was recorded.
    pub introduced_version: String,
    /// Time the migration was recorded, in milliseconds since the Unix epoch.
    pub applied_time: i64,
}
31
32pub struct SqliteDurableStore {
34 connection: Connection,
35 path: Option<PathBuf>,
36}
37
38impl SqliteDurableStore {
39 pub fn open(path: impl AsRef<Path>) -> anyhow::Result<Self> {
41 let connection = Connection::open(path.as_ref()).with_context(|| {
42 format!(
43 "failed to open durable SQLite database {}",
44 path.as_ref().display()
45 )
46 })?;
47 configure_connection(&connection)?;
48 Ok(Self {
49 connection,
50 path: Some(path.as_ref().to_path_buf()),
51 })
52 }
53
54 pub fn in_memory() -> anyhow::Result<Self> {
56 let connection = Connection::open_in_memory()
57 .context("failed to open in-memory durable SQLite database")?;
58 configure_connection(&connection)?;
59 Ok(Self {
60 connection,
61 path: None,
62 })
63 }
64
65 pub fn path(&self) -> Option<&Path> {
67 self.path.as_deref()
68 }
69
70 pub fn connection(&self) -> &Connection {
72 &self.connection
73 }
74
75 pub fn connection_mut(&mut self) -> &mut Connection {
77 &mut self.connection
78 }
79
80 pub fn init(&mut self) -> anyhow::Result<usize> {
82 self.apply_migrations(None)
83 }
84
85 pub fn apply_migrations(&mut self, target_version: Option<i64>) -> anyhow::Result<usize> {
90 configure_connection(&self.connection)?;
91 let tx = self
92 .connection
93 .transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)
94 .context("failed to begin durable migration transaction")?;
95
96 let result = (|| -> anyhow::Result<usize> {
97 ensure_migrations_table(&tx)?;
98
99 let applied = applied_migration_ids(&tx)?;
100 let max_applied = applied.iter().copied().max().unwrap_or(0);
101 let max_known = embedded_migrations::MIGRATIONS
102 .last()
103 .map(|migration| migration.id)
104 .unwrap_or(0);
105 let target = target_version.unwrap_or(max_known);
106
107 if target < max_applied {
108 bail!(
109 "target migration version {target} is older than applied version {max_applied}"
110 );
111 }
112 if target > max_known {
113 bail!(
114 "target migration version {target} is newer than available version {max_known}"
115 );
116 }
117
118 let mut applied_count = 0usize;
119 for migration in embedded_migrations::MIGRATIONS
120 .iter()
121 .filter(|migration| migration.id <= target)
122 {
123 if applied.contains(&migration.id) {
124 continue;
125 }
126 tx.execute_batch(migration.sql).with_context(|| {
127 format!("failed to apply durable migration {}", migration.id)
128 })?;
129 tx.execute(
130 r#"
131 INSERT INTO sw_migrations (
132 id,
133 introduced_version,
134 applied_time
135 )
136 VALUES (?1, ?2, ?3)
137 "#,
138 params![migration.id, migration.introduced_version, now_ms()],
139 )
140 .with_context(|| format!("failed to record durable migration {}", migration.id))?;
141 applied_count += 1;
142 }
143 Ok(applied_count)
144 })();
145
146 match result {
147 Ok(applied_count) => {
148 tx.commit()
149 .context("failed to commit durable migration transaction")?;
150 Ok(applied_count)
151 }
152 Err(error) => {
153 let _ = tx.rollback();
154 Err(error)
155 }
156 }
157 }
158
159 pub fn migration_records(&self) -> anyhow::Result<Vec<MigrationRecord>> {
161 ensure_migrations_table(&self.connection)?;
162 let mut statement = self
163 .connection
164 .prepare(
165 r#"
166 SELECT id, introduced_version, applied_time
167 FROM sw_migrations
168 ORDER BY id
169 "#,
170 )
171 .context("failed to prepare durable migration records query")?;
172 let rows = statement
173 .query_map([], |row| {
174 Ok(MigrationRecord {
175 id: row.get(0)?,
176 introduced_version: row.get(1)?,
177 applied_time: row.get(2)?,
178 })
179 })
180 .context("failed to query durable migration records")?;
181
182 let mut records = Vec::new();
183 for row in rows {
184 records.push(row.context("failed to read durable migration record")?);
185 }
186 Ok(records)
187 }
188
189 pub fn current_schema_version(&self) -> anyhow::Result<i64> {
191 ensure_migrations_table(&self.connection)?;
192 self.connection
193 .query_row(
194 r#"
195 SELECT COALESCE(MAX(id), 0)
196 FROM sw_migrations
197 "#,
198 [],
199 |row| row.get(0),
200 )
201 .context("failed to read durable schema version")
202 }
203}
204
205fn configure_connection(connection: &Connection) -> anyhow::Result<()> {
206 connection
207 .pragma_update(None, "foreign_keys", "ON")
208 .context("failed to enable SQLite foreign_keys")?;
209 connection
210 .busy_timeout(std::time::Duration::from_millis(5_000))
211 .context("failed to configure SQLite busy_timeout")?;
212
213 let journal_mode: String = connection
214 .pragma_query_value(None, "journal_mode", |row| row.get(0))
215 .context("failed to read SQLite journal_mode")?;
216 if !journal_mode.eq_ignore_ascii_case("memory") {
217 let mode: String = connection
218 .pragma_update_and_check(None, "journal_mode", "WAL", |row| row.get(0))
219 .context("failed to enable SQLite WAL journal_mode")?;
220 if !mode.eq_ignore_ascii_case("wal") {
221 return Err(anyhow!("expected SQLite journal_mode WAL, found {mode}"));
222 }
223 }
224
225 Ok(())
226}
227
228fn ensure_migrations_table(connection: &Connection) -> anyhow::Result<()> {
229 connection
230 .execute_batch(MIGRATIONS_TABLE_SQL)
231 .context("failed to ensure durable migrations table")
232}
233
234fn applied_migration_ids(
235 connection: &Connection,
236) -> anyhow::Result<std::collections::HashSet<i64>> {
237 let mut statement = connection
238 .prepare(
239 r#"
240 SELECT id
241 FROM sw_migrations
242 ORDER BY id
243 "#,
244 )
245 .context("failed to prepare applied durable migrations query")?;
246 let rows = statement
247 .query_map([], |row| row.get::<_, i64>(0))
248 .context("failed to query applied durable migrations")?;
249 let mut ids = std::collections::HashSet::new();
250 for row in rows {
251 ids.insert(row.context("failed to read applied durable migration id")?);
252 }
253 Ok(ids)
254}
255
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before the epoch yields `0`. A millisecond count that
/// does not fit in an `i64` saturates at `i64::MAX` instead of wrapping.
pub(crate) fn now_ms() -> i64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    // `as_millis` is a u128; clamp before narrowing so the cast cannot wrap.
    since_epoch.as_millis().min(i64::MAX as u128) as i64
}
262
263pub(crate) fn new_id(prefix: &str) -> String {
275 format!(
276 "{prefix}_{}",
277 ulid::Ulid::new().to_string().to_ascii_lowercase()
278 )
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a fresh in-memory store so each test starts from an empty schema.
    fn fresh_store() -> SqliteDurableStore {
        SqliteDurableStore::in_memory().expect("store should open")
    }

    #[test]
    fn generated_ids_use_lowercase_ulid_suffixes() {
        for prefix in ["task", "run", "step", "budget"] {
            let id = new_id(prefix);
            assert!(id.starts_with(&format!("{prefix}_")));
            assert_eq!(id, id.to_ascii_lowercase());
        }
    }

    #[test]
    fn initializes_schema_and_records_migration() {
        let mut store = fresh_store();
        let applied = store.init().expect("migrations should apply");
        assert_eq!(applied, embedded_migrations::MIGRATIONS.len());
        assert_eq!(
            store.current_schema_version().unwrap(),
            embedded_migrations::MIGRATIONS.last().unwrap().id
        );

        let records = store.migration_records().unwrap();
        assert_eq!(records.len(), embedded_migrations::MIGRATIONS.len());
        assert_eq!(records[0].id, 1);
        assert_eq!(
            records[0].introduced_version,
            embedded_migrations::MIGRATIONS[0].introduced_version
        );

        // The migrations are expected to have created the workflow tables.
        let table_count: i64 = store
            .connection()
            .query_row(
                r#"
                SELECT COUNT(*)
                FROM sqlite_master
                WHERE type = 'table'
                  AND name IN (
                    'sw_workflow_tasks',
                    'sw_workflow_runs',
                    'sw_workflow_steps',
                    'sw_budget_ledger'
                  )
                "#,
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(table_count, 4);
    }

    #[test]
    fn migrations_are_idempotent() {
        let mut store = fresh_store();
        assert_eq!(store.init().unwrap(), embedded_migrations::MIGRATIONS.len());
        assert_eq!(store.init().unwrap(), 0);
        assert_eq!(
            store.migration_records().unwrap().len(),
            embedded_migrations::MIGRATIONS.len()
        );
    }

    #[test]
    fn rejects_target_older_than_applied_version() {
        let mut store = fresh_store();
        store.init().unwrap();
        let error = store.apply_migrations(Some(0)).unwrap_err();
        assert!(
            error.to_string().contains("older than applied"),
            "unexpected error: {error:#}"
        );
    }

    #[test]
    fn rejects_target_newer_than_available_version() {
        let mut store = fresh_store();
        // No embedded migration can carry this id, so the request must be refused.
        let error = store.apply_migrations(Some(i64::MAX)).unwrap_err();
        assert!(
            error.to_string().contains("newer than available"),
            "unexpected error: {error:#}"
        );

        // A rejected run is rolled back, so nothing may have been recorded.
        assert_eq!(store.current_schema_version().unwrap(), 0);
        assert!(store.migration_records().unwrap().is_empty());
    }
}