rust_query/
migrate.rs

1pub mod config;
2pub mod migration;
3#[cfg(test)]
4mod test;
5
6use std::{
7    collections::{BTreeSet, HashMap},
8    marker::PhantomData,
9    sync::atomic::AtomicI64,
10};
11
12use annotate_snippets::{Renderer, renderer::DecorStyle};
13use rusqlite::config::DbConfig;
14use sea_query::{Alias, ColumnDef, IntoIden, SqliteQueryBuilder};
15use self_cell::MutBorrow;
16
17use crate::{
18    Table, Transaction,
19    alias::Scope,
20    migrate::{
21        config::Config,
22        migration::{SchemaBuilder, TransactionMigrate},
23    },
24    pool::Pool,
25    schema::{from_db, from_macro, read::read_schema},
26    transaction::{Database, OwnedTransaction, TXN, TransactionWithRows},
27};
28
/// Collects the table definitions that make up schema `S`.
///
/// Tables are added with [TableTypBuilder::table]. The type parameter `S` only
/// ties the builder to one schema, so that only tables whose `Table::Schema` is
/// `S` can be registered.
pub struct TableTypBuilder<S> {
    /// The collected table definitions, keyed by table name.
    pub(crate) ast: from_macro::Schema,
    _p: PhantomData<S>,
}
33
34impl<S> Default for TableTypBuilder<S> {
35    fn default() -> Self {
36        Self {
37            ast: Default::default(),
38            _p: Default::default(),
39        }
40    }
41}
42
43impl<S> TableTypBuilder<S> {
44    pub fn table<T: Table<Schema = S>>(&mut self) {
45        let table = from_macro::Table::new::<T>();
46        let old = self.ast.tables.insert(T::NAME, table);
47        debug_assert!(old.is_none());
48    }
49}
50
/// One version of a database schema.
///
/// NOTE(review): presumably implemented by the schema macro rather than by hand — confirm.
pub trait Schema: Sized + 'static {
    /// Version number of this schema.
    ///
    /// It is written to the database `user_version` by [Migrator::finish] and
    /// compared against the on-disk `user_version` when opening a database.
    const VERSION: i64;
    /// Passed to the schema diff when rendering a mismatch report
    /// (presumably the source text of the schema definition — confirm).
    const SOURCE: &str;
    /// Passed to the schema diff when rendering a mismatch report
    /// (presumably the path of the file containing the schema — confirm).
    const PATH: &str;
    /// NOTE(review): not used in this file; presumably the (start, end) span of
    /// the schema definition inside `SOURCE` — confirm.
    const SPAN: (usize, usize);
    /// Register the tables of this schema on `b`
    /// (presumably via [TableTypBuilder::table] for each table).
    fn typs(b: &mut TableTypBuilder<Self>);
}
58
59fn new_table_inner(table: &crate::schema::from_macro::Table, alias: impl IntoIden) -> String {
60    let alias = alias.into_iden();
61    let mut create = table.create();
62    create
63        .table(alias.clone())
64        .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
65    let mut sql = create.to_string(SqliteQueryBuilder);
66    sql.push_str(" STRICT");
67    sql
68}
69
/// A migration from schema [SchemaMigration::From] to [SchemaMigration::To].
///
/// A value of this type is returned by the closure passed to [Migrator::migrate].
/// [SchemaMigration::tables] is then called with a [SchemaBuilder] for `From`.
pub trait SchemaMigration<'a> {
    /// The schema the database is migrated from.
    type From: Schema;
    /// The schema the database is migrated to.
    type To: Schema;

    /// Apply the table changes of this migration through `b`.
    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
}
76
77impl<S: Schema> Database<S> {
78    /// Create a [Migrator] to migrate a database.
79    ///
80    /// Returns [None] if the database `user_version` on disk is older than `S`.
81    pub fn migrator(config: Config) -> Option<Migrator<S>> {
82        let synchronous = config.synchronous.as_str();
83        let foreign_keys = config.foreign_keys.as_str();
84        let manager = config.manager.with_init(move |inner| {
85            inner.pragma_update(None, "journal_mode", "WAL")?;
86            inner.pragma_update(None, "synchronous", synchronous)?;
87            inner.pragma_update(None, "foreign_keys", foreign_keys)?;
88            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
89            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
90            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
91            Ok(())
92        });
93
94        use r2d2::ManageConnection;
95        let conn = manager.connect().unwrap();
96        conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
97        let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
98            Some(
99                conn.borrow_mut()
100                    .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
101                    .unwrap(),
102            )
103        });
104
105        let mut user_version = Some(user_version(txn.get()).unwrap());
106
107        // check if this database is newly created
108        if schema_version(txn.get()) == 0 {
109            user_version = None;
110
111            let schema = crate::schema::from_macro::Schema::new::<S>();
112
113            for (&table_name, table) in &schema.tables {
114                txn.get()
115                    .execute(&new_table_inner(table, table_name), [])
116                    .unwrap();
117                for stmt in table.delayed_indices(table_name) {
118                    txn.get().execute(&stmt, []).unwrap();
119                }
120            }
121            (config.init)(txn.get());
122        } else if user_version.unwrap() < S::VERSION {
123            // We can not migrate databases older than `S`
124            return None;
125        }
126
127        debug_assert_eq!(
128            foreign_key_check(txn.get()),
129            None,
130            "foreign key constraint violated"
131        );
132
133        Some(Migrator {
134            user_version,
135            manager,
136            transaction: txn,
137            _p: PhantomData,
138        })
139    }
140}
141
/// [Migrator] is used to apply database migrations.
///
/// When all migrations are done, it can be turned into a [Database] instance with
/// [Migrator::finish].
pub struct Migrator<S> {
    /// Connection manager that becomes the pool of the resulting [Database].
    manager: r2d2_sqlite::SqliteConnectionManager,
    /// The exclusive transaction that all migrations run in; committed by [Migrator::finish].
    transaction: OwnedTransaction,
    /// Initialized to the on-disk user version when the transaction starts, or None if
    /// the database was newly created by [Database::migrator].
    /// It is set to None by the first migration/fixup/finish transaction that runs.
    /// Fixups are only applied if the user_version is None.
    /// At the moment it is taken from Some to None, the schema is checked and indices
    /// are fixed, before any migration code runs.
    user_version: Option<i64>,
    _p: PhantomData<S>,
}
156
impl<S: Schema> Migrator<S> {
    /// Run `f` on the migration transaction for schema `S`.
    ///
    /// The owned transaction is moved into a scoped thread and installed in the `TXN`
    /// slot. `f` runs there, and the transaction is then taken back out of `TXN` and
    /// stored in `self` again. A panic on that thread is re-raised with
    /// [std::panic::resume_unwind].
    ///
    /// If `user_version` is still `Some`, this is the first transaction applied by this
    /// [Migrator]. In that case the on-disk schema is checked against `S` and indices
    /// are fixed before `f` runs. `user_version` is `None` afterwards.
    fn with_transaction(mut self, f: impl Send + FnOnce(&mut Transaction<S>)) -> Self {
        // Callers only get here when the database is at schema `S` (or was created/migrated by us).
        assert!(self.user_version.is_none_or(|x| x == S::VERSION));
        let res = std::thread::scope(|s| {
            s.spawn(|| {
                TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
                let txn = Transaction::new_ref();

                // check if this is the first migration that is applied
                if self.user_version.take().is_some() {
                    // we check the schema before doing any migrations
                    check_schema::<S>(txn);
                    // fixing indices before migrations can help with migration performance
                    fix_indices::<S>(txn);
                }

                f(txn);

                let transaction = TXN.take().unwrap();

                transaction.into_owner()
            })
            .join()
        });
        match res {
            Ok(val) => self.transaction = val,
            Err(payload) => std::panic::resume_unwind(payload),
        }
        self
    }

