Skip to main content

rust_query/
migrate.rs

1pub mod config;
2mod fix_by_copy;
3pub mod migration;
4#[cfg(test)]
5mod test;
6
7use std::{collections::HashMap, marker::PhantomData, sync::atomic::AtomicI64};
8
9use annotate_snippets::{Renderer, renderer::DecorStyle};
10use self_cell::MutBorrow;
11
12use crate::{
13    Table, Transaction,
14    lower::{self, list_writer::Alias},
15    migrate::{
16        config::Config,
17        fix_by_copy::fix_by_copy,
18        migration::{SchemaBuilder, TransactionMigrate},
19    },
20    pool::Pool,
21    schema::{from_macro, read::read_schema},
22    transaction::{Database, OwnedTransaction, TXN, TransactionWithRows},
23};
24
25pub struct TableTypBuilder<S> {
26    pub(crate) ast: from_macro::Schema,
27    _p: PhantomData<S>,
28}
29
30impl<S> Default for TableTypBuilder<S> {
31    fn default() -> Self {
32        Self {
33            ast: Default::default(),
34            _p: Default::default(),
35        }
36    }
37}
38
39impl<S> TableTypBuilder<S> {
40    pub fn table<T: Table<Schema = S>>(&mut self) {
41        let table = from_macro::Table::new::<T>();
42        let old = self.ast.tables.insert(T::NAME, table);
43        debug_assert!(old.is_none());
44    }
45}
46
47pub trait Schema: Sized + 'static {
48    const VERSION: i64;
49    const SOURCE: &str;
50    const PATH: &str;
51    const SPAN: (usize, usize);
52    fn typs(b: &mut TableTypBuilder<Self>);
53}
54
55pub trait SchemaMigration<'a> {
56    type From: Schema;
57    type To: Schema;
58
59    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
60}
61
62impl<S: Schema> Database<S> {
63    /// Create a [Migrator] to migrate a database.
64    ///
65    /// Returns [None] if the database `user_version` on disk is older than `S`.
66    pub fn migrator(config: Config) -> Option<Migrator<S>> {
67        let pool = Pool::new(config);
68
69        let conn = pool.pop();
70        conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
71        let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
72            Some(
73                conn.borrow_mut()
74                    .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
75                    .unwrap(),
76            )
77        });
78
79        let mut user_version = Some(user_version(txn.get()).unwrap());
80
81        // check if this database is newly created
82        if schema_version(txn.get()) == 0 {
83            user_version = None;
84
85            let schema = crate::schema::from_macro::Schema::new::<S>();
86
87            for (table_name, table) in schema.tables {
88                let table = table.to_db();
89                let create = table.create(lower::JoinableTable::Table(table_name), "id");
90                txn.get().execute(&create, []).unwrap();
91                for stmt in table.delayed_indices(table_name) {
92                    txn.get().execute(&stmt, []).unwrap();
93                }
94            }
95        } else if user_version.unwrap() < S::VERSION {
96            // We can not migrate databases older than `S`
97            return None;
98        }
99
100        debug_assert_eq!(
101            foreign_key_check(txn.get()),
102            None,
103            "foreign key constraint violated"
104        );
105
106        Some(Migrator {
107            user_version,
108            pool,
109            transaction: txn,
110            _p: PhantomData,
111        })
112    }
113}
114
/// [Migrator] is used to apply database migrations.
///
/// When all migrations are done, it can be turned into a [Database] instance with
/// [Migrator::finish].
///
/// The type parameter `S` is the schema this migrator is positioned at; every
/// [Migrator::migrate] call returns a [Migrator] for the next schema. All work happens in one
/// exclusive transaction that is committed by [Migrator::finish].
pub struct Migrator<S> {
    // Connection pool; handed over to the [Database] built by [Migrator::finish].
    pool: Pool,
    // The exclusive transaction that every migration and fixup runs in.
    transaction: OwnedTransaction,
    // Initialized to the user version when the transaction starts.
    // This is set to None if the schema user_version is updated.
    // Fixups are only applied if the user_version is None.
    // Indices are fixed before this is set to None.
    user_version: Option<i64>,
    // Marker for the schema `S`; no value of `S` is stored.
    _p: PhantomData<S>,
}
129
impl<S: Schema> Migrator<S> {
    /// Run `f` with the migration transaction installed in the thread-local [TXN].
    ///
    /// `f` runs on a freshly spawned scoped thread: the transaction is moved into that
    /// thread's [TXN], and taken back out once `f` returns. If the thread panics, the panic is
    /// re-raised on the calling thread with [std::panic::resume_unwind].
    ///
    /// On the first call (while `user_version` is still `Some`), the on-disk schema is checked
    /// against `S` and indices are fixed before `f` runs; `user_version` becomes [None].
    // NOTE(review): the dedicated thread presumably guarantees the `&'static mut Transaction`
    // handed to `f` cannot outlive the thread-local state it points into — confirm.
    fn with_transaction(mut self, f: impl Send + FnOnce(&'static mut Transaction<S>)) -> Self {
        assert!(self.user_version.is_none_or(|x| x == S::VERSION));
        let res = std::thread::scope(|s| {
            s.spawn(|| {
                TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
                let txn = Transaction::new_ref();

                // check if this is the first migration that is applied
                if self.user_version.take().is_some() {
                    // we check the schema before doing any migrations
                    check_schema::<S>(txn, false);
                    // fixing indices before migrations can help with migration performance
                    fix_by_copy::<S>(txn, fix_by_copy::Detail::Indexes);
                }

                f(txn);

                // Take the transaction back out of the thread-local so it can be returned
                // to the calling thread.
                let transaction = TXN.take().unwrap();

                transaction.into_owner()
            })
            .join()
        });
        match res {
            Ok(val) => self.transaction = val,
            Err(payload) => std::panic::resume_unwind(payload),
        }
        self
    }

    /// Apply a database migration if the current schema is `S` and return a [Migrator] for the next schema `M::To`.
    ///
    /// The migration (including the closure `m`) is skipped when no migration has run yet
    /// and the `user_version` on disk is not `S::VERSION`.
    ///
    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
    pub fn migrate<'x, M>(
        mut self,
        m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
    ) -> Migrator<M::To>
    where
        M: SchemaMigration<'x, From = S>,
    {
        if self.user_version.is_none_or(|x| x == S::VERSION) {
            self = self.with_transaction(|txn| {
                // `rename_map` and `extra_index` start empty and are read back below.
                let mut txn = TransactionMigrate {
                    inner: txn.copy(),
                    scope: Default::default(),
                    rename_map: HashMap::new(),
                    extra_index: Vec::new(),
                };
                let m = m(&mut txn);

                let mut builder = SchemaBuilder {
                    drop: vec![],
                    foreign_key: HashMap::new(),
                    inner: txn,
                };
                m.tables(&mut builder);
                let txn = builder.inner.inner;

                // Apply what was collected: run the `drop` statements, rename each temporary
                // table to its final name, then run the extra index statements.
                for drop in builder.drop {
                    txn.execute(&drop);
                }
                for (to, tmp) in builder.inner.rename_map {
                    txn.execute(&format!("ALTER TABLE main.{tmp} RENAME TO {}", Alias(to)));
                }
                for stmt in builder.inner.extra_index {
                    txn.execute(&stmt);
                }

                // Change transaction schema because we are now on the new version already
                fix_by_copy::<M::To>(&Transaction::new(), fix_by_copy::Detail::ForeignKeys);

                // Foreign keys are switched off on this connection, so violations are checked
                // explicitly. The first violation is dispatched to the callback registered in
                // `builder.foreign_key` for that table; a missing entry panics on `unwrap`.
                let transaction = TXN.take().unwrap();
                if let Some(fk) = foreign_key_check(transaction.get()) {
                    (builder.foreign_key.remove(&*fk).unwrap())();
                }

                // Put the transaction back so `with_transaction` can retrieve it.
                TXN.set(Some(transaction));
            });
        }

        // Same state, re-typed for the next schema.
        Migrator {
            user_version: self.user_version,
            pool: self.pool,
            transaction: self.transaction,
            _p: PhantomData,
        }
    }

