1use std::{
2 collections::{HashMap, HashSet},
3 convert::Infallible,
4 marker::PhantomData,
5 ops::{Deref, Not},
6 path::Path,
7 sync::atomic::AtomicI64,
8};
9
10use rusqlite::{Connection, config::DbConfig};
11use sea_query::{Alias, ColumnDef, IntoTableRef, SqliteQueryBuilder, TableDropStatement};
12use self_cell::MutBorrow;
13
14use crate::{
15 FromExpr, IntoExpr, Table, TableRow, Transaction,
16 alias::{Scope, TmpTable},
17 hash,
18 schema_pragma::read_schema,
19 transaction::{Database, OwnedTransaction, TXN, try_insert_private},
20};
21
22pub struct TableTypBuilder<S> {
23 pub(crate) ast: hash::Schema,
24 _p: PhantomData<S>,
25}
26
27impl<S> Default for TableTypBuilder<S> {
28 fn default() -> Self {
29 Self {
30 ast: Default::default(),
31 _p: Default::default(),
32 }
33 }
34}
35
36impl<S> TableTypBuilder<S> {
37 pub fn table<T: Table<Schema = S>>(&mut self) {
38 let mut b = hash::TypBuilder::default();
39 T::typs(&mut b);
40 self.ast.tables.insert((T::NAME.to_owned(), b.ast));
41 }
42}
43
/// A versioned database schema: a set of tables identified by a version number.
pub trait Schema: Sized + 'static {
    /// Version number of this schema.
    ///
    /// It is stored in SQLite's `user_version` pragma when the database is
    /// created at, or migrated to, this schema.
    const VERSION: i64;
    /// Registers every table of this schema with `b` (see `TableTypBuilder::table`).
    fn typs(b: &mut TableTypBuilder<Self>);
}
48
/// Describes how the rows of one table are carried over into its new version.
///
/// Values of the implementing type are produced per old row by the closures
/// passed to `TransactionMigrate::migrate` and its variants.
pub trait Migration {
    /// Schema that the old table belongs to.
    type FromSchema: 'static;
    /// The table being migrated, as it exists in `FromSchema`.
    type From: Table<Schema = Self::FromSchema>;
    /// The new version of the table; it must declare `From` as its `MigrateFrom`.
    type To: Table<MigrateFrom = Self::From>;
    /// Error returned by `TransactionMigrate::migrate` when a row cannot be inserted.
    type Conflict;

    /// Builds the insert for the new table from the migration value `val` and
    /// an expression `prev` referring to the old row.
    #[doc(hidden)]
    fn prepare(
        val: Self,
        prev: crate::Expr<'static, Self::FromSchema, Self::From>,
    ) -> <Self::To as Table>::Insert;
    /// Converts the old row whose insert failed into the `Conflict` error.
    #[doc(hidden)]
    fn map_conflict(val: TableRow<Self::From>) -> Self::Conflict;
}
63
64pub struct TransactionMigrate<FromSchema> {
66 inner: Transaction<FromSchema>,
67 scope: Scope,
68 rename_map: HashMap<&'static str, TmpTable>,
69}
70
71impl<FromSchema> Deref for TransactionMigrate<FromSchema> {
72 type Target = Transaction<FromSchema>;
73
74 fn deref(&self) -> &Self::Target {
75 &self.inner
76 }
77}
78
impl<FromSchema> TransactionMigrate<FromSchema> {
    /// Returns the temporary table that will replace table `T` at the end of
    /// the migration, creating it on first use.
    ///
    /// The table is created through `new_table::<T>` on the transaction stored
    /// in the `TXN` thread-local. Its name is memoised in `rename_map` under
    /// `T::NAME`, so repeated calls for the same table return the same name.
    ///
    /// # Panics
    /// Panics if `TXN` holds no transaction or if creating the table fails.
    fn new_table_name<T: Table>(&mut self) -> TmpTable {
        *self.rename_map.entry(T::NAME).or_insert_with(|| {
            let new_table_name = self.scope.tmp_table();
            TXN.with_borrow(|txn| new_table::<T>(txn.as_ref().unwrap().get(), new_table_name));
            new_table_name
        })
    }