    /// Apply a database migration if the current schema is `S` and return a [Migrator] for the next schema `N`.
    ///
    /// If the database on disk is already at a newer version than `S`, `m` is not run
    /// and only the type of the returned [Migrator] advances.
    ///
    /// Steps, in order:
    /// 1. Run `m` and then [SchemaMigration::tables] on its result.
    /// 2. Execute the collected `DROP` statements.
    /// 3. Rename the temporary tables to their final names.
    /// 4. Run `PRAGMA foreign_key_check`; if it reports a violation, call the handler
    ///    registered in `builder.foreign_key` for that table.
    /// 5. Create the extra (non-unique) indices.
    ///
    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
    pub fn migrate<'x, M>(
        mut self,
        m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
    ) -> Migrator<M::To>
    where
        M: SchemaMigration<'x, From = S>,
    {
        if self.user_version.is_none_or(|x| x == S::VERSION) {
            self = self.with_transaction(|txn| {
                let mut txn = TransactionMigrate {
                    inner: txn.copy(),
                    scope: Default::default(),
                    rename_map: HashMap::new(),
                    extra_index: Vec::new(),
                };
                let m = m(&mut txn);

                let mut builder = SchemaBuilder {
                    drop: vec![],
                    foreign_key: HashMap::new(),
                    inner: txn,
                };
                m.tables(&mut builder);

                // Temporarily take the transaction out of `TXN` to run raw statements on it.
                let transaction = TXN.take().unwrap();

                for drop in builder.drop {
                    let sql = drop.to_string(SqliteQueryBuilder);
                    transaction.get().execute(&sql, []).unwrap();
                }
                // `rename_map` maps final table name -> temporary table name.
                for (to, tmp) in builder.inner.rename_map {
                    let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
                    let sql = rename.to_string(SqliteQueryBuilder);
                    transaction.get().execute(&sql, []).unwrap();
                }
                // The handler is looked up by the parent table reported in the first violation.
                // NOTE(review): the `unreachable_code` allow below suggests the handler never
                // returns (presumably it panics) — confirm.
                if let Some(fk) = foreign_key_check(transaction.get()) {
                    (builder.foreign_key.remove(&*fk).unwrap())();
                }
                #[allow(
                    unreachable_code,
                    reason = "rustc is stupid and thinks this is unreachable"
                )]
                // adding non unique indexes is fine to do after checking foreign keys
                for stmt in builder.inner.extra_index {
                    transaction.get().execute(&stmt, []).unwrap();
                }

                // Put the transaction back so `with_transaction` can take it again.
                TXN.set(Some(transaction));
            });
        }

        Migrator {
            user_version: self.user_version,
            manager: self.manager,
            transaction: self.transaction,
            _p: PhantomData,
        }
    }

    /// Mutate the database as part of migrations.
    ///
    /// The closure will only be executed if the database got migrated to schema version `S`
    /// by this [Migrator] instance.
    /// If [Migrator::fixup] is used before [Migrator::migrate], then the closure is only executed
    /// when the database is created.
    pub fn fixup(mut self, f: impl Send + FnOnce(&mut Transaction<S>)) -> Self {
        if self.user_version.is_none() {
            self = self.with_transaction(f);
        }
        self
    }

    /// Commit the migration transaction and return a [Database].
    ///
    /// Before committing, this function:
    /// 1. Checks the schema (and fixes indices if no migration ran).
    /// 2. Runs `PRAGMA optimize`.
    /// 3. Stores `S::VERSION` as the database `user_version`.
    ///
    /// Returns [None] if the database schema version is newer than `S`.
    ///
    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
    pub fn finish(mut self) -> Option<Database<S>> {
        if self.user_version.is_some_and(|x| x != S::VERSION) {
            return None;
        }

        // This checks that the schema is correct and fixes indices etc
        self = self.with_transaction(|txn| {
            // sanity check, this should never fail
            check_schema::<S>(txn);
        });

        // adds an sqlite_stat1 table
        self.transaction
            .get()
            .execute_batch("PRAGMA optimize;")
            .unwrap();

        set_user_version(self.transaction.get(), S::VERSION).unwrap();
        let schema_version = schema_version(self.transaction.get());
        self.transaction.with(|x| x.commit().unwrap());

        Some(Database {
            manager: Pool::new(self.manager),
            schema_version: AtomicI64::new(schema_version),
            schema: PhantomData,
            mut_lock: parking_lot::FairMutex::new(()),
        })
    }
}
297
298fn fix_indices<S: Schema>(txn: &Transaction<S>) {
299    let schema = read_schema(txn);
300    let expected_schema = crate::schema::from_macro::Schema::new::<S>();
301
302    fn check_eq(expected: &from_macro::Table, actual: &from_db::Table) -> bool {
303        let expected: BTreeSet<_> = expected.indices.iter().map(|idx| &idx.def).collect();
304        let actual: BTreeSet<_> = actual.indices.values().collect();
305        expected == actual
306    }
307
308    for (&table_name, expected_table) in &expected_schema.tables {
309        let table = &schema.tables[table_name];
310
311        if !check_eq(expected_table, &table) {
312            // Unique constraints that are part of a table definition
313            // can not be dropped, so we assume the worst and just recreate
314            // the whole table.
315
316            let scope = Scope::default();
317            let tmp_name = scope.tmp_table();
318
319            txn.execute(&new_table_inner(expected_table, tmp_name));
320
321            let mut columns: Vec<_> = expected_table
322                .columns
323                .keys()
324                .map(|x| Alias::new(x))
325                .collect();
326            columns.push(Alias::new("id"));
327
328            txn.execute(
329                &sea_query::InsertStatement::new()
330                    .into_table(tmp_name)
331                    .columns(columns.clone())
332                    .select_from(
333                        sea_query::SelectStatement::new()
334                            .from(table_name)
335                            .columns(columns)
336                            .take(),
337                    )
338                    .unwrap()
339                    .build(SqliteQueryBuilder)
340                    .0,
341            );
342
343            txn.execute(
344                &sea_query::TableDropStatement::new()
345                    .table(table_name)
346                    .build(SqliteQueryBuilder),
347            );
348
349            txn.execute(
350                &sea_query::TableRenameStatement::new()
351                    .table(tmp_name, table_name)
352                    .build(SqliteQueryBuilder),
353            );
354            // Add the new non-unique indices
355            for sql in expected_table.delayed_indices(table_name) {
356                txn.execute(&sql);
357            }
358        }
359    }
360
361    // check that we solved the mismatch
362    let schema = read_schema(txn);
363    for (name, table) in schema.tables {
364        let expected_table = &expected_schema.tables[&*name];
365        assert!(check_eq(expected_table, &table));
366    }
367}
368
369impl<S> Transaction<S> {
370    #[track_caller]
371    pub(crate) fn execute(&self, sql: &str) {
372        TXN.with_borrow(|txn| txn.as_ref().unwrap().get().execute(sql, []))
373            .unwrap();
374    }
375}
376
377pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
378    conn.pragma_query_value(None, "schema_version", |r| r.get(0))
379        .unwrap()
380}
381
382// Read user version field from the SQLite db
383pub fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
384    conn.query_row("PRAGMA user_version", [], |row| row.get(0))
385}
386
387// Set user version field from the SQLite db
388fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
389    conn.pragma_update(None, "user_version", v)
390}
391
392pub(crate) fn check_schema<S: Schema>(txn: &Transaction<S>) {
393    let from_macro = crate::schema::from_macro::Schema::new::<S>();
394    let from_db = read_schema(txn);
395    let report = from_db.diff(from_macro, S::SOURCE, S::PATH, S::VERSION);
396    if !report.is_empty() {
397        let renderer = if cfg!(test) {
398            Renderer::plain().anonymized_line_numbers(true)
399        } else {
400            Renderer::styled()
401        }
402        .decor_style(DecorStyle::Unicode);
403        panic!("{}", renderer.render(&report));
404    }
405}
406
407fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
408    let error = conn
409        .prepare("PRAGMA foreign_key_check")
410        .unwrap()
411        .query_map([], |row| row.get(2))
412        .unwrap()
413        .next();
414    error.transpose().unwrap()
415}
416
417impl<S> Transaction<S> {
418    #[cfg(test)]
419    pub(crate) fn schema(&self) -> Vec<String> {
420        TXN.with_borrow(|x| {
421            x.as_ref()
422                .unwrap()
423                .get()
424                .prepare("SELECT sql FROM 'main'.'sqlite_schema'")
425                .unwrap()
426                .query_map([], |row| row.get::<_, Option<String>>("sql"))
427                .unwrap()
428                .flat_map(|x| x.unwrap())
429                .collect()
430        })
431    }
432}