Skip to main content

rust_query/
migrate.rs

1pub mod config;
2pub mod migration;
3#[cfg(test)]
4mod test;
5
6use std::{
7    collections::{BTreeSet, HashMap},
8    marker::PhantomData,
9    sync::atomic::AtomicI64,
10};
11
12use annotate_snippets::{Renderer, renderer::DecorStyle};
13use rusqlite::config::DbConfig;
14use sea_query::{Alias, ColumnDef, IntoIden, SqliteQueryBuilder};
15use self_cell::MutBorrow;
16
17use crate::{
18    Table, Transaction,
19    alias::Scope,
20    migrate::{
21        config::Config,
22        migration::{SchemaBuilder, TransactionMigrate},
23    },
24    pool::Pool,
25    schema::{from_db, from_macro, read::read_schema},
26    transaction::{Database, OwnedTransaction, TXN, TransactionWithRows},
27};
28
29pub struct TableTypBuilder<S> {
30    pub(crate) ast: from_macro::Schema,
31    _p: PhantomData<S>,
32}
33
34impl<S> Default for TableTypBuilder<S> {
35    fn default() -> Self {
36        Self {
37            ast: Default::default(),
38            _p: Default::default(),
39        }
40    }
41}
42
43impl<S> TableTypBuilder<S> {
44    pub fn table<T: Table<Schema = S>>(&mut self) {
45        let table = from_macro::Table::new::<T>();
46        let old = self.ast.tables.insert(T::NAME, table);
47        debug_assert!(old.is_none());
48    }
49}
50
51pub trait Schema: Sized + 'static {
52    const VERSION: i64;
53    const SOURCE: &str;
54    const PATH: &str;
55    const SPAN: (usize, usize);
56    fn typs(b: &mut TableTypBuilder<Self>);
57}
58
59fn new_table_inner(table: &crate::schema::from_macro::Table, alias: impl IntoIden) -> String {
60    let alias = alias.into_iden();
61    let mut create = table.create();
62    create
63        .table(alias.clone())
64        .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
65    let mut sql = create.to_string(SqliteQueryBuilder);
66    sql.push_str(" STRICT");
67    sql
68}
69
70pub trait SchemaMigration<'a> {
71    type From: Schema;
72    type To: Schema;
73
74    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
75}
76
77impl<S: Schema> Database<S> {
78    /// Create a [Migrator] to migrate a database.
79    ///
80    /// Returns [None] if the database `user_version` on disk is older than `S`.
81    pub fn migrator(config: Config) -> Option<Migrator<S>> {
82        let synchronous = config.synchronous.as_str();
83        let foreign_keys = config.foreign_keys.as_str();
84        let manager = config.manager.with_init(move |inner| {
85            inner.pragma_update(None, "journal_mode", "WAL")?;
86            inner.pragma_update(None, "synchronous", synchronous)?;
87            inner.pragma_update(None, "foreign_keys", foreign_keys)?;
88            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
89            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
90            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
91
92            #[cfg(feature = "bundled")]
93            inner.create_scalar_function(
94                "floor",
95                1,
96                rusqlite::functions::FunctionFlags::SQLITE_DETERMINISTIC,
97                |ctx| {
98                    assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
99                    let res = ctx.get::<Option<f64>>(0)?.map(|x| x.floor());
100                    Ok(res)
101                },
102            )?;
103
104            #[cfg(feature = "bundled")]
105            inner.create_scalar_function(
106                "ceil",
107                1,
108                rusqlite::functions::FunctionFlags::SQLITE_DETERMINISTIC,
109                |ctx| {
110                    assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
111                    let res = ctx.get::<Option<f64>>(0)?.map(|x| x.ceil());
112                    Ok(res)
113                },
114            )?;
115
116            #[cfg(feature = "jiff-02")]
117            inner.create_scalar_function(
118                "timestamp_add_nanosecond",
119                2,
120                rusqlite::functions::FunctionFlags::SQLITE_DETERMINISTIC,
121                |ctx| {
122                    use crate::value::DbTyp;
123                    assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
124                    if matches!(ctx.get_raw(0), rusqlite::types::ValueRef::Null)
125                        || matches!(ctx.get_raw(1), rusqlite::types::ValueRef::Null)
126                    {
127                        return Ok(None);
128                    }
129
130                    let timestamp = jiff::Timestamp::from_sql(ctx.get_raw(0))?;
131                    let seconds = ctx.get::<i64>(1)?;
132                    let new = timestamp + jiff::SignedDuration::from_nanos(seconds);
133                    let sea_query::Value::String(Some(res)) = jiff::Timestamp::out_to_value(new)
134                    else {
135                        unreachable!("func always returns some string")
136                    };
137                    Ok(Some(res))
138                },
139            )?;
140
141            #[cfg(feature = "jiff-02")]
142            inner.create_scalar_function(
143                "timestamp_subsec_nanosecond",
144                1,
145                rusqlite::functions::FunctionFlags::SQLITE_DETERMINISTIC,
146                |ctx| {
147                    use crate::value::DbTyp;
148                    assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
149                    if matches!(ctx.get_raw(0), rusqlite::types::ValueRef::Null) {
150                        return Ok(None);
151                    }
152
153                    let timestamp = jiff::Timestamp::from_sql(ctx.get_raw(0))?;
154                    Ok(Some(timestamp.subsec_nanosecond()))
155                },
156            )?;
157
158            #[cfg(feature = "jiff-02")]
159            inner.create_scalar_function(
160                "timestamp_to_second",
161                1,
162                rusqlite::functions::FunctionFlags::SQLITE_DETERMINISTIC,
163                |ctx| {
164                    use crate::value::DbTyp;
165                    assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
166                    if matches!(ctx.get_raw(0), rusqlite::types::ValueRef::Null) {
167                        return Ok(None);
168                    }
169
170                    let timestamp = jiff::Timestamp::from_sql(ctx.get_raw(0))?;
171                    Ok(Some(timestamp.as_second()))
172                },
173            )?;
174
175            #[cfg(feature = "jiff-02")]
176            inner.create_scalar_function(
177                "timestamp_to_date",
178                2,
179                rusqlite::functions::FunctionFlags::SQLITE_DETERMINISTIC,
180                |ctx| {
181                    use jiff::fmt::temporal;
182
183                    use crate::value::DbTyp;
184                    assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
185                    if matches!(ctx.get_raw(0), rusqlite::types::ValueRef::Null)
186                        || matches!(ctx.get_raw(1), rusqlite::types::ValueRef::Null)
187                    {
188                        return Ok(None);
189                    }
190
191                    static PARSER: temporal::DateTimeParser = temporal::DateTimeParser::new();
192
193                    let timestamp = jiff::Timestamp::from_sql(ctx.get_raw(0))?;
194                    let timezone = PARSER
195                        .parse_time_zone(ctx.get_raw(1).as_str()?)
196                        .expect("time zone was serialized with jiff");
197                    let date = timezone.to_datetime(timestamp).date();
198                    let sea_query::Value::String(Some(res)) = jiff::civil::Date::out_to_value(date)
199                    else {
200                        unreachable!("func always returns some string")
201                    };
202                    Ok(Some(res))
203                },
204            )?;
205
206            Ok(())
207        });
208
209        use r2d2::ManageConnection;
210        let conn = manager.connect().unwrap();
211        conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
212        let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
213            Some(
214                conn.borrow_mut()
215                    .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
216                    .unwrap(),
217            )
218        });
219
220        let mut user_version = Some(user_version(txn.get()).unwrap());
221
222        // check if this database is newly created
223        if schema_version(txn.get()) == 0 {
224            user_version = None;
225
226            let schema = crate::schema::from_macro::Schema::new::<S>();
227
228            for (&table_name, table) in &schema.tables {
229                txn.get()
230                    .execute(&new_table_inner(table, table_name), [])
231                    .unwrap();
232                for stmt in table.delayed_indices(table_name) {
233                    txn.get().execute(&stmt, []).unwrap();
234                }
235            }
236            (config.init)(txn.get());
237        } else if user_version.unwrap() < S::VERSION {
238            // We can not migrate databases older than `S`
239            return None;
240        }
241
242        debug_assert_eq!(
243            foreign_key_check(txn.get()),
244            None,
245            "foreign key constraint violated"
246        );
247
248        Some(Migrator {
249            user_version,
250            manager,
251            transaction: txn,
252            _p: PhantomData,
253        })
254    }
255}
256
257/// [Migrator] is used to apply database migrations.
258///
259/// When all migrations are done, it can be turned into a [Database] instance with
260/// [Migrator::finish].
261pub struct Migrator<S> {
262    manager: r2d2_sqlite::SqliteConnectionManager,
263    transaction: OwnedTransaction,
264    // Initialized to the user version when the transaction starts.
265    // This is set to None if the schema user_version is updated.
266    // Fixups are only applied if the user_version is None.
267    // Indices are fixed before this is set to None.
268    user_version: Option<i64>,
269    _p: PhantomData<S>,
270}
271
272impl<S: Schema> Migrator<S> {
273    fn with_transaction(mut self, f: impl Send + FnOnce(&'static mut Transaction<S>)) -> Self {
274        assert!(self.user_version.is_none_or(|x| x == S::VERSION));
275        let res = std::thread::scope(|s| {
276            s.spawn(|| {
277                TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
278                let txn = Transaction::new_ref();
279
280                // check if this is the first migration that is applied
281                if self.user_version.take().is_some() {
282                    // we check the schema before doing any migrations
283                    check_schema::<S>(txn);
284                    // fixing indices before migrations can help with migration performance
285                    fix_indices::<S>(txn);
286                }
287
288                f(txn);
289
290                let transaction = TXN.take().unwrap();
291
292                transaction.into_owner()
293            })
294            .join()
295        });
296        match res {
297            Ok(val) => self.transaction = val,
298            Err(payload) => std::panic::resume_unwind(payload),
299        }
300        self
301    }
302
303    /// Apply a database migration if the current schema is `S` and return a [Migrator] for the next schema `N`.
304    ///
305    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
306    pub fn migrate<'x, M>(
307        mut self,
308        m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
309    ) -> Migrator<M::To>
310    where
311        M: SchemaMigration<'x, From = S>,
312    {
313        if self.user_version.is_none_or(|x| x == S::VERSION) {
314            self = self.with_transaction(|txn| {
315                let mut txn = TransactionMigrate {
316                    inner: txn.copy(),
317                    scope: Default::default(),
318                    rename_map: HashMap::new(),
319                    extra_index: Vec::new(),
320                };
321                let m = m(&mut txn);
322
323                let mut builder = SchemaBuilder {
324                    drop: vec![],
325                    foreign_key: HashMap::new(),
326                    inner: txn,
327                };
328                m.tables(&mut builder);
329
330                let transaction = TXN.take().unwrap();
331
332                for drop in builder.drop {
333                    let sql = drop.to_string(SqliteQueryBuilder);
334                    transaction.get().execute(&sql, []).unwrap();
335                }
336                for (to, tmp) in builder.inner.rename_map {
337                    let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
338                    let sql = rename.to_string(SqliteQueryBuilder);
339                    transaction.get().execute(&sql, []).unwrap();
340                }
341                #[allow(
342                    unreachable_code,
343                    reason = "rustc is stupid and thinks this is unreachable"
344                )]
345                if let Some(fk) = foreign_key_check(transaction.get()) {
346                    (builder.foreign_key.remove(&*fk).unwrap())();
347                }
348                // adding non unique indexes is fine to do after checking foreign keys
349                for stmt in builder.inner.extra_index {
350                    transaction.get().execute(&stmt, []).unwrap();
351                }
352
353                TXN.set(Some(transaction));
354            });
355        }
356
357        Migrator {
358            user_version: self.user_version,
359            manager: self.manager,
360            transaction: self.transaction,
361            _p: PhantomData,
362        }
363    }
364
365    /// Mutate the database as part of migrations.
366    ///
367    /// The closure will only be executed if the database got migrated to schema version `S`
368    /// by this [Migrator] instance.
369    /// If [Migrator::fixup] is used before all [Migrator::migrate], then the closures is only executed
370    /// when the database is created.
371    pub fn fixup(mut self, f: impl Send + FnOnce(&'static mut Transaction<S>)) -> Self {
372        if self.user_version.is_none() {
373            self = self.with_transaction(f);
374        }
375        self
376    }
377
378    /// Commit the migration transaction and return a [Database].
379    ///
380    /// Returns [None] if the database schema version is newer than `S`.
381    ///
382    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
383    pub fn finish(mut self) -> Option<Database<S>> {
384        if self.user_version.is_some_and(|x| x != S::VERSION) {
385            return None;
386        }
387
388        // This checks that the schema is correct and fixes indices etc
389        self = self.with_transaction(|txn| {
390            // sanity check, this should never fail
391            check_schema::<S>(txn);
392        });
393
394        // adds an sqlite_stat1 table
395        self.transaction
396            .get()
397            .execute_batch("PRAGMA optimize;")
398            .unwrap();
399
400        set_user_version(self.transaction.get(), S::VERSION).unwrap();
401        let schema_version = schema_version(self.transaction.get());
402        self.transaction.with(|x| x.commit().unwrap());
403
404        Some(Database {
405            manager: Pool::new(self.manager),
406            schema_version: AtomicI64::new(schema_version),
407            schema: PhantomData,
408            mut_lock: parking_lot::FairMutex::new(()),
409        })
410    }
411}
412
413fn fix_indices<S: Schema>(txn: &Transaction<S>) {
414    let schema = read_schema(txn);
415    let expected_schema = crate::schema::from_macro::Schema::new::<S>();
416
417    fn check_eq(expected: &from_macro::Table, actual: &from_db::Table) -> bool {
418        let expected: BTreeSet<_> = expected.indices.iter().map(|idx| &idx.def).collect();
419        let actual: BTreeSet<_> = actual.indices.values().collect();
420        expected == actual
421    }
422
423    for (&table_name, expected_table) in &expected_schema.tables {
424        let table = &schema.tables[table_name];
425
426        if !check_eq(expected_table, table) {
427            // Unique constraints that are part of a table definition
428            // can not be dropped, so we assume the worst and just recreate
429            // the whole table.
430
431            let scope = Scope::default();
432            let tmp_name = scope.tmp_table();
433
434            txn.execute(&new_table_inner(expected_table, tmp_name));
435
436            let mut columns: Vec<_> = expected_table.columns.keys().map(Alias::new).collect();
437            columns.push(Alias::new("id"));
438
439            txn.execute(
440                &sea_query::InsertStatement::new()
441                    .into_table(tmp_name)
442                    .columns(columns.clone())
443                    .select_from(
444                        sea_query::SelectStatement::new()
445                            .from(table_name)
446                            .columns(columns)
447                            .take(),
448                    )
449                    .unwrap()
450                    .build(SqliteQueryBuilder)
451                    .0,
452            );
453
454            txn.execute(
455                &sea_query::TableDropStatement::new()
456                    .table(table_name)
457                    .build(SqliteQueryBuilder),
458            );
459
460            txn.execute(
461                &sea_query::TableRenameStatement::new()
462                    .table(tmp_name, table_name)
463                    .build(SqliteQueryBuilder),
464            );
465            // Add the new non-unique indices
466            for sql in expected_table.delayed_indices(table_name) {
467                txn.execute(&sql);
468            }
469        }
470    }
471
472    // check that we solved the mismatch
473    let schema = read_schema(txn);
474    for (name, table) in schema.tables {
475        let expected_table = &expected_schema.tables[&*name];
476        assert!(check_eq(expected_table, &table));
477    }
478}
479
480impl<S> Transaction<S> {
481    #[track_caller]
482    pub(crate) fn execute(&self, sql: &str) {
483        TXN.with_borrow(|txn| txn.as_ref().unwrap().get().execute(sql, []))
484            .unwrap();
485    }
486}
487
488pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
489    conn.pragma_query_value(None, "schema_version", |r| r.get(0))
490        .unwrap()
491}
492
493// Read user version field from the SQLite db
494pub fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
495    conn.query_row("PRAGMA user_version", [], |row| row.get(0))
496}
497
498// Set user version field from the SQLite db
499fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
500    conn.pragma_update(None, "user_version", v)
501}
502
503pub(crate) fn check_schema<S: Schema>(txn: &Transaction<S>) {
504    let from_macro = crate::schema::from_macro::Schema::new::<S>();
505    let from_db = read_schema(txn);
506    let report = from_db.diff(from_macro, S::SOURCE, S::PATH, S::VERSION);
507    if !report.is_empty() {
508        let renderer = if cfg!(test) {
509            Renderer::plain().anonymized_line_numbers(true)
510        } else {
511            Renderer::styled()
512        }
513        .decor_style(DecorStyle::Unicode);
514        panic!("{}", renderer.render(&report));
515    }
516}
517
518fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
519    let error = conn
520        .prepare("PRAGMA foreign_key_check")
521        .unwrap()
522        .query_map([], |row| row.get(2))
523        .unwrap()
524        .next();
525    error.transpose().unwrap()
526}
527
528impl<S> Transaction<S> {
529    #[cfg(test)]
530    pub(crate) fn schema(&self) -> Vec<String> {
531        TXN.with_borrow(|x| {
532            x.as_ref()
533                .unwrap()
534                .get()
535                .prepare("SELECT sql FROM 'main'.'sqlite_schema'")
536                .unwrap()
537                .query_map([], |row| row.get::<_, Option<String>>("sql"))
538                .unwrap()
539                .flat_map(|x| x.unwrap())
540                .collect()
541        })
542    }
543}