1use std::{marker::PhantomData, path::Path, sync::atomic::AtomicBool};
2
3use rusqlite::{config::DbConfig, Connection};
4use sea_query::{
5 Alias, ColumnDef, InsertStatement, IntoTableRef, SqliteQueryBuilder, TableDropStatement,
6 TableRenameStatement,
7};
8use sea_query_rusqlite::RusqliteBinder;
9
10use crate::{
11 alias::{Scope, TmpTable},
12 ast::MySelect,
13 dummy::{Cached, Cacher},
14 hash,
15 insert::Reader,
16 pragma::read_schema,
17 token::LocalClient,
18 transaction::Database,
19 value, Column, IntoColumn, Rows, Table,
20};
21
/// Boxed closure describing how to migrate one table from `From` to `To`.
///
/// The closure is called once with a `::rust_query::Column` that refers to the rows of the
/// old `From` table, valid for the higher-ranked lifetime `'t`. It returns an [`Alter`] whose
/// migration produces the values of each new row. Values of this type are consumed by
/// `SchemaBuilder::migrate_table`.
pub type M<'a, From, To> = Box<
    dyn 'a
        + for<'t> FnOnce(
            ::rust_query::Column<'t, <From as Table>::Schema, From>,
        ) -> Alter<'t, 'a, From, To>,
>;
28
29pub struct Alter<'t, 'a, From, To> {
40 _p: PhantomData<&'t &'a ()>,
41 inner: Box<dyn TableMigration<'t, 'a, From = From, To = To> + 't>,
42}
43
44impl<'t, 'a, From, To> Alter<'t, 'a, From, To> {
45 pub fn new(val: impl TableMigration<'t, 'a, From = From, To = To> + 't) -> Self {
46 Self {
47 _p: PhantomData,
48 inner: Box::new(val),
49 }
50 }
51}
52
/// Boxed closure describing how to fill a new table `To` from a query over `FromSchema`.
///
/// The closure receives the source query as `&mut Rows`, which it may extend (for example by
/// joining or filtering tables of the old schema). It returns a [`Create`] that produces the
/// new rows. Values of this type are consumed by `SchemaBuilder::create_from`.
pub type C<'a, FromSchema, To> =
    Box<dyn 'a + for<'t> FnOnce(&mut Rows<'t, FromSchema>) -> Create<'t, 'a, FromSchema, To>>;
55
56pub struct Create<'t, 'a, FromSchema, To> {
60 _p: PhantomData<&'t &'a ()>,
61 inner: Box<dyn TableCreation<'t, 'a, FromSchema = FromSchema, To = To> + 't>,
62}
63
64impl<'t, 'a, FromSchema, To: 'a> Create<'t, 'a, FromSchema, To> {
65 pub fn new(val: impl TableCreation<'t, 'a, FromSchema = FromSchema, To = To> + 't) -> Self {
66 Self {
67 _p: PhantomData,
68 inner: Box::new(val),
69 }
70 }
71
72 pub fn empty(rows: &mut Rows<'t, FromSchema>) -> Self {
74 rows.filter(false);
75 Create::new(NeverCreate(PhantomData, PhantomData))
76 }
77}
78
79struct NeverCreate<FromSchema, To>(PhantomData<FromSchema>, PhantomData<To>);
80
81impl<'t, 'a, FromSchema, To> TableCreation<'t, 'a> for NeverCreate<FromSchema, To> {
82 type FromSchema = FromSchema;
83 type To = To;
84
85 fn prepare(
86 self: Box<Self>,
87 _: Cacher<'_, 't, Self::FromSchema>,
88 ) -> Box<dyn FnMut(crate::private::Row<'_, 't, 'a>, Reader<'_, 't, Self::FromSchema>) + 't>
89 where
90 'a: 't,
91 {
92 Box::new(|_, _| unreachable!())
93 }
94}
95
96pub struct TableTypBuilder<S> {
97 pub(crate) ast: hash::Schema,
98 _p: PhantomData<S>,
99}
100
101impl<S> Default for TableTypBuilder<S> {
102 fn default() -> Self {
103 Self {
104 ast: Default::default(),
105 _p: Default::default(),
106 }
107 }
108}
109
110impl<S> TableTypBuilder<S> {
111 pub fn table<T: Table<Schema = S>>(&mut self) {
112 let mut b = hash::TypBuilder::default();
113 T::typs(&mut b);
114 self.ast.tables.insert((T::NAME.to_owned(), b.ast));
115 }
116}
117
/// A complete database schema at one specific version.
pub trait Schema: Sized + 'static {
    /// Version number of this schema.
    ///
    /// The migrator stores it in SQLite's `user_version` pragma and compares it against that
    /// pragma to decide which migrations still have to run.
    const VERSION: i64;

    /// Registers the tables of this schema with `b`.
    ///
    /// Every registered table is created when the database is new, and the result is compared
    /// against the live schema by `foreign_key_check`, so all tables must be registered.
    fn typs(b: &mut TableTypBuilder<Self>);
}
122
/// Maps each row of an existing table `From` to a row of the replacement table.
///
/// Implementations are type-erased inside [`Alter`].
pub trait TableMigration<'t, 'a> {
    /// The table being migrated.
    type From: Table;
    /// The table type produced by the migration.
    type To;

