1use std::{
2 collections::{HashMap, HashSet},
3 convert::Infallible,
4 marker::PhantomData,
5 ops::{Deref, Not},
6 path::Path,
7};
8
9use rusqlite::{Connection, config::DbConfig};
10use sea_query::{Alias, ColumnDef, IntoTableRef, SqliteQueryBuilder, TableDropStatement};
11use self_cell::MutBorrow;
12
13use crate::{
14 FromExpr, IntoExpr, Table, TableRow, Transaction,
15 alias::{Scope, TmpTable},
16 hash,
17 schema_pragma::read_schema,
18 transaction::{Database, OwnedTransaction, TXN, try_insert_private},
19};
20
/// Collects the column types of every table declared by schema `S`.
///
/// The accumulated `hash::Schema` is what `check_schema` compares against
/// the schema read back from the database.
pub struct TableTypBuilder<S> {
    pub(crate) ast: hash::Schema,
    // Ties the builder to schema `S` without storing a value of that type.
    _p: PhantomData<S>,
}
25
26impl<S> Default for TableTypBuilder<S> {
27 fn default() -> Self {
28 Self {
29 ast: Default::default(),
30 _p: Default::default(),
31 }
32 }
33}
34
35impl<S> TableTypBuilder<S> {
36 pub fn table<T: Table<Schema = S>>(&mut self) {
37 let mut b = hash::TypBuilder::default();
38 T::typs(&mut b);
39 self.ast.tables.insert((T::NAME.to_owned(), b.ast));
40 }
41}
42
/// A versioned database schema.
pub trait Schema: Sized + 'static {
    /// Version number of this schema. It is written to SQLite's
    /// `user_version` pragma once the database has been created with, or
    /// migrated to, this schema.
    const VERSION: i64;
    /// Registers every table of this schema with `b`.
    fn typs(b: &mut TableTypBuilder<Self>);
}
47
/// A value describing how to build one row of the new table `To` from one
/// row of the old table `From`.
pub trait Migration {
    /// Schema that `From` belongs to.
    type FromSchema: 'static;
    /// Table in the old schema whose rows are migrated.
    type From: Table<Schema = Self::FromSchema>;
    /// Table in the new schema; it declares `From` as its `MigrateFrom`.
    type To: Table<MigrateFrom = Self::From>;
    /// Error produced when a migrated row cannot be inserted into `To`.
    type Conflict;

