1pub mod config;
2pub mod migration;
3
4use std::{
5 collections::{BTreeSet, HashMap},
6 marker::PhantomData,
7 sync::atomic::AtomicI64,
8};
9
10use annotate_snippets::{Renderer, renderer::DecorStyle};
11use rusqlite::{Connection, config::DbConfig};
12use sea_query::{Alias, ColumnDef, IntoTableRef, SqliteQueryBuilder};
13use self_cell::MutBorrow;
14
15use crate::{
16 Table, Transaction,
17 migrate::{
18 config::Config,
19 migration::{SchemaBuilder, TransactionMigrate},
20 },
21 schema::{
22 from_db, from_macro,
23 read::{read_index_names_for_table, read_schema},
24 },
25 transaction::{Database, OwnedTransaction, TXN, TransactionWithRows},
26};
27
28pub struct TableTypBuilder<S> {
29 pub(crate) ast: from_macro::Schema,
30 _p: PhantomData<S>,
31}
32
33impl<S> Default for TableTypBuilder<S> {
34 fn default() -> Self {
35 Self {
36 ast: Default::default(),
37 _p: Default::default(),
38 }
39 }
40}
41
42impl<S> TableTypBuilder<S> {
43 pub fn table<T: Table<Schema = S>>(&mut self) {
44 let table = from_macro::Table::new::<T>();
45 let old = self.ast.tables.insert(T::NAME.to_owned(), table);
46 debug_assert!(old.is_none());
47 }
48}
49
50pub trait Schema: Sized + 'static {
51 const VERSION: i64;
52 const SOURCE: &str;
53 const PATH: &str;
54 const SPAN: (usize, usize);
55 fn typs(b: &mut TableTypBuilder<Self>);
56}
57
58fn new_table_inner(
59 conn: &Connection,
60 table: &crate::schema::from_macro::Table,
61 alias: impl IntoTableRef,
62) {
63 let mut create = table.create();
64 create
65 .table(alias)
66 .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
67 let mut sql = create.to_string(SqliteQueryBuilder);
68 sql.push_str(" STRICT");
69 conn.execute(&sql, []).unwrap();
70}
71
72pub trait SchemaMigration<'a> {
73 type From: Schema;
74 type To: Schema;
75
76 fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
77}
78
79impl<S: Schema> Database<S> {
80 pub fn migrator(config: Config) -> Option<Migrator<S>> {
84 let synchronous = config.synchronous.as_str();
85 let foreign_keys = config.foreign_keys.as_str();
86 let manager = config.manager.with_init(move |inner| {
87 inner.pragma_update(None, "journal_mode", "WAL")?;
88 inner.pragma_update(None, "synchronous", synchronous)?;
89 inner.pragma_update(None, "foreign_keys", foreign_keys)?;
90 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
91 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
92 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
93 Ok(())
94 });
95
96 use r2d2::ManageConnection;
97 let conn = manager.connect().unwrap();
98 conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
99 let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
100 Some(
101 conn.borrow_mut()
102 .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
103 .unwrap(),
104 )
105 });
106
107 if schema_version(txn.get()) == 0 {
109 let schema = crate::schema::from_macro::Schema::new::<S>();
110
111 for (table_name, table) in &schema.tables {
112 let table_name_ref = Alias::new(table_name);
113 new_table_inner(txn.get(), table, table_name_ref);
114 for stmt in table.create_indices(table_name) {
115 txn.get().execute(&stmt, []).unwrap();
116 }
117 }
118 (config.init)(txn.get());
119 set_user_version(txn.get(), S::VERSION).unwrap();
120 }
121
122 let user_version = user_version(txn.get()).unwrap();
123 if user_version < S::VERSION {
125 return None;
126 }
127 debug_assert_eq!(
128 foreign_key_check(txn.get()),
129 None,
130 "foreign key constraint violated"
131 );
132
133 Some(Migrator {
134 indices_fixed: false,
135 manager,
136 transaction: txn,
137 _p: PhantomData,
138 })
139 }
140}
141
142pub struct Migrator<S> {
147 manager: r2d2_sqlite::SqliteConnectionManager,
148 transaction: OwnedTransaction,
149 indices_fixed: bool,
150 _p: PhantomData<S>,
151}
152
153impl<S: Schema> Migrator<S> {
154 pub fn migrate<'x, M>(
158 mut self,
159 m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
160 ) -> Migrator<M::To>
161 where
162 M: SchemaMigration<'x, From = S>,
163 {
164 if user_version(self.transaction.get()).unwrap() == S::VERSION {
165 let res = std::thread::scope(|s| {
166 s.spawn(|| {
167 TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
168 let txn = Transaction::new_ref();
169
170 check_schema::<S>(txn);
171 if !self.indices_fixed {
172 fix_indices::<S>(txn);
173 self.indices_fixed = true;
174 }
175
176 let mut txn = TransactionMigrate {
177 inner: Transaction::new(),
178 scope: Default::default(),
179 rename_map: HashMap::new(),
180 extra_index: Vec::new(),
181 };
182 let m = m(&mut txn);
183
184 let mut builder = SchemaBuilder {
185 drop: vec![],
186 foreign_key: HashMap::new(),
187 inner: txn,
188 };
189 m.tables(&mut builder);
190
191 let transaction = TXN.take().unwrap();
192
193 for drop in builder.drop {
194 let sql = drop.to_string(SqliteQueryBuilder);
195 transaction.get().execute(&sql, []).unwrap();
196 }
197 for (to, tmp) in builder.inner.rename_map {
198 let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
199 let sql = rename.to_string(SqliteQueryBuilder);
200 transaction.get().execute(&sql, []).unwrap();
201 }
202 if let Some(fk) = foreign_key_check(transaction.get()) {
203 (builder.foreign_key.remove(&*fk).unwrap())();
204 }
205 #[allow(
206 unreachable_code,
207 reason = "rustc is stupid and thinks this is unreachable"
208 )]
209 for stmt in builder.inner.extra_index {
211 transaction.get().execute(&stmt, []).unwrap();
212 }
213 set_user_version(transaction.get(), M::To::VERSION).unwrap();
214
215 transaction.into_owner()
216 })
217 .join()
218 });
219 match res {
220 Ok(val) => self.transaction = val,
221 Err(payload) => std::panic::resume_unwind(payload),
222 }
223 }
224
225 Migrator {
226 indices_fixed: self.indices_fixed,
227 manager: self.manager,
228 transaction: self.transaction,
229 _p: PhantomData,
230 }
231 }
232
233 pub fn finish(mut self) -> Option<Database<S>> {
239 if user_version(self.transaction.get()).unwrap() != S::VERSION {
240 return None;
241 }
242
243 let res = std::thread::scope(|s| {
244 s.spawn(|| {
245 TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
246 let txn = Transaction::new_ref();
247
248 check_schema::<S>(txn);
249 if !self.indices_fixed {
250 fix_indices::<S>(txn);
251 self.indices_fixed = true;
252 }
253
254 TXN.take().unwrap().into_owner()
255 })
256 .join()
257 });
258 match res {
259 Ok(val) => self.transaction = val,
260 Err(payload) => std::panic::resume_unwind(payload),
261 }
262
263 self.transaction
265 .get()
266 .execute_batch("PRAGMA optimize;")
267 .unwrap();
268
269 let schema_version = schema_version(self.transaction.get());
270 self.transaction.with(|x| x.commit().unwrap());
271
272 Some(Database {
273 manager: self.manager,
274 schema_version: AtomicI64::new(schema_version),
275 schema: PhantomData,
276 mut_lock: parking_lot::FairMutex::new(()),
277 })
278 }
279}
280
281fn fix_indices<S: Schema>(txn: &Transaction<S>) {
282 let schema = read_schema(txn);
283 let expected_schema = crate::schema::from_macro::Schema::new::<S>();
284
285 fn check_eq(expected: &from_macro::Table, actual: &from_db::Table) -> bool {
286 let expected: BTreeSet<_> = expected.indices.iter().map(|idx| &idx.def).collect();
287 let actual: BTreeSet<_> = actual.indices.values().collect();
288 expected == actual
289 }
290
291 for (name, table) in schema.tables {
292 let expected_table = &expected_schema.tables[&name];
293
294 if !check_eq(expected_table, &table) {
295 for index_name in read_index_names_for_table(&crate::Transaction::new(), &name) {
297 let sql = sea_query::Index::drop()
298 .name(index_name)
299 .build(SqliteQueryBuilder);
300 txn.execute(&sql);
301 }
302
303 for sql in expected_table.create_indices(&name) {
305 txn.execute(&sql);
306 }
307 }
308 }
309
310 let schema = read_schema(txn);
312 for (name, table) in schema.tables {
313 let expected_table = &expected_schema.tables[&name];
314 assert!(check_eq(expected_table, &table));
315 }
316}
317
318impl<S> Transaction<S> {
319 pub(crate) fn execute(&self, sql: &str) {
320 TXN.with_borrow(|txn| txn.as_ref().unwrap().get().execute(sql, []))
321 .unwrap();
322 }
323}
324
325pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
326 conn.pragma_query_value(None, "schema_version", |r| r.get(0))
327 .unwrap()
328}
329
330pub fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
332 conn.query_row("PRAGMA user_version", [], |row| row.get(0))
333}
334
335fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
337 conn.pragma_update(None, "user_version", v)
338}
339
340pub(crate) fn check_schema<S: Schema>(txn: &Transaction<S>) {
341 let from_macro = crate::schema::from_macro::Schema::new::<S>();
342 let from_db = read_schema(txn);
343 let report = from_db.diff(from_macro, S::SOURCE, S::PATH, S::VERSION);
344 if !report.is_empty() {
345 let renderer = if cfg!(test) {
346 Renderer::plain().anonymized_line_numbers(true)
347 } else {
348 Renderer::styled()
349 }
350 .decor_style(DecorStyle::Unicode);
351 panic!("{}", renderer.render(&report));
352 }
353}
354
355fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
356 let error = conn
357 .prepare("PRAGMA foreign_key_check")
358 .unwrap()
359 .query_map([], |row| row.get(2))
360 .unwrap()
361 .next();
362 error.transpose().unwrap()
363}
364
365impl<S> Transaction<S> {
366 #[cfg(test)]
367 pub(crate) fn schema(&self) -> Vec<String> {
368 TXN.with_borrow(|x| {
369 x.as_ref()
370 .unwrap()
371 .get()
372 .prepare("SELECT sql FROM 'main'.'sqlite_schema'")
373 .unwrap()
374 .query_map([], |row| row.get("sql"))
375 .unwrap()
376 .map(|x| x.unwrap())
377 .collect()
378 })
379 }
380}