    /// Prepares the migration.
    ///
    /// `prev` is the cached handle of the old `From` row, and `cacher` can register further
    /// values to read from each source row. The returned callback is invoked once per source
    /// row with that row and a `Reader` that collects the column values of the new row.
    fn prepare(
        self: Box<Self>,
        prev: Cached<'t, Self::From>,
        cacher: Cacher<'_, 't, <Self::From as Table>::Schema>,
    ) -> Box<
        dyn FnMut(crate::private::Row<'_, 't, 'a>, Reader<'_, 't, <Self::From as Table>::Schema>)
            + 't,
    >
    where
        'a: 't;
}
138
/// Fills a new table from the rows of a query over `FromSchema`.
///
/// Implementations are type-erased inside [`Create`].
pub trait TableCreation<'t, 'a> {
    /// Schema of the tables read by the source query.
    type FromSchema;
    /// The table type being created.
    type To;

    /// Prepares the creation.
    ///
    /// `cacher` can register values that the callback reads from each source row. The
    /// returned callback is invoked once per row of the source query, with that row and a
    /// `Reader` that collects the column values of the new row.
    fn prepare(
        self: Box<Self>,
        cacher: Cacher<'_, 't, Self::FromSchema>,
    ) -> Box<dyn FnMut(crate::private::Row<'_, 't, 'a>, Reader<'_, 't, Self::FromSchema>) + 't>
    where
        'a: 't;
}
150
151struct Wrapper<'t, 'a, From: Table, To>(
152 Box<dyn TableMigration<'t, 'a, From = From, To = To> + 't>,
153 Column<'t, From::Schema, From>,
154);
155
156impl<'t, 'a, From: Table, To> TableCreation<'t, 'a> for Wrapper<'t, 'a, From, To> {
157 type FromSchema = From::Schema;
158 type To = To;
159
160 fn prepare(
161 self: Box<Self>,
162 mut cacher: Cacher<'_, 't, Self::FromSchema>,
163 ) -> Box<dyn FnMut(crate::private::Row<'_, 't, 'a>, Reader<'_, 't, Self::FromSchema>) + 't>
164 where
165 'a: 't,
166 {
167 let db_id = cacher.cache(self.1);
168 let mut prepared = Box::new(self.0).prepare(db_id, cacher);
169 Box::new(move |row, reader| {
170 reader.col(From::ID, row.get(db_id));
172 prepared(row, reader);
173 })
174 }
175}
176
177impl<'inner, S> Rows<'inner, S> {
178 fn cacher<'t>(&'_ self) -> Cacher<'_, 't, S> {
179 Cacher {
180 ast: &self.ast,
181 _p: PhantomData,
182 }
183 }
184}
185
/// Collects the table operations of a single [`Migration`] step.
///
/// New tables are created and filled immediately, under temporary names. Dropping tables and
/// renaming the temporary tables to their final names are only queued in `drop` and `rename`.
/// `Migrator::migrate` executes the queued statements after `Migration::tables` returns.
pub struct SchemaBuilder<'x, 'a> {
    // Hands out the temporary table names used for newly created tables.
    scope: Scope,
    conn: &'x rusqlite::Transaction<'x>,
    // Queued `DROP TABLE` statements.
    drop: Vec<TableDropStatement>,
    // Queued renames from a temporary table name to the final `To::NAME`.
    rename: Vec<TableRenameStatement>,
    // `fn(&'a ()) -> &'a ()` makes the struct invariant in `'a`.
    _p: PhantomData<fn(&'a ()) -> &'a ()>,
}

impl<'a> SchemaBuilder<'_, 'a> {
    /// Copies every row of table `From` into a new table `To` using `m`, then queues dropping
    /// `From`.
    ///
    /// `From::join` adds `From` to the source query, and `m` receives the resulting column.
    /// The returned [`Alter`] is wrapped so that each new row also gets the source row's
    /// `From::ID` value.
    pub fn migrate_table<From: Table, To: Table>(&mut self, m: M<'a, From, To>) {
        self.create_inner::<From::Schema, To>(|rows| {
            let db_id = From::join(rows);
            let migration = m(db_id.clone());
            Create::new(Wrapper(migration.inner, db_id))
        });

        self.drop.push(
            sea_query::Table::drop()
                .table(Alias::new(From::NAME))
                .take(),
        );
    }