    /// Builds the insert for `To` from `val`; `prev` is an expression for
    /// the old row that `val` was produced from.
    #[doc(hidden)]
    fn prepare(
        val: Self,
        prev: crate::Expr<'static, Self::FromSchema, Self::From>,
    ) -> <Self::To as Table>::Insert;
    /// Turns the old row whose insert failed into a `Conflict`.
    #[doc(hidden)]
    fn map_conflict(val: TableRow<Self::From>) -> Self::Conflict;
}
62
/// Transaction handed to the closure passed to `Migrator::migrate`.
///
/// It dereferences to a `Transaction` over the old schema. It also tracks the
/// temporary tables that are built to replace the old tables.
pub struct TransactionMigrate<FromSchema> {
    inner: Transaction<FromSchema>,
    // Source of fresh temporary table names.
    scope: Scope,
    // Final table name -> temporary table holding its new contents. Each
    // temporary table is renamed to its key when the migration is applied.
    rename_map: HashMap<&'static str, TmpTable>,
}
69
70impl<FromSchema> Deref for TransactionMigrate<FromSchema> {
71 type Target = Transaction<FromSchema>;
72
73 fn deref(&self) -> &Self::Target {
74 &self.inner
75 }
76}
77
78impl<FromSchema> TransactionMigrate<FromSchema> {
79 fn new_table_name<T: Table>(&mut self) -> TmpTable {
80 *self.rename_map.entry(T::NAME).or_insert_with(|| {
81 let new_table_name = self.scope.tmp_table();
82 TXN.with_borrow(|txn| new_table::<T>(txn.as_ref().unwrap().get(), new_table_name));
83 new_table_name
84 })
85 }
86
87 fn unmigrated<M: Migration<FromSchema = FromSchema>, Out>(
88 &self,
89 new_name: TmpTable,
90 ) -> impl Iterator<Item = (i64, Out)>
91 where
92 Out: FromExpr<FromSchema, M::From>,
93 {
94 let data = self.inner.query(|rows| {
95 let old = rows.join(<M::From as Table>::TOKEN);
96 rows.into_vec((&old, Out::from_expr(&old)))
97 });
98
99 let migrated = Transaction::new().query(|rows| {
100 let new = rows.join_tmp::<M::From>(new_name);
101 rows.into_vec(new)
102 });
103 let migrated: HashSet<_> = migrated.into_iter().map(|x| x.inner.idx).collect();
104
105 data.into_iter().filter_map(move |(row, data)| {
106 migrated
107 .contains(&row.inner.idx)
108 .not()
109 .then_some((row.inner.idx, data))
110 })
111 }
112
113 pub fn migrate_optional<
121 M: Migration<FromSchema = FromSchema>,
122 X: FromExpr<FromSchema, M::From>,
123 >(
124 &mut self,
125 mut f: impl FnMut(X) -> Option<M>,
126 ) -> Result<(), M::Conflict> {
127 let new_name = self.new_table_name::<M::To>();
128
129 for (idx, x) in self.unmigrated::<M, X>(new_name) {
130 if let Some(new) = f(x) {
131 try_insert_private::<M::To>(
132 new_name.into_table_ref(),
133 Some(idx),
134 M::prepare(new, TableRow::new(idx).into_expr()),
135 )
136 .map_err(|_| M::map_conflict(TableRow::new(idx)))?;
137 };
138 }
139 Ok(())
140 }
141
142 pub fn migrate<M: Migration<FromSchema = FromSchema>, X: FromExpr<FromSchema, M::From>>(
149 &mut self,
150 mut f: impl FnMut(X) -> M,
151 ) -> Result<Migrated<'static, FromSchema, M::To>, M::Conflict> {
152 self.migrate_optional::<M, X>(|x| Some(f(x)))?;
153
154 Ok(Migrated {
155 _p: PhantomData,
156 f: Box::new(|_| {}),
157 _local: PhantomData,
158 })
159 }
160
161 pub fn migrate_ok<
165 M: Migration<FromSchema = FromSchema, Conflict = Infallible>,
166 X: FromExpr<FromSchema, M::From>,
167 >(
168 &mut self,
169 f: impl FnMut(X) -> M,
170 ) -> Migrated<'static, FromSchema, M::To> {
171 let Ok(res) = self.migrate(f);
172 res
173 }
174}
175
/// Collects the schema-level changes of one migration; passed to
/// `SchemaMigration::tables`.
pub struct SchemaBuilder<'t, FromSchema> {
    inner: TransactionMigrate<FromSchema>,
    // `DROP TABLE` statements, executed before the temporary tables are renamed.
    drop: Vec<TableDropStatement>,
    // Handlers for foreign key violations, keyed by the table name that
    // `PRAGMA foreign_key_check` reports. The `Infallible` return type means
    // a handler cannot return normally.
    foreign_key: HashMap<&'static str, Box<dyn 't + FnOnce() -> Infallible>>,
}
181
182impl<'t, FromSchema: 'static> SchemaBuilder<'t, FromSchema> {
183 pub fn foreign_key<To: Table>(&mut self, err: impl 't + FnOnce() -> Infallible) {
184 self.inner.new_table_name::<To>();
185
186 self.foreign_key.insert(To::NAME, Box::new(err));
187 }
188
189 pub fn create_empty<To: Table>(&mut self) {
190 self.inner.new_table_name::<To>();
191 }
192
193 pub fn drop_table<T: Table>(&mut self) {
194 let name = Alias::new(T::NAME);
195 let step = sea_query::Table::drop().table(name).take();
196 self.drop.push(step);
197 }
198}
199
200fn new_table<T: Table>(conn: &Connection, alias: TmpTable) {
201 let mut f = crate::hash::TypBuilder::default();
202 T::typs(&mut f);
203 new_table_inner(conn, &f.ast, alias);
204}
205
206fn new_table_inner(conn: &Connection, table: &crate::hash::Table, alias: impl IntoTableRef) {
207 let mut create = table.create();
208 create
209 .table(alias)
210 .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
211 let mut sql = create.to_string(SqliteQueryBuilder);
212 sql.push_str(" STRICT");
213 conn.execute(&sql, []).unwrap();
214}
215
/// A migration from schema `From` to schema `To`; returned by the closure
/// passed to `Migrator::migrate`.
pub trait SchemaMigration<'a> {
    /// Schema being migrated from.
    type From: Schema;
    /// Schema being migrated to. Its `VERSION` is written to `user_version`
    /// once the migration has been applied.
    type To: Schema;

    /// Registers the table-level changes of this migration with `b`.
    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
}
222
/// Options for opening a database; consumed by `Database::migrator`.
pub struct Config {
    // Source of SQLite connections.
    manager: r2d2_sqlite::SqliteConnectionManager,
    // Runs once, inside the migration transaction, when a new (empty)
    // database is set up.
    init: Box<dyn FnOnce(&rusqlite::Transaction)>,
    /// Value used for SQLite's `synchronous` pragma on every pooled
    /// connection. `Config::open` and `Config::open_in_memory` set it to
    /// `Synchronous::Full`.
    pub synchronous: Synchronous,
}
242
/// Setting for SQLite's `synchronous` pragma.
#[non_exhaustive]
pub enum Synchronous {
    /// `PRAGMA synchronous = FULL`.
    Full,

