rust_query/
migrate.rs

1pub mod config;
2pub mod migration;
3
4use std::{
5    collections::{BTreeSet, HashMap},
6    marker::PhantomData,
7    sync::atomic::AtomicI64,
8};
9
10use annotate_snippets::{Renderer, renderer::DecorStyle};
11use rusqlite::{Connection, config::DbConfig};
12use sea_query::{Alias, ColumnDef, IntoTableRef, SqliteQueryBuilder};
13use self_cell::MutBorrow;
14
15use crate::{
16    Table, Transaction,
17    migrate::{
18        config::Config,
19        migration::{SchemaBuilder, TransactionMigrate},
20    },
21    schema::{
22        from_db, from_macro,
23        read::{read_index_names_for_table, read_schema},
24    },
25    transaction::{Database, OwnedTransaction, TXN, TransactionWithRows},
26};
27
28pub struct TableTypBuilder<S> {
29    pub(crate) ast: from_macro::Schema,
30    _p: PhantomData<S>,
31}
32
33impl<S> Default for TableTypBuilder<S> {
34    fn default() -> Self {
35        Self {
36            ast: Default::default(),
37            _p: Default::default(),
38        }
39    }
40}
41
42impl<S> TableTypBuilder<S> {
43    pub fn table<T: Table<Schema = S>>(&mut self) {
44        let table = from_macro::Table::new::<T>();
45        let old = self.ast.tables.insert(T::NAME.to_owned(), table);
46        debug_assert!(old.is_none());
47    }
48}
49
/// Static description of one version of a database schema.
///
/// `VERSION` is compared with the SQLite `user_version` of the database to
/// decide whether the on-disk data is at this schema version.
pub trait Schema: Sized + 'static {
    /// Version number of this schema; written to the database `user_version`
    /// when a database is created at, or migrated to, this schema.
    const VERSION: i64;
    /// Source text handed to the schema diff report when the on-disk schema
    /// does not match. NOTE(review): presumably the text of the schema
    /// definition — confirm against the code that generates this impl.
    const SOURCE: &str;
    /// Path handed to the schema diff report together with `SOURCE`.
    /// NOTE(review): presumably the file the schema is defined in — confirm.
    const PATH: &str;
    /// NOTE(review): not used in this file; presumably a (start, end) range
    /// into `SOURCE` — confirm against its producer.
    const SPAN: (usize, usize);
    /// Register every table of this schema with `b`.
    fn typs(b: &mut TableTypBuilder<Self>);
}
57
58fn new_table_inner(
59    conn: &Connection,
60    table: &crate::schema::from_macro::Table,
61    alias: impl IntoTableRef,
62) {
63    let mut create = table.create();
64    create
65        .table(alias)
66        .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
67    let mut sql = create.to_string(SqliteQueryBuilder);
68    sql.push_str(" STRICT");
69    conn.execute(&sql, []).unwrap();
70}
71
/// A migration from schema `From` to schema `To`.
///
/// A value of this trait is produced by the closure passed to
/// [Migrator::migrate], which then calls [SchemaMigration::tables] on it.
pub trait SchemaMigration<'a> {
    /// Schema the database is migrated from.
    type From: Schema;
    /// Schema the database is migrated to; its `VERSION` is written to the
    /// database `user_version` once the migration has been applied.
    type To: Schema;

    /// Populate `b` with this migration's table changes. [Migrator::migrate]
    /// afterwards executes the drops and renames collected in the builder.
    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
}
78
impl<S: Schema> Database<S> {
    /// Create a [Migrator] to migrate a database.
    ///
    /// One connection is opened from `config.manager`, foreign key enforcement
    /// is switched off on it, and an exclusive transaction is started. The
    /// returned [Migrator] holds that transaction until [Migrator::finish].
    ///
    /// If the database has no schema yet (its `schema_version` is 0), every
    /// table and index of `S` is created, `config.init` is run, and the
    /// `user_version` is set to `S::VERSION`.
    ///
    /// Returns [None] if the database `user_version` on disk is older than `S`.
    ///
    /// # Panics
    ///
    /// Panics if the connection cannot be opened or if any of the setup SQL
    /// statements fail.
    pub fn migrator(config: Config) -> Option<Migrator<S>> {
        let synchronous = config.synchronous.as_str();
        let foreign_keys = config.foreign_keys.as_str();
        // Settings the manager applies to each connection it opens.
        let manager = config.manager.with_init(move |inner| {
            inner.pragma_update(None, "journal_mode", "WAL")?;
            inner.pragma_update(None, "synchronous", synchronous)?;
            inner.pragma_update(None, "foreign_keys", foreign_keys)?;
            // Disallow double-quoted string literals in DDL and DML.
            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
            Ok(())
        });

        use r2d2::ManageConnection;
        let conn = manager.connect().unwrap();
        // Foreign keys are not enforced on the migration connection; integrity
        // is instead checked explicitly with `PRAGMA foreign_key_check`.
        // NOTE(review): presumably so tables can be dropped and renamed
        // mid-migration — confirm.
        conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
        // The exclusive transaction lives as long as the returned `Migrator`.
        let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
            Some(
                conn.borrow_mut()
                    .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
                    .unwrap(),
            )
        });

        // check if this database is newly created
        if schema_version(txn.get()) == 0 {
            let schema = crate::schema::from_macro::Schema::new::<S>();

            for (table_name, table) in &schema.tables {
                let table_name_ref = Alias::new(table_name);
                new_table_inner(txn.get(), table, table_name_ref);
                for stmt in table.create_indices(table_name) {
                    txn.get().execute(&stmt, []).unwrap();
                }
            }
            // User-supplied initialisation runs after the tables exist and
            // before the version is recorded.
            (config.init)(txn.get());
            set_user_version(txn.get(), S::VERSION).unwrap();
        }

        let user_version = user_version(txn.get()).unwrap();
        // We cannot migrate databases older than `S`. Returning here drops
        // `txn` without committing it.
        if user_version < S::VERSION {
            return None;
        }
        debug_assert_eq!(
            foreign_key_check(txn.get()),
            None,
            "foreign key constraint violated"
        );

        Some(Migrator {
            indices_fixed: false,
            manager,
            transaction: txn,
            _p: PhantomData,
        })
    }
}
141
/// [Migrator] is used to apply database migrations.
///
/// The type parameter `S` is the schema the database is expected to be at
/// after the migrations applied so far.
///
/// When all migrations are done, it can be turned into a [Database] instance with
/// [Migrator::finish].
pub struct Migrator<S> {
    /// Connection manager that is moved into the [Database] returned by [Migrator::finish].
    manager: r2d2_sqlite::SqliteConnectionManager,
    /// The exclusive transaction all migrations run in; committed by [Migrator::finish].
    transaction: OwnedTransaction,
    /// Whether `fix_indices` has already run for this migrator, so it runs at most once.
    indices_fixed: bool,
    _p: PhantomData<S>,
}
152
impl<S: Schema> Migrator<S> {
    /// Apply a database migration if the current schema is `S` and return a
    /// [Migrator] for the migration's target schema `M::To`.
    ///
    /// The migration only runs when the database `user_version` equals
    /// `S::VERSION`; otherwise nothing is executed and only the schema type of
    /// the returned [Migrator] changes.
    ///
    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
    pub fn migrate<'x, M>(
        mut self,
        m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
    ) -> Migrator<M::To>
    where
        M: SchemaMigration<'x, From = S>,
    {
        if user_version(self.transaction.get()).unwrap() == S::VERSION {
            // The migration runs on a scoped thread. The transaction is moved
            // into that thread's `TXN` slot and handed back as the thread's
            // return value. NOTE(review): presumably a fresh thread is used so
            // the caller's own `TXN` is never disturbed — confirm.
            let res = std::thread::scope(|s| {
                s.spawn(|| {
                    TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
                    let txn = Transaction::new_ref();

                    // Panics with a diff report if the database is not at schema `S`.
                    check_schema::<S>(txn);
                    // Index mismatches are repaired instead of reported, at most once per migrator.
                    if !self.indices_fixed {
                        fix_indices::<S>(txn);
                        self.indices_fixed = true;
                    }

                    let mut txn = TransactionMigrate {
                        inner: Transaction::new(),
                        scope: Default::default(),
                        rename_map: HashMap::new(),
                        extra_index: Vec::new(),
                    };
                    // Let the user closure run its migration logic and produce the migration value.
                    let m = m(&mut txn);

                    let mut builder = SchemaBuilder {
                        drop: vec![],
                        foreign_key: HashMap::new(),
                        inner: txn,
                    };
                    m.tables(&mut builder);

                    // Take the transaction back out of `TXN` to run the collected DDL directly.
                    let transaction = TXN.take().unwrap();

                    for drop in builder.drop {
                        let sql = drop.to_string(SqliteQueryBuilder);
                        transaction.get().execute(&sql, []).unwrap();
                    }
                    // Rename each temporary table to its final name. `rename_map`
                    // is a `HashMap`, so the order of these renames is unspecified.
                    for (to, tmp) in builder.inner.rename_map {
                        let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
                        let sql = rename.to_string(SqliteQueryBuilder);
                        transaction.get().execute(&sql, []).unwrap();
                    }
                    // On a foreign key violation, call the handler registered for the
                    // parent table reported by `foreign_key_check`. NOTE(review):
                    // presumably that handler panics with a descriptive message
                    // (hence the `unreachable_code` allowance below) — confirm.
                    if let Some(fk) = foreign_key_check(transaction.get()) {
                        (builder.foreign_key.remove(&*fk).unwrap())();
                    }
                    #[allow(
                        unreachable_code,
                        reason = "rustc is stupid and thinks this is unreachable"
                    )]
                    // adding indexes is fine to do after checking foreign keys
                    for stmt in builder.inner.extra_index {
                        transaction.get().execute(&stmt, []).unwrap();
                    }
                    // Record that the database is now at the target schema.
                    set_user_version(transaction.get(), M::To::VERSION).unwrap();

                    transaction.into_owner()
                })
                .join()
            });
            // Re-raise a panic from the migration thread on the calling thread.
            match res {
                Ok(val) => self.transaction = val,
                Err(payload) => std::panic::resume_unwind(payload),
            }
        }

        Migrator {
            indices_fixed: self.indices_fixed,
            manager: self.manager,
            transaction: self.transaction,
            _p: PhantomData,
        }
    }

    /// Commit the migration transaction and return a [Database].
    ///
    /// Returns [None] if the database schema version is newer than `S`, that
    /// is, if its `user_version` differs from `S::VERSION`. In that case the
    /// transaction is dropped without being committed.
    ///
    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
    pub fn finish(mut self) -> Option<Database<S>> {
        if user_version(self.transaction.get()).unwrap() != S::VERSION {
            return None;
        }

        // Same scoped-thread hand-off as in `migrate`, used only for the schema
        // check and the one-time index repair.
        let res = std::thread::scope(|s| {
            s.spawn(|| {
                TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
                let txn = Transaction::new_ref();

                check_schema::<S>(txn);
                if !self.indices_fixed {
                    fix_indices::<S>(txn);
                    self.indices_fixed = true;
                }

                TXN.take().unwrap().into_owner()
            })
            .join()
        });
        match res {
            Ok(val) => self.transaction = val,
            Err(payload) => std::panic::resume_unwind(payload),
        }

        // adds an sqlite_stat1 table
        self.transaction
            .get()
            .execute_batch("PRAGMA optimize;")
            .unwrap();

        // Read before committing; stored in the returned `Database`.
        let schema_version = schema_version(self.transaction.get());
        self.transaction.with(|x| x.commit().unwrap());

        Some(Database {
            manager: self.manager,
            schema_version: AtomicI64::new(schema_version),
            schema: PhantomData,
            mut_lock: parking_lot::FairMutex::new(()),
        })
    }
}
280
281fn fix_indices<S: Schema>(txn: &Transaction<S>) {
282    let schema = read_schema(txn);
283    let expected_schema = crate::schema::from_macro::Schema::new::<S>();
284
285    fn check_eq(expected: &from_macro::Table, actual: &from_db::Table) -> bool {
286        let expected: BTreeSet<_> = expected.indices.iter().map(|idx| &idx.def).collect();
287        let actual: BTreeSet<_> = actual.indices.values().collect();
288        expected == actual
289    }
290
291    for (name, table) in schema.tables {
292        let expected_table = &expected_schema.tables[&name];
293
294        if !check_eq(expected_table, &table) {
295            // Delete all indices associated with the table
296            for index_name in read_index_names_for_table(&crate::Transaction::new(), &name) {
297                let sql = sea_query::Index::drop()
298                    .name(index_name)
299                    .build(SqliteQueryBuilder);
300                txn.execute(&sql);
301            }
302
303            // Add the new indices
304            for sql in expected_table.create_indices(&name) {
305                txn.execute(&sql);
306            }
307        }
308    }
309
310    // check that we solved the mismatch
311    let schema = read_schema(txn);
312    for (name, table) in schema.tables {
313        let expected_table = &expected_schema.tables[&name];
314        assert!(check_eq(expected_table, &table));
315    }
316}
317
318impl<S> Transaction<S> {
319    pub(crate) fn execute(&self, sql: &str) {
320        TXN.with_borrow(|txn| txn.as_ref().unwrap().get().execute(sql, []))
321            .unwrap();
322    }
323}
324
325pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
326    conn.pragma_query_value(None, "schema_version", |r| r.get(0))
327        .unwrap()
328}
329
330// Read user version field from the SQLite db
331pub fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
332    conn.query_row("PRAGMA user_version", [], |row| row.get(0))
333}
334
335// Set user version field from the SQLite db
336fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
337    conn.pragma_update(None, "user_version", v)
338}
339
340pub(crate) fn check_schema<S: Schema>(txn: &Transaction<S>) {
341    let from_macro = crate::schema::from_macro::Schema::new::<S>();
342    let from_db = read_schema(txn);
343    let report = from_db.diff(from_macro, S::SOURCE, S::PATH, S::VERSION);
344    if !report.is_empty() {
345        let renderer = if cfg!(test) {
346            Renderer::plain().anonymized_line_numbers(true)
347        } else {
348            Renderer::styled()
349        }
350        .decor_style(DecorStyle::Unicode);
351        panic!("{}", renderer.render(&report));
352    }
353}
354
355fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
356    let error = conn
357        .prepare("PRAGMA foreign_key_check")
358        .unwrap()
359        .query_map([], |row| row.get(2))
360        .unwrap()
361        .next();
362    error.transpose().unwrap()
363}
364
365impl<S> Transaction<S> {
366    #[cfg(test)]
367    pub(crate) fn schema(&self) -> Vec<String> {
368        TXN.with_borrow(|x| {
369            x.as_ref()
370                .unwrap()
371                .get()
372                .prepare("SELECT sql FROM 'main'.'sqlite_schema'")
373                .unwrap()
374                .query_map([], |row| row.get("sql"))
375                .unwrap()
376                .map(|x| x.unwrap())
377                .collect()
378        })
379    }
380}