rust_query/
migrate.rs

1pub mod config;
2pub mod migration;
3
4use std::{collections::HashMap, marker::PhantomData, sync::atomic::AtomicI64};
5
6use rusqlite::{Connection, config::DbConfig};
7use sea_query::{Alias, ColumnDef, IntoTableRef, SqliteQueryBuilder};
8use self_cell::MutBorrow;
9
10use crate::{
11    Table, Transaction, hash,
12    migrate::{
13        config::Config,
14        migration::{SchemaBuilder, TransactionMigrate},
15    },
16    schema_pragma::{read_index_names_for_table, read_schema},
17    transaction::{Database, OwnedTransaction, TXN, TransactionWithRows},
18};
19
/// Collects the table definitions that make up schema `S`.
///
/// Tables are added one at a time with [`TableTypBuilder::table`].
pub struct TableTypBuilder<S> {
    /// Schema description being built; its `tables` map is keyed by table name
    /// (see [`TableTypBuilder::table`]).
    pub(crate) ast: hash::Schema,
    /// Ties the builder to the schema type `S` without storing a value of it.
    _p: PhantomData<S>,
}
24
25impl<S> Default for TableTypBuilder<S> {
26    fn default() -> Self {
27        Self {
28            ast: Default::default(),
29            _p: Default::default(),
30        }
31    }
32}
33
34impl<S> TableTypBuilder<S> {
35    pub fn table<T: Table<Schema = S>>(&mut self) {
36        let table = hash::Table::new::<T>();
37        let old = self.ast.tables.insert(T::NAME.to_owned(), table);
38        debug_assert!(old.is_none());
39    }
40}
41
/// A versioned database schema.
pub trait Schema: Sized + 'static {
    /// Version number of this schema. It is compared with, and written to, the SQLite
    /// `user_version` field of the database when migrating.
    const VERSION: i64;
    /// Implementations are expected to register each table of the schema with `b`
    /// via [`TableTypBuilder::table`].
    fn typs(b: &mut TableTypBuilder<Self>);
}
46
47fn new_table_inner(conn: &Connection, table: &crate::hash::Table, alias: impl IntoTableRef) {
48    let mut create = table.create();
49    create
50        .table(alias)
51        .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
52    let mut sql = create.to_string(SqliteQueryBuilder);
53    sql.push_str(" STRICT");
54    conn.execute(&sql, []).unwrap();
55}
56
/// A single migration step from schema `From` to schema `To`.
///
/// A value of this type is returned by the closure passed to [`Migrator::migrate`].
pub trait SchemaMigration<'a> {
    /// Schema the database must be at for this migration to run.
    type From: Schema;
    /// Schema the database is at afterwards; its `VERSION` is written to `user_version`
    /// once the migration has been applied.
    type To: Schema;

    /// Records this migration's table changes in `b`.
    ///
    /// `Migrator::migrate` then executes the collected drops, renames and extra indices.
    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
}
63
impl<S: Schema> Database<S> {
    /// Create a [Migrator] to migrate a database.
    ///
    /// A single connection is opened from `config` and an exclusive transaction is started on
    /// it. All later migration steps and [Migrator::finish] run inside that transaction.
    /// If the database is brand new (its `schema_version` is 0), the following happens:
    /// all tables and indices of `S` are created, `config.init` is run, and `user_version`
    /// is set to `S::VERSION`.
    ///
    /// Returns [None] if the database `user_version` on disk is older than `S`.
    ///
    /// Panics if connecting, starting the transaction, or any of the setup statements fail.
    pub fn migrator(config: Config) -> Option<Migrator<S>> {
        let synchronous = config.synchronous.as_str();
        let foreign_keys = config.foreign_keys.as_str();
        // Initialisation applied by the manager to each new connection:
        // - WAL journaling;
        // - the configured `synchronous` and `foreign_keys` settings;
        // - no double-quoted string literals in DDL/DML, so `"x"` is always an identifier;
        // - SQLite's defensive mode.
        let manager = config.manager.with_init(move |inner| {
            inner.pragma_update(None, "journal_mode", "WAL")?;
            inner.pragma_update(None, "synchronous", synchronous)?;
            inner.pragma_update(None, "foreign_keys", foreign_keys)?;
            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
            Ok(())
        });

        use r2d2::ManageConnection;
        let conn = manager.connect().unwrap();
        // Foreign key enforcement is switched off for the migration connection.
        // This must happen before the transaction starts, because SQLite ignores
        // `PRAGMA foreign_keys` inside a transaction. Violations are detected with
        // `foreign_key_check` instead.
        conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
        // Keep the connection together with the transaction that borrows it.
        // `Exclusive` takes the write lock when the transaction begins.
        let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
            Some(
                conn.borrow_mut()
                    .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
                    .unwrap(),
            )
        });

        // check if this database is newly created (a schema_version of 0 means no
        // schema has ever been written to it)
        if schema_version(txn.get()) == 0 {
            let schema = crate::hash::Schema::new::<S>();

            for (table_name, table) in &schema.tables {
                let table_name_ref = Alias::new(table_name);
                new_table_inner(txn.get(), table, table_name_ref);
                for stmt in table.create_indices(table_name) {
                    txn.get().execute(&stmt, []).unwrap();
                }
            }
            // User-supplied initialisation for a freshly created database.
            (config.init)(txn.get());
            set_user_version(txn.get(), S::VERSION).unwrap();
        }

        let user_version = user_version(txn.get()).unwrap();
        // We can not migrate databases older than `S`
        if user_version < S::VERSION {
            return None;
        }
        // Debug builds only: the database must not start out with foreign key violations.
        debug_assert_eq!(
            foreign_key_check(txn.get()),
            None,
            "foreign key constraint violated"
        );

        Some(Migrator {
            indices_fixed: false,
            manager,
            transaction: txn,
            _p: PhantomData,
        })
    }
}
126
/// [Migrator] is used to apply database migrations.
///
/// When all migrations are done, it can be turned into a [Database] instance with
/// [Migrator::finish].
///
/// All migration steps share one exclusive transaction, which is only committed by
/// [Migrator::finish]. `S` is the schema that the next [Migrator::migrate] step starts from.
pub struct Migrator<S> {
    /// Connection manager, handed over to the [Database] created by [Migrator::finish].
    manager: r2d2_sqlite::SqliteConnectionManager,
    /// The exclusive transaction opened by [Database::migrator]; all migrations run in it.
    transaction: OwnedTransaction,
    /// Whether `fix_indices` has already been run for this chain of migrations.
    indices_fixed: bool,
    _p: PhantomData<S>,
}
137
impl<S: Schema> Migrator<S> {
    /// Apply a database migration if the current schema is `S` and return a [Migrator] for the next schema `N`.
    ///
    /// The step only runs when the database `user_version` equals `S::VERSION`. Otherwise
    /// it is skipped, and the returned [Migrator] continues with the same transaction.
    ///
    /// `m` receives a [TransactionMigrate] for schema `S` and returns the migration to apply.
    /// The work runs on a scoped helper thread, and a panic there is re-raised on the caller's
    /// thread. NOTE(review): the helper thread presumably exists so that the thread-local
    /// `TXN` starts out empty there — confirm.
    ///
    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
    pub fn migrate<'x, M>(
        mut self,
        m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
    ) -> Migrator<M::To>
    where
        M: SchemaMigration<'x, From = S>,
    {
        if user_version(self.transaction.get()).unwrap() == S::VERSION {
            let res = std::thread::scope(|s| {
                s.spawn(|| {
                    // Hand the transaction to the thread-local so `Transaction` handles can use it.
                    TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
                    let txn = Transaction::new_ref();

                    // The schema on disk must match `S` before migrating away from it.
                    check_schema::<S>(txn);
                    if !self.indices_fixed {
                        fix_indices::<S>(txn);
                        self.indices_fixed = true;
                    }

                    let mut txn = TransactionMigrate {
                        inner: Transaction::new(),
                        scope: Default::default(),
                        rename_map: HashMap::new(),
                        extra_index: Vec::new(),
                    };
                    // Run the user's migration code; it returns the migration description.
                    let m = m(&mut txn);

                    let mut builder = SchemaBuilder {
                        drop: vec![],
                        foreign_key: HashMap::new(),
                        inner: txn,
                    };
                    m.tables(&mut builder);

                    // Take the transaction back out of the thread-local to run the collected DDL.
                    let transaction = TXN.take().unwrap();

                    // Drop the tables the migration marked for removal...
                    for drop in builder.drop {
                        let sql = drop.to_string(SqliteQueryBuilder);
                        transaction.get().execute(&sql, []).unwrap();
                    }
                    // ...then rename each temporary table `tmp` to its final name `to`.
                    for (to, tmp) in builder.inner.rename_map {
                        let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
                        let sql = rename.to_string(SqliteQueryBuilder);
                        transaction.get().execute(&sql, []).unwrap();
                    }
                    // For the first violation, `foreign_key_check` yields the referenced (parent)
                    // table. The handler registered for that table is invoked.
                    // NOTE(review): the `allow(unreachable_code)` below suggests the handler never
                    // returns — confirm.
                    if let Some(fk) = foreign_key_check(transaction.get()) {
                        (builder.foreign_key.remove(&*fk).unwrap())();
                    }
                    #[allow(
                        unreachable_code,
                        reason = "rustc is stupid and thinks this is unreachable"
                    )]
                    // adding indexes is fine to do after checking foreign keys
                    for stmt in builder.inner.extra_index {
                        transaction.get().execute(&stmt, []).unwrap();
                    }
                    set_user_version(transaction.get(), M::To::VERSION).unwrap();

                    // Return the owned transaction to the calling thread.
                    transaction.into_owner()
                })
                .join()
            });
            match res {
                Ok(val) => self.transaction = val,
                Err(payload) => std::panic::resume_unwind(payload),
            }
        }

        Migrator {
            indices_fixed: self.indices_fixed,
            manager: self.manager,
            transaction: self.transaction,
            _p: PhantomData,
        }
    }

    /// Commit the migration transaction and return a [Database].
    ///
    /// Returns [None] if the database `user_version` differs from `S::VERSION`, e.g. when the
    /// database on disk is newer than `S`. In that case the transaction is dropped without
    /// being committed.
    ///
    /// Before committing, the schema is checked (and indices repaired if that has not happened
    /// yet), and `PRAGMA optimize` is run.
    ///
    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
    pub fn finish(mut self) -> Option<Database<S>> {
        if user_version(self.transaction.get()).unwrap() != S::VERSION {
            return None;
        }

        let res = std::thread::scope(|s| {
            s.spawn(|| {
                // Same thread-local handoff as in `migrate`, for the final schema check.
                TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
                let txn = Transaction::new_ref();

                check_schema::<S>(txn);
                if !self.indices_fixed {
                    fix_indices::<S>(txn);
                    self.indices_fixed = true;
                }

                TXN.take().unwrap().into_owner()
            })
            .join()
        });
        match res {
            Ok(val) => self.transaction = val,
            Err(payload) => std::panic::resume_unwind(payload),
        }

        // `PRAGMA optimize` may run ANALYZE, which creates/updates the sqlite_stat1 table
        self.transaction
            .get()
            .execute_batch("PRAGMA optimize;")
            .unwrap();

        // Captured before commit and stored in the returned `Database`.
        let schema_version = schema_version(self.transaction.get());
        self.transaction.with(|x| x.commit().unwrap());

        Some(Database {
            manager: self.manager,
            schema_version: AtomicI64::new(schema_version),
            schema: PhantomData,
            mut_lock: parking_lot::FairMutex::new(()),
        })
    }
}
265
266fn fix_indices<S: Schema>(txn: &Transaction<S>) {
267    let schema = read_schema(txn);
268    let expected_schema = crate::hash::Schema::new::<S>();
269
270    for (name, table) in schema.tables {
271        let expected_table = &expected_schema.tables[&name];
272
273        if expected_table.indices != table.indices {
274            // Delete all indices associated with the table
275            for index_name in read_index_names_for_table(&crate::Transaction::new(), &name) {
276                let sql = sea_query::Index::drop()
277                    .name(index_name)
278                    .build(SqliteQueryBuilder);
279                txn.execute(&sql);
280            }
281
282            // Add the new indices
283            for sql in expected_table.create_indices(&name) {
284                txn.execute(&sql);
285            }
286        }
287    }
288
289    assert_eq!(expected_schema, read_schema(txn));
290}
291
292impl<S> Transaction<S> {
293    pub(crate) fn execute(&self, sql: &str) {
294        TXN.with_borrow(|txn| txn.as_ref().unwrap().get().execute(sql, []))
295            .unwrap();
296    }
297}
298
299pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
300    conn.pragma_query_value(None, "schema_version", |r| r.get(0))
301        .unwrap()
302}
303
304// Read user version field from the SQLite db
305pub fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
306    conn.query_row("PRAGMA user_version", [], |row| row.get(0))
307}
308
309// Set user version field from the SQLite db
310fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
311    conn.pragma_update(None, "user_version", v)
312}
313
314pub(crate) fn check_schema<S: Schema>(txn: &Transaction<S>) {
315    // normalize both sides, because we only care about compatibility
316    pretty_assertions::assert_eq!(
317        crate::hash::Schema::new::<S>().normalize(),
318        read_schema(txn).normalize(),
319        "schema is different (expected left, but got right)",
320    );
321}
322
323fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
324    let error = conn
325        .prepare("PRAGMA foreign_key_check")
326        .unwrap()
327        .query_map([], |row| row.get(2))
328        .unwrap()
329        .next();
330    error.transpose().unwrap()
331}
332
333impl<S> Transaction<S> {
334    #[cfg(test)]
335    pub(crate) fn schema(&self) -> Vec<String> {
336        TXN.with_borrow(|x| {
337            x.as_ref()
338                .unwrap()
339                .get()
340                .prepare("SELECT sql FROM 'main'.'sqlite_schema'")
341                .unwrap()
342                .query_map([], |row| row.get("sql"))
343                .unwrap()
344                .map(|x| x.unwrap())
345                .collect()
346        })
347    }
348}
349
350impl<S: Send + Sync + Schema> Database<S> {
351    #[cfg(test)]
352    fn check_schema(&self, expect: expect_test::Expect) {
353        let mut schema = self.transaction(|txn| txn.schema());
354        schema.sort();
355        expect.assert_eq(&schema.join("\n"));
356    }
357}
358
#[test]
fn fix_indices_test() {
    mod without_index {
        #[crate::migration::schema(Schema)]
        pub mod vN {
            pub struct Foo {
                pub bar: String,
            }
        }
    }

    mod with_index {
        #[crate::migration::schema(Schema)]
        pub mod vN {
            pub struct Foo {
                #[index]
                pub bar: String,
            }
        }
    }

    // Several `Database` handles must see the same data, so the test uses an on-disk file.
    const DB_FILE: &str = "index_test.sqlite";

    // Start from an empty database. Otherwise a file (or WAL/SHM side file) left behind by an
    // earlier, possibly aborted, run would leak state into this one.
    for suffix in ["", "-wal", "-shm"] {
        let path = format!("{DB_FILE}{suffix}");
        match std::fs::remove_file(&path) {
            Ok(()) => {}
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => {}
            Err(err) => panic!("failed to remove stale `{path}`: {err}"),
        }
    }

    let db = Database::<without_index::v0::Schema>::migrator(Config::open(DB_FILE))
        .unwrap()
        .finish()
        .unwrap();
    // The first database is opened with a schema without index
    db.check_schema(expect_test::expect![[
        r#"CREATE TABLE "foo" ( "bar" text NOT NULL, "id" integer PRIMARY KEY ) STRICT"#
    ]]);

    let db_with_index =
        Database::<with_index::v0::Schema>::migrator(Config::open(DB_FILE))
            .unwrap()
            .finish()
            .unwrap();
    // The database is updated without a new schema version.
    // Adding an index is allowed because it does not change database validity.
    db_with_index.check_schema(expect_test::expect![[r#"
        CREATE INDEX "foo_index_0" ON "foo" ("bar")
        CREATE TABLE "foo" ( "bar" text NOT NULL, "id" integer PRIMARY KEY ) STRICT"#]]);

    // Using the old database connection will still work, because the new schema is compatible.
    db.check_schema(expect_test::expect![[r#"
        CREATE INDEX "foo_index_0" ON "foo" ("bar")
        CREATE TABLE "foo" ( "bar" text NOT NULL, "id" integer PRIMARY KEY ) STRICT"#]]);

    let db = Database::<without_index::v0::Schema>::migrator(Config::open(DB_FILE))
        .unwrap()
        .finish()
        .unwrap();
    // Opening the database with the old schema again removes the index.
    db.check_schema(expect_test::expect![[
        r#"CREATE TABLE "foo" ( "bar" text NOT NULL, "id" integer PRIMARY KEY ) STRICT"#
    ]]);
}