rust_query/
migrate.rs

1use std::{
2    collections::{HashMap, HashSet},
3    convert::Infallible,
4    marker::PhantomData,
5    ops::Deref,
6    path::Path,
7    sync::atomic::AtomicI64,
8};
9
10use rusqlite::{Connection, config::DbConfig};
11use sea_query::{Alias, ColumnDef, IntoTableRef, SqliteQueryBuilder, TableDropStatement};
12use self_cell::MutBorrow;
13
14use crate::{
15    IntoExpr, Lazy, Table, TableRow, Transaction,
16    alias::{Scope, TmpTable},
17    hash,
18    schema_pragma::read_schema,
19    transaction::{Database, OwnedTransaction, TXN, TransactionWithRows, try_insert_private},
20};
21
/// Collects the table definitions that make up schema `S`.
///
/// Tables are registered one at a time with [TableTypBuilder::table].
pub struct TableTypBuilder<S> {
    /// The collected definitions; `ast.tables` maps each table name to its definition.
    pub(crate) ast: hash::Schema,
    _p: PhantomData<S>,
}
26
27impl<S> Default for TableTypBuilder<S> {
28    fn default() -> Self {
29        Self {
30            ast: Default::default(),
31            _p: Default::default(),
32        }
33    }
34}
35
36impl<S> TableTypBuilder<S> {
37    pub fn table<T: Table<Schema = S>>(&mut self) {
38        let table = hash::Table::new::<T>();
39        let old = self.ast.tables.insert(T::NAME.to_owned(), table);
40        debug_assert!(old.is_none());
41    }
42}
43
/// A versioned database schema.
pub trait Schema: Sized + 'static {
    /// Version of this schema. It is written to, and compared with, the SQLite `user_version`.
    const VERSION: i64;
    /// Register every table of this schema on the builder (via [TableTypBuilder::table]).
    fn typs(b: &mut TableTypBuilder<Self>);
}
48
/// Migration of a single table from [Migration::From] (old schema) to [Migration::To] (new schema).
pub trait Migration {
    /// The schema that [Migration::From] belongs to.
    type FromSchema: 'static;
    /// The table in the old schema.
    type From: Table<Schema = Self::FromSchema>;
    /// The table in the new schema; it declares [Migration::From] as the table it migrates from.
    type To: Table<MigrateFrom = Self::From>;
    /// Error returned when a migrated row can not be inserted
    /// (see [TransactionMigrate::migrate_optional]).
    type Conflict;

