// smol_workflow_engine/durable/sqlite.rs

//! SQLite durable workflow store infrastructure.
//!
//! This module owns database opening, SQLite connection-pragma setup, and
//! schema migrations. Numbered SQL files are embedded at compile time, applied
//! inside one immediate transaction, and recorded in a migrations table.
6
7use anyhow::{anyhow, bail, Context};
8use rusqlite::{params, Connection};
9use std::path::{Path, PathBuf};
10use std::time::{SystemTime, UNIX_EPOCH};
11
/// Migration definitions generated at build time into Cargo's `OUT_DIR`.
///
/// NOTE(review): presumably produced by this crate's build script from the
/// numbered SQL files. This module relies on it exposing `MIGRATIONS`, whose
/// items carry `id`, `introduced_version`, and `sql` fields.
mod embedded_migrations {
    include!(concat!(env!("OUT_DIR"), "/smol_workflow_migrations.rs"));
}
15
/// DDL for the bookkeeping table that records which migrations were applied.
///
/// `IF NOT EXISTS` makes it safe to execute repeatedly. `applied_time` is
/// written from `now_ms()`, i.e. Unix epoch milliseconds.
const MIGRATIONS_TABLE_SQL: &str = r#"
CREATE TABLE IF NOT EXISTS sw_migrations (
    id INTEGER PRIMARY KEY,
    introduced_version TEXT NOT NULL,
    applied_time INTEGER NOT NULL
)
"#;
23
/// A row of `sw_migrations`: one schema migration that has been applied.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MigrationRecord {
    /// Migration number (the embedded migration's id).
    pub id: i64,
    /// Engine version string that introduced this migration.
    pub introduced_version: String,
    /// Time the migration was applied, in Unix epoch milliseconds.
    pub applied_time: i64,
}
31
/// SQLite-backed durable workflow store.
///
/// Wraps a single SQLite [`Connection`]. `path` is `Some` for stores created
/// with `open` and `None` for stores created with `in_memory`.
pub struct SqliteDurableStore {
    connection: Connection,
    path: Option<PathBuf>,
}
37
38impl SqliteDurableStore {
39    /// Open a durable store at `path` and apply connection pragmas.
40    pub fn open(path: impl AsRef<Path>) -> anyhow::Result<Self> {
41        let connection = Connection::open(path.as_ref()).with_context(|| {
42            format!(
43                "failed to open durable SQLite database {}",
44                path.as_ref().display()
45            )
46        })?;
47        configure_connection(&connection)?;
48        Ok(Self {
49            connection,
50            path: Some(path.as_ref().to_path_buf()),
51        })
52    }
53
54    /// Create an in-memory durable store. Useful for tests.
55    pub fn in_memory() -> anyhow::Result<Self> {
56        let connection = Connection::open_in_memory()
57            .context("failed to open in-memory durable SQLite database")?;
58        configure_connection(&connection)?;
59        Ok(Self {
60            connection,
61            path: None,
62        })
63    }
64
65    /// Return the durable database path when this store was opened from a file.
66    pub fn path(&self) -> Option<&Path> {
67        self.path.as_deref()
68    }
69
70    /// Borrow the underlying SQLite connection.
71    pub fn connection(&self) -> &Connection {
72        &self.connection
73    }
74
75    /// Mutably borrow the underlying SQLite connection.
76    pub fn connection_mut(&mut self) -> &mut Connection {
77        &mut self.connection
78    }
79
80    /// Initialize the durable schema by applying all available migrations.
81    pub fn init(&mut self) -> anyhow::Result<usize> {
82        self.apply_migrations(None)
83    }
84
85    /// Apply migrations up to `target_version`, or all available migrations when
86    /// `target_version` is `None`.
87    ///
88    /// Returns the number of migrations applied in this call.
89    pub fn apply_migrations(&mut self, target_version: Option<i64>) -> anyhow::Result<usize> {
90        configure_connection(&self.connection)?;
91        let tx = self
92            .connection
93            .transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)
94            .context("failed to begin durable migration transaction")?;
95
96        let result = (|| -> anyhow::Result<usize> {
97            ensure_migrations_table(&tx)?;
98
99            let applied = applied_migration_ids(&tx)?;
100            let max_applied = applied.iter().copied().max().unwrap_or(0);
101            let max_known = embedded_migrations::MIGRATIONS
102                .last()
103                .map(|migration| migration.id)
104                .unwrap_or(0);
105            let target = target_version.unwrap_or(max_known);
106
107            if target < max_applied {
108                bail!(
109                    "target migration version {target} is older than applied version {max_applied}"
110                );
111            }
112            if target > max_known {
113                bail!(
114                    "target migration version {target} is newer than available version {max_known}"
115                );
116            }
117
118            let mut applied_count = 0usize;
119            for migration in embedded_migrations::MIGRATIONS
120                .iter()
121                .filter(|migration| migration.id <= target)
122            {
123                if applied.contains(&migration.id) {
124                    continue;
125                }
126                tx.execute_batch(migration.sql).with_context(|| {
127                    format!("failed to apply durable migration {}", migration.id)
128                })?;
129                tx.execute(
130                    r#"
131                    INSERT INTO sw_migrations (
132                        id,
133                        introduced_version,
134                        applied_time
135                    )
136                    VALUES (?1, ?2, ?3)
137                    "#,
138                    params![migration.id, migration.introduced_version, now_ms()],
139                )
140                .with_context(|| format!("failed to record durable migration {}", migration.id))?;
141                applied_count += 1;
142            }
143            Ok(applied_count)
144        })();
145
146        match result {
147            Ok(applied_count) => {
148                tx.commit()
149                    .context("failed to commit durable migration transaction")?;
150                Ok(applied_count)
151            }
152            Err(error) => {
153                let _ = tx.rollback();
154                Err(error)
155            }
156        }
157    }
158
159    /// Return applied migrations in ascending id order.
160    pub fn migration_records(&self) -> anyhow::Result<Vec<MigrationRecord>> {
161        ensure_migrations_table(&self.connection)?;
162        let mut statement = self
163            .connection
164            .prepare(
165                r#"
166                SELECT id, introduced_version, applied_time
167                FROM sw_migrations
168                ORDER BY id
169                "#,
170            )
171            .context("failed to prepare durable migration records query")?;
172        let rows = statement
173            .query_map([], |row| {
174                Ok(MigrationRecord {
175                    id: row.get(0)?,
176                    introduced_version: row.get(1)?,
177                    applied_time: row.get(2)?,
178                })
179            })
180            .context("failed to query durable migration records")?;
181
182        let mut records = Vec::new();
183        for row in rows {
184            records.push(row.context("failed to read durable migration record")?);
185        }
186        Ok(records)
187    }
188
189    /// Return the latest applied migration id, or `0` when none are applied.
190    pub fn current_schema_version(&self) -> anyhow::Result<i64> {
191        ensure_migrations_table(&self.connection)?;
192        self.connection
193            .query_row(
194                r#"
195                SELECT COALESCE(MAX(id), 0)
196                FROM sw_migrations
197                "#,
198                [],
199                |row| row.get(0),
200            )
201            .context("failed to read durable schema version")
202    }
203}
204
205fn configure_connection(connection: &Connection) -> anyhow::Result<()> {
206    connection
207        .pragma_update(None, "foreign_keys", "ON")
208        .context("failed to enable SQLite foreign_keys")?;
209    connection
210        .busy_timeout(std::time::Duration::from_millis(5_000))
211        .context("failed to configure SQLite busy_timeout")?;
212
213    let journal_mode: String = connection
214        .pragma_query_value(None, "journal_mode", |row| row.get(0))
215        .context("failed to read SQLite journal_mode")?;
216    if !journal_mode.eq_ignore_ascii_case("memory") {
217        let mode: String = connection
218            .pragma_update_and_check(None, "journal_mode", "WAL", |row| row.get(0))
219            .context("failed to enable SQLite WAL journal_mode")?;
220        if !mode.eq_ignore_ascii_case("wal") {
221            return Err(anyhow!("expected SQLite journal_mode WAL, found {mode}"));
222        }
223    }
224
225    Ok(())
226}
227
228fn ensure_migrations_table(connection: &Connection) -> anyhow::Result<()> {
229    connection
230        .execute_batch(MIGRATIONS_TABLE_SQL)
231        .context("failed to ensure durable migrations table")
232}
233
234fn applied_migration_ids(
235    connection: &Connection,
236) -> anyhow::Result<std::collections::HashSet<i64>> {
237    let mut statement = connection
238        .prepare(
239            r#"
240            SELECT id
241            FROM sw_migrations
242            ORDER BY id
243            "#,
244        )
245        .context("failed to prepare applied durable migrations query")?;
246    let rows = statement
247        .query_map([], |row| row.get::<_, i64>(0))
248        .context("failed to query applied durable migrations")?;
249    let mut ids = std::collections::HashSet::new();
250    for row in rows {
251        ids.insert(row.context("failed to read applied durable migration id")?);
252    }
253    Ok(ids)
254}
255
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reads earlier than the epoch. Saturates at
/// `i64::MAX` instead of silently wrapping the `u128` millisecond count.
pub(crate) fn now_ms() -> i64 {
    let millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis();
    // The `min` bounds the value, so the cast below cannot truncate.
    millis.min(i64::MAX as u128) as i64
}
262
263/// Generate a durable-engine ID with a stable lowercase text form.
264///
265/// ULID's canonical display form is uppercase Crockford Base32, but these IDs are
266/// used as opaque workflow/database identifiers rather than values that users
267/// need to parse as canonical ULIDs. We normalize the suffix to lowercase so IDs
268/// remain consistent with the snake_case prefixes (`run_`, `task_`, `step_`,
269/// `budget_`) and are easier to copy through logs, URLs, shells, and
270/// case-sensitive external systems without mixed-case surprises.
271///
272/// Any code that compares durable IDs should treat this lowercase form as the
273/// stored/canonical engine representation.
274pub(crate) fn new_id(prefix: &str) -> String {
275    format!(
276        "{prefix}_{}",
277        ulid::Ulid::new().to_string().to_ascii_lowercase()
278    )
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Highest embedded migration id, without assuming `MIGRATIONS` is sorted.
    fn latest_known_id() -> i64 {
        embedded_migrations::MIGRATIONS
            .iter()
            .map(|migration| migration.id)
            .max()
            .expect("at least one embedded migration should exist")
    }

    #[test]
    fn generated_ids_use_lowercase_ulid_suffixes() {
        for prefix in ["task", "run", "step", "budget"] {
            let id = new_id(prefix);
            let suffix = id
                .strip_prefix(prefix)
                .and_then(|rest| rest.strip_prefix('_'))
                .unwrap_or_else(|| panic!("unexpected id shape: {id}"));
            // Canonical ULID text is 26 Crockford Base32 characters.
            assert_eq!(suffix.len(), 26, "unexpected suffix length in {id}");
            assert_eq!(id, id.to_ascii_lowercase());
        }
    }

    #[test]
    fn initializes_schema_and_records_migration() {
        let mut store = SqliteDurableStore::in_memory().expect("store should open");
        let applied = store.init().expect("migrations should apply");
        assert_eq!(applied, embedded_migrations::MIGRATIONS.len());
        assert_eq!(store.current_schema_version().unwrap(), latest_known_id());

        let records = store.migration_records().unwrap();
        assert_eq!(records.len(), embedded_migrations::MIGRATIONS.len());
        assert_eq!(records[0].id, 1);
        assert_eq!(
            records[0].introduced_version,
            embedded_migrations::MIGRATIONS
                .iter()
                .find(|migration| migration.id == 1)
                .expect("migration 1 should be embedded")
                .introduced_version
        );

        let table_count: i64 = store
            .connection()
            .query_row(
                r#"
                SELECT COUNT(*)
                FROM sqlite_master
                WHERE type = 'table'
                  AND name IN (
                      'sw_workflow_tasks',
                      'sw_workflow_runs',
                      'sw_workflow_steps',
                      'sw_budget_ledger'
                  )
                "#,
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(table_count, 4);
    }

    #[test]
    fn migrations_are_idempotent() {
        let mut store = SqliteDurableStore::in_memory().expect("store should open");
        assert_eq!(store.init().unwrap(), embedded_migrations::MIGRATIONS.len());
        assert_eq!(store.init().unwrap(), 0);
        assert_eq!(
            store.migration_records().unwrap().len(),
            embedded_migrations::MIGRATIONS.len()
        );
    }

    #[test]
    fn applies_migrations_incrementally_up_to_target() {
        let mut store = SqliteDurableStore::in_memory().expect("store should open");
        assert_eq!(store.apply_migrations(Some(1)).unwrap(), 1);
        assert_eq!(store.current_schema_version().unwrap(), 1);
        assert_eq!(
            store.init().unwrap(),
            embedded_migrations::MIGRATIONS.len() - 1
        );
        assert_eq!(store.current_schema_version().unwrap(), latest_known_id());
    }

    #[test]
    fn rejects_target_older_than_applied_version() {
        let mut store = SqliteDurableStore::in_memory().expect("store should open");
        store.init().unwrap();
        let error = store.apply_migrations(Some(0)).unwrap_err();
        assert!(
            error.to_string().contains("older than applied"),
            "unexpected error: {error:#}"
        );
    }

    #[test]
    fn rejects_target_newer_than_available_version() {
        let mut store = SqliteDurableStore::in_memory().expect("store should open");
        let error = store
            .apply_migrations(Some(latest_known_id() + 1))
            .unwrap_err();
        assert!(
            error.to_string().contains("newer than available"),
            "unexpected error: {error:#}"
        );
        // The failed call must leave no migrations behind.
        assert_eq!(store.current_schema_version().unwrap(), 0);
    }

    #[test]
    fn rejects_database_newer_than_embedded_migrations() {
        let mut store = SqliteDurableStore::in_memory().expect("store should open");
        store.init().unwrap();
        store
            .connection()
            .execute(
                "INSERT INTO sw_migrations (id, introduced_version, applied_time) VALUES (?1, ?2, ?3)",
                params![latest_known_id() + 1, "future", now_ms()],
            )
            .unwrap();
        assert!(store.init().is_err());
    }
}