rust_query/
migrate.rs

1use std::{
2    collections::{HashMap, HashSet},
3    convert::Infallible,
4    marker::PhantomData,
5    ops::{Deref, Not},
6    path::Path,
7};
8
9use rusqlite::{Connection, config::DbConfig};
10use sea_query::{Alias, ColumnDef, IntoTableRef, SqliteQueryBuilder, TableDropStatement};
11use self_cell::MutBorrow;
12
13use crate::{
14    FromExpr, IntoExpr, Table, TableRow, Transaction,
15    alias::{Scope, TmpTable},
16    hash,
17    schema_pragma::read_schema,
18    transaction::{Database, OwnedTransaction, TXN, try_insert_private},
19};
20
/// Accumulates the table definitions that make up schema `S`.
///
/// Each call to [`TableTypBuilder::table`] stores one table definition in `ast`.
pub struct TableTypBuilder<S> {
    /// Schema description collected so far.
    pub(crate) ast: hash::Schema,
    /// Ties the builder to schema `S` without storing a value of that type.
    _p: PhantomData<S>,
}
25
26impl<S> Default for TableTypBuilder<S> {
27    fn default() -> Self {
28        Self {
29            ast: Default::default(),
30            _p: Default::default(),
31        }
32    }
33}
34
35impl<S> TableTypBuilder<S> {
36    pub fn table<T: Table<Schema = S>>(&mut self) {
37        let mut b = hash::TypBuilder::default();
38        T::typs(&mut b);
39        self.ast.tables.insert((T::NAME.to_owned(), b.ast));
40    }
41}
42
/// A versioned database schema.
pub trait Schema: Sized + 'static {
    /// Version of this schema.
    ///
    /// It is compared against, and written to, the SQLite `user_version` pragma.
    const VERSION: i64;
    /// Describes the schema on the builder `b`.
    ///
    /// NOTE(review): implementations are presumably generated and call
    /// [`TableTypBuilder::table`] once per table — confirm against the macro.
    fn typs(b: &mut TableTypBuilder<Self>);
}
47
/// Describes how rows of one table in the old schema become rows of a table in
/// the new schema.
pub trait Migration {
    /// Schema that the migration reads from.
    type FromSchema: 'static;
    /// Table in [`Self::FromSchema`] whose rows are migrated.
    type From: Table<Schema = Self::FromSchema>;
    /// Destination table; it must name [`Self::From`] as its `MigrateFrom` table.
    type To: Table<MigrateFrom = Self::From>;
    /// Error value reported when inserting a migrated row fails.
    type Conflict;

