rust_query/
migrate.rs

1use std::{
2    collections::{HashMap, HashSet},
3    convert::Infallible,
4    marker::PhantomData,
5    ops::{Deref, Not},
6    path::Path,
7};
8
9use rusqlite::{Connection, config::DbConfig};
10use sea_query::{Alias, ColumnDef, IntoTableRef, SqliteQueryBuilder, TableDropStatement};
11use self_cell::MutBorrow;
12
13use crate::{
14    FromExpr, IntoExpr, Table, TableRow, Transaction,
15    alias::{Scope, TmpTable},
16    hash,
17    schema_pragma::read_schema,
18    transaction::{Database, OwnedTransaction, TXN, try_insert_private},
19};
20
21pub struct TableTypBuilder<S> {
22    pub(crate) ast: hash::Schema,
23    _p: PhantomData<S>,
24}
25
26impl<S> Default for TableTypBuilder<S> {
27    fn default() -> Self {
28        Self {
29            ast: Default::default(),
30            _p: Default::default(),
31        }
32    }
33}
34
35impl<S> TableTypBuilder<S> {
36    pub fn table<T: Table<Schema = S>>(&mut self) {
37        let mut b = hash::TypBuilder::default();
38        T::typs(&mut b);
39        self.ast.tables.insert((T::NAME.to_owned(), b.ast));
40    }
41}
42
43pub trait Schema: Sized + 'static {
44    const VERSION: i64;
45    fn typs(b: &mut TableTypBuilder<Self>);
46}
47
48pub trait Migration {
49    type FromSchema: 'static;
50    type From: Table<Schema = Self::FromSchema>;
51    type To: Table<MigrateFrom = Self::From>;
52    type Conflict;
53
54    #[doc(hidden)]
55    fn prepare(
56        val: Self,
57        prev: crate::Expr<'static, Self::FromSchema, Self::From>,
58    ) -> <Self::To as Table>::Insert;
59    #[doc(hidden)]
60    fn map_conflict(val: TableRow<Self::From>) -> Self::Conflict;
61}
62
63/// Transaction type for use in migrations.
64pub struct TransactionMigrate<FromSchema> {
65    inner: Transaction<FromSchema>,
66    scope: Scope,
67    rename_map: HashMap<&'static str, TmpTable>,
68}
69
70impl<FromSchema> Deref for TransactionMigrate<FromSchema> {
71    type Target = Transaction<FromSchema>;
72
73    fn deref(&self) -> &Self::Target {
74        &self.inner
75    }
76}
77
78impl<FromSchema> TransactionMigrate<FromSchema> {
79    fn new_table_name<T: Table>(&mut self) -> TmpTable {
80        *self.rename_map.entry(T::NAME).or_insert_with(|| {
81            let new_table_name = self.scope.tmp_table();
82            TXN.with_borrow(|txn| new_table::<T>(txn.as_ref().unwrap().get(), new_table_name));
83            new_table_name
84        })
85    }
86
87    fn unmigrated<M: Migration<FromSchema = FromSchema>, Out>(
88        &self,
89        new_name: TmpTable,
90    ) -> impl Iterator<Item = (i64, Out)>
91    where
92        Out: FromExpr<FromSchema, M::From>,
93    {
94        let data = self.inner.query(|rows| {
95            let old = rows.join(<M::From as Table>::TOKEN);
96            rows.into_vec((&old, Out::from_expr(&old)))
97        });
98
99        let migrated = Transaction::new().query(|rows| {
100            let new = rows.join_tmp::<M::From>(new_name);
101            rows.into_vec(new)
102        });
103        let migrated: HashSet<_> = migrated.into_iter().map(|x| x.inner.idx).collect();
104
105        data.into_iter().filter_map(move |(row, data)| {
106            migrated
107                .contains(&row.inner.idx)
108                .not()
109                .then_some((row.inner.idx, data))
110        })
111    }
112
113    /// Migrate some rows to the new schema.
114    ///
115    /// This will return an error when there is a conflict.
116    /// The error type depends on the number of unique constraints that the
117    /// migration can violate:
118    /// - 0 => [Infallible]
119    /// - 1.. => `TableRow<T::From>` (row in the old table that could not be migrated)
120    pub fn migrate_optional<
121        M: Migration<FromSchema = FromSchema>,
122        X: FromExpr<FromSchema, M::From>,
123    >(
124        &mut self,
125        mut f: impl FnMut(X) -> Option<M>,
126    ) -> Result<(), M::Conflict> {
127        let new_name = self.new_table_name::<M::To>();
128
129        for (idx, x) in self.unmigrated::<M, X>(new_name) {
130            if let Some(new) = f(x) {
131                try_insert_private::<M::To>(
132                    new_name.into_table_ref(),
133                    Some(idx),
134                    M::prepare(new, TableRow::new(idx).into_expr()),
135                )
136                .map_err(|_| M::map_conflict(TableRow::new(idx)))?;
137            };
138        }
139        Ok(())
140    }
141
142    /// Migrate all rows to the new schema.
143    ///
144    /// Conflict errors work the same as in [Self::migrate_optional].
145    ///
146    /// However, this method will return [Migrated] when all rows are migrated.
147    /// This can then be used as proof that there will be no foreign key violations.
148    pub fn migrate<M: Migration<FromSchema = FromSchema>, X: FromExpr<FromSchema, M::From>>(
149        &mut self,
150        mut f: impl FnMut(X) -> M,
151    ) -> Result<Migrated<'static, FromSchema, M::To>, M::Conflict> {
152        self.migrate_optional::<M, X>(|x| Some(f(x)))?;
153
154        Ok(Migrated {
155            _p: PhantomData,
156            f: Box::new(|_| {}),
157            _local: PhantomData,
158        })
159    }
160
161    /// Helper method for [Self::migrate].
162    ///
163    /// It can only be used when the migration is known to never cause unique constraint conflicts.
164    pub fn migrate_ok<
165        M: Migration<FromSchema = FromSchema, Conflict = Infallible>,
166        X: FromExpr<FromSchema, M::From>,
167    >(
168        &mut self,
169        f: impl FnMut(X) -> M,
170    ) -> Migrated<'static, FromSchema, M::To> {
171        let Ok(res) = self.migrate(f);
172        res
173    }
174}
175
176pub struct SchemaBuilder<'t, FromSchema> {
177    inner: TransactionMigrate<FromSchema>,
178    drop: Vec<TableDropStatement>,
179    foreign_key: HashMap<&'static str, Box<dyn 't + FnOnce() -> Infallible>>,
180}
181
182impl<'t, FromSchema: 'static> SchemaBuilder<'t, FromSchema> {
183    pub fn foreign_key<To: Table>(&mut self, err: impl 't + FnOnce() -> Infallible) {
184        self.inner.new_table_name::<To>();
185
186        self.foreign_key.insert(To::NAME, Box::new(err));
187    }
188
189    pub fn create_empty<To: Table>(&mut self) {
190        self.inner.new_table_name::<To>();
191    }
192
193    pub fn drop_table<T: Table>(&mut self) {
194        let name = Alias::new(T::NAME);
195        let step = sea_query::Table::drop().table(name).take();
196        self.drop.push(step);
197    }
198}
199
200fn new_table<T: Table>(conn: &Connection, alias: TmpTable) {
201    let mut f = crate::hash::TypBuilder::default();
202    T::typs(&mut f);
203    new_table_inner(conn, &f.ast, alias);
204}
205
206fn new_table_inner(conn: &Connection, table: &crate::hash::Table, alias: impl IntoTableRef) {
207    let mut create = table.create();
208    create
209        .table(alias)
210        .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
211    let mut sql = create.to_string(SqliteQueryBuilder);
212    sql.push_str(" STRICT");
213    conn.execute(&sql, []).unwrap();
214}
215
216pub trait SchemaMigration<'a> {
217    type From: Schema;
218    type To: Schema;
219
220    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
221}
222
223/// [Config] is used to open a database from a file or in memory.
224///
225/// This is the first step in the [Config] -> [Migrator] -> [Database] chain to
226/// get a [Database] instance.
227///
228/// # Sqlite config
229///
230/// Sqlite is configured to be in [WAL mode](https://www.sqlite.org/wal.html).
231/// The effect of this mode is that there can be any number of readers with one concurrent writer.
232/// What is nice about this is that a [& crate::Transaction] can always be made immediately.
233/// Making a [&mut crate::Transaction] has to wait until all other [&mut crate::Transaction]s are finished.
234///
235/// Sqlite is also configured with [`synchronous=NORMAL`](https://www.sqlite.org/pragma.html#pragma_synchronous). This gives better performance by fsyncing less.
236/// The database will not lose transactions due to application crashes, but it might due to system crashes or power loss.
237pub struct Config {
238    manager: r2d2_sqlite::SqliteConnectionManager,
239    init: Box<dyn FnOnce(&rusqlite::Transaction)>,
240}
241
242impl Config {
243    /// Open a database that is stored in a file.
244    /// Creates the database if it does not exist.
245    ///
246    /// Opening the same database multiple times at the same time is fine,
247    /// as long as they migrate to or use the same schema.
248    /// All locking is done by sqlite, so connections can even be made using different client implementations.
249    pub fn open(p: impl AsRef<Path>) -> Self {
250        let manager = r2d2_sqlite::SqliteConnectionManager::file(p);
251        Self::open_internal(manager)
252    }
253
254    /// Creates a new empty database in memory.
255    pub fn open_in_memory() -> Self {
256        let manager = r2d2_sqlite::SqliteConnectionManager::memory();
257        Self::open_internal(manager)
258    }
259
260    fn open_internal(manager: r2d2_sqlite::SqliteConnectionManager) -> Self {
261        let manager = manager.with_init(|inner| {
262            inner.pragma_update(None, "journal_mode", "WAL")?;
263            inner.pragma_update(None, "synchronous", "NORMAL")?;
264            inner.pragma_update(None, "foreign_keys", "ON")?;
265            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
266            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
267            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
268            Ok(())
269        });
270
271        Self {
272            manager,
273            init: Box::new(|_| {}),
274        }
275    }
276
277    /// Execute a raw sql statement if the database was just created.
278    ///
279    /// The statement is executed after creating the empty database and executingall previous statements.
280    pub fn init_stmt(mut self, sql: &'static str) -> Self {
281        self.init = Box::new(move |txn| {
282            (self.init)(txn);
283
284            txn.execute_batch(sql)
285                .expect("raw sql statement to populate db failed");
286        });
287        self
288    }
289}
290
291impl<S: Schema> Database<S> {
292    /// Create a [Migrator] to migrate a database.
293    ///
294    /// Returns [None] if the database `user_version` on disk is older than `S`.
295    pub fn migrator(config: Config) -> Option<Migrator<S>> {
296        use r2d2::ManageConnection;
297        let conn = config.manager.connect().unwrap();
298        conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
299        let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
300            Some(
301                conn.borrow_mut()
302                    .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
303                    .unwrap(),
304            )
305        });
306
307        // check if this database is newly created
308        if schema_version(txn.get()) == 0 {
309            let mut b = TableTypBuilder::default();
310            S::typs(&mut b);
311
312            for (table_name, table) in &*b.ast.tables {
313                new_table_inner(txn.get(), table, Alias::new(table_name));
314            }
315            (config.init)(txn.get());
316            set_user_version(txn.get(), S::VERSION).unwrap();
317        }
318
319        let user_version = user_version(txn.get()).unwrap();
320        // We can not migrate databases older than `S`
321        if user_version < S::VERSION {
322            return None;
323        }
324        assert_eq!(
325            foreign_key_check(txn.get()),
326            None,
327            "foreign key constraint violated"
328        );
329
330        Some(Migrator {
331            manager: config.manager,
332            transaction: txn,
333            _p: PhantomData,
334        })
335    }
336}
337
338/// [Migrator] is used to apply database migrations.
339///
340/// When all migrations are done, it can be turned into a [Database] instance with
341/// [Migrator::finish].
342pub struct Migrator<S> {
343    manager: r2d2_sqlite::SqliteConnectionManager,
344    transaction: OwnedTransaction,
345    _p: PhantomData<S>,
346}
347
348/// [Migrated] provides a proof of migration.
349///
350/// This only needs to be provided for tables that are migrated from a previous table.
351pub struct Migrated<'t, FromSchema, T> {
352    _p: PhantomData<T>,
353    f: Box<dyn 't + FnOnce(&mut SchemaBuilder<'t, FromSchema>)>,
354    _local: PhantomData<*const ()>,
355}
356
357impl<'t, FromSchema: 'static, T: Table> Migrated<'t, FromSchema, T> {
358    /// Don't migrate the remaining rows.
359    ///
360    /// This can cause foreign key constraint violations, which is why an error callback needs to be provided.
361    pub fn map_fk_err(err: impl 't + FnOnce() -> Infallible) -> Self {
362        Self {
363            _p: PhantomData,
364            f: Box::new(|x| x.foreign_key::<T>(err)),
365            _local: PhantomData,
366        }
367    }
368
369    #[doc(hidden)]
370    pub fn apply(self, b: &mut SchemaBuilder<'t, FromSchema>) {
371        (self.f)(b)
372    }
373}
374
375impl<S: Schema> Migrator<S> {
376    /// Apply a database migration if the current schema is `S` and return a [Migrator] for the next schema `N`.
377    ///
378    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
379    pub fn migrate<'x, M>(
380        mut self,
381        m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
382    ) -> Migrator<M::To>
383    where
384        M: SchemaMigration<'x, From = S>,
385    {
386        if user_version(self.transaction.get()).unwrap() == S::VERSION {
387            self.transaction = std::thread::scope(|s| {
388                s.spawn(|| {
389                    TXN.set(Some(self.transaction));
390
391                    check_schema::<S>();
392
393                    let mut txn = TransactionMigrate {
394                        inner: Transaction::new(),
395                        scope: Default::default(),
396                        rename_map: HashMap::new(),
397                    };
398                    let m = m(&mut txn);
399
400                    let mut builder = SchemaBuilder {
401                        drop: vec![],
402                        foreign_key: HashMap::new(),
403                        inner: txn,
404                    };
405                    m.tables(&mut builder);
406
407                    let transaction = TXN.take().unwrap();
408
409                    for drop in builder.drop {
410                        let sql = drop.to_string(SqliteQueryBuilder);
411                        transaction.get().execute(&sql, []).unwrap();
412                    }
413                    for (to, tmp) in builder.inner.rename_map {
414                        let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
415                        let sql = rename.to_string(SqliteQueryBuilder);
416                        transaction.get().execute(&sql, []).unwrap();
417                    }
418                    if let Some(fk) = foreign_key_check(transaction.get()) {
419                        (builder.foreign_key.remove(&*fk).unwrap())();
420                    }
421                    #[allow(
422                        unreachable_code,
423                        reason = "rustc is stupid and thinks this is unreachable"
424                    )]
425                    set_user_version(transaction.get(), M::To::VERSION).unwrap();
426
427                    transaction
428                })
429                .join()
430                .unwrap()
431            });
432        }
433
434        Migrator {
435            manager: self.manager,
436            transaction: self.transaction,
437            _p: PhantomData,
438        }
439    }
440
441    /// Commit the migration transaction and return a [Database].
442    ///
443    /// Returns [None] if the database schema version is newer than `S`.
444    ///
445    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
446    pub fn finish(mut self) -> Option<Database<S>> {
447        let conn = &self.transaction;
448        if user_version(conn.get()).unwrap() != S::VERSION {
449            return None;
450        }
451
452        self.transaction = std::thread::scope(|s| {
453            s.spawn(|| {
454                TXN.set(Some(self.transaction));
455                check_schema::<S>();
456                TXN.take().unwrap()
457            })
458            .join()
459            .unwrap()
460        });
461
462        // adds an sqlite_stat1 table
463        self.transaction
464            .get()
465            .execute_batch("PRAGMA optimize;")
466            .unwrap();
467
468        let schema_version = schema_version(self.transaction.get());
469        self.transaction.with(|x| x.commit().unwrap());
470
471        Some(Database {
472            manager: self.manager,
473            schema_version,
474            schema: PhantomData,
475        })
476    }
477}
478
479pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
480    conn.pragma_query_value(None, "schema_version", |r| r.get(0))
481        .unwrap()
482}
483
484// Read user version field from the SQLite db
485fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
486    conn.query_row("PRAGMA user_version", [], |row| row.get(0))
487}
488
489// Set user version field from the SQLite db
490fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
491    conn.pragma_update(None, "user_version", v)
492}
493
494fn check_schema<S: Schema>() {
495    let mut b = TableTypBuilder::default();
496    S::typs(&mut b);
497    pretty_assertions::assert_eq!(
498        b.ast,
499        read_schema(&crate::Transaction::new()),
500        "schema is different (expected left, but got right)",
501    );
502}
503
504fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
505    let error = conn
506        .prepare("PRAGMA foreign_key_check")
507        .unwrap()
508        .query_map([], |row| row.get(2))
509        .unwrap()
510        .next();
511    error.transpose().unwrap()
512}
513
/// Two in-memory databases can have a migrator at the same time.
#[test]
fn open_multiple() {
    #[crate::migration::schema(Empty)]
    pub mod vN {}

    // Both migrators stay alive until the end of the test.
    let _first = Database::<v0::Empty>::migrator(Config::open_in_memory());
    let _second = Database::<v0::Empty>::migrator(Config::open_in_memory());
}