    /// Yields `(row id, Out)` for every row of `M::From` whose id is not yet
    /// present in the temporary table `new_name`.
    ///
    /// Both queries run eagerly: every old row (with its `Out` projection) and
    /// every id already copied to `new_name` are loaded into memory before the
    /// iterator is returned. Rows inserted into `new_name` by an earlier call
    /// are therefore skipped.
    fn unmigrated<M: Migration<FromSchema = FromSchema>, Out>(
        &self,
        new_name: TmpTable,
    ) -> impl Iterator<Item = (i64, Out)>
    where
        Out: FromExpr<FromSchema, M::From>,
    {
        let data = self.inner.query(|rows| {
            let old = rows.join(<M::From as Table>::TOKEN);
            rows.into_vec((&old, Out::from_expr(&old)))
        });

        // The temporary table is read as if it were `M::From`; only the row ids
        // are used, to find rows that were already migrated.
        let migrated = Transaction::new().query(|rows| {
            let new = rows.join_tmp::<M::From>(new_name);
            rows.into_vec(new)
        });
        let migrated: HashSet<_> = migrated.into_iter().map(|x| x.inner.idx).collect();

        data.into_iter().filter_map(move |(row, data)| {
            migrated
                .contains(&row.inner.idx)
                .not()
                .then_some((row.inner.idx, data))
        })
    }

    /// Migrates the rows of table `M::From` into table `M::To`, letting `f`
    /// drop rows.
    ///
    /// `f` receives, as `X`, the data of each row of `M::From` that has not
    /// been migrated yet.
    /// - `Some(m)`: the row built by `M::prepare` is inserted into the
    ///   replacement table of `M::To`, reusing the old row's id.
    /// - `None`: the row is left out of the new table.
    ///
    /// # Errors
    /// If an insert fails, iteration stops and the old row is passed to
    /// `M::map_conflict`; its result is returned. Rows inserted before that
    /// point stay in the replacement table.
    pub fn migrate_optional<
        M: Migration<FromSchema = FromSchema>,
        X: FromExpr<FromSchema, M::From>,
    >(
        &mut self,
        mut f: impl FnMut(X) -> Option<M>,
    ) -> Result<(), M::Conflict> {
        let new_name = self.new_table_name::<M::To>();

        for (idx, x) in self.unmigrated::<M, X>(new_name) {
            if let Some(new) = f(x) {
                // Reusing `idx` keeps row ids (and references to them) stable.
                try_insert_private::<M::To>(
                    new_name.into_table_ref(),
                    Some(idx),
                    M::prepare(new, TableRow::new(idx).into_expr()),
                )
                .map_err(|_| M::map_conflict(TableRow::new(idx)))?;
            };
        }
        Ok(())
    }

    /// Migrates every row of table `M::From` into table `M::To` using `f`.
    ///
    /// Behaves like `migrate_optional` with a closure that never drops rows.
    /// The returned `Migrated` carries no extra per-table step.
    ///
    /// # Errors
    /// Returns `M::map_conflict` of the first row whose insert fails.
    pub fn migrate<M: Migration<FromSchema = FromSchema>, X: FromExpr<FromSchema, M::From>>(
        &mut self,
        mut f: impl FnMut(X) -> M,
    ) -> Result<Migrated<'static, FromSchema, M::To>, M::Conflict> {
        self.migrate_optional::<M, X>(|x| Some(f(x)))?;

