1pub mod config;
2pub mod migration;
3#[cfg(test)]
4mod test;
5
6use std::{
7 collections::{BTreeSet, HashMap},
8 marker::PhantomData,
9 sync::atomic::AtomicI64,
10};
11
12use annotate_snippets::{Renderer, renderer::DecorStyle};
13use rusqlite::config::DbConfig;
14use sea_query::{Alias, ColumnDef, IntoIden, SqliteQueryBuilder};
15use self_cell::MutBorrow;
16
17use crate::{
18 Table, Transaction,
19 alias::Scope,
20 migrate::{
21 config::Config,
22 migration::{SchemaBuilder, TransactionMigrate},
23 },
24 pool::Pool,
25 schema::{from_db, from_macro, read::read_schema},
26 transaction::{Database, OwnedTransaction, TXN, TransactionWithRows},
27};
28
/// Collects the table definitions belonging to schema `S`.
///
/// A `&mut TableTypBuilder<S>` is what [`Schema::typs`] receives; each call to
/// [`TableTypBuilder::table`] adds one table to `ast.tables`, keyed by the
/// table's `NAME`.
pub struct TableTypBuilder<S> {
    /// The schema description accumulated so far.
    pub(crate) ast: from_macro::Schema,
    /// Ties the builder to schema `S` without storing a value of that type.
    _p: PhantomData<S>,
}
33
34impl<S> Default for TableTypBuilder<S> {
35 fn default() -> Self {
36 Self {
37 ast: Default::default(),
38 _p: Default::default(),
39 }
40 }
41}
42
43impl<S> TableTypBuilder<S> {
44 pub fn table<T: Table<Schema = S>>(&mut self) {
45 let table = from_macro::Table::new::<T>();
46 let old = self.ast.tables.insert(T::NAME, table);
47 debug_assert!(old.is_none());
48 }
49}
50
/// One version of a database schema.
///
/// NOTE(review): presumably implemented by the schema macro (this file builds
/// `from_macro::Schema::new::<S>()` from implementors) — confirm.
pub trait Schema: Sized + 'static {
    /// Version number of this schema; compared against and written to SQLite's
    /// `user_version` pragma during migration.
    const VERSION: i64;
    /// Schema source text, passed to the schema diff whose report `check_schema`
    /// renders on mismatch.
    const SOURCE: &str;
    /// Path of the schema source, passed to the same schema diff.
    const PATH: &str;
    /// Pair of offsets; not used in this file. NOTE(review): presumably the
    /// span of the schema definition within `SOURCE` — confirm.
    const SPAN: (usize, usize);
    /// Fills `b` with this schema's tables, expected to call
    /// [`TableTypBuilder::table`] once per table.
    fn typs(b: &mut TableTypBuilder<Self>);
}
58
59fn new_table_inner(table: &crate::schema::from_macro::Table, alias: impl IntoIden) -> String {
60 let alias = alias.into_iden();
61 let mut create = table.create();
62 create
63 .table(alias.clone())
64 .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
65 let mut sql = create.to_string(SqliteQueryBuilder);
66 sql.push_str(" STRICT");
67 sql
68}
69
/// A migration from schema `From` to schema `To`.
///
/// `Migrator::migrate` obtains a value of this type from the user's closure and
/// then calls [`SchemaMigration::tables`]. Afterwards it consumes what was
/// collected in the builder (dropped tables, renames, extra indices, foreign
/// key entries) and sets `user_version` to `To::VERSION`.
pub trait SchemaMigration<'a> {
    /// Schema the database is at before this migration.
    type From: Schema;
    /// Schema the database is at after this migration.
    type To: Schema;

    /// Records this migration's table changes in `b`.
    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
}
76
77impl<S: Schema> Database<S> {
78 pub fn migrator(config: Config) -> Option<Migrator<S>> {
82 let synchronous = config.synchronous.as_str();
83 let foreign_keys = config.foreign_keys.as_str();
84 let manager = config.manager.with_init(move |inner| {
85 inner.pragma_update(None, "journal_mode", "WAL")?;
86 inner.pragma_update(None, "synchronous", synchronous)?;
87 inner.pragma_update(None, "foreign_keys", foreign_keys)?;
88 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
89 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
90 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
91 Ok(())
92 });
93
94 use r2d2::ManageConnection;
95 let conn = manager.connect().unwrap();
96 conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
97 let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
98 Some(
99 conn.borrow_mut()
100 .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
101 .unwrap(),
102 )
103 });
104
105 if schema_version(txn.get()) == 0 {
107 let schema = crate::schema::from_macro::Schema::new::<S>();
108
109 for (&table_name, table) in &schema.tables {
110 txn.get()
111 .execute(&new_table_inner(table, table_name), [])
112 .unwrap();
113 for stmt in table.delayed_indices(table_name) {
114 txn.get().execute(&stmt, []).unwrap();
115 }
116 }
117 (config.init)(txn.get());
118 set_user_version(txn.get(), S::VERSION).unwrap();
119 }
120
121 let user_version = user_version(txn.get()).unwrap();
122 if user_version < S::VERSION {
124 return None;
125 }
126 debug_assert_eq!(
127 foreign_key_check(txn.get()),
128 None,
129 "foreign key constraint violated"
130 );
131
132 Some(Migrator {
133 indices_fixed: false,
134 manager,
135 transaction: txn,
136 _p: PhantomData,
137 })
138 }
139}
140
/// An in-progress migration session for a database currently typed as schema `S`.
///
/// Holds the exclusive transaction opened by [`Database::migrator`]. Every
/// migration step runs inside it, and [`Migrator::finish`] commits it.
pub struct Migrator<S> {
    /// Connection manager; `finish` wraps it in the [`Database`]'s pool.
    manager: r2d2_sqlite::SqliteConnectionManager,
    /// The single exclusive transaction all migration steps run in.
    transaction: OwnedTransaction,
    /// Whether `fix_indices` has already run during this session.
    indices_fixed: bool,
    _p: PhantomData<S>,
}
151
152impl<S: Schema> Migrator<S> {
153 pub fn migrate<'x, M>(
157 mut self,
158 m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
159 ) -> Migrator<M::To>
160 where
161 M: SchemaMigration<'x, From = S>,
162 {
163 if user_version(self.transaction.get()).unwrap() == S::VERSION {
164 let res = std::thread::scope(|s| {
165 s.spawn(|| {
166 TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
167 let txn = Transaction::new_ref();
168
169 check_schema::<S>(txn);
170 if !self.indices_fixed {
171 fix_indices::<S>(txn);
172 self.indices_fixed = true;
173 }
174
175 let mut txn = TransactionMigrate {
176 inner: Transaction::new(),
177 scope: Default::default(),
178 rename_map: HashMap::new(),
179 extra_index: Vec::new(),
180 };
181 let m = m(&mut txn);
182
183 let mut builder = SchemaBuilder {
184 drop: vec![],
185 foreign_key: HashMap::new(),
186 inner: txn,
187 };
188 m.tables(&mut builder);
189
190 let transaction = TXN.take().unwrap();
191
192 for drop in builder.drop {
193 let sql = drop.to_string(SqliteQueryBuilder);
194 transaction.get().execute(&sql, []).unwrap();
195 }
196 for (to, tmp) in builder.inner.rename_map {
197 let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
198 let sql = rename.to_string(SqliteQueryBuilder);
199 transaction.get().execute(&sql, []).unwrap();
200 }
201 if let Some(fk) = foreign_key_check(transaction.get()) {
202 (builder.foreign_key.remove(&*fk).unwrap())();
203 }
204 #[allow(
205 unreachable_code,
206 reason = "rustc is stupid and thinks this is unreachable"
207 )]
208 for stmt in builder.inner.extra_index {
210 transaction.get().execute(&stmt, []).unwrap();
211 }
212 set_user_version(transaction.get(), M::To::VERSION).unwrap();
213
214 transaction.into_owner()
215 })
216 .join()
217 });
218 match res {
219 Ok(val) => self.transaction = val,
220 Err(payload) => std::panic::resume_unwind(payload),
221 }
222 }
223
224 Migrator {
225 indices_fixed: self.indices_fixed,
226 manager: self.manager,
227 transaction: self.transaction,
228 _p: PhantomData,
229 }
230 }
231
232 pub fn finish(mut self) -> Option<Database<S>> {
238 if user_version(self.transaction.get()).unwrap() != S::VERSION {
239 return None;
240 }
241
242 let res = std::thread::scope(|s| {
243 s.spawn(|| {
244 TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
245 let txn = Transaction::new_ref();
246
247 check_schema::<S>(txn);
248 if !self.indices_fixed {
249 fix_indices::<S>(txn);
250 self.indices_fixed = true;
251 }
252
253 TXN.take().unwrap().into_owner()
254 })
255 .join()
256 });
257 match res {
258 Ok(val) => self.transaction = val,
259 Err(payload) => std::panic::resume_unwind(payload),
260 }
261
262 self.transaction
264 .get()
265 .execute_batch("PRAGMA optimize;")
266 .unwrap();
267
268 let schema_version = schema_version(self.transaction.get());
269 self.transaction.with(|x| x.commit().unwrap());
270
271 Some(Database {
272 manager: Pool::new(self.manager),
273 schema_version: AtomicI64::new(schema_version),
274 schema: PhantomData,
275 mut_lock: parking_lot::FairMutex::new(()),
276 })
277 }
278}
279
280fn fix_indices<S: Schema>(txn: &Transaction<S>) {
281 let schema = read_schema(txn);
282 let expected_schema = crate::schema::from_macro::Schema::new::<S>();
283
284 fn check_eq(expected: &from_macro::Table, actual: &from_db::Table) -> bool {
285 let expected: BTreeSet<_> = expected.indices.iter().map(|idx| &idx.def).collect();
286 let actual: BTreeSet<_> = actual.indices.values().collect();
287 expected == actual
288 }
289
290 for (&table_name, expected_table) in &expected_schema.tables {
291 let table = &schema.tables[table_name];
292
293 if !check_eq(expected_table, &table) {
294 let scope = Scope::default();
299 let tmp_name = scope.tmp_table();
300
301 txn.execute(&new_table_inner(expected_table, tmp_name));
302
303 let mut columns: Vec<_> = expected_table
304 .columns
305 .keys()
306 .map(|x| Alias::new(x))
307 .collect();
308 columns.push(Alias::new("id"));
309
310 txn.execute(
311 &sea_query::InsertStatement::new()
312 .into_table(tmp_name)
313 .columns(columns.clone())
314 .select_from(
315 sea_query::SelectStatement::new()
316 .from(table_name)
317 .columns(columns)
318 .take(),
319 )
320 .unwrap()
321 .build(SqliteQueryBuilder)
322 .0,
323 );
324
325 txn.execute(
326 &sea_query::TableDropStatement::new()
327 .table(table_name)
328 .build(SqliteQueryBuilder),
329 );
330
331 txn.execute(
332 &sea_query::TableRenameStatement::new()
333 .table(tmp_name, table_name)
334 .build(SqliteQueryBuilder),
335 );
336 for sql in expected_table.delayed_indices(table_name) {
338 txn.execute(&sql);
339 }
340 }
341 }
342
343 let schema = read_schema(txn);
345 for (name, table) in schema.tables {
346 let expected_table = &expected_schema.tables[&*name];
347 assert!(check_eq(expected_table, &table));
348 }
349}
350
351impl<S> Transaction<S> {
352 #[track_caller]
353 pub(crate) fn execute(&self, sql: &str) {
354 TXN.with_borrow(|txn| txn.as_ref().unwrap().get().execute(sql, []))
355 .unwrap();
356 }
357}
358
359pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
360 conn.pragma_query_value(None, "schema_version", |r| r.get(0))
361 .unwrap()
362}
363
364pub fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
366 conn.query_row("PRAGMA user_version", [], |row| row.get(0))
367}
368
369fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
371 conn.pragma_update(None, "user_version", v)
372}
373
374pub(crate) fn check_schema<S: Schema>(txn: &Transaction<S>) {
375 let from_macro = crate::schema::from_macro::Schema::new::<S>();
376 let from_db = read_schema(txn);
377 let report = from_db.diff(from_macro, S::SOURCE, S::PATH, S::VERSION);
378 if !report.is_empty() {
379 let renderer = if cfg!(test) {
380 Renderer::plain().anonymized_line_numbers(true)
381 } else {
382 Renderer::styled()
383 }
384 .decor_style(DecorStyle::Unicode);
385 panic!("{}", renderer.render(&report));
386 }
387}
388
389fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
390 let error = conn
391 .prepare("PRAGMA foreign_key_check")
392 .unwrap()
393 .query_map([], |row| row.get(2))
394 .unwrap()
395 .next();
396 error.transpose().unwrap()
397}
398
399impl<S> Transaction<S> {
400 #[cfg(test)]
401 pub(crate) fn schema(&self) -> Vec<String> {
402 TXN.with_borrow(|x| {
403 x.as_ref()
404 .unwrap()
405 .get()
406 .prepare("SELECT sql FROM 'main'.'sqlite_schema'")
407 .unwrap()
408 .query_map([], |row| row.get::<_, Option<String>>("sql"))
409 .unwrap()
410 .flat_map(|x| x.unwrap())
411 .collect()
412 })
413 }
414}