    /// Build the value to insert into [Migration::To] from this migration value and the old row `prev`.
    #[doc(hidden)]
    fn prepare(
        val: Self,
        prev: crate::Expr<'static, Self::FromSchema, Self::From>,
    ) -> <Self::To as Table>::Insert;
    /// Turn the old row that could not be migrated into the [Migration::Conflict] error.
    #[doc(hidden)]
    fn map_conflict(val: TableRow<Self::From>) -> Self::Conflict;
}
63
/// Transaction type for use in migrations.
///
/// It dereferences to a [Transaction] on the old schema, so the old data can be read
/// while migrated rows are inserted into new, temporary tables.
pub struct TransactionMigrate<FromSchema> {
    inner: Transaction<FromSchema>,
    /// Source of the temporary table names (via `tmp_table`).
    scope: Scope,
    /// Final table name -> temporary table that holds the migrated rows for it.
    rename_map: HashMap<&'static str, TmpTable>,
    /// SQL for the indices of the new tables.
    // creating indices is delayed so that they don't need to be renamed
    extra_index: Vec<String>,
}
72
73impl<FromSchema> Deref for TransactionMigrate<FromSchema> {
74    type Target = Transaction<FromSchema>;
75
76    fn deref(&self) -> &Self::Target {
77        &self.inner
78    }
79}
80
impl<FromSchema: 'static> TransactionMigrate<FromSchema> {
    /// Returns the temporary table that migrated rows of `T` are inserted into.
    ///
    /// The table is created the first time it is requested for `T::NAME` and remembered in
    /// `rename_map`; later calls return the same table. The SQL for its indices is not executed
    /// here but appended to `extra_index`, already addressed to the final name `T::NAME`.
    fn new_table_name<T: Table>(&mut self) -> TmpTable {
        *self.rename_map.entry(T::NAME).or_insert_with(|| {
            let new_table_name = self.scope.tmp_table();
            TXN.with_borrow(|txn| {
                // panics if no transaction is installed in `TXN` on this thread
                let conn = txn.as_ref().unwrap().get();
                let table = crate::hash::Table::new::<T>();
                let extra_indices = new_table_inner(conn, &table, new_table_name, T::NAME);
                self.extra_index.extend(extra_indices);
            });
            new_table_name
        })
    }

    /// Returns the rows of `M::From` that are not yet present in the temporary table `new_name`.
    ///
    /// Rows are matched on `inner.idx`: migrated rows are inserted with the `idx` of
    /// their old row (see [Self::migrate_optional]).
    fn unmigrated<M: Migration<FromSchema = FromSchema>>(
        &self,
        new_name: TmpTable,
    ) -> impl Iterator<Item = TableRow<M::From>> {
        // all rows of the old table
        let data = self.inner.query(|rows| {
            let old = rows.join_private::<M::From>();
            rows.into_vec(old)
        });

        // rows that were already inserted into the new table
        let migrated = Transaction::new().query(|rows| {
            let new = rows.join_tmp::<M::From>(new_name);
            rows.into_vec(new)
        });
        let migrated: HashSet<_> = migrated.into_iter().map(|x| x.inner.idx).collect();

        data.into_iter()
            .filter(move |row| !migrated.contains(&row.inner.idx))
    }

    /// Migrate some rows to the new schema.
    ///
    /// `f` is called for every row of the old table that has not been migrated yet;
    /// rows for which it returns [None] are skipped and not inserted.
    ///
    /// This will return an error when there is a conflict.
    /// The error type depends on the number of unique constraints that the
    /// migration can violate:
    /// - 0 => [Infallible]
    /// - 1.. => `TableRow<M::From>` (row in the old table that could not be migrated)
    pub fn migrate_optional<'t, M: Migration<FromSchema = FromSchema>>(
        &'t mut self,
        mut f: impl FnMut(Lazy<'t, M::From>) -> Option<M>,
    ) -> Result<(), M::Conflict> {
        let new_name = self.new_table_name::<M::To>();

        for row in self.unmigrated::<M>(new_name) {
            if let Some(new) = f(self.lazy(row)) {
                // The new row reuses the `idx` of the old row.
                // NOTE(review): every insert error is reported as a conflict; this assumes
                // `try_insert_private` only fails on unique constraint violations.
                try_insert_private::<M::To>(
                    new_name.into_table_ref(),
                    Some(row.inner.idx),
                    M::prepare(new, row.into_expr()),
                )
                .map_err(|_| M::map_conflict(row))?;
            };
        }
        Ok(())
    }

    /// Migrate all rows to the new schema.
    ///
    /// Conflict errors work the same as in [Self::migrate_optional].
    ///
    /// However, this method will return [Migrated] when all rows are migrated.
    /// This can then be used as proof that there will be no foreign key violations.
    pub fn migrate<'t, M: Migration<FromSchema = FromSchema>>(
        &'t mut self,
        mut f: impl FnMut(Lazy<'t, M::From>) -> M,
    ) -> Result<Migrated<'static, FromSchema, M::To>, M::Conflict> {
        self.migrate_optional::<M>(|x| Some(f(x)))?;

        // every row was migrated, so the proof has nothing to register on the schema builder
        Ok(Migrated {
            _p: PhantomData,
            f: Box::new(|_| {}),
            _local: PhantomData,
        })
    }

    /// Helper method for [Self::migrate].
    ///
    /// It can only be used when the migration is known to never cause unique constraint conflicts.
    pub fn migrate_ok<'t, M: Migration<FromSchema = FromSchema, Conflict = Infallible>>(
        &'t mut self,
        f: impl FnMut(Lazy<'t, M::From>) -> M,
    ) -> Migrated<'static, FromSchema, M::To> {
        // irrefutable: the `Err` variant would have to hold an `Infallible`
        let Ok(res) = self.migrate(f);
        res
    }
}
170
/// Collects the table-level actions of a schema migration (see [SchemaMigration::tables]).
pub struct SchemaBuilder<'t, FromSchema> {
    /// The migration transaction, including the temporary tables created so far.
    inner: TransactionMigrate<FromSchema>,
    /// Old tables to drop; these run before the temporary tables are renamed.
    drop: Vec<TableDropStatement>,
    /// Callbacks keyed by table name. After the migration, the callback for the table reported
    /// by the foreign key check is called; its `Infallible` return means it can not return normally.
    foreign_key: HashMap<&'static str, Box<dyn 't + FnOnce() -> Infallible>>,
}
176
177impl<'t, FromSchema: 'static> SchemaBuilder<'t, FromSchema> {
178    pub fn foreign_key<To: Table>(&mut self, err: impl 't + FnOnce() -> Infallible) {
179        self.inner.new_table_name::<To>();
180
181        self.foreign_key.insert(To::NAME, Box::new(err));
182    }
183
184    pub fn create_empty<To: Table>(&mut self) {
185        self.inner.new_table_name::<To>();
186    }
187
188    pub fn drop_table<T: Table>(&mut self) {
189        let name = Alias::new(T::NAME);
190        let step = sea_query::Table::drop().table(name).take();
191        self.drop.push(step);
192    }
193}
194
195fn new_table_inner(
196    conn: &Connection,
197    table: &crate::hash::Table,
198    alias: impl IntoTableRef,
199    index_table: &str,
200) -> Vec<String> {
201    let mut extra_indices = Vec::new();
202    let mut create = table.create(&mut extra_indices);
203    create
204        .table(alias)
205        .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
206    let mut sql = create.to_string(SqliteQueryBuilder);
207    sql.push_str(" STRICT");
208    conn.execute(&sql, []).unwrap();
209
210    let index_table_ref = Alias::new(index_table);
211    extra_indices
212        .into_iter()
213        .enumerate()
214        .map(|(index_num, mut index)| {
215            index
216                .table(index_table_ref.clone())
217                .name(format!("{index_table}_index_{index_num}"))
218                .to_string(SqliteQueryBuilder)
219        })
220        .collect()
221}
222
/// Migration of a complete schema from [SchemaMigration::From] to [SchemaMigration::To].
///
/// A value of this type is returned by the closure passed to [Migrator::migrate].
pub trait SchemaMigration<'a> {
    /// The schema before the migration; it must match the database's current schema.
    type From: Schema;
    /// The schema after the migration; its `VERSION` becomes the new `user_version`.
    type To: Schema;

    /// Register the table-level actions of this migration on the builder
    /// (presumably through methods such as [SchemaBuilder::drop_table]).
    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
}
229
/// [Config] is used to open a database from a file or in memory.
///
/// This is the first step in the [Config] -> [Migrator] -> [Database] chain to
/// get a [Database] instance.
///
/// # Sqlite config
///
/// Sqlite is configured to be in [WAL mode](https://www.sqlite.org/wal.html).
/// The effect of this mode is that there can be any number of readers with one concurrent writer.
/// What is nice about this is that a `&`[crate::Transaction] can always be made immediately.
/// Making a `&mut`[crate::Transaction] has to wait until all other `&mut`[crate::Transaction]s are finished.
pub struct Config {
    /// Connection manager used to open the SQLite connections.
    manager: r2d2_sqlite::SqliteConnectionManager,
    /// Raw SQL setup that only runs when the database was just created (see [Config::init_stmt]).
    init: Box<dyn FnOnce(&rusqlite::Transaction)>,
    /// Configure how often SQLite will synchronize the database to disk.
    ///
    /// The default is [Synchronous::Full].
    pub synchronous: Synchronous,
    /// Configure how foreign keys should be checked.
    ///
    /// The default is [ForeignKeys::SQLite].
    pub foreign_keys: ForeignKeys,
}
251
/// <https://www.sqlite.org/pragma.html#pragma_synchronous>
///
/// Note that the database uses WAL mode, so make sure to read the WAL specific section.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Synchronous {
    /// SQLite will fsync after every transaction.
    ///
    /// Transactions are durable, even following a power failure or hard reboot.
    Full,

    /// SQLite will only do essential fsync to prevent corruption.
    ///
    /// The database will not rollback transactions due to application crashes, but it might rollback due to a hardware reset or power loss.
    /// Use this when performance is more important than durability.
    Normal,
}

