1pub mod config;
2pub mod migration;
3#[cfg(test)]
4mod test;
5
6use std::{
7 collections::{BTreeSet, HashMap},
8 marker::PhantomData,
9 sync::atomic::AtomicI64,
10};
11
12use annotate_snippets::{Renderer, renderer::DecorStyle};
13use rusqlite::config::DbConfig;
14use sea_query::{Alias, ColumnDef, IntoIden, SqliteQueryBuilder};
15use self_cell::MutBorrow;
16
17use crate::{
18 Table, Transaction,
19 alias::Scope,
20 migrate::{
21 config::Config,
22 migration::{SchemaBuilder, TransactionMigrate},
23 },
24 pool::Pool,
25 schema::{from_db, from_macro, read::read_schema},
26 transaction::{Database, OwnedTransaction, TXN, TransactionWithRows},
27};
28
29pub struct TableTypBuilder<S> {
30 pub(crate) ast: from_macro::Schema,
31 _p: PhantomData<S>,
32}
33
34impl<S> Default for TableTypBuilder<S> {
35 fn default() -> Self {
36 Self {
37 ast: Default::default(),
38 _p: Default::default(),
39 }
40 }
41}
42
43impl<S> TableTypBuilder<S> {
44 pub fn table<T: Table<Schema = S>>(&mut self) {
45 let table = from_macro::Table::new::<T>();
46 let old = self.ast.tables.insert(T::NAME, table);
47 debug_assert!(old.is_none());
48 }
49}
50
51pub trait Schema: Sized + 'static {
52 const VERSION: i64;
53 const SOURCE: &str;
54 const PATH: &str;
55 const SPAN: (usize, usize);
56 fn typs(b: &mut TableTypBuilder<Self>);
57}
58
59fn new_table_inner(table: &crate::schema::from_macro::Table, alias: impl IntoIden) -> String {
60 let alias = alias.into_iden();
61 let mut create = table.create();
62 create
63 .table(alias.clone())
64 .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
65 let mut sql = create.to_string(SqliteQueryBuilder);
66 sql.push_str(" STRICT");
67 sql
68}
69
70pub trait SchemaMigration<'a> {
71 type From: Schema;
72 type To: Schema;
73
74 fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
75}
76
77impl<S: Schema> Database<S> {
78 pub fn migrator(config: Config) -> Option<Migrator<S>> {
82 let synchronous = config.synchronous.as_str();
83 let foreign_keys = config.foreign_keys.as_str();
84 let manager = config.manager.with_init(move |inner| {
85 inner.pragma_update(None, "journal_mode", "WAL")?;
86 inner.pragma_update(None, "synchronous", synchronous)?;
87 inner.pragma_update(None, "foreign_keys", foreign_keys)?;
88 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
89 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
90 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
91 Ok(())
92 });
93
94 use r2d2::ManageConnection;
95 let conn = manager.connect().unwrap();
96 conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
97 let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
98 Some(
99 conn.borrow_mut()
100 .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
101 .unwrap(),
102 )
103 });
104
105 let mut user_version = Some(user_version(txn.get()).unwrap());
106
107 if schema_version(txn.get()) == 0 {
109 user_version = None;
110
111 let schema = crate::schema::from_macro::Schema::new::<S>();
112
113 for (&table_name, table) in &schema.tables {
114 txn.get()
115 .execute(&new_table_inner(table, table_name), [])
116 .unwrap();
117 for stmt in table.delayed_indices(table_name) {
118 txn.get().execute(&stmt, []).unwrap();
119 }
120 }
121 (config.init)(txn.get());
122 } else if user_version.unwrap() < S::VERSION {
123 return None;
125 }
126
127 debug_assert_eq!(
128 foreign_key_check(txn.get()),
129 None,
130 "foreign key constraint violated"
131 );
132
133 Some(Migrator {
134 user_version,
135 manager,
136 transaction: txn,
137 _p: PhantomData,
138 })
139 }
140}
141
142pub struct Migrator<S> {
147 manager: r2d2_sqlite::SqliteConnectionManager,
148 transaction: OwnedTransaction,
149 user_version: Option<i64>,
154 _p: PhantomData<S>,
155}
156
157impl<S: Schema> Migrator<S> {
158 fn with_transaction(mut self, f: impl Send + FnOnce(&mut Transaction<S>)) -> Self {
159 assert!(self.user_version.is_none_or(|x| x == S::VERSION));
160 let res = std::thread::scope(|s| {
161 s.spawn(|| {
162 TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
163 let txn = Transaction::new_ref();
164
165 if self.user_version.take().is_some() {
167 check_schema::<S>(txn);
169 fix_indices::<S>(txn);
171 }
172
173 f(txn);
174
175 let transaction = TXN.take().unwrap();
176
177 transaction.into_owner()
178 })
179 .join()
180 });
181 match res {
182 Ok(val) => self.transaction = val,
183 Err(payload) => std::panic::resume_unwind(payload),
184 }
185 self
186 }
187
188 pub fn migrate<'x, M>(
192 mut self,
193 m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
194 ) -> Migrator<M::To>
195 where
196 M: SchemaMigration<'x, From = S>,
197 {
198 if self.user_version.is_none_or(|x| x == S::VERSION) {
199 self = self.with_transaction(|txn| {
200 let mut txn = TransactionMigrate {
201 inner: txn.copy(),
202 scope: Default::default(),
203 rename_map: HashMap::new(),
204 extra_index: Vec::new(),
205 };
206 let m = m(&mut txn);
207
208 let mut builder = SchemaBuilder {
209 drop: vec![],
210 foreign_key: HashMap::new(),
211 inner: txn,
212 };
213 m.tables(&mut builder);
214
215 let transaction = TXN.take().unwrap();
216
217 for drop in builder.drop {
218 let sql = drop.to_string(SqliteQueryBuilder);
219 transaction.get().execute(&sql, []).unwrap();
220 }
221 for (to, tmp) in builder.inner.rename_map {
222 let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
223 let sql = rename.to_string(SqliteQueryBuilder);
224 transaction.get().execute(&sql, []).unwrap();
225 }
226 if let Some(fk) = foreign_key_check(transaction.get()) {
227 (builder.foreign_key.remove(&*fk).unwrap())();
228 }
229 #[allow(
230 unreachable_code,
231 reason = "rustc is stupid and thinks this is unreachable"
232 )]
233 for stmt in builder.inner.extra_index {
235 transaction.get().execute(&stmt, []).unwrap();
236 }
237
238 TXN.set(Some(transaction));
239 });
240 }
241
242 Migrator {
243 user_version: self.user_version,
244 manager: self.manager,
245 transaction: self.transaction,
246 _p: PhantomData,
247 }
248 }
249
250 pub fn fixup(mut self, f: impl Send + FnOnce(&mut Transaction<S>)) -> Self {
257 if self.user_version.is_none() {
258 self = self.with_transaction(f);
259 }
260 self
261 }
262
263 pub fn finish(mut self) -> Option<Database<S>> {
269 if self.user_version.is_some_and(|x| x != S::VERSION) {
270 return None;
271 }
272
273 self = self.with_transaction(|txn| {
275 check_schema::<S>(txn);
277 });
278
279 self.transaction
281 .get()
282 .execute_batch("PRAGMA optimize;")
283 .unwrap();
284
285 set_user_version(self.transaction.get(), S::VERSION).unwrap();
286 let schema_version = schema_version(self.transaction.get());
287 self.transaction.with(|x| x.commit().unwrap());
288
289 Some(Database {
290 manager: Pool::new(self.manager),
291 schema_version: AtomicI64::new(schema_version),
292 schema: PhantomData,
293 mut_lock: parking_lot::FairMutex::new(()),
294 })
295 }
296}
297
298fn fix_indices<S: Schema>(txn: &Transaction<S>) {
299 let schema = read_schema(txn);
300 let expected_schema = crate::schema::from_macro::Schema::new::<S>();
301
302 fn check_eq(expected: &from_macro::Table, actual: &from_db::Table) -> bool {
303 let expected: BTreeSet<_> = expected.indices.iter().map(|idx| &idx.def).collect();
304 let actual: BTreeSet<_> = actual.indices.values().collect();
305 expected == actual
306 }
307
308 for (&table_name, expected_table) in &expected_schema.tables {
309 let table = &schema.tables[table_name];
310
311 if !check_eq(expected_table, &table) {
312 let scope = Scope::default();
317 let tmp_name = scope.tmp_table();
318
319 txn.execute(&new_table_inner(expected_table, tmp_name));
320
321 let mut columns: Vec<_> = expected_table
322 .columns
323 .keys()
324 .map(|x| Alias::new(x))
325 .collect();
326 columns.push(Alias::new("id"));
327
328 txn.execute(
329 &sea_query::InsertStatement::new()
330 .into_table(tmp_name)
331 .columns(columns.clone())
332 .select_from(
333 sea_query::SelectStatement::new()
334 .from(table_name)
335 .columns(columns)
336 .take(),
337 )
338 .unwrap()
339 .build(SqliteQueryBuilder)
340 .0,
341 );
342
343 txn.execute(
344 &sea_query::TableDropStatement::new()
345 .table(table_name)
346 .build(SqliteQueryBuilder),
347 );
348
349 txn.execute(
350 &sea_query::TableRenameStatement::new()
351 .table(tmp_name, table_name)
352 .build(SqliteQueryBuilder),
353 );
354 for sql in expected_table.delayed_indices(table_name) {
356 txn.execute(&sql);
357 }
358 }
359 }
360
361 let schema = read_schema(txn);
363 for (name, table) in schema.tables {
364 let expected_table = &expected_schema.tables[&*name];
365 assert!(check_eq(expected_table, &table));
366 }
367}
368
369impl<S> Transaction<S> {
370 #[track_caller]
371 pub(crate) fn execute(&self, sql: &str) {
372 TXN.with_borrow(|txn| txn.as_ref().unwrap().get().execute(sql, []))
373 .unwrap();
374 }
375}
376
377pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
378 conn.pragma_query_value(None, "schema_version", |r| r.get(0))
379 .unwrap()
380}
381
382pub fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
384 conn.query_row("PRAGMA user_version", [], |row| row.get(0))
385}
386
387fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
389 conn.pragma_update(None, "user_version", v)
390}
391
392pub(crate) fn check_schema<S: Schema>(txn: &Transaction<S>) {
393 let from_macro = crate::schema::from_macro::Schema::new::<S>();
394 let from_db = read_schema(txn);
395 let report = from_db.diff(from_macro, S::SOURCE, S::PATH, S::VERSION);
396 if !report.is_empty() {
397 let renderer = if cfg!(test) {
398 Renderer::plain().anonymized_line_numbers(true)
399 } else {
400 Renderer::styled()
401 }
402 .decor_style(DecorStyle::Unicode);
403 panic!("{}", renderer.render(&report));
404 }
405}
406
407fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
408 let error = conn
409 .prepare("PRAGMA foreign_key_check")
410 .unwrap()
411 .query_map([], |row| row.get(2))
412 .unwrap()
413 .next();
414 error.transpose().unwrap()
415}
416
417impl<S> Transaction<S> {
418 #[cfg(test)]
419 pub(crate) fn schema(&self) -> Vec<String> {
420 TXN.with_borrow(|x| {
421 x.as_ref()
422 .unwrap()
423 .get()
424 .prepare("SELECT sql FROM 'main'.'sqlite_schema'")
425 .unwrap()
426 .query_map([], |row| row.get::<_, Option<String>>("sql"))
427 .unwrap()
428 .flat_map(|x| x.unwrap())
429 .collect()
430 })
431 }
432}