        Ok(Migrated {
            _p: PhantomData,
            f: Box::new(|_| {}),
            _local: PhantomData,
        })
    }

    /// Infallible variant of `migrate` for migrations whose `Conflict` is `Infallible`.
    pub fn migrate_ok<
        M: Migration<FromSchema = FromSchema, Conflict = Infallible>,
        X: FromExpr<FromSchema, M::From>,
    >(
        &mut self,
        f: impl FnMut(X) -> M,
    ) -> Migrated<'static, FromSchema, M::To> {
        // `Err(Infallible)` is uninhabited, so this pattern is irrefutable.
        let Ok(res) = self.migrate(f);
        res
    }
}
176
177pub struct SchemaBuilder<'t, FromSchema> {
178 inner: TransactionMigrate<FromSchema>,
179 drop: Vec<TableDropStatement>,
180 foreign_key: HashMap<&'static str, Box<dyn 't + FnOnce() -> Infallible>>,
181}
182
183impl<'t, FromSchema: 'static> SchemaBuilder<'t, FromSchema> {
184 pub fn foreign_key<To: Table>(&mut self, err: impl 't + FnOnce() -> Infallible) {
185 self.inner.new_table_name::<To>();
186
187 self.foreign_key.insert(To::NAME, Box::new(err));
188 }
189
190 pub fn create_empty<To: Table>(&mut self) {
191 self.inner.new_table_name::<To>();
192 }
193
194 pub fn drop_table<T: Table>(&mut self) {
195 let name = Alias::new(T::NAME);
196 let step = sea_query::Table::drop().table(name).take();
197 self.drop.push(step);
198 }
199}
200
201fn new_table<T: Table>(conn: &Connection, alias: TmpTable) {
202 let mut f = crate::hash::TypBuilder::default();
203 T::typs(&mut f);
204 new_table_inner(conn, &f.ast, alias);
205}
206
207fn new_table_inner(conn: &Connection, table: &crate::hash::Table, alias: impl IntoTableRef) {
208 let mut create = table.create();
209 create
210 .table(alias)
211 .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
212 let mut sql = create.to_string(SqliteQueryBuilder);
213 sql.push_str(" STRICT");
214 conn.execute(&sql, []).unwrap();
215}
216
/// Migration from schema `From` to schema `To`.
///
/// `Migrator::migrate` obtains a value of this type from the user's closure and
/// then calls `tables` to register the table-level steps.
pub trait SchemaMigration<'a> {
    /// Schema the database is migrated from.
    type From: Schema;
    /// Schema the database is migrated to.
    type To: Schema;

    /// Registers the table-level steps of this migration with `b`.
    ///
    /// Steps include drops, empty tables and foreign key handlers.
    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
}
223
/// Options for opening a database with `Database::migrator`.
///
/// Created with `Config::open` or `Config::open_in_memory`. The public fields
/// may be changed before the config is passed on.
pub struct Config {
    // Factory for the connections of the file or in-memory database.
    manager: r2d2_sqlite::SqliteConnectionManager,
    // Runs once, when `Database::migrator` creates a fresh database; a no-op
    // unless extended with `Config::init_stmt`.
    init: Box<dyn FnOnce(&rusqlite::Transaction)>,
    /// `PRAGMA synchronous` used for connections created by the manager (default `Full`).
    pub synchronous: Synchronous,
    /// `PRAGMA foreign_keys` used for connections created by the manager (default `SQLite`).
    pub foreign_keys: ForeignKeys,
}
245
/// Value used for SQLite's `PRAGMA synchronous` on database connections.
///
/// See the SQLite documentation of `PRAGMA synchronous` for the durability
/// trade-offs of each level.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Synchronous {
    /// `PRAGMA synchronous = FULL` (the default chosen by `Config`).
    Full,

    /// `PRAGMA synchronous = NORMAL`.
    Normal,
}

impl Synchronous {
    /// Keyword passed to `PRAGMA synchronous`.
    ///
    /// Takes `self` by value; the enum is `Copy`, so this does not move the
    /// setting out of a `Config`.
    fn as_str(self) -> &'static str {
        match self {
            Synchronous::Full => "FULL",
            Synchronous::Normal => "NORMAL",
        }
    }
}
271
/// Value used for SQLite's `PRAGMA foreign_keys` on database connections.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ForeignKeys {
    /// `PRAGMA foreign_keys = OFF`: SQLite does not enforce foreign keys.
    ///
    /// NOTE(review): presumably the typed Rust API is relied upon to keep
    /// references valid in this mode — confirm.
    Rust,

    /// `PRAGMA foreign_keys = ON`: SQLite enforces foreign keys (the default
    /// chosen by `Config`).
    SQLite,
}