    /// Creates and fills a new table `To` from an arbitrary query over `FromSchema`.
    pub fn create_from<FromSchema, To: Table>(&mut self, f: C<'a, FromSchema, To>) {
        self.create_inner::<FromSchema, To>(f);
    }

    /// Shared implementation of `migrate_table` and `create_from`.
    ///
    /// 1. Creates table `To` under a fresh temporary name and queues its rename to `To::NAME`.
    /// 2. Lets `f` build the source query and the per-row creation logic.
    /// 3. Runs the query, and for each result row executes one `INSERT ... SELECT` into the
    ///    temporary table.
    ///
    /// # Panics
    ///
    /// Panics (via `unwrap`) if preparing or executing any of the SQL statements fails.
    fn create_inner<FromSchema, To: Table>(
        &mut self,
        f: impl for<'t> FnOnce(&mut Rows<'t, FromSchema>) -> Create<'t, 'a, FromSchema, To>,
    ) {
        let new_table_name = self.scope.tmp_table();
        new_table::<To>(self.conn, new_table_name);

        self.rename.push(
            sea_query::Table::rename()
                .table(new_table_name, Alias::new(To::NAME))
                .take(),
        );

        // `f` fills in this query, for example by joining source tables.
        let mut q = Rows::<FromSchema> {
            phantom: PhantomData,
            ast: MySelect::default(),
        };
        let create = f(&mut q);
        // Preparation may register cached values on the query, so it runs before the
        // SELECT is generated below.
        let mut prepared = create.inner.prepare(q.cacher());

        let select = q.ast.simple();
        let (sql, values) = select.build_rusqlite(SqliteQueryBuilder);

        let mut statement = self.conn.prepare(&sql).unwrap();
        let mut rows = statement.query(&*values.as_params()).unwrap();

        // `rusqlite::Rows` is a streaming cursor rather than an `Iterator`, hence `while let`.
        while let Some(row) = rows.next().unwrap() {
            let row = crate::private::Row {
                _p: PhantomData,
                _p2: PhantomData,
                row,
            };

            // The callback writes the new row's columns into `new_ast` through `reader`.
            let new_ast = MySelect::default();
            let reader = Reader {
                ast: &new_ast,
                _p: PhantomData,
                _p2: PhantomData,
            };
            prepared(row, reader);

            let mut insert = InsertStatement::new();
            let names = new_ast.select.iter().map(|(_field, name)| *name);
            insert.into_table(new_table_name);
            insert.columns(names);
            insert.select_from(new_ast.simple()).unwrap();

            let (sql, values) = insert.build_rusqlite(SqliteQueryBuilder);
            // `prepare_cached` reuses the compiled statement whenever the generated SQL text
            // repeats between rows.
            let mut statement = self.conn.prepare_cached(&sql).unwrap();
            statement.execute(&*values.as_params()).unwrap();
        }
    }