    /// Builds the value to insert into [`Self::To`] from the migration value
    /// `val` and an expression `prev` that refers to the old row.
    #[doc(hidden)]
    fn prepare(
        val: Self,
        prev: crate::Expr<'static, Self::FromSchema, Self::From>,
    ) -> <Self::To as Table>::Insert;
    /// Converts the old row whose insert failed into [`Self::Conflict`].
    #[doc(hidden)]
    fn map_conflict(val: TableRow<Self::From>) -> Self::Conflict;
}
62
/// Transaction type for use in migrations.
///
/// It dereferences to a [`Transaction`] on the old schema, so the old data can
/// still be queried while the new tables are being filled.
pub struct TransactionMigrate<FromSchema> {
    /// Transaction on the old schema; used to read the rows that get migrated.
    inner: Transaction<FromSchema>,
    /// Source of fresh temporary table names.
    scope: Scope,
    /// Maps the final name of each new table to the temporary table that
    /// currently holds its contents.
    rename_map: HashMap<&'static str, TmpTable>,
}
69
70impl<FromSchema> Deref for TransactionMigrate<FromSchema> {
71    type Target = Transaction<FromSchema>;
72
73    fn deref(&self) -> &Self::Target {
74        &self.inner
75    }
76}
77
78impl<FromSchema> TransactionMigrate<FromSchema> {
79    fn new_table_name<T: Table>(&mut self) -> TmpTable {
80        *self.rename_map.entry(T::NAME).or_insert_with(|| {
81            let new_table_name = self.scope.tmp_table();
82            TXN.with_borrow(|txn| new_table::<T>(txn.as_ref().unwrap().get(), new_table_name));
83            new_table_name
84        })
85    }
86
87    fn unmigrated<M: Migration<FromSchema = FromSchema>, Out>(
88        &self,
89        new_name: TmpTable,
90    ) -> impl Iterator<Item = (i64, Out)>
91    where
92        Out: FromExpr<FromSchema, M::From>,
93    {
94        let data = self.inner.query(|rows| {
95            let old = rows.join(<M::From as Table>::TOKEN);
96            rows.into_vec((&old, Out::from_expr(&old)))
97        });
98
99        let migrated = Transaction::new().query(|rows| {
100            let new = rows.join_tmp::<M::From>(new_name);
101            rows.into_vec(new)
102        });
103        let migrated: HashSet<_> = migrated.into_iter().map(|x| x.inner.idx).collect();
104
105        data.into_iter().filter_map(move |(row, data)| {
106            migrated
107                .contains(&row.inner.idx)
108                .not()
109                .then_some((row.inner.idx, data))
110        })
111    }
112
113    /// Migrate some rows to the new schema.
114    ///
115    /// This will return an error when there is a conflict.
116    /// The error type depends on the number of unique constraints that the
117    /// migration can violate:
118    /// - 0 => [Infallible]
119    /// - 1.. => `TableRow<T::From>` (row in the old table that could not be migrated)
120    pub fn migrate_optional<
121        M: Migration<FromSchema = FromSchema>,
122        X: FromExpr<FromSchema, M::From>,
123    >(
124        &mut self,
125        mut f: impl FnMut(X) -> Option<M>,
126    ) -> Result<(), M::Conflict> {
127        let new_name = self.new_table_name::<M::To>();
128
129        for (idx, x) in self.unmigrated::<M, X>(new_name) {
130            if let Some(new) = f(x) {
131                try_insert_private::<M::To>(
132                    new_name.into_table_ref(),
133                    Some(idx),
134                    M::prepare(new, TableRow::new(idx).into_expr()),
135                )
136                .map_err(|_| M::map_conflict(TableRow::new(idx)))?;
137            };
138        }
139        Ok(())
140    }
141
142    /// Migrate all rows to the new schema.
143    ///
144    /// Conflict errors work the same as in [Self::migrate_optional].
145    ///
146    /// However, this method will return [Migrated] when all rows are migrated.
147    /// This can then be used as proof that there will be no foreign key violations.
148    pub fn migrate<M: Migration<FromSchema = FromSchema>, X: FromExpr<FromSchema, M::From>>(
149        &mut self,
150        mut f: impl FnMut(X) -> M,
151    ) -> Result<Migrated<'static, FromSchema, M::To>, M::Conflict> {
152        self.migrate_optional::<M, X>(|x| Some(f(x)))?;
153
154        Ok(Migrated {
155            _p: PhantomData,
156            f: Box::new(|_| {}),
157            _local: PhantomData,
158        })
159    }
160
161    /// Helper method for [Self::migrate].
162    ///
163    /// It can only be used when the migration is known to never cause unique constraint conflicts.
164    pub fn migrate_ok<
165        M: Migration<FromSchema = FromSchema, Conflict = Infallible>,
166        X: FromExpr<FromSchema, M::From>,
167    >(
168        &mut self,
169        f: impl FnMut(X) -> M,
170    ) -> Migrated<'static, FromSchema, M::To> {
171        let Ok(res) = self.migrate(f);
172        res
173    }
174}
175
/// Collects the schema-level steps of a migration: the new tables, the tables
/// to drop and the handlers for foreign key violations.
pub struct SchemaBuilder<'t, FromSchema> {
    /// Migration transaction; it owns the map of newly created tables.
    inner: TransactionMigrate<FromSchema>,
    /// Drop statements to execute when the migration is applied.
    drop: Vec<TableDropStatement>,
    /// Handlers keyed by table name; a handler is called when the foreign key
    /// check reports its table.
    foreign_key: HashMap<&'static str, Box<dyn 't + FnOnce() -> Infallible>>,
}
181
182impl<'t, FromSchema: 'static> SchemaBuilder<'t, FromSchema> {
183    pub fn foreign_key<To: Table>(&mut self, err: impl 't + FnOnce() -> Infallible) {
184        self.inner.new_table_name::<To>();
185
186        self.foreign_key.insert(To::NAME, Box::new(err));
187    }
188
189    pub fn create_empty<To: Table>(&mut self) {
190        self.inner.new_table_name::<To>();
191    }
192
193    pub fn drop_table<T: Table>(&mut self) {
194        let name = Alias::new(T::NAME);
195        let step = sea_query::Table::drop().table(name).take();
196        self.drop.push(step);
197    }
198}
199
200fn new_table<T: Table>(conn: &Connection, alias: TmpTable) {
201    let mut f = crate::hash::TypBuilder::default();
202    T::typs(&mut f);
203    new_table_inner(conn, &f.ast, alias);
204}
205
206fn new_table_inner(conn: &Connection, table: &crate::hash::Table, alias: impl IntoTableRef) {
207    let mut create = table.create();
208    create
209        .table(alias)
210        .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
211    let mut sql = create.to_string(SqliteQueryBuilder);
212    sql.push_str(" STRICT");
213    conn.execute(&sql, []).unwrap();
214}
215
/// A migration from schema [`Self::From`] to schema [`Self::To`].
pub trait SchemaMigration<'a> {
    /// Schema before the migration.
    type From: Schema;
    /// Schema after the migration.
    type To: Schema;

    /// Registers the table-level steps of this migration on the builder `b`.
    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
}
222
/// [Config] is used to open a database from a file or in memory.
///
/// This is the first step in the [Config] -> [Migrator] -> [Database] chain to
/// get a [Database] instance.
///
/// # Sqlite config
///
/// Sqlite is configured to be in [WAL mode](https://www.sqlite.org/wal.html).
/// The effect of this mode is that there can be any number of readers with one concurrent writer.
/// What is nice about this is that a `&`[crate::Transaction] can always be made immediately.
/// Making a `&mut`[crate::Transaction] has to wait until all other `&mut`[crate::Transaction]s are finished.
pub struct Config {
    /// Connection manager for the file or in-memory database.
    manager: r2d2_sqlite::SqliteConnectionManager,
    /// Callback that runs inside the migration transaction when the database
    /// is found to be newly created.
    init: Box<dyn FnOnce(&rusqlite::Transaction)>,
    /// Configure how often SQLite will synchronize the database to disk.
    ///
    /// The default is [Synchronous::Full].
    pub synchronous: Synchronous,
}
242
/// Value for the SQLite `synchronous` pragma,
/// see <https://www.sqlite.org/pragma.html#pragma_synchronous>.
///
/// Note that the database uses WAL mode, so make sure to read the WAL specific section.
#[non_exhaustive]
pub enum Synchronous {
    /// SQLite will fsync after every transaction.
    ///
    /// Transactions are durable, even following a power failure or hard reboot.
    Full,