impl ForeignKeys {
    /// Value passed to `PRAGMA foreign_keys`.
    ///
    /// Takes `self` by value; the enum is `Copy`, so this does not move the
    /// setting out of a `Config`.
    fn as_str(self) -> &'static str {
        match self {
            ForeignKeys::Rust => "OFF",
            ForeignKeys::SQLite => "ON",
        }
    }
}
306
307impl Config {
308 pub fn open(p: impl AsRef<Path>) -> Self {
315 let manager = r2d2_sqlite::SqliteConnectionManager::file(p);
316 Self::open_internal(manager)
317 }
318
319 pub fn open_in_memory() -> Self {
321 let manager = r2d2_sqlite::SqliteConnectionManager::memory();
322 Self::open_internal(manager)
323 }
324
325 fn open_internal(manager: r2d2_sqlite::SqliteConnectionManager) -> Self {
326 Self {
327 manager,
328 init: Box::new(|_| {}),
329 synchronous: Synchronous::Full,
330 foreign_keys: ForeignKeys::SQLite,
331 }
332 }
333
334 pub fn init_stmt(mut self, sql: &'static str) -> Self {
338 self.init = Box::new(move |txn| {
339 (self.init)(txn);
340
341 txn.execute_batch(sql)
342 .expect("raw sql statement to populate db failed");
343 });
344 self
345 }
346}
347
impl<S: Schema> Database<S> {
    /// Opens the database described by `config` and starts a migration
    /// session for schema `S`.
    ///
    /// Every connection created by the returned manager is initialised with:
    /// - `journal_mode = WAL`;
    /// - the configured `synchronous` and `foreign_keys` pragmas;
    /// - double-quoted string literals disabled for DDL and DML;
    /// - SQLite's defensive mode enabled.
    ///
    /// One connection is then opened with foreign keys turned off, and an
    /// exclusive transaction is started on it. All migration work happens in
    /// that transaction. If the database's `schema_version` is 0 (which SQLite
    /// reports when no schema exists yet), three things happen:
    /// - the tables of `S` are created;
    /// - `config.init` runs;
    /// - `user_version` is set to `S::VERSION`.
    ///
    /// Returns `None` if the stored `user_version` is lower than `S::VERSION`.
    /// The transaction is dropped without being committed in that case.
    ///
    /// # Panics
    /// Panics if the following fails:
    /// - opening the connection (including the per-connection pragmas);
    /// - starting the transaction;
    /// - creating the tables;
    /// - reading or writing `user_version`.
    ///
    /// It also panics if `config.init` panics. In debug builds it additionally
    /// panics if the database already contains foreign key violations.
    pub fn migrator(config: Config) -> Option<Migrator<S>> {
        let synchronous = config.synchronous.as_str();
        let foreign_keys = config.foreign_keys.as_str();
        // Applied to every connection the manager creates, including the
        // migration connection opened below.
        let manager = config.manager.with_init(move |inner| {
            inner.pragma_update(None, "journal_mode", "WAL")?;
            inner.pragma_update(None, "synchronous", synchronous)?;
            inner.pragma_update(None, "foreign_keys", foreign_keys)?;
            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
            Ok(())
        });

        use r2d2::ManageConnection;
        let conn = manager.connect().unwrap();
        // `PRAGMA foreign_keys` has no effect inside a transaction, so it must be
        // set before the transaction starts. Violations are instead checked
        // explicitly with `foreign_key_check`.
        conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
        let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
            Some(
                conn.borrow_mut()
                    .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
                    .unwrap(),
            )
        });

        // Fresh database: create schema `S` directly instead of migrating.
        if schema_version(txn.get()) == 0 {
            let mut b = TableTypBuilder::default();
            S::typs(&mut b);

            for (table_name, table) in &*b.ast.tables {
                new_table_inner(txn.get(), table, Alias::new(table_name));
            }
            (config.init)(txn.get());
            set_user_version(txn.get(), S::VERSION).unwrap();
        }

        let user_version = user_version(txn.get()).unwrap();
        // The database predates the oldest schema this migrator knows about.
        if user_version < S::VERSION {
            return None;
        }
        debug_assert_eq!(
            foreign_key_check(txn.get()),
            None,
            "foreign key constraint violated"
        );

        Some(Migrator {
            manager,
            transaction: txn,
            _p: PhantomData,
        })
    }
}
406
/// An in-progress migration session whose database is typed as being at schema `S`.
///
/// Holds the exclusive transaction opened by `Database::migrator`. Chain
/// `Migrator::migrate` calls and end with `Migrator::finish`, which commits.
pub struct Migrator<S> {
    // Connection factory handed over to the resulting `Database`.
    manager: r2d2_sqlite::SqliteConnectionManager,
    // The exclusive transaction every migration step runs in.
    transaction: OwnedTransaction,
    _p: PhantomData<S>,
}
416
417pub struct Migrated<'t, FromSchema, T> {
421 _p: PhantomData<T>,
422 f: Box<dyn 't + FnOnce(&mut SchemaBuilder<'t, FromSchema>)>,
423 _local: PhantomData<*const ()>,
424}
425
426impl<'t, FromSchema: 'static, T: Table> Migrated<'t, FromSchema, T> {
427 pub fn map_fk_err(err: impl 't + FnOnce() -> Infallible) -> Self {
431 Self {
432 _p: PhantomData,
433 f: Box::new(|x| x.foreign_key::<T>(err)),
434 _local: PhantomData,
435 }
436 }
437
438 #[doc(hidden)]
439 pub fn apply(self, b: &mut SchemaBuilder<'t, FromSchema>) {
440 (self.f)(b)
441 }
442}
443
444impl<S: Schema> Migrator<S> {
445 pub fn migrate<'x, M>(
449 mut self,
450 m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
451 ) -> Migrator<M::To>
452 where
453 M: SchemaMigration<'x, From = S>,
454 {
455 if user_version(self.transaction.get()).unwrap() == S::VERSION {
456 let res = std::thread::scope(|s| {
457 s.spawn(|| {
458 TXN.set(Some(self.transaction));
459
460 check_schema::<S>();
461
462 let mut txn = TransactionMigrate {
463 inner: Transaction::new(),
464 scope: Default::default(),
465 rename_map: HashMap::new(),
466 };
467 let m = m(&mut txn);
468
469 let mut builder = SchemaBuilder {
470 drop: vec![],
471 foreign_key: HashMap::new(),
472 inner: txn,
473 };
474 m.tables(&mut builder);
475
476 let transaction = TXN.take().unwrap();
477
478 for drop in builder.drop {
479 let sql = drop.to_string(SqliteQueryBuilder);
480 transaction.get().execute(&sql, []).unwrap();
481 }
482 for (to, tmp) in builder.inner.rename_map {
483 let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
484 let sql = rename.to_string(SqliteQueryBuilder);
485 transaction.get().execute(&sql, []).unwrap();
486 }
487 if let Some(fk) = foreign_key_check(transaction.get()) {
488 (builder.foreign_key.remove(&*fk).unwrap())();
489 }
490 #[allow(
491 unreachable_code,
492 reason = "rustc is stupid and thinks this is unreachable"
493 )]
494 set_user_version(transaction.get(), M::To::VERSION).unwrap();
495
496 transaction
497 })
498 .join()
499 });
500 match res {
501 Ok(val) => self.transaction = val,
502 Err(payload) => std::panic::resume_unwind(payload),
503 }
504 }
505
506 Migrator {
507 manager: self.manager,
508 transaction: self.transaction,
509 _p: PhantomData,
510 }
511 }
512
513 pub fn finish(mut self) -> Option<Database<S>> {
519 let conn = &self.transaction;
520 if user_version(conn.get()).unwrap() != S::VERSION {
521 return None;
522 }
523
524 let res = std::thread::scope(|s| {
525 s.spawn(|| {
526 TXN.set(Some(self.transaction));
527 check_schema::<S>();
528 TXN.take().unwrap()
529 })
530 .join()
531 });
532 match res {
533 Ok(val) => self.transaction = val,
534 Err(payload) => std::panic::resume_unwind(payload),
535 }
536
537 self.transaction
539 .get()
540 .execute_batch("PRAGMA optimize;")
541 .unwrap();
542
543 let schema_version = schema_version(self.transaction.get());
544 self.transaction.with(|x| x.commit().unwrap());
545
546 Some(Database {
547 manager: self.manager,
548 schema_version: AtomicI64::new(schema_version),
549 schema: PhantomData,
550 mut_lock: parking_lot::FairMutex::new(()),
551 })
552 }
553}
554
555pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
556 conn.pragma_query_value(None, "schema_version", |r| r.get(0))
557 .unwrap()
558}
559
560pub fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
562 conn.query_row("PRAGMA user_version", [], |row| row.get(0))
563}
564
565fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
567 conn.pragma_update(None, "user_version", v)
568}
569
570pub(crate) fn check_schema<S: Schema>() {
571 let mut b = TableTypBuilder::default();
572 S::typs(&mut b);
573 pretty_assertions::assert_eq!(
574 b.ast,
575 read_schema(&crate::Transaction::new()),
576 "schema is different (expected left, but got right)",
577 );
578}
579
580fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
581 let error = conn
582 .prepare("PRAGMA foreign_key_check")
583 .unwrap()
584 .query_map([], |row| row.get(2))
585 .unwrap()
586 .next();
587 error.transpose().unwrap()
588}
589
/// Opens two in-memory databases with the same (empty) schema while both
/// migrators are still alive.
///
/// NOTE(review): presumably this checks that two simultaneous exclusive
/// migration transactions don't interfere. Each in-memory connection is a
/// separate SQLite database — confirm the intent.
#[test]
fn open_multiple() {
    // Declares a schema module with a single empty version (`v0::Empty`).
    #[crate::migration::schema(Empty)]
    pub mod vN {}

    let _a = Database::<v0::Empty>::migrator(Config::open_in_memory());
    let _b = Database::<v0::Empty>::migrator(Config::open_in_memory());
}
597}