    /// Queues dropping table `T`. The statement runs after `Migration::tables` returns.
    pub fn drop_table<T: Table>(&mut self) {
        let name = Alias::new(T::NAME);
        let step = sea_query::Table::drop().table(name).take();
        self.drop.push(step);
    }
}
274
275fn new_table<T: Table>(conn: &Connection, alias: TmpTable) {
276 let mut f = crate::hash::TypBuilder::default();
277 T::typs(&mut f);
278 new_table_inner(conn, &f.ast, alias);
279}
280
281fn new_table_inner(conn: &Connection, table: &crate::hash::Table, alias: impl IntoTableRef) {
282 let mut create = table.create();
283 create
284 .table(alias)
285 .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
286 let mut sql = create.to_string(SqliteQueryBuilder);
287 sql.push_str(" STRICT");
288 conn.execute(&sql, []).unwrap();
289}
290
/// One migration step that turns a database at schema `From` into one at schema `To`.
pub trait Migration<'a> {
    /// Schema the database must be at for this step to run.
    type From: Schema;
    /// Schema the database is at after this step.
    type To: Schema;

    /// Declares this step's table operations (create, migrate, drop) on `b`.
    fn tables(self, b: &mut SchemaBuilder<'_, 'a>);
}
297
/// Settings for opening a database: the connection manager and the initialisation SQL
/// for a brand-new database.
pub struct Config {
    manager: r2d2_sqlite::SqliteConnectionManager,
    // Runs inside the migration transaction, only when the database is new (its
    // `schema_version` is 0), after the tables of the initial schema have been created.
    init: Box<dyn FnOnce(&rusqlite::Transaction)>,
}
306
/// One-shot guard: the first `Config::open*` call swaps this to `false`, and every later call
/// fails its assertion.
static ALLOWED: AtomicBool = AtomicBool::new(true);
308
309impl Config {
310 pub fn open(p: impl AsRef<Path>) -> Self {
317 let manager = r2d2_sqlite::SqliteConnectionManager::file(p);
318 Self::open_internal(manager)
319 }
320
321 pub fn open_in_memory() -> Self {
323 let manager = r2d2_sqlite::SqliteConnectionManager::memory();
324 Self::open_internal(manager)
325 }
326
327 fn open_internal(manager: r2d2_sqlite::SqliteConnectionManager) -> Self {
328 assert!(ALLOWED.swap(false, std::sync::atomic::Ordering::Relaxed));
329 let manager = manager.with_init(|inner| {
330 inner.pragma_update(None, "journal_mode", "WAL")?;
331 inner.pragma_update(None, "synchronous", "NORMAL")?;
332 inner.pragma_update(None, "foreign_keys", "ON")?;
333 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
334 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
335 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
336 Ok(())
337 });
338
339 Self {
340 manager,
341 init: Box::new(|_| {}),
342 }
343 }
344
345 pub fn init_stmt(mut self, sql: &'static str) -> Self {
349 self.init = Box::new(move |txn| {
350 (self.init)(txn);
351
352 txn.execute_batch(sql)
353 .expect("raw sql statement to populate db failed");
354 });
355 self
356 }
357}
358
359impl LocalClient {
360 pub fn migrator<'t, S: Schema>(&'t mut self, config: Config) -> Option<Migrator<'t, S>> {
366 use r2d2::ManageConnection;
367 let conn = self.conn.insert(config.manager.connect().unwrap());
368 conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
369
370 let conn = conn
371 .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
372 .unwrap();
373
374 if schema_version(&conn) == 0 {
376 let mut b = TableTypBuilder::default();
377 S::typs(&mut b);
378
379 for (table_name, table) in &*b.ast.tables {
380 new_table_inner(&conn, table, Alias::new(table_name));
381 }
382 (config.init)(&conn);
383 set_user_version(&conn, S::VERSION).unwrap();
384 }
385
386 let user_version = user_version(&conn).unwrap();
387 if user_version < S::VERSION {
389 return None;
390 } else if user_version == S::VERSION {
391 foreign_key_check::<S>(&conn);
392 }
393
394 Some(Migrator {
395 manager: config.manager,
396 transaction: conn,
397 _p: PhantomData,
398 _local: PhantomData,
399 })
400 }
401}
402
/// Applies a chain of migrations inside one exclusive transaction.
///
/// `S` is the schema the chain of `migrate` calls has reached so far. The stored
/// `user_version` decides whether each step actually runs.
pub struct Migrator<'t, S> {
    manager: r2d2_sqlite::SqliteConnectionManager,
    // Exclusive transaction started by `LocalClient::migrator`; committed by `finish`.
    transaction: rusqlite::Transaction<'t>,
    _p: PhantomData<S>,
    // NOTE(review): presumably makes `Migrator` inherit `LocalClient`'s auto traits
    // (e.g. `Send`/`Sync`). Confirm against the definition of `LocalClient`.
    _local: PhantomData<LocalClient>,
}
416
417impl<'t, S: Schema> Migrator<'t, S> {
418 pub fn migrate<M, N: Schema>(self, m: M) -> Migrator<'t, N>
422 where
423 M: Migration<'t, From = S, To = N>,
424 {
425 let conn = &self.transaction;
426
427 if user_version(conn).unwrap() == S::VERSION {
428 let mut builder = SchemaBuilder {
429 scope: Default::default(),
430 conn,
431 drop: vec![],
432 rename: vec![],
433 _p: PhantomData,
434 };
435 m.tables(&mut builder);
436 for drop in builder.drop {
437 let sql = drop.to_string(SqliteQueryBuilder);
438 conn.execute(&sql, []).unwrap();
439 }
440 for rename in builder.rename {
441 let sql = rename.to_string(SqliteQueryBuilder);
442 conn.execute(&sql, []).unwrap();
443 }
444 foreign_key_check::<N>(conn);
445 set_user_version(conn, N::VERSION).unwrap();
446 }
447
448 Migrator {
449 manager: self.manager,
450 transaction: self.transaction,
451 _p: PhantomData,
452 _local: PhantomData,
453 }
454 }
455
456 pub fn finish(self) -> Option<Database<S>> {
460 let conn = &self.transaction;
461 if user_version(conn).unwrap() != S::VERSION {
462 return None;
463 }
464
465 let schema_version = schema_version(conn);
466 self.transaction.commit().unwrap();
467
468 Some(Database {
469 manager: self.manager,
470 schema_version,
471 schema: PhantomData,
472 })
473 }
474}
475
476pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
477 conn.pragma_query_value(None, "schema_version", |r| r.get(0))
478 .unwrap()
479}
480
481fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
483 conn.query_row("PRAGMA user_version", [], |row| row.get(0))
484}
485
486fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
488 conn.pragma_update(None, "user_version", v)
489}
490
491fn foreign_key_check<S: Schema>(conn: &rusqlite::Transaction) {
492 let errors = conn
493 .prepare("PRAGMA foreign_key_check")
494 .unwrap()
495 .query_map([], |_| Ok(()))
496 .unwrap()
497 .count();
498 if errors != 0 {
499 panic!("migration violated foreign key constraint")
500 }
501
502 let mut b = TableTypBuilder::default();
503 S::typs(&mut b);
504 pretty_assertions::assert_eq!(
505 b.ast,
506 read_schema(conn),
507 "schema is different (expected left, but got right)",
508 );
509}
510
511#[derive(Clone, Copy)]
513pub struct NoTable(());
514
515impl value::Typed for NoTable {
516 type Typ = NoTable;
517 fn build_expr(&self, _b: value::ValueBuilder) -> sea_query::SimpleExpr {
518 unreachable!("NoTable can not be constructed")
519 }
520}
521impl<S> IntoColumn<'_, S> for NoTable {
522 type Owned = Self;
523
524 fn into_owned(self) -> Self::Owned {
525 self
526 }
527}