    /// Mutate the database as part of migrations.
    ///
    /// The closure will only be executed if the database got migrated to schema version `S`
    /// by this [Migrator] instance.
    /// If [Migrator::fixup] is used before all [Migrator::migrate], then the closure is only executed
    /// when the database is created.
    pub fn fixup(mut self, f: impl Send + FnOnce(&'static mut Transaction<S>)) -> Self {
        if self.user_version.is_none() {
            self = self.with_transaction(f);
        }
        self
    }

    /// Commit the migration transaction and return a [Database].
    ///
    /// Returns [None] if no migration was applied and the `user_version` on disk differs from
    /// `S::VERSION` (normally because the database is newer than `S`).
    ///
    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
    pub fn finish(mut self) -> Option<Database<S>> {
        if self.user_version.is_some_and(|x| x != S::VERSION) {
            return None;
        }

        // This checks that the schema is correct and fixes indices etc
        self = self.with_transaction(|txn| {
            // sanity check, this should never fail
            check_schema::<S>(txn, true);
        });

        // adds an sqlite_stat1 table
        self.transaction
            .get()
            .execute_batch("PRAGMA optimize;")
            .unwrap();

        // Record the schema version and commit everything done by this migrator at once.
        set_user_version(self.transaction.get(), S::VERSION).unwrap();
        let schema_version = schema_version(self.transaction.get());
        self.transaction.with(|x| x.commit().unwrap());

        Some(Database {
            pool: self.pool,
            schema_version: AtomicI64::new(schema_version),
            schema: PhantomData,
            mut_lock: parking_lot::FairMutex::new(()),
        })
    }
}
266
267impl<S> Transaction<S> {
268    #[track_caller]
269    pub(crate) fn execute(&self, sql: &str) {
270        TXN.with_borrow(|txn| txn.as_ref().unwrap().get().execute(sql, []))
271            .unwrap();
272    }
273}
274
275pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
276    conn.pragma_query_value(None, "schema_version", |r| r.get(0))
277        .unwrap()
278}
279
280// Read user version field from the SQLite db
281pub fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
282    conn.query_row("PRAGMA user_version", [], |row| row.get(0))
283}
284
285// Set user version field from the SQLite db
286fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
287    conn.pragma_update(None, "user_version", v)
288}
289
290pub(crate) fn check_schema<S: Schema>(txn: &Transaction<S>, sanity: bool) {
291    let from_macro = crate::schema::from_macro::Schema::new::<S>();
292    let from_db = read_schema(txn);
293    let report = from_db.diff(from_macro, S::SOURCE, S::PATH, S::VERSION);
294    if !report.is_empty() {
295        let renderer = if cfg!(test) {
296            Renderer::plain().anonymized_line_numbers(true)
297        } else {
298            Renderer::styled()
299        }
300        .decor_style(DecorStyle::Unicode);
301        if sanity {
302            unreachable!("THIS IS A RUST-QUERY BUG {}", renderer.render(&report));
303        } else {
304            panic!("{}", renderer.render(&report));
305        }
306    }
307}
308
309fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
310    let error = conn
311        .prepare("PRAGMA foreign_key_check")
312        .unwrap()
313        .query_map([], |row| row.get(2))
314        .unwrap()
315        .next();
316    error.transpose().unwrap()
317}
318
319impl<S> Transaction<S> {
320    #[cfg(test)]
321    pub(crate) fn schema(&self) -> Vec<String> {
322        TXN.with_borrow(|x| {
323            x.as_ref()
324                .unwrap()
325                .get()
326                .prepare("SELECT sql FROM 'main'.'sqlite_schema'")
327                .unwrap()
328                .query_map([], |row| row.get::<_, Option<String>>("sql"))
329                .unwrap()
330                .flat_map(|x| x.unwrap())
331                .collect()
332        })
333    }
334}