rust_query/
migrate.rs

1use std::{
2    collections::{HashMap, HashSet},
3    convert::Infallible,
4    marker::PhantomData,
5    ops::{Deref, Not},
6    path::Path,
7    rc::Rc,
8    sync::atomic::AtomicBool,
9};
10
11use rusqlite::{Connection, config::DbConfig};
12use sea_query::{Alias, ColumnDef, IntoTableRef, SqliteQueryBuilder, TableDropStatement};
13
14use crate::{
15    FromExpr, Table, TableRow, Transaction,
16    alias::{Scope, TmpTable},
17    client::LocalClient,
18    hash,
19    schema_pragma::read_schema,
20    transaction::{Database, try_insert_private},
21};
22
23pub struct TableTypBuilder<S> {
24    pub(crate) ast: hash::Schema,
25    _p: PhantomData<S>,
26}
27
28impl<S> Default for TableTypBuilder<S> {
29    fn default() -> Self {
30        Self {
31            ast: Default::default(),
32            _p: Default::default(),
33        }
34    }
35}
36
37impl<S> TableTypBuilder<S> {
38    pub fn table<T: Table<Schema = S>>(&mut self) {
39        let mut b = hash::TypBuilder::default();
40        T::typs(&mut b);
41        self.ast.tables.insert((T::NAME.to_owned(), b.ast));
42    }
43}
44
45pub trait Schema: Sized + 'static {
46    const VERSION: i64;
47    fn typs(b: &mut TableTypBuilder<Self>);
48}
49
50pub trait Migration<'t> {
51    type FromSchema: 'static;
52    type From: Table<Schema = Self::FromSchema>;
53    type To: Table<MigrateFrom = Self::From>;
54    type Conflict;
55
56    #[doc(hidden)]
57    fn prepare(val: Self, prev: TableRow<'t, Self::From>) -> <Self::To as Table>::Insert<'t>;
58    #[doc(hidden)]
59    fn map_conflict(val: TableRow<'t, Self::From>) -> Self::Conflict;
60}
61
62/// Transaction type for use in migrations.
63pub struct TransactionMigrate<'t, FromSchema> {
64    inner: Transaction<'t, FromSchema>,
65    scope: Scope,
66    rename_map: HashMap<&'static str, TmpTable>,
67}
68
69impl<'t, FromSchema> Deref for TransactionMigrate<'t, FromSchema> {
70    type Target = Transaction<'t, FromSchema>;
71
72    fn deref(&self) -> &Self::Target {
73        &self.inner
74    }
75}
76
77impl<'t, FromSchema> TransactionMigrate<'t, FromSchema> {
78    fn new_table_name<T: Table>(&mut self) -> TmpTable {
79        *self.rename_map.entry(T::NAME).or_insert_with(|| {
80            let new_table_name = self.scope.tmp_table();
81            new_table::<T>(&self.inner.transaction, new_table_name);
82            new_table_name
83        })
84    }
85
86    fn unmigrated<M: Migration<'t, FromSchema = FromSchema>, Out>(
87        &self,
88        new_name: TmpTable,
89    ) -> impl Iterator<Item = (i64, Out)>
90    where
91        Out: FromExpr<'t, FromSchema, M::From>,
92    {
93        let data = self.inner.query(|rows| {
94            let old = rows.join(<M::From as Table>::TOKEN);
95            rows.into_vec((&old, Out::from_expr(&old)))
96        });
97
98        let migrated = Transaction::new(self.inner.transaction.clone()).query(|rows| {
99            let new = rows.join_tmp::<M::From>(new_name);
100            rows.into_vec(new)
101        });
102        let migrated: HashSet<_> = migrated.into_iter().map(|x| x.inner.idx).collect();
103
104        data.into_iter().filter_map(move |(row, data)| {
105            migrated
106                .contains(&row.inner.idx)
107                .not()
108                .then_some((row.inner.idx, data))
109        })
110    }
111
112    /// Migrate some rows to the new schema.
113    ///
114    /// This will return an error when there is a conflict.
115    /// The error type depends on the number of unique constraints that the
116    /// migration can violate:
117    /// - 0 => [Infallible]
118    /// - 1.. => `TableRow<T::From>` (row in the old table that could not be migrated)
119    pub fn migrate_optional<
120        M: Migration<'t, FromSchema = FromSchema>,
121        X: FromExpr<'t, FromSchema, M::From>,
122    >(
123        &mut self,
124        mut f: impl FnMut(X) -> Option<M>,
125    ) -> Result<(), M::Conflict> {
126        let new_name = self.new_table_name::<M::To>();
127
128        for (idx, x) in self.unmigrated::<M, X>(new_name) {
129            if let Some(new) = f(x) {
130                try_insert_private::<M::To>(
131                    &self.transaction,
132                    new_name.into_table_ref(),
133                    Some(idx),
134                    M::prepare(new, TableRow::new(idx)),
135                )
136                .map_err(|_| M::map_conflict(TableRow::new(idx)))?;
137            };
138        }
139        Ok(())
140    }
141
142    /// Migrate all rows to the new schema.
143    ///
144    /// Conflict errors work the same as in [Self::migrate_optional].
145    ///
146    /// However, this method will return [Migrated] when all rows are migrated.
147    /// This can then be used as proof that there will be no foreign key violations.
148    pub fn migrate<
149        M: Migration<'t, FromSchema = FromSchema>,
150        X: FromExpr<'t, FromSchema, M::From>,
151    >(
152        &mut self,
153        mut f: impl FnMut(X) -> M,
154    ) -> Result<Migrated<'t, FromSchema, M::To>, M::Conflict> {
155        self.migrate_optional::<M, X>(|x| Some(f(x)))?;
156
157        Ok(Migrated {
158            _p: PhantomData,
159            f: Box::new(|_| {}),
160        })
161    }
162
163    /// Helper method for [Self::migrate].
164    ///
165    /// It can only be used when the migration is known to never cause unique constraint conflicts.
166    pub fn migrate_ok<
167        M: Migration<'t, FromSchema = FromSchema, Conflict = Infallible>,
168        X: FromExpr<'t, FromSchema, M::From>,
169    >(
170        &mut self,
171        f: impl FnMut(X) -> M,
172    ) -> Migrated<'t, FromSchema, M::To> {
173        let Ok(res) = self.migrate(f);
174        res
175    }
176}
177
178pub struct SchemaBuilder<'t, FromSchema> {
179    inner: TransactionMigrate<'t, FromSchema>,
180    drop: Vec<TableDropStatement>,
181    foreign_key: HashMap<&'static str, Box<dyn 't + FnOnce() -> Infallible>>,
182}
183
184impl<'t, FromSchema: 'static> SchemaBuilder<'t, FromSchema> {
185    pub fn foreign_key<To: Table>(&mut self, err: impl 't + FnOnce() -> Infallible) {
186        self.inner.new_table_name::<To>();
187
188        self.foreign_key.insert(To::NAME, Box::new(err));
189    }
190
191    pub fn create_empty<To: Table>(&mut self) {
192        self.inner.new_table_name::<To>();
193    }
194
195    pub fn drop_table<T: Table>(&mut self) {
196        let name = Alias::new(T::NAME);
197        let step = sea_query::Table::drop().table(name).take();
198        self.drop.push(step);
199    }
200}
201
202fn new_table<T: Table>(conn: &Connection, alias: TmpTable) {
203    let mut f = crate::hash::TypBuilder::default();
204    T::typs(&mut f);
205    new_table_inner(conn, &f.ast, alias);
206}
207
208fn new_table_inner(conn: &Connection, table: &crate::hash::Table, alias: impl IntoTableRef) {
209    let mut create = table.create();
210    create
211        .table(alias)
212        .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
213    let mut sql = create.to_string(SqliteQueryBuilder);
214    sql.push_str(" STRICT");
215    conn.execute(&sql, []).unwrap();
216}
217
218pub trait SchemaMigration<'a> {
219    type From: Schema;
220    type To: Schema;
221
222    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
223}
224
225/// [Config] is used to open a database from a file or in memory.
226///
227/// This is the first step in the [Config] -> [Migrator] -> [Database] chain to
228/// get a [Database] instance.
229///
230/// # Sqlite config
231///
232/// Sqlite is configured to be in [WAL mode](https://www.sqlite.org/wal.html).
233/// The effect of this mode is that there can be any number of readers with one concurrent writer.
234/// What is nice about this is that a [Transaction] can always be made immediately.
235/// Making a [crate::TransactionMut] has to wait until all other [crate::TransactionMut]s are finished.
236///
237/// Sqlite is also configured with [`synchronous=NORMAL`](https://www.sqlite.org/pragma.html#pragma_synchronous). This gives better performance by fsyncing less.
238/// The database will not lose transactions due to application crashes, but it might due to system crashes or power loss.
239pub struct Config {
240    manager: r2d2_sqlite::SqliteConnectionManager,
241    init: Box<dyn FnOnce(&rusqlite::Transaction)>,
242}
243
244static ALLOWED: AtomicBool = AtomicBool::new(true);
245
246impl Config {
247    /// Open a database that is stored in a file.
248    /// Creates the database if it does not exist.
249    ///
250    /// Opening the same database multiple times at the same time is fine,
251    /// as long as they migrate to or use the same schema.
252    /// All locking is done by sqlite, so connections can even be made using different client implementations.
253    pub fn open(p: impl AsRef<Path>) -> Self {
254        let manager = r2d2_sqlite::SqliteConnectionManager::file(p);
255        Self::open_internal(manager)
256    }
257
258    /// Creates a new empty database in memory.
259    pub fn open_in_memory() -> Self {
260        let manager = r2d2_sqlite::SqliteConnectionManager::memory();
261        Self::open_internal(manager)
262    }
263
264    fn open_internal(manager: r2d2_sqlite::SqliteConnectionManager) -> Self {
265        assert!(ALLOWED.swap(false, std::sync::atomic::Ordering::Relaxed));
266        let manager = manager.with_init(|inner| {
267            inner.pragma_update(None, "journal_mode", "WAL")?;
268            inner.pragma_update(None, "synchronous", "NORMAL")?;
269            inner.pragma_update(None, "foreign_keys", "ON")?;
270            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
271            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
272            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
273            Ok(())
274        });
275
276        Self {
277            manager,
278            init: Box::new(|_| {}),
279        }
280    }
281
282    /// Execute a raw sql statement if the database was just created.
283    ///
284    /// The statement is executed after creating the empty database and executingall previous statements.
285    pub fn init_stmt(mut self, sql: &'static str) -> Self {
286        self.init = Box::new(move |txn| {
287            (self.init)(txn);
288
289            txn.execute_batch(sql)
290                .expect("raw sql statement to populate db failed");
291        });
292        self
293    }
294}
295
296impl LocalClient {
297    /// Create a [Migrator] to migrate a database.
298    ///
299    /// Returns [None] if the database `user_version` on disk is older than `S`.
300    pub fn migrator<S: Schema>(&mut self, config: Config) -> Option<Migrator<'_, S>> {
301        use r2d2::ManageConnection;
302        let conn = self.conn.insert(config.manager.connect().unwrap());
303        conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
304
305        let conn = conn
306            .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
307            .unwrap();
308        let conn = Rc::new(conn);
309
310        // check if this database is newly created
311        if schema_version(&conn) == 0 {
312            let mut b = TableTypBuilder::default();
313            S::typs(&mut b);
314
315            for (table_name, table) in &*b.ast.tables {
316                new_table_inner(&conn, table, Alias::new(table_name));
317            }
318            (config.init)(&conn);
319            set_user_version(&conn, S::VERSION).unwrap();
320        }
321
322        let user_version = user_version(&conn).unwrap();
323        // We can not migrate databases older than `S`
324        if user_version < S::VERSION {
325            return None;
326        }
327        assert_eq!(
328            foreign_key_check(&conn),
329            None,
330            "foreign key constraint violated"
331        );
332
333        Some(Migrator {
334            manager: config.manager,
335            transaction: conn,
336            _p: PhantomData,
337            _local: PhantomData,
338            _p0: PhantomData,
339        })
340    }
341}
342
343/// [Migrator] is used to apply database migrations.
344///
345/// When all migrations are done, it can be turned into a [Database] instance with
346/// [Migrator::finish].
347pub struct Migrator<'t, S> {
348    manager: r2d2_sqlite::SqliteConnectionManager,
349    transaction: Rc<rusqlite::Transaction<'t>>,
350    _p0: PhantomData<fn(&'t ()) -> &'t ()>,
351    _p: PhantomData<S>,
352    // We want to make sure that Migrator is always used with the same LocalClient
353    // so we make it local to the current thread.
354    // This is mostly important because the LocalClient can have a reference to our transaction.
355    _local: PhantomData<LocalClient>,
356}
357
358/// [Migrated] provides a proof of migration.
359///
360/// This only needs to be provided for tables that are migrated from a previous table.
361// TODO: is this lifetime bound good enough, why?
362pub struct Migrated<'t, FromSchema, T> {
363    _p: PhantomData<(fn(&'t ()) -> &'t (), T)>,
364    f: Box<dyn 't + FnOnce(&mut SchemaBuilder<'t, FromSchema>)>,
365}
366
367impl<'t, FromSchema: 'static, T: Table> Migrated<'t, FromSchema, T> {
368    /// Don't migrate the remaining rows.
369    ///
370    /// This can cause foreign key constraint violations, which is why an error callback needs to be provided.
371    pub fn map_fk_err(err: impl 't + FnOnce() -> Infallible) -> Self {
372        Self {
373            _p: PhantomData,
374            f: Box::new(|x| x.foreign_key::<T>(err)),
375        }
376    }
377
378    #[doc(hidden)]
379    pub fn apply(self, b: &mut SchemaBuilder<'t, FromSchema>) {
380        (self.f)(b)
381    }
382}
383
384impl<'t, S: Schema> Migrator<'t, S> {
385    /// Apply a database migration if the current schema is `S` and return a [Migrator] for the next schema `N`.
386    ///
387    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
388    pub fn migrate<M>(
389        self,
390        m: impl FnOnce(&mut TransactionMigrate<'t, S>) -> M,
391    ) -> Migrator<'t, M::To>
392    where
393        M: SchemaMigration<'t, From = S>,
394    {
395        if user_version(&self.transaction).unwrap() == S::VERSION {
396            check_schema::<S>(&self.transaction);
397
398            let mut txn = TransactionMigrate {
399                inner: Transaction::new(self.transaction.clone()),
400                scope: Default::default(),
401                rename_map: HashMap::new(),
402            };
403            let m = m(&mut txn);
404
405            let mut builder = SchemaBuilder {
406                drop: vec![],
407                foreign_key: HashMap::new(),
408                inner: txn,
409            };
410            m.tables(&mut builder);
411
412            for drop in builder.drop {
413                let sql = drop.to_string(SqliteQueryBuilder);
414                self.transaction.execute(&sql, []).unwrap();
415            }
416            for (to, tmp) in builder.inner.rename_map {
417                let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
418                let sql = rename.to_string(SqliteQueryBuilder);
419                self.transaction.execute(&sql, []).unwrap();
420            }
421            if let Some(fk) = foreign_key_check(&self.transaction) {
422                (builder.foreign_key.remove(&*fk).unwrap())();
423            }
424            #[allow(
425                unreachable_code,
426                reason = "rustc is stupid and thinks this is unreachable"
427            )]
428            set_user_version(&self.transaction, M::To::VERSION).unwrap();
429        }
430
431        Migrator {
432            manager: self.manager,
433            transaction: self.transaction,
434            _p: PhantomData,
435            _local: PhantomData,
436            _p0: PhantomData,
437        }
438    }
439
440    /// Commit the migration transaction and return a [Database].
441    ///
442    /// Returns [None] if the database schema version is newer than `S`.
443    ///
444    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
445    pub fn finish(self) -> Option<Database<S>> {
446        let conn = &self.transaction;
447        if user_version(conn).unwrap() != S::VERSION {
448            return None;
449        }
450        check_schema::<S>(&self.transaction);
451
452        // adds an sqlite_stat1 table
453        self.transaction.execute_batch("PRAGMA optimize;").unwrap();
454
455        let schema_version = schema_version(conn);
456        Rc::into_inner(self.transaction).unwrap().commit().unwrap();
457
458        Some(Database {
459            manager: self.manager,
460            schema_version,
461            schema: PhantomData,
462        })
463    }
464}
465
466pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
467    conn.pragma_query_value(None, "schema_version", |r| r.get(0))
468        .unwrap()
469}
470
471// Read user version field from the SQLite db
472fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
473    conn.query_row("PRAGMA user_version", [], |row| row.get(0))
474}
475
476// Set user version field from the SQLite db
477fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
478    conn.pragma_update(None, "user_version", v)
479}
480
481fn check_schema<S: Schema>(conn: &Rc<rusqlite::Transaction>) {
482    let mut b = TableTypBuilder::default();
483    S::typs(&mut b);
484    pretty_assertions::assert_eq!(
485        b.ast,
486        read_schema(&crate::Transaction::new(conn.clone())),
487        "schema is different (expected left, but got right)",
488    );
489}
490
491fn foreign_key_check(conn: &Rc<rusqlite::Transaction>) -> Option<String> {
492    let error = conn
493        .prepare("PRAGMA foreign_key_check")
494        .unwrap()
495        .query_map([], |row| row.get(2))
496        .unwrap()
497        .next();
498    error.transpose().unwrap()
499}