    /// `PRAGMA synchronous = NORMAL`.
    Normal,
}
259
260impl Synchronous {
261 fn as_str(self) -> &'static str {
262 match self {
263 Synchronous::Full => "FULL",
264 Synchronous::Normal => "NORMAL",
265 }
266 }
267}
268
269impl Config {
270 pub fn open(p: impl AsRef<Path>) -> Self {
277 let manager = r2d2_sqlite::SqliteConnectionManager::file(p);
278 Self::open_internal(manager)
279 }
280
281 pub fn open_in_memory() -> Self {
283 let manager = r2d2_sqlite::SqliteConnectionManager::memory();
284 Self::open_internal(manager)
285 }
286
287 fn open_internal(manager: r2d2_sqlite::SqliteConnectionManager) -> Self {
288 Self {
289 manager,
290 init: Box::new(|_| {}),
291 synchronous: Synchronous::Full,
292 }
293 }
294
295 pub fn init_stmt(mut self, sql: &'static str) -> Self {
299 self.init = Box::new(move |txn| {
300 (self.init)(txn);
301
302 txn.execute_batch(sql)
303 .expect("raw sql statement to populate db failed");
304 });
305 self
306 }
307}
308
309impl<S: Schema> Database<S> {
310 pub fn migrator(config: Config) -> Option<Migrator<S>> {
314 let synchronous = config.synchronous.as_str();
315 let manager = config.manager.with_init(move |inner| {
316 inner.pragma_update(None, "journal_mode", "WAL")?;
317 inner.pragma_update(None, "synchronous", synchronous)?;
318 inner.pragma_update(None, "foreign_keys", "ON")?;
319 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
320 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
321 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
322 Ok(())
323 });
324
325 use r2d2::ManageConnection;
326 let conn = manager.connect().unwrap();
327 conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
328 let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
329 Some(
330 conn.borrow_mut()
331 .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
332 .unwrap(),
333 )
334 });
335
336 if schema_version(txn.get()) == 0 {
338 let mut b = TableTypBuilder::default();
339 S::typs(&mut b);
340
341 for (table_name, table) in &*b.ast.tables {
342 new_table_inner(txn.get(), table, Alias::new(table_name));
343 }
344 (config.init)(txn.get());
345 set_user_version(txn.get(), S::VERSION).unwrap();
346 }
347
348 let user_version = user_version(txn.get()).unwrap();
349 if user_version < S::VERSION {
351 return None;
352 }
353 assert_eq!(
354 foreign_key_check(txn.get()),
355 None,
356 "foreign key constraint violated"
357 );
358
359 Some(Migrator {
360 manager,
361 transaction: txn,
362 _p: PhantomData,
363 })
364 }
365}
366
/// An in-progress migration whose current schema type is `S`.
///
/// Holds the exclusive transaction opened by `Database::migrator`. The only
/// commit happens in `Migrator::finish`.
pub struct Migrator<S> {
    // Connection manager handed on to the resulting `Database`.
    manager: r2d2_sqlite::SqliteConnectionManager,
    // The exclusive migration transaction.
    transaction: OwnedTransaction,
    _p: PhantomData<S>,
}
376
/// Result of migrating table `T`. It carries an extra action to run on the
/// `SchemaBuilder`, for example registering a foreign key error handler via
/// `Migrated::map_fk_err`.
pub struct Migrated<'t, FromSchema, T> {
    _p: PhantomData<T>,
    // Deferred action, run by `apply`.
    f: Box<dyn 't + FnOnce(&mut SchemaBuilder<'t, FromSchema>)>,
    // `*const ()` is neither `Send` nor `Sync`, so this marker makes
    // `Migrated` neither `Send` nor `Sync` as well.
    _local: PhantomData<*const ()>,
}
385
386impl<'t, FromSchema: 'static, T: Table> Migrated<'t, FromSchema, T> {
387 pub fn map_fk_err(err: impl 't + FnOnce() -> Infallible) -> Self {
391 Self {
392 _p: PhantomData,
393 f: Box::new(|x| x.foreign_key::<T>(err)),
394 _local: PhantomData,
395 }
396 }
397
398 #[doc(hidden)]
399 pub fn apply(self, b: &mut SchemaBuilder<'t, FromSchema>) {
400 (self.f)(b)
401 }
402}
403
404impl<S: Schema> Migrator<S> {
405 pub fn migrate<'x, M>(
409 mut self,
410 m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
411 ) -> Migrator<M::To>
412 where
413 M: SchemaMigration<'x, From = S>,
414 {
415 if user_version(self.transaction.get()).unwrap() == S::VERSION {
416 self.transaction = std::thread::scope(|s| {
417 s.spawn(|| {
418 TXN.set(Some(self.transaction));
419
420 check_schema::<S>();
421
422 let mut txn = TransactionMigrate {
423 inner: Transaction::new(),
424 scope: Default::default(),
425 rename_map: HashMap::new(),
426 };
427 let m = m(&mut txn);
428
429 let mut builder = SchemaBuilder {
430 drop: vec![],
431 foreign_key: HashMap::new(),
432 inner: txn,
433 };
434 m.tables(&mut builder);
435
436 let transaction = TXN.take().unwrap();
437
438 for drop in builder.drop {
439 let sql = drop.to_string(SqliteQueryBuilder);
440 transaction.get().execute(&sql, []).unwrap();
441 }
442 for (to, tmp) in builder.inner.rename_map {
443 let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
444 let sql = rename.to_string(SqliteQueryBuilder);
445 transaction.get().execute(&sql, []).unwrap();
446 }
447 if let Some(fk) = foreign_key_check(transaction.get()) {
448 (builder.foreign_key.remove(&*fk).unwrap())();
449 }
450 #[allow(
451 unreachable_code,
452 reason = "rustc is stupid and thinks this is unreachable"
453 )]
454 set_user_version(transaction.get(), M::To::VERSION).unwrap();
455
456 transaction
457 })
458 .join()
459 .unwrap()
460 });
461 }
462
463 Migrator {
464 manager: self.manager,
465 transaction: self.transaction,
466 _p: PhantomData,
467 }
468 }
469
470 pub fn finish(mut self) -> Option<Database<S>> {
476 let conn = &self.transaction;
477 if user_version(conn.get()).unwrap() != S::VERSION {
478 return None;
479 }
480
481 self.transaction = std::thread::scope(|s| {
482 s.spawn(|| {
483 TXN.set(Some(self.transaction));
484 check_schema::<S>();
485 TXN.take().unwrap()
486 })
487 .join()
488 .unwrap()
489 });
490
491 self.transaction
493 .get()
494 .execute_batch("PRAGMA optimize;")
495 .unwrap();
496
497 let schema_version = schema_version(self.transaction.get());
498 self.transaction.with(|x| x.commit().unwrap());
499
500 Some(Database {
501 manager: self.manager,
502 schema_version,
503 schema: PhantomData,
504 })
505 }
506}
507
508pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
509 conn.pragma_query_value(None, "schema_version", |r| r.get(0))
510 .unwrap()
511}
512
513fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
515 conn.query_row("PRAGMA user_version", [], |row| row.get(0))
516}
517
518fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
520 conn.pragma_update(None, "user_version", v)
521}
522
523fn check_schema<S: Schema>() {
524 let mut b = TableTypBuilder::default();
525 S::typs(&mut b);
526 pretty_assertions::assert_eq!(
527 b.ast,
528 read_schema(&crate::Transaction::new()),
529 "schema is different (expected left, but got right)",
530 );
531}
532
533fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
534 let error = conn
535 .prepare("PRAGMA foreign_key_check")
536 .unwrap()
537 .query_map([], |row| row.get(2))
538 .unwrap()
539 .next();
540 error.transpose().unwrap()
541}
542
/// Two migrators over separate in-memory databases with the same schema are
/// kept alive at the same time. The test passes if neither call panics.
#[test]
fn open_multiple() {
    // The schema macro expands this into a `v0` module containing the
    // `Empty` schema used below.
    #[crate::migration::schema(Empty)]
    pub mod vN {}

    let _a = Database::<v0::Empty>::migrator(Config::open_in_memory());
    let _b = Database::<v0::Empty>::migrator(Config::open_in_memory());
}