1use std::{
2 collections::{HashMap, HashSet},
3 convert::Infallible,
4 marker::PhantomData,
5 ops::{Deref, Not},
6 path::Path,
7};
8
9use rusqlite::{Connection, config::DbConfig};
10use sea_query::{Alias, ColumnDef, IntoTableRef, SqliteQueryBuilder, TableDropStatement};
11use self_cell::MutBorrow;
12
13use crate::{
14 FromExpr, IntoExpr, Table, TableRow, Transaction,
15 alias::{Scope, TmpTable},
16 hash,
17 schema_pragma::read_schema,
18 transaction::{Database, OwnedTransaction, TXN, try_insert_private},
19};
20
21pub struct TableTypBuilder<S> {
22 pub(crate) ast: hash::Schema,
23 _p: PhantomData<S>,
24}
25
26impl<S> Default for TableTypBuilder<S> {
27 fn default() -> Self {
28 Self {
29 ast: Default::default(),
30 _p: Default::default(),
31 }
32 }
33}
34
35impl<S> TableTypBuilder<S> {
36 pub fn table<T: Table<Schema = S>>(&mut self) {
37 let mut b = hash::TypBuilder::default();
38 T::typs(&mut b);
39 self.ast.tables.insert((T::NAME.to_owned(), b.ast));
40 }
41}
42
43pub trait Schema: Sized + 'static {
44 const VERSION: i64;
45 fn typs(b: &mut TableTypBuilder<Self>);
46}
47
48pub trait Migration {
49 type FromSchema: 'static;
50 type From: Table<Schema = Self::FromSchema>;
51 type To: Table<MigrateFrom = Self::From>;
52 type Conflict;
53
54 #[doc(hidden)]
55 fn prepare(
56 val: Self,
57 prev: crate::Expr<'static, Self::FromSchema, Self::From>,
58 ) -> <Self::To as Table>::Insert;
59 #[doc(hidden)]
60 fn map_conflict(val: TableRow<Self::From>) -> Self::Conflict;
61}
62
63pub struct TransactionMigrate<FromSchema> {
65 inner: Transaction<FromSchema>,
66 scope: Scope,
67 rename_map: HashMap<&'static str, TmpTable>,
68}
69
70impl<FromSchema> Deref for TransactionMigrate<FromSchema> {
71 type Target = Transaction<FromSchema>;
72
73 fn deref(&self) -> &Self::Target {
74 &self.inner
75 }
76}
77
78impl<FromSchema> TransactionMigrate<FromSchema> {
79 fn new_table_name<T: Table>(&mut self) -> TmpTable {
80 *self.rename_map.entry(T::NAME).or_insert_with(|| {
81 let new_table_name = self.scope.tmp_table();
82 TXN.with_borrow(|txn| new_table::<T>(txn.as_ref().unwrap().get(), new_table_name));
83 new_table_name
84 })
85 }
86
87 fn unmigrated<M: Migration<FromSchema = FromSchema>, Out>(
88 &self,
89 new_name: TmpTable,
90 ) -> impl Iterator<Item = (i64, Out)>
91 where
92 Out: FromExpr<FromSchema, M::From>,
93 {
94 let data = self.inner.query(|rows| {
95 let old = rows.join(<M::From as Table>::TOKEN);
96 rows.into_vec((&old, Out::from_expr(&old)))
97 });
98
99 let migrated = Transaction::new().query(|rows| {
100 let new = rows.join_tmp::<M::From>(new_name);
101 rows.into_vec(new)
102 });
103 let migrated: HashSet<_> = migrated.into_iter().map(|x| x.inner.idx).collect();
104
105 data.into_iter().filter_map(move |(row, data)| {
106 migrated
107 .contains(&row.inner.idx)
108 .not()
109 .then_some((row.inner.idx, data))
110 })
111 }
112
113 pub fn migrate_optional<
121 M: Migration<FromSchema = FromSchema>,
122 X: FromExpr<FromSchema, M::From>,
123 >(
124 &mut self,
125 mut f: impl FnMut(X) -> Option<M>,
126 ) -> Result<(), M::Conflict> {
127 let new_name = self.new_table_name::<M::To>();
128
129 for (idx, x) in self.unmigrated::<M, X>(new_name) {
130 if let Some(new) = f(x) {
131 try_insert_private::<M::To>(
132 new_name.into_table_ref(),
133 Some(idx),
134 M::prepare(new, TableRow::new(idx).into_expr()),
135 )
136 .map_err(|_| M::map_conflict(TableRow::new(idx)))?;
137 };
138 }
139 Ok(())
140 }
141
142 pub fn migrate<M: Migration<FromSchema = FromSchema>, X: FromExpr<FromSchema, M::From>>(
149 &mut self,
150 mut f: impl FnMut(X) -> M,
151 ) -> Result<Migrated<'static, FromSchema, M::To>, M::Conflict> {
152 self.migrate_optional::<M, X>(|x| Some(f(x)))?;
153
154 Ok(Migrated {
155 _p: PhantomData,
156 f: Box::new(|_| {}),
157 _local: PhantomData,
158 })
159 }
160
161 pub fn migrate_ok<
165 M: Migration<FromSchema = FromSchema, Conflict = Infallible>,
166 X: FromExpr<FromSchema, M::From>,
167 >(
168 &mut self,
169 f: impl FnMut(X) -> M,
170 ) -> Migrated<'static, FromSchema, M::To> {
171 let Ok(res) = self.migrate(f);
172 res
173 }
174}
175
176pub struct SchemaBuilder<'t, FromSchema> {
177 inner: TransactionMigrate<FromSchema>,
178 drop: Vec<TableDropStatement>,
179 foreign_key: HashMap<&'static str, Box<dyn 't + FnOnce() -> Infallible>>,
180}
181
182impl<'t, FromSchema: 'static> SchemaBuilder<'t, FromSchema> {
183 pub fn foreign_key<To: Table>(&mut self, err: impl 't + FnOnce() -> Infallible) {
184 self.inner.new_table_name::<To>();
185
186 self.foreign_key.insert(To::NAME, Box::new(err));
187 }
188
189 pub fn create_empty<To: Table>(&mut self) {
190 self.inner.new_table_name::<To>();
191 }
192
193 pub fn drop_table<T: Table>(&mut self) {
194 let name = Alias::new(T::NAME);
195 let step = sea_query::Table::drop().table(name).take();
196 self.drop.push(step);
197 }
198}
199
200fn new_table<T: Table>(conn: &Connection, alias: TmpTable) {
201 let mut f = crate::hash::TypBuilder::default();
202 T::typs(&mut f);
203 new_table_inner(conn, &f.ast, alias);
204}
205
206fn new_table_inner(conn: &Connection, table: &crate::hash::Table, alias: impl IntoTableRef) {
207 let mut create = table.create();
208 create
209 .table(alias)
210 .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
211 let mut sql = create.to_string(SqliteQueryBuilder);
212 sql.push_str(" STRICT");
213 conn.execute(&sql, []).unwrap();
214}
215
216pub trait SchemaMigration<'a> {
217 type From: Schema;
218 type To: Schema;
219
220 fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
221}
222
223pub struct Config {
238 manager: r2d2_sqlite::SqliteConnectionManager,
239 init: Box<dyn FnOnce(&rusqlite::Transaction)>,
240}
241
242impl Config {
243 pub fn open(p: impl AsRef<Path>) -> Self {
250 let manager = r2d2_sqlite::SqliteConnectionManager::file(p);
251 Self::open_internal(manager)
252 }
253
254 pub fn open_in_memory() -> Self {
256 let manager = r2d2_sqlite::SqliteConnectionManager::memory();
257 Self::open_internal(manager)
258 }
259
260 fn open_internal(manager: r2d2_sqlite::SqliteConnectionManager) -> Self {
261 let manager = manager.with_init(|inner| {
262 inner.pragma_update(None, "journal_mode", "WAL")?;
263 inner.pragma_update(None, "synchronous", "NORMAL")?;
264 inner.pragma_update(None, "foreign_keys", "ON")?;
265 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
266 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
267 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
268 Ok(())
269 });
270
271 Self {
272 manager,
273 init: Box::new(|_| {}),
274 }
275 }
276
277 pub fn init_stmt(mut self, sql: &'static str) -> Self {
281 self.init = Box::new(move |txn| {
282 (self.init)(txn);
283
284 txn.execute_batch(sql)
285 .expect("raw sql statement to populate db failed");
286 });
287 self
288 }
289}
290
291impl<S: Schema> Database<S> {
292 pub fn migrator(config: Config) -> Option<Migrator<S>> {
296 use r2d2::ManageConnection;
297 let conn = config.manager.connect().unwrap();
298 conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
299 let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
300 Some(
301 conn.borrow_mut()
302 .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
303 .unwrap(),
304 )
305 });
306
307 if schema_version(txn.get()) == 0 {
309 let mut b = TableTypBuilder::default();
310 S::typs(&mut b);
311
312 for (table_name, table) in &*b.ast.tables {
313 new_table_inner(txn.get(), table, Alias::new(table_name));
314 }
315 (config.init)(txn.get());
316 set_user_version(txn.get(), S::VERSION).unwrap();
317 }
318
319 let user_version = user_version(txn.get()).unwrap();
320 if user_version < S::VERSION {
322 return None;
323 }
324 assert_eq!(
325 foreign_key_check(txn.get()),
326 None,
327 "foreign key constraint violated"
328 );
329
330 Some(Migrator {
331 manager: config.manager,
332 transaction: txn,
333 _p: PhantomData,
334 })
335 }
336}
337
338pub struct Migrator<S> {
343 manager: r2d2_sqlite::SqliteConnectionManager,
344 transaction: OwnedTransaction,
345 _p: PhantomData<S>,
346}
347
348pub struct Migrated<'t, FromSchema, T> {
352 _p: PhantomData<T>,
353 f: Box<dyn 't + FnOnce(&mut SchemaBuilder<'t, FromSchema>)>,
354 _local: PhantomData<*const ()>,
355}
356
357impl<'t, FromSchema: 'static, T: Table> Migrated<'t, FromSchema, T> {
358 pub fn map_fk_err(err: impl 't + FnOnce() -> Infallible) -> Self {
362 Self {
363 _p: PhantomData,
364 f: Box::new(|x| x.foreign_key::<T>(err)),
365 _local: PhantomData,
366 }
367 }
368
369 #[doc(hidden)]
370 pub fn apply(self, b: &mut SchemaBuilder<'t, FromSchema>) {
371 (self.f)(b)
372 }
373}
374
375impl<S: Schema> Migrator<S> {
376 pub fn migrate<'x, M>(
380 mut self,
381 m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
382 ) -> Migrator<M::To>
383 where
384 M: SchemaMigration<'x, From = S>,
385 {
386 if user_version(self.transaction.get()).unwrap() == S::VERSION {
387 self.transaction = std::thread::scope(|s| {
388 s.spawn(|| {
389 TXN.set(Some(self.transaction));
390
391 check_schema::<S>();
392
393 let mut txn = TransactionMigrate {
394 inner: Transaction::new(),
395 scope: Default::default(),
396 rename_map: HashMap::new(),
397 };
398 let m = m(&mut txn);
399
400 let mut builder = SchemaBuilder {
401 drop: vec![],
402 foreign_key: HashMap::new(),
403 inner: txn,
404 };
405 m.tables(&mut builder);
406
407 let transaction = TXN.take().unwrap();
408
409 for drop in builder.drop {
410 let sql = drop.to_string(SqliteQueryBuilder);
411 transaction.get().execute(&sql, []).unwrap();
412 }
413 for (to, tmp) in builder.inner.rename_map {
414 let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
415 let sql = rename.to_string(SqliteQueryBuilder);
416 transaction.get().execute(&sql, []).unwrap();
417 }
418 if let Some(fk) = foreign_key_check(transaction.get()) {
419 (builder.foreign_key.remove(&*fk).unwrap())();
420 }
421 #[allow(
422 unreachable_code,
423 reason = "rustc is stupid and thinks this is unreachable"
424 )]
425 set_user_version(transaction.get(), M::To::VERSION).unwrap();
426
427 transaction
428 })
429 .join()
430 .unwrap()
431 });
432 }
433
434 Migrator {
435 manager: self.manager,
436 transaction: self.transaction,
437 _p: PhantomData,
438 }
439 }
440
441 pub fn finish(mut self) -> Option<Database<S>> {
447 let conn = &self.transaction;
448 if user_version(conn.get()).unwrap() != S::VERSION {
449 return None;
450 }
451
452 self.transaction = std::thread::scope(|s| {
453 s.spawn(|| {
454 TXN.set(Some(self.transaction));
455 check_schema::<S>();
456 TXN.take().unwrap()
457 })
458 .join()
459 .unwrap()
460 });
461
462 self.transaction
464 .get()
465 .execute_batch("PRAGMA optimize;")
466 .unwrap();
467
468 let schema_version = schema_version(self.transaction.get());
469 self.transaction.with(|x| x.commit().unwrap());
470
471 Some(Database {
472 manager: self.manager,
473 schema_version,
474 schema: PhantomData,
475 })
476 }
477}
478
479pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
480 conn.pragma_query_value(None, "schema_version", |r| r.get(0))
481 .unwrap()
482}
483
484fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
486 conn.query_row("PRAGMA user_version", [], |row| row.get(0))
487}
488
489fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
491 conn.pragma_update(None, "user_version", v)
492}
493
494fn check_schema<S: Schema>() {
495 let mut b = TableTypBuilder::default();
496 S::typs(&mut b);
497 pretty_assertions::assert_eq!(
498 b.ast,
499 read_schema(&crate::Transaction::new()),
500 "schema is different (expected left, but got right)",
501 );
502}
503
504fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
505 let error = conn
506 .prepare("PRAGMA foreign_key_check")
507 .unwrap()
508 .query_map([], |row| row.get(2))
509 .unwrap()
510 .next();
511 error.transpose().unwrap()
512}
513
/// Requesting a migrator for a second in-memory database while the first
/// migrator is still alive must not panic.
#[test]
fn open_multiple() {
    #[crate::migration::schema(Empty)]
    pub mod vN {}

    let _first = Database::<v0::Empty>::migrator(Config::open_in_memory());
    let _second = Database::<v0::Empty>::migrator(Config::open_in_memory());
}