impl Synchronous {
    /// The value passed to `PRAGMA synchronous` for this setting.
    fn as_str(self) -> &'static str {
        match self {
            Synchronous::Full => "FULL",
            Synchronous::Normal => "NORMAL",
        }
    }
}
277
/// Which method should be used to check foreign-key constraints.
///
/// The default is [ForeignKeys::SQLite], but this is likely to change to [ForeignKeys::Rust].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ForeignKeys {
    /// Foreign-key constraints are checked by rust-query only.
    ///
    /// Most foreign-key checks are done at compile time and are thus completely free.
    /// However, some runtime checks are required for deletes.
    Rust,

    /// Foreign-key constraints are checked by SQLite in addition to the checks done by rust-query.
    ///
    /// This is useful when using rust-query with [crate::TransactionWeak::rusqlite_transaction]
    /// or when other software can write to the database.
    /// Both can result in "dangling" foreign keys (which point at a non-existent row) if written incorrectly.
    /// Dangling foreign keys can result in wrong results, but these dangling foreign keys can also turn
    /// into "false" foreign keys if a new record is inserted that makes the foreign key valid.
    /// This is a lot worse than a dangling foreign key, because it is generally not possible to detect.
    ///
    /// With the [ForeignKeys::SQLite] option, rust-query will prevent creating such false foreign keys
    /// and panic instead.
    /// The downside is that indexes are required on all foreign keys to make the checks efficient.
    SQLite,
}