    /// SQLite will only do essential fsync to prevent corruption.
    ///
    /// The database will not rollback transactions due to application crashes, but it might rollback due to a hardware reset or power loss.
    /// Use this when performance is more important than durability.
    Normal,
}

impl Synchronous {
    /// Returns the keyword that is passed to `PRAGMA synchronous` for this mode.
    fn as_str(self) -> &'static str {
        let keyword = match self {
            Self::Full => "FULL",
            Self::Normal => "NORMAL",
        };
        keyword
    }
}
268
269impl Config {
270    /// Open a database that is stored in a file.
271    /// Creates the database if it does not exist.
272    ///
273    /// Opening the same database multiple times at the same time is fine,
274    /// as long as they migrate to or use the same schema.
275    /// All locking is done by sqlite, so connections can even be made using different client implementations.
276    pub fn open(p: impl AsRef<Path>) -> Self {
277        let manager = r2d2_sqlite::SqliteConnectionManager::file(p);
278        Self::open_internal(manager)
279    }
280
281    /// Creates a new empty database in memory.
282    pub fn open_in_memory() -> Self {
283        let manager = r2d2_sqlite::SqliteConnectionManager::memory();
284        Self::open_internal(manager)
285    }
286
287    fn open_internal(manager: r2d2_sqlite::SqliteConnectionManager) -> Self {
288        Self {
289            manager,
290            init: Box::new(|_| {}),
291            synchronous: Synchronous::Full,
292        }
293    }
294
295    /// Append a raw sql statement to be executed if the database was just created.
296    ///
297    /// The statement is executed after creating the empty database and executing all previous statements.
298    pub fn init_stmt(mut self, sql: &'static str) -> Self {
299        self.init = Box::new(move |txn| {
300            (self.init)(txn);
301
302            txn.execute_batch(sql)
303                .expect("raw sql statement to populate db failed");
304        });
305        self
306    }
307}
308
309impl<S: Schema> Database<S> {
310    /// Create a [Migrator] to migrate a database.
311    ///
312    /// Returns [None] if the database `user_version` on disk is older than `S`.
313    pub fn migrator(config: Config) -> Option<Migrator<S>> {
314        let synchronous = config.synchronous.as_str();
315        let manager = config.manager.with_init(move |inner| {
316            inner.pragma_update(None, "journal_mode", "WAL")?;
317            inner.pragma_update(None, "synchronous", synchronous)?;
318            inner.pragma_update(None, "foreign_keys", "ON")?;
319            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
320            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
321            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
322            Ok(())
323        });
324
325        use r2d2::ManageConnection;
326        let conn = manager.connect().unwrap();
327        conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
328        let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
329            Some(
330                conn.borrow_mut()
331                    .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
332                    .unwrap(),
333            )
334        });
335
336        // check if this database is newly created
337        if schema_version(txn.get()) == 0 {
338            let mut b = TableTypBuilder::default();
339            S::typs(&mut b);
340
341            for (table_name, table) in &*b.ast.tables {
342                new_table_inner(txn.get(), table, Alias::new(table_name));
343            }
344            (config.init)(txn.get());
345            set_user_version(txn.get(), S::VERSION).unwrap();
346        }
347
348        let user_version = user_version(txn.get()).unwrap();
349        // We can not migrate databases older than `S`
350        if user_version < S::VERSION {
351            return None;
352        }
353        assert_eq!(
354            foreign_key_check(txn.get()),
355            None,
356            "foreign key constraint violated"
357        );
358
359        Some(Migrator {
360            manager,
361            transaction: txn,
362            _p: PhantomData,
363        })
364    }
365}
366
/// [Migrator] is used to apply database migrations.
///
/// When all migrations are done, it can be turned into a [Database] instance with
/// [Migrator::finish].
pub struct Migrator<S> {
    /// Connection manager that is handed to the resulting [Database].
    manager: r2d2_sqlite::SqliteConnectionManager,
    /// The single transaction in which all migrations run; it is committed
    /// by [Migrator::finish].
    transaction: OwnedTransaction,
    /// Schema that the database is expected to have at this point.
    _p: PhantomData<S>,
}
376
/// [Migrated] provides a proof of migration.
///
/// This only needs to be provided for tables that are migrated from a previous table.
pub struct Migrated<'t, FromSchema, T> {
    /// Table that this proof is about.
    _p: PhantomData<T>,
    /// Step that is run against the [SchemaBuilder] when the proof is applied.
    f: Box<dyn 't + FnOnce(&mut SchemaBuilder<'t, FromSchema>)>,
    /// The raw pointer marker makes this type neither `Send` nor `Sync`.
    /// NOTE(review): presumably to keep the proof on the thread that runs the
    /// migration — confirm.
    _local: PhantomData<*const ()>,
}
385
386impl<'t, FromSchema: 'static, T: Table> Migrated<'t, FromSchema, T> {
387    /// Don't migrate the remaining rows.
388    ///
389    /// This can cause foreign key constraint violations, which is why an error callback needs to be provided.
390    pub fn map_fk_err(err: impl 't + FnOnce() -> Infallible) -> Self {
391        Self {
392            _p: PhantomData,
393            f: Box::new(|x| x.foreign_key::<T>(err)),
394            _local: PhantomData,
395        }
396    }
397
398    #[doc(hidden)]
399    pub fn apply(self, b: &mut SchemaBuilder<'t, FromSchema>) {
400        (self.f)(b)
401    }
402}
403
404impl<S: Schema> Migrator<S> {
405    /// Apply a database migration if the current schema is `S` and return a [Migrator] for the next schema `N`.
406    ///
407    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
408    pub fn migrate<'x, M>(
409        mut self,
410        m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
411    ) -> Migrator<M::To>
412    where
413        M: SchemaMigration<'x, From = S>,
414    {
415        if user_version(self.transaction.get()).unwrap() == S::VERSION {
416            self.transaction = std::thread::scope(|s| {
417                s.spawn(|| {
418                    TXN.set(Some(self.transaction));
419
420                    check_schema::<S>();
421
422                    let mut txn = TransactionMigrate {
423                        inner: Transaction::new(),
424                        scope: Default::default(),
425                        rename_map: HashMap::new(),
426                    };
427                    let m = m(&mut txn);
428
429                    let mut builder = SchemaBuilder {
430                        drop: vec![],
431                        foreign_key: HashMap::new(),
432                        inner: txn,
433                    };
434                    m.tables(&mut builder);
435
436                    let transaction = TXN.take().unwrap();
437
438                    for drop in builder.drop {
439                        let sql = drop.to_string(SqliteQueryBuilder);
440                        transaction.get().execute(&sql, []).unwrap();
441                    }
442                    for (to, tmp) in builder.inner.rename_map {
443                        let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
444                        let sql = rename.to_string(SqliteQueryBuilder);
445                        transaction.get().execute(&sql, []).unwrap();
446                    }
447                    if let Some(fk) = foreign_key_check(transaction.get()) {
448                        (builder.foreign_key.remove(&*fk).unwrap())();
449                    }
450                    #[allow(
451                        unreachable_code,
452                        reason = "rustc is stupid and thinks this is unreachable"
453                    )]
454                    set_user_version(transaction.get(), M::To::VERSION).unwrap();
455
456                    transaction
457                })
458                .join()
459                .unwrap()
460            });
461        }
462
463        Migrator {
464            manager: self.manager,
465            transaction: self.transaction,
466            _p: PhantomData,
467        }
468    }
469
470    /// Commit the migration transaction and return a [Database].
471    ///
472    /// Returns [None] if the database schema version is newer than `S`.
473    ///
474    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
475    pub fn finish(mut self) -> Option<Database<S>> {
476        let conn = &self.transaction;
477        if user_version(conn.get()).unwrap() != S::VERSION {
478            return None;
479        }
480
481        self.transaction = std::thread::scope(|s| {
482            s.spawn(|| {
483                TXN.set(Some(self.transaction));
484                check_schema::<S>();
485                TXN.take().unwrap()
486            })
487            .join()
488            .unwrap()
489        });
490
491        // adds an sqlite_stat1 table
492        self.transaction
493            .get()
494            .execute_batch("PRAGMA optimize;")
495            .unwrap();
496
497        let schema_version = schema_version(self.transaction.get());
498        self.transaction.with(|x| x.commit().unwrap());
499
500        Some(Database {
501            manager: self.manager,
502            schema_version,
503            schema: PhantomData,
504        })
505    }
506}
507
/// Reads the SQLite `schema_version` pragma inside the transaction `conn`.
///
/// # Panics
///
/// Panics if the pragma can not be read.
pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
    conn.pragma_query_value(None, "schema_version", |r| r.get(0))
        .unwrap()
}
512
// Read the `user_version` field from the SQLite db.
// The value is taken from the first column of the row that the pragma returns.
fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
    conn.query_row("PRAGMA user_version", [], |row| row.get(0))
}
517
// Set the `user_version` field in the SQLite db to `v`.
fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
    conn.pragma_update(None, "user_version", v)
}
522
523fn check_schema<S: Schema>() {
524    let mut b = TableTypBuilder::default();
525    S::typs(&mut b);
526    pretty_assertions::assert_eq!(
527        b.ast,
528        read_schema(&crate::Transaction::new()),
529        "schema is different (expected left, but got right)",
530    );
531}
532
533fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
534    let error = conn
535        .prepare("PRAGMA foreign_key_check")
536        .unwrap()
537        .query_map([], |row| row.get(2))
538        .unwrap()
539        .next();
540    error.transpose().unwrap()
541}
542
// Creating a migrator for two separate in-memory databases with the same
// schema must not panic.
#[test]
fn open_multiple() {
    // NOTE(review): the `schema` macro presumably expands `vN` into the
    // versioned module `v0` that is used below — confirm against the macro.
    #[crate::migration::schema(Empty)]
    pub mod vN {}

    let _a = Database::<v0::Empty>::migrator(Config::open_in_memory());
    let _b = Database::<v0::Empty>::migrator(Config::open_in_memory());
}