1use std::{
2 collections::{HashMap, HashSet},
3 convert::Infallible,
4 marker::PhantomData,
5 ops::{Deref, Not},
6 path::Path,
7 rc::Rc,
8 sync::atomic::AtomicBool,
9};
10
11use rusqlite::{Connection, config::DbConfig};
12use sea_query::{Alias, ColumnDef, IntoTableRef, SqliteQueryBuilder, TableDropStatement};
13
14use crate::{
15 FromExpr, Table, TableRow, Transaction,
16 alias::{Scope, TmpTable},
17 client::LocalClient,
18 hash,
19 schema_pragma::read_schema,
20 transaction::{Database, try_insert_private},
21};
22
23pub struct TableTypBuilder<S> {
24 pub(crate) ast: hash::Schema,
25 _p: PhantomData<S>,
26}
27
28impl<S> Default for TableTypBuilder<S> {
29 fn default() -> Self {
30 Self {
31 ast: Default::default(),
32 _p: Default::default(),
33 }
34 }
35}
36
37impl<S> TableTypBuilder<S> {
38 pub fn table<T: Table<Schema = S>>(&mut self) {
39 let mut b = hash::TypBuilder::default();
40 T::typs(&mut b);
41 self.ast.tables.insert((T::NAME.to_owned(), b.ast));
42 }
43}
44
/// A database schema: a version number plus the set of tables it contains.
pub trait Schema: Sized + 'static {
    /// Version number stored in SQLite's `user_version` pragma when the
    /// database is at this schema.
    const VERSION: i64;
    /// Registers every table of this schema with `b`.
    fn typs(b: &mut TableTypBuilder<Self>);
}
49
/// Row-level migration from table `From` (in schema `FromSchema`) to table `To`.
///
/// A value implementing this trait is produced for every old row that should
/// be carried over; `'t` is presumably the lifetime of the migration
/// transaction (TODO confirm against the derive/macro that implements this).
pub trait Migration<'t> {
    /// Schema that the source table belongs to.
    type FromSchema: 'static;
    /// Table whose rows are being migrated.
    type From: Table<Schema = Self::FromSchema>;
    /// Table the rows are migrated into; it must name `From` as its `MigrateFrom`.
    type To: Table<MigrateFrom = Self::From>;
    /// Error reported when inserting a migrated row into `To` fails.
    type Conflict;

    /// Builds the insert for `To` from this value and the original row `prev`.
    #[doc(hidden)]
    fn prepare(val: Self, prev: TableRow<'t, Self::From>) -> <Self::To as Table>::Insert<'t>;
    /// Converts a failed insert for the original row `val` into a `Conflict`.
    #[doc(hidden)]
    fn map_conflict(val: TableRow<'t, Self::From>) -> Self::Conflict;
}
61
62pub struct TransactionMigrate<'t, FromSchema> {
64 inner: Transaction<'t, FromSchema>,
65 scope: Scope,
66 rename_map: HashMap<&'static str, TmpTable>,
67}
68
69impl<'t, FromSchema> Deref for TransactionMigrate<'t, FromSchema> {
70 type Target = Transaction<'t, FromSchema>;
71
72 fn deref(&self) -> &Self::Target {
73 &self.inner
74 }
75}
76
77impl<'t, FromSchema> TransactionMigrate<'t, FromSchema> {
78 fn new_table_name<T: Table>(&mut self) -> TmpTable {
79 *self.rename_map.entry(T::NAME).or_insert_with(|| {
80 let new_table_name = self.scope.tmp_table();
81 new_table::<T>(&self.inner.transaction, new_table_name);
82 new_table_name
83 })
84 }
85
86 fn unmigrated<M: Migration<'t, FromSchema = FromSchema>, Out>(
87 &self,
88 new_name: TmpTable,
89 ) -> impl Iterator<Item = (i64, Out)>
90 where
91 Out: FromExpr<'t, FromSchema, M::From>,
92 {
93 let data = self.inner.query(|rows| {
94 let old = rows.join(<M::From as Table>::TOKEN);
95 rows.into_vec((&old, Out::from_expr(&old)))
96 });
97
98 let migrated = Transaction::new(self.inner.transaction.clone()).query(|rows| {
99 let new = rows.join_tmp::<M::From>(new_name);
100 rows.into_vec(new)
101 });
102 let migrated: HashSet<_> = migrated.into_iter().map(|x| x.inner.idx).collect();
103
104 data.into_iter().filter_map(move |(row, data)| {
105 migrated
106 .contains(&row.inner.idx)
107 .not()
108 .then_some((row.inner.idx, data))
109 })
110 }
111
112 pub fn migrate_optional<
120 M: Migration<'t, FromSchema = FromSchema>,
121 X: FromExpr<'t, FromSchema, M::From>,
122 >(
123 &mut self,
124 mut f: impl FnMut(X) -> Option<M>,
125 ) -> Result<(), M::Conflict> {
126 let new_name = self.new_table_name::<M::To>();
127
128 for (idx, x) in self.unmigrated::<M, X>(new_name) {
129 if let Some(new) = f(x) {
130 try_insert_private::<M::To>(
131 &self.transaction,
132 new_name.into_table_ref(),
133 Some(idx),
134 M::prepare(new, TableRow::new(idx)),
135 )
136 .map_err(|_| M::map_conflict(TableRow::new(idx)))?;
137 };
138 }
139 Ok(())
140 }
141
142 pub fn migrate<
149 M: Migration<'t, FromSchema = FromSchema>,
150 X: FromExpr<'t, FromSchema, M::From>,
151 >(
152 &mut self,
153 mut f: impl FnMut(X) -> M,
154 ) -> Result<Migrated<'t, FromSchema, M::To>, M::Conflict> {
155 self.migrate_optional::<M, X>(|x| Some(f(x)))?;
156
157 Ok(Migrated {
158 _p: PhantomData,
159 f: Box::new(|_| {}),
160 })
161 }
162
163 pub fn migrate_ok<
167 M: Migration<'t, FromSchema = FromSchema, Conflict = Infallible>,
168 X: FromExpr<'t, FromSchema, M::From>,
169 >(
170 &mut self,
171 f: impl FnMut(X) -> M,
172 ) -> Migrated<'t, FromSchema, M::To> {
173 let Ok(res) = self.migrate(f);
174 res
175 }
176}
177
178pub struct SchemaBuilder<'t, FromSchema> {
179 inner: TransactionMigrate<'t, FromSchema>,
180 drop: Vec<TableDropStatement>,
181 foreign_key: HashMap<&'static str, Box<dyn 't + FnOnce() -> Infallible>>,
182}
183
184impl<'t, FromSchema: 'static> SchemaBuilder<'t, FromSchema> {
185 pub fn foreign_key<To: Table>(&mut self, err: impl 't + FnOnce() -> Infallible) {
186 self.inner.new_table_name::<To>();
187
188 self.foreign_key.insert(To::NAME, Box::new(err));
189 }
190
191 pub fn create_empty<To: Table>(&mut self) {
192 self.inner.new_table_name::<To>();
193 }
194
195 pub fn drop_table<T: Table>(&mut self) {
196 let name = Alias::new(T::NAME);
197 let step = sea_query::Table::drop().table(name).take();
198 self.drop.push(step);
199 }
200}
201
202fn new_table<T: Table>(conn: &Connection, alias: TmpTable) {
203 let mut f = crate::hash::TypBuilder::default();
204 T::typs(&mut f);
205 new_table_inner(conn, &f.ast, alias);
206}
207
208fn new_table_inner(conn: &Connection, table: &crate::hash::Table, alias: impl IntoTableRef) {
209 let mut create = table.create();
210 create
211 .table(alias)
212 .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
213 let mut sql = create.to_string(SqliteQueryBuilder);
214 sql.push_str(" STRICT");
215 conn.execute(&sql, []).unwrap();
216}
217
/// A migration step between two schema versions.
///
/// `Migrator::migrate` obtains it from the user's closure.
pub trait SchemaMigration<'a> {
    /// Schema before this step.
    type From: Schema;
    /// Schema after this step.
    type To: Schema;

    /// Registers this step's table-level changes on `b`.
    ///
    /// `Migrator::migrate` calls this before executing the drops and renames.
    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
}
224
225pub struct Config {
240 manager: r2d2_sqlite::SqliteConnectionManager,
241 init: Box<dyn FnOnce(&rusqlite::Transaction)>,
242}
243
244static ALLOWED: AtomicBool = AtomicBool::new(true);
245
246impl Config {
247 pub fn open(p: impl AsRef<Path>) -> Self {
254 let manager = r2d2_sqlite::SqliteConnectionManager::file(p);
255 Self::open_internal(manager)
256 }
257
258 pub fn open_in_memory() -> Self {
260 let manager = r2d2_sqlite::SqliteConnectionManager::memory();
261 Self::open_internal(manager)
262 }
263
264 fn open_internal(manager: r2d2_sqlite::SqliteConnectionManager) -> Self {
265 assert!(ALLOWED.swap(false, std::sync::atomic::Ordering::Relaxed));
266 let manager = manager.with_init(|inner| {
267 inner.pragma_update(None, "journal_mode", "WAL")?;
268 inner.pragma_update(None, "synchronous", "NORMAL")?;
269 inner.pragma_update(None, "foreign_keys", "ON")?;
270 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
271 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
272 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
273 Ok(())
274 });
275
276 Self {
277 manager,
278 init: Box::new(|_| {}),
279 }
280 }
281
282 pub fn init_stmt(mut self, sql: &'static str) -> Self {
286 self.init = Box::new(move |txn| {
287 (self.init)(txn);
288
289 txn.execute_batch(sql)
290 .expect("raw sql statement to populate db failed");
291 });
292 self
293 }
294}
295
296impl LocalClient {
297 pub fn migrator<S: Schema>(&mut self, config: Config) -> Option<Migrator<'_, S>> {
301 use r2d2::ManageConnection;
302 let conn = self.conn.insert(config.manager.connect().unwrap());
303 conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
304
305 let conn = conn
306 .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
307 .unwrap();
308 let conn = Rc::new(conn);
309
310 if schema_version(&conn) == 0 {
312 let mut b = TableTypBuilder::default();
313 S::typs(&mut b);
314
315 for (table_name, table) in &*b.ast.tables {
316 new_table_inner(&conn, table, Alias::new(table_name));
317 }
318 (config.init)(&conn);
319 set_user_version(&conn, S::VERSION).unwrap();
320 }
321
322 let user_version = user_version(&conn).unwrap();
323 if user_version < S::VERSION {
325 return None;
326 }
327 assert_eq!(
328 foreign_key_check(&conn),
329 None,
330 "foreign key constraint violated"
331 );
332
333 Some(Migrator {
334 manager: config.manager,
335 transaction: conn,
336 _p: PhantomData,
337 _local: PhantomData,
338 _p0: PhantomData,
339 })
340 }
341}
342
/// Applies a chain of schema migrations inside a single exclusive transaction.
///
/// `S` is the schema reached so far by the chain of `Migrator::migrate` calls.
/// `Migrator::finish` commits the transaction.
pub struct Migrator<'t, S> {
    manager: r2d2_sqlite::SqliteConnectionManager,
    transaction: Rc<rusqlite::Transaction<'t>>,
    // `fn(&'t ()) -> &'t ()` is both contravariant and covariant in `'t`,
    // which makes `Migrator` invariant in `'t`.
    _p0: PhantomData<fn(&'t ()) -> &'t ()>,
    _p: PhantomData<S>,
    // NOTE(review): presumably ties the migrator's auto traits (Send/Sync) to
    // `LocalClient` — confirm.
    _local: PhantomData<LocalClient>,
}
357
358pub struct Migrated<'t, FromSchema, T> {
363 _p: PhantomData<(fn(&'t ()) -> &'t (), T)>,
364 f: Box<dyn 't + FnOnce(&mut SchemaBuilder<'t, FromSchema>)>,
365}
366
367impl<'t, FromSchema: 'static, T: Table> Migrated<'t, FromSchema, T> {
368 pub fn map_fk_err(err: impl 't + FnOnce() -> Infallible) -> Self {
372 Self {
373 _p: PhantomData,
374 f: Box::new(|x| x.foreign_key::<T>(err)),
375 }
376 }
377
378 #[doc(hidden)]
379 pub fn apply(self, b: &mut SchemaBuilder<'t, FromSchema>) {
380 (self.f)(b)
381 }
382}
383
384impl<'t, S: Schema> Migrator<'t, S> {
385 pub fn migrate<M>(
389 self,
390 m: impl FnOnce(&mut TransactionMigrate<'t, S>) -> M,
391 ) -> Migrator<'t, M::To>
392 where
393 M: SchemaMigration<'t, From = S>,
394 {
395 if user_version(&self.transaction).unwrap() == S::VERSION {
396 check_schema::<S>(&self.transaction);
397
398 let mut txn = TransactionMigrate {
399 inner: Transaction::new(self.transaction.clone()),
400 scope: Default::default(),
401 rename_map: HashMap::new(),
402 };
403 let m = m(&mut txn);
404
405 let mut builder = SchemaBuilder {
406 drop: vec![],
407 foreign_key: HashMap::new(),
408 inner: txn,
409 };
410 m.tables(&mut builder);
411
412 for drop in builder.drop {
413 let sql = drop.to_string(SqliteQueryBuilder);
414 self.transaction.execute(&sql, []).unwrap();
415 }
416 for (to, tmp) in builder.inner.rename_map {
417 let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
418 let sql = rename.to_string(SqliteQueryBuilder);
419 self.transaction.execute(&sql, []).unwrap();
420 }
421 if let Some(fk) = foreign_key_check(&self.transaction) {
422 (builder.foreign_key.remove(&*fk).unwrap())();
423 }
424 #[allow(
425 unreachable_code,
426 reason = "rustc is stupid and thinks this is unreachable"
427 )]
428 set_user_version(&self.transaction, M::To::VERSION).unwrap();
429 }
430
431 Migrator {
432 manager: self.manager,
433 transaction: self.transaction,
434 _p: PhantomData,
435 _local: PhantomData,
436 _p0: PhantomData,
437 }
438 }
439
440 pub fn finish(self) -> Option<Database<S>> {
446 let conn = &self.transaction;
447 if user_version(conn).unwrap() != S::VERSION {
448 return None;
449 }
450 check_schema::<S>(&self.transaction);
451
452 self.transaction.execute_batch("PRAGMA optimize;").unwrap();
454
455 let schema_version = schema_version(conn);
456 Rc::into_inner(self.transaction).unwrap().commit().unwrap();
457
458 Some(Database {
459 manager: self.manager,
460 schema_version,
461 schema: PhantomData,
462 })
463 }
464}
465
466pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
467 conn.pragma_query_value(None, "schema_version", |r| r.get(0))
468 .unwrap()
469}
470
471fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
473 conn.query_row("PRAGMA user_version", [], |row| row.get(0))
474}
475
476fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
478 conn.pragma_update(None, "user_version", v)
479}
480
481fn check_schema<S: Schema>(conn: &Rc<rusqlite::Transaction>) {
482 let mut b = TableTypBuilder::default();
483 S::typs(&mut b);
484 pretty_assertions::assert_eq!(
485 b.ast,
486 read_schema(&crate::Transaction::new(conn.clone())),
487 "schema is different (expected left, but got right)",
488 );
489}
490
491fn foreign_key_check(conn: &Rc<rusqlite::Transaction>) -> Option<String> {
492 let error = conn
493 .prepare("PRAGMA foreign_key_check")
494 .unwrap()
495 .query_map([], |row| row.get(2))
496 .unwrap()
497 .next();
498 error.transpose().unwrap()
499}