impl ForeignKeys {
    /// The value passed to `PRAGMA foreign_keys`: `OFF` for [ForeignKeys::Rust],
    /// `ON` for [ForeignKeys::SQLite].
    fn as_str(self) -> &'static str {
        match self {
            ForeignKeys::Rust => "OFF",
            ForeignKeys::SQLite => "ON",
        }
    }
}
312
313impl Config {
314    /// Open a database that is stored in a file.
315    /// Creates the database if it does not exist.
316    ///
317    /// Opening the same database multiple times at the same time is fine,
318    /// as long as they migrate to or use the same schema.
319    /// All locking is done by sqlite, so connections can even be made using different client implementations.
320    pub fn open(p: impl AsRef<Path>) -> Self {
321        let manager = r2d2_sqlite::SqliteConnectionManager::file(p);
322        Self::open_internal(manager)
323    }
324
325    /// Creates a new empty database in memory.
326    pub fn open_in_memory() -> Self {
327        let manager = r2d2_sqlite::SqliteConnectionManager::memory();
328        Self::open_internal(manager)
329    }
330
331    fn open_internal(manager: r2d2_sqlite::SqliteConnectionManager) -> Self {
332        Self {
333            manager,
334            init: Box::new(|_| {}),
335            synchronous: Synchronous::Full,
336            foreign_keys: ForeignKeys::SQLite,
337        }
338    }
339
340    /// Append a raw sql statement to be executed if the database was just created.
341    ///
342    /// The statement is executed after creating the empty database and executing all previous statements.
343    pub fn init_stmt(mut self, sql: &'static str) -> Self {
344        self.init = Box::new(move |txn| {
345            (self.init)(txn);
346
347            txn.execute_batch(sql)
348                .expect("raw sql statement to populate db failed");
349        });
350        self
351    }
352}
353
impl<S: Schema> Database<S> {
    /// Create a [Migrator] to migrate a database.
    ///
    /// Returns [None] if the database `user_version` on disk is older than `S`.
    ///
    /// If the database was just created (its SQLite `schema_version` is 0), all tables of `S`
    /// are created first. Then the statements from [Config::init_stmt] run, and the
    /// `user_version` is set to `S::VERSION`.
    ///
    /// # Panics
    ///
    /// Panics if the connection can not be opened or a setup statement fails. In debug builds it
    /// also panics if a foreign key constraint is violated at this point.
    pub fn migrator(config: Config) -> Option<Migrator<S>> {
        let synchronous = config.synchronous.as_str();
        let foreign_keys = config.foreign_keys.as_str();
        // Connection setup registered on the manager with `with_init`, so it also applies to the
        // connections that the resulting `Database` opens through this manager.
        let manager = config.manager.with_init(move |inner| {
            inner.pragma_update(None, "journal_mode", "WAL")?;
            inner.pragma_update(None, "synchronous", synchronous)?;
            inner.pragma_update(None, "foreign_keys", foreign_keys)?;
            // disallow double-quoted string literals in DDL and DML
            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
            // defensive mode blocks SQL features that can deliberately corrupt the database file
            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
            Ok(())
        });

        use r2d2::ManageConnection;
        let conn = manager.connect().unwrap();
        // SQLite does not enforce foreign keys on the migration connection;
        // violations are checked explicitly with `foreign_key_check` instead.
        conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
        // An exclusive transaction keeps other connections from writing while migrating.
        let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
            Some(
                conn.borrow_mut()
                    .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
                    .unwrap(),
            )
        });

        // check if this database is newly created
        if schema_version(txn.get()) == 0 {
            let schema = crate::hash::Schema::new::<S>();

            for (table_name, table) in &schema.tables {
                let table_name_ref = Alias::new(table_name);
                // tables are created under their final name, so their indices can be added right away
                let extra_indices = new_table_inner(txn.get(), table, table_name_ref, table_name);
                for stmt in extra_indices {
                    txn.get().execute(&stmt, []).unwrap();
                }
            }
            (config.init)(txn.get());
            set_user_version(txn.get(), S::VERSION).unwrap();
        }

        let user_version = user_version(txn.get()).unwrap();
        // We can not migrate databases older than `S`
        if user_version < S::VERSION {
            return None;
        }
        debug_assert_eq!(
            foreign_key_check(txn.get()),
            None,
            "foreign key constraint violated"
        );

        Some(Migrator {
            manager,
            transaction: txn,
            _p: PhantomData,
        })
    }
}
415
/// [Migrator] is used to apply database migrations.
///
/// When all migrations are done, it can be turned into a [Database] instance with
/// [Migrator::finish].
pub struct Migrator<S> {
    /// Connection manager that is handed to the resulting [Database].
    manager: r2d2_sqlite::SqliteConnectionManager,
    /// The exclusive transaction that all migrations run in; committed by [Migrator::finish].
    transaction: OwnedTransaction,
    _p: PhantomData<S>,
}
425
/// [Migrated] provides a proof of migration.
///
/// This only needs to be provided for tables that are migrated from a previous table.
pub struct Migrated<'t, FromSchema, T> {
    _p: PhantomData<T>,
    /// Action applied to the [SchemaBuilder]; a no-op when all rows were migrated.
    f: Box<dyn 't + FnOnce(&mut SchemaBuilder<'t, FromSchema>)>,
    /// The raw pointer marker makes this type neither `Send` nor `Sync`.
    _local: PhantomData<*const ()>,
}
434
435impl<'t, FromSchema: 'static, T: Table> Migrated<'t, FromSchema, T> {
436    /// Don't migrate the remaining rows.
437    ///
438    /// This can cause foreign key constraint violations, which is why an error callback needs to be provided.
439    pub fn map_fk_err(err: impl 't + FnOnce() -> Infallible) -> Self {
440        Self {
441            _p: PhantomData,
442            f: Box::new(|x| x.foreign_key::<T>(err)),
443            _local: PhantomData,
444        }
445    }
446
447    #[doc(hidden)]
448    pub fn apply(self, b: &mut SchemaBuilder<'t, FromSchema>) {
449        (self.f)(b)
450    }
451}
452
impl<S: Schema> Migrator<S> {
    /// Apply a database migration if the current schema is `S` and return a [Migrator] for the next schema `M::To`.
    ///
    /// When the database `user_version` is not `S::VERSION`, `m` is not called and only the
    /// schema type of the returned [Migrator] changes.
    ///
    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
    pub fn migrate<'x, M>(
        mut self,
        m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
    ) -> Migrator<M::To>
    where
        M: SchemaMigration<'x, From = S>,
    {
        if user_version(self.transaction.get()).unwrap() == S::VERSION {
            // The migration runs on a scoped helper thread, which owns the transaction through
            // the `TXN` thread-local; a panic there is re-raised on this thread below.
            let res = std::thread::scope(|s| {
                s.spawn(|| {
                    TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));

                    // the data on disk must match `S` before it is migrated
                    check_schema::<S>();

                    let mut txn = TransactionMigrate {
                        inner: Transaction::new(),
                        scope: Default::default(),
                        rename_map: HashMap::new(),
                        extra_index: Vec::new(),
                    };
                    // user closure: may migrate rows into temporary tables
                    let m = m(&mut txn);

                    let mut builder = SchemaBuilder {
                        drop: vec![],
                        foreign_key: HashMap::new(),
                        inner: txn,
                    };
                    // registers drops, empty tables and foreign key callbacks
                    m.tables(&mut builder);

                    let transaction = TXN.take().unwrap();

                    // drop old tables first, so the temporary tables can take over their names
                    for drop in builder.drop {
                        let sql = drop.to_string(SqliteQueryBuilder);
                        transaction.get().execute(&sql, []).unwrap();
                    }
                    // move every temporary table to its final name
                    for (to, tmp) in builder.inner.rename_map {
                        let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
                        let sql = rename.to_string(SqliteQueryBuilder);
                        transaction.get().execute(&sql, []).unwrap();
                    }
                    // Only the first reported violation is handled. Its callback returns
                    // `Infallible`, so it can not return normally.
                    // NOTE(review): `unwrap` panics if that table has no registered callback.
                    if let Some(fk) = foreign_key_check(transaction.get()) {
                        (builder.foreign_key.remove(&*fk).unwrap())();
                    }
                    #[allow(
                        unreachable_code,
                        reason = "rustc is stupid and thinks this is unreachable"
                    )]
                    // adding indexes is fine to do after checking foreign keys
                    for stmt in builder.inner.extra_index {
                        transaction.get().execute(&stmt, []).unwrap();
                    }
                    set_user_version(transaction.get(), M::To::VERSION).unwrap();

                    // hand the transaction back to the calling thread
                    transaction.into_owner()
                })
                .join()
            });
            match res {
                Ok(val) => self.transaction = val,
                Err(payload) => std::panic::resume_unwind(payload),
            }
        }

        Migrator {
            manager: self.manager,
            transaction: self.transaction,
            _p: PhantomData,
        }
    }

    /// Commit the migration transaction and return a [Database].
    ///
    /// Returns [None] if the database schema version is newer than `S`;
    /// in that case the transaction is not committed.
    ///
    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
    pub fn finish(mut self) -> Option<Database<S>> {
        let conn = &self.transaction;
        if user_version(conn.get()).unwrap() != S::VERSION {
            return None;
        }

        // Check the final schema on a scoped thread that owns the transaction via `TXN`;
        // a panic there is re-raised on this thread.
        let res = std::thread::scope(|s| {
            s.spawn(|| {
                TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
                check_schema::<S>();
                TXN.take().unwrap().into_owner()
            })
            .join()
        });
        match res {
            Ok(val) => self.transaction = val,
            Err(payload) => std::panic::resume_unwind(payload),
        }

        // adds an sqlite_stat1 table
        self.transaction
            .get()
            .execute_batch("PRAGMA optimize;")
            .unwrap();

        // read before committing; the value is stored in the returned `Database`
        let schema_version = schema_version(self.transaction.get());
        self.transaction.with(|x| x.commit().unwrap());

        Some(Database {
            manager: self.manager,
            schema_version: AtomicI64::new(schema_version),
            schema: PhantomData,
            mut_lock: parking_lot::FairMutex::new(()),
        })
    }
}
568
569pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
570    conn.pragma_query_value(None, "schema_version", |r| r.get(0))
571        .unwrap()
572}
573
574// Read user version field from the SQLite db
575pub fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
576    conn.query_row("PRAGMA user_version", [], |row| row.get(0))
577}
578
579// Set user version field from the SQLite db
580fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
581    conn.pragma_update(None, "user_version", v)
582}
583
584pub(crate) fn check_schema<S: Schema>() {
585    // normalize both sides, because we only care about compatibility
586    pretty_assertions::assert_eq!(
587        crate::hash::Schema::new::<S>().normalize(),
588        read_schema(&crate::Transaction::new()).normalize(),
589        "schema is different (expected left, but got right)",
590    );
591}
592
593fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
594    let error = conn
595        .prepare("PRAGMA foreign_key_check")
596        .unwrap()
597        .query_map([], |row| row.get(2))
598        .unwrap()
599        .next();
600    error.transpose().unwrap()
601}
602
/// Two in-memory databases can be opened at the same time.
///
/// `_a` (unlike `_`) keeps the first [Migrator] alive while the second one is created.
#[test]
fn open_multiple() {
    // declares a schema `Empty` without tables; the macro presumably expands `vN`
    // into versioned modules such as `v0`
    #[crate::migration::schema(Empty)]
    pub mod vN {}

    let _a = Database::<v0::Empty>::migrator(Config::open_in_memory());
    let _b = Database::<v0::Empty>::migrator(Config::open_in_memory());
}