rust_query/
migrate.rs

1pub mod config;
2pub mod migration;
3#[cfg(test)]
4mod test;
5
6use std::{
7    collections::{BTreeSet, HashMap},
8    marker::PhantomData,
9    sync::atomic::AtomicI64,
10};
11
12use annotate_snippets::{Renderer, renderer::DecorStyle};
13use rusqlite::config::DbConfig;
14use sea_query::{Alias, ColumnDef, IntoIden, SqliteQueryBuilder};
15use self_cell::MutBorrow;
16
17use crate::{
18    Table, Transaction,
19    alias::Scope,
20    migrate::{
21        config::Config,
22        migration::{SchemaBuilder, TransactionMigrate},
23    },
24    pool::Pool,
25    schema::{from_db, from_macro, read::read_schema},
26    transaction::{Database, OwnedTransaction, TXN, TransactionWithRows},
27};
28
/// Collects the table definitions that make up schema `S`.
///
/// A `&mut TableTypBuilder<Self>` is passed to [Schema::typs], which registers every table of
/// the schema through [TableTypBuilder::table].
pub struct TableTypBuilder<S> {
    /// The schema description assembled from the registered tables.
    pub(crate) ast: from_macro::Schema,
    /// Ties the builder to schema `S` without storing a value of it.
    _p: PhantomData<S>,
}
33
34impl<S> Default for TableTypBuilder<S> {
35    fn default() -> Self {
36        Self {
37            ast: Default::default(),
38            _p: Default::default(),
39        }
40    }
41}
42
43impl<S> TableTypBuilder<S> {
44    pub fn table<T: Table<Schema = S>>(&mut self) {
45        let table = from_macro::Table::new::<T>();
46        let old = self.ast.tables.insert(T::NAME, table);
47        debug_assert!(old.is_none());
48    }
49}
50
/// Describes one version of a database schema.
///
/// NOTE(review): implementations presumably come from the schema macro (the `from_macro`
/// module builds its description from this trait) — confirm before relying on this.
pub trait Schema: Sized + 'static {
    /// Version number of this schema. Once a database has been created or migrated to this
    /// schema, the value is stored in its `user_version` pragma.
    const VERSION: i64;
    /// Source text passed to the schema diff when a mismatch report is rendered.
    const SOURCE: &str;
    /// Path passed to the schema diff together with `SOURCE` when rendering a mismatch report.
    const PATH: &str;
    /// NOTE(review): not used in this file; presumably a `(start, end)` span into `SOURCE` —
    /// confirm.
    const SPAN: (usize, usize);
    /// Registers every table of this schema with the builder `b`.
    fn typs(b: &mut TableTypBuilder<Self>);
}
58
59fn new_table_inner(table: &crate::schema::from_macro::Table, alias: impl IntoIden) -> String {
60    let alias = alias.into_iden();
61    let mut create = table.create();
62    create
63        .table(alias.clone())
64        .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
65    let mut sql = create.to_string(SqliteQueryBuilder);
66    sql.push_str(" STRICT");
67    sql
68}
69
/// A migration from schema `From` to schema `To`.
///
/// A value of this type is returned by the closure passed to [Migrator::migrate]. Its
/// [SchemaMigration::tables] method is then called to describe the table changes, and
/// afterwards the database `user_version` is set to `To::VERSION`.
pub trait SchemaMigration<'a> {
    /// The schema the database is in before the migration.
    type From: Schema;
    /// The schema the database is in after the migration.
    type To: Schema;

    /// Describes the table changes of this migration to the builder `b`.
    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
}
76
77impl<S: Schema> Database<S> {
78    /// Create a [Migrator] to migrate a database.
79    ///
80    /// Returns [None] if the database `user_version` on disk is older than `S`.
81    pub fn migrator(config: Config) -> Option<Migrator<S>> {
82        let synchronous = config.synchronous.as_str();
83        let foreign_keys = config.foreign_keys.as_str();
84        let manager = config.manager.with_init(move |inner| {
85            inner.pragma_update(None, "journal_mode", "WAL")?;
86            inner.pragma_update(None, "synchronous", synchronous)?;
87            inner.pragma_update(None, "foreign_keys", foreign_keys)?;
88            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
89            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
90            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
91            Ok(())
92        });
93
94        use r2d2::ManageConnection;
95        let conn = manager.connect().unwrap();
96        conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
97        let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
98            Some(
99                conn.borrow_mut()
100                    .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
101                    .unwrap(),
102            )
103        });
104
105        // check if this database is newly created
106        if schema_version(txn.get()) == 0 {
107            let schema = crate::schema::from_macro::Schema::new::<S>();
108
109            for (&table_name, table) in &schema.tables {
110                txn.get()
111                    .execute(&new_table_inner(table, table_name), [])
112                    .unwrap();
113                for stmt in table.delayed_indices(table_name) {
114                    txn.get().execute(&stmt, []).unwrap();
115                }
116            }
117            (config.init)(txn.get());
118            set_user_version(txn.get(), S::VERSION).unwrap();
119        }
120
121        let user_version = user_version(txn.get()).unwrap();
122        // We can not migrate databases older than `S`
123        if user_version < S::VERSION {
124            return None;
125        }
126        debug_assert_eq!(
127            foreign_key_check(txn.get()),
128            None,
129            "foreign key constraint violated"
130        );
131
132        Some(Migrator {
133            indices_fixed: false,
134            manager,
135            transaction: txn,
136            _p: PhantomData,
137        })
138    }
139}
140
/// [Migrator] is used to apply database migrations.
///
/// When all migrations are done, it can be turned into a [Database] instance with
/// [Migrator::finish].
///
/// All migrations run inside the single exclusive transaction opened by
/// [Database::migrator]. That transaction is only committed by [Migrator::finish].
pub struct Migrator<S> {
    /// Connection manager that [Migrator::finish] hands to the [Database]'s pool.
    manager: r2d2_sqlite::SqliteConnectionManager,
    /// The exclusive transaction in which all migrations run.
    transaction: OwnedTransaction,
    /// Whether `fix_indices` has already run for this chain of migrators.
    indices_fixed: bool,
    /// The schema the database is expected to be in at this point of the chain.
    _p: PhantomData<S>,
}
151
152impl<S: Schema> Migrator<S> {
153    /// Apply a database migration if the current schema is `S` and return a [Migrator] for the next schema `N`.
154    ///
155    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
156    pub fn migrate<'x, M>(
157        mut self,
158        m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
159    ) -> Migrator<M::To>
160    where
161        M: SchemaMigration<'x, From = S>,
162    {
163        if user_version(self.transaction.get()).unwrap() == S::VERSION {
164            let res = std::thread::scope(|s| {
165                s.spawn(|| {
166                    TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
167                    let txn = Transaction::new_ref();
168
169                    check_schema::<S>(txn);
170                    if !self.indices_fixed {
171                        fix_indices::<S>(txn);
172                        self.indices_fixed = true;
173                    }
174
175                    let mut txn = TransactionMigrate {
176                        inner: Transaction::new(),
177                        scope: Default::default(),
178                        rename_map: HashMap::new(),
179                        extra_index: Vec::new(),
180                    };
181                    let m = m(&mut txn);
182
183                    let mut builder = SchemaBuilder {
184                        drop: vec![],
185                        foreign_key: HashMap::new(),
186                        inner: txn,
187                    };
188                    m.tables(&mut builder);
189
190                    let transaction = TXN.take().unwrap();
191
192                    for drop in builder.drop {
193                        let sql = drop.to_string(SqliteQueryBuilder);
194                        transaction.get().execute(&sql, []).unwrap();
195                    }
196                    for (to, tmp) in builder.inner.rename_map {
197                        let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
198                        let sql = rename.to_string(SqliteQueryBuilder);
199                        transaction.get().execute(&sql, []).unwrap();
200                    }
201                    if let Some(fk) = foreign_key_check(transaction.get()) {
202                        (builder.foreign_key.remove(&*fk).unwrap())();
203                    }
204                    #[allow(
205                        unreachable_code,
206                        reason = "rustc is stupid and thinks this is unreachable"
207                    )]
208                    // adding non unique indexes is fine to do after checking foreign keys
209                    for stmt in builder.inner.extra_index {
210                        transaction.get().execute(&stmt, []).unwrap();
211                    }
212                    set_user_version(transaction.get(), M::To::VERSION).unwrap();
213
214                    transaction.into_owner()
215                })
216                .join()
217            });
218            match res {
219                Ok(val) => self.transaction = val,
220                Err(payload) => std::panic::resume_unwind(payload),
221            }
222        }
223
224        Migrator {
225            indices_fixed: self.indices_fixed,
226            manager: self.manager,
227            transaction: self.transaction,
228            _p: PhantomData,
229        }
230    }
231
232    /// Commit the migration transaction and return a [Database].
233    ///
234    /// Returns [None] if the database schema version is newer than `S`.
235    ///
236    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
237    pub fn finish(mut self) -> Option<Database<S>> {
238        if user_version(self.transaction.get()).unwrap() != S::VERSION {
239            return None;
240        }
241
242        let res = std::thread::scope(|s| {
243            s.spawn(|| {
244                TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
245                let txn = Transaction::new_ref();
246
247                check_schema::<S>(txn);
248                if !self.indices_fixed {
249                    fix_indices::<S>(txn);
250                    self.indices_fixed = true;
251                }
252
253                TXN.take().unwrap().into_owner()
254            })
255            .join()
256        });
257        match res {
258            Ok(val) => self.transaction = val,
259            Err(payload) => std::panic::resume_unwind(payload),
260        }
261
262        // adds an sqlite_stat1 table
263        self.transaction
264            .get()
265            .execute_batch("PRAGMA optimize;")
266            .unwrap();
267
268        let schema_version = schema_version(self.transaction.get());
269        self.transaction.with(|x| x.commit().unwrap());
270
271        Some(Database {
272            manager: Pool::new(self.manager),
273            schema_version: AtomicI64::new(schema_version),
274            schema: PhantomData,
275            mut_lock: parking_lot::FairMutex::new(()),
276        })
277    }
278}
279
280fn fix_indices<S: Schema>(txn: &Transaction<S>) {
281    let schema = read_schema(txn);
282    let expected_schema = crate::schema::from_macro::Schema::new::<S>();
283
284    fn check_eq(expected: &from_macro::Table, actual: &from_db::Table) -> bool {
285        let expected: BTreeSet<_> = expected.indices.iter().map(|idx| &idx.def).collect();
286        let actual: BTreeSet<_> = actual.indices.values().collect();
287        expected == actual
288    }
289
290    for (&table_name, expected_table) in &expected_schema.tables {
291        let table = &schema.tables[table_name];
292
293        if !check_eq(expected_table, &table) {
294            // Unique constraints that are part of a table definition
295            // can not be dropped, so we assume the worst and just recreate
296            // the whole table.
297
298            let scope = Scope::default();
299            let tmp_name = scope.tmp_table();
300
301            txn.execute(&new_table_inner(expected_table, tmp_name));
302
303            let mut columns: Vec<_> = expected_table
304                .columns
305                .keys()
306                .map(|x| Alias::new(x))
307                .collect();
308            columns.push(Alias::new("id"));
309
310            txn.execute(
311                &sea_query::InsertStatement::new()
312                    .into_table(tmp_name)
313                    .columns(columns.clone())
314                    .select_from(
315                        sea_query::SelectStatement::new()
316                            .from(table_name)
317                            .columns(columns)
318                            .take(),
319                    )
320                    .unwrap()
321                    .build(SqliteQueryBuilder)
322                    .0,
323            );
324
325            txn.execute(
326                &sea_query::TableDropStatement::new()
327                    .table(table_name)
328                    .build(SqliteQueryBuilder),
329            );
330
331            txn.execute(
332                &sea_query::TableRenameStatement::new()
333                    .table(tmp_name, table_name)
334                    .build(SqliteQueryBuilder),
335            );
336            // Add the new non-unique indices
337            for sql in expected_table.delayed_indices(table_name) {
338                txn.execute(&sql);
339            }
340        }
341    }
342
343    // check that we solved the mismatch
344    let schema = read_schema(txn);
345    for (name, table) in schema.tables {
346        let expected_table = &expected_schema.tables[&*name];
347        assert!(check_eq(expected_table, &table));
348    }
349}
350
351impl<S> Transaction<S> {
352    #[track_caller]
353    pub(crate) fn execute(&self, sql: &str) {
354        TXN.with_borrow(|txn| txn.as_ref().unwrap().get().execute(sql, []))
355            .unwrap();
356    }
357}
358
359pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
360    conn.pragma_query_value(None, "schema_version", |r| r.get(0))
361        .unwrap()
362}
363
364// Read user version field from the SQLite db
365pub fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
366    conn.query_row("PRAGMA user_version", [], |row| row.get(0))
367}
368
369// Set user version field from the SQLite db
370fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
371    conn.pragma_update(None, "user_version", v)
372}
373
374pub(crate) fn check_schema<S: Schema>(txn: &Transaction<S>) {
375    let from_macro = crate::schema::from_macro::Schema::new::<S>();
376    let from_db = read_schema(txn);
377    let report = from_db.diff(from_macro, S::SOURCE, S::PATH, S::VERSION);
378    if !report.is_empty() {
379        let renderer = if cfg!(test) {
380            Renderer::plain().anonymized_line_numbers(true)
381        } else {
382            Renderer::styled()
383        }
384        .decor_style(DecorStyle::Unicode);
385        panic!("{}", renderer.render(&report));
386    }
387}
388
389fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
390    let error = conn
391        .prepare("PRAGMA foreign_key_check")
392        .unwrap()
393        .query_map([], |row| row.get(2))
394        .unwrap()
395        .next();
396    error.transpose().unwrap()
397}
398
399impl<S> Transaction<S> {
400    #[cfg(test)]
401    pub(crate) fn schema(&self) -> Vec<String> {
402        TXN.with_borrow(|x| {
403            x.as_ref()
404                .unwrap()
405                .get()
406                .prepare("SELECT sql FROM 'main'.'sqlite_schema'")
407                .unwrap()
408                .query_map([], |row| row.get::<_, Option<String>>("sql"))
409                .unwrap()
410                .flat_map(|x| x.unwrap())
411                .collect()
412        })
413    }
414}