1pub mod config;
2pub mod migration;
3
4use std::{collections::HashMap, marker::PhantomData, sync::atomic::AtomicI64};
5
6use rusqlite::{Connection, config::DbConfig};
7use sea_query::{Alias, ColumnDef, IntoTableRef, SqliteQueryBuilder};
8use self_cell::MutBorrow;
9
10use crate::{
11 Table, Transaction, hash,
12 migrate::{
13 config::Config,
14 migration::{SchemaBuilder, TransactionMigrate},
15 },
16 schema_pragma::{read_index_names_for_table, read_schema},
17 transaction::{Database, OwnedTransaction, TXN, TransactionWithRows},
18};
19
20pub struct TableTypBuilder<S> {
21 pub(crate) ast: hash::Schema,
22 _p: PhantomData<S>,
23}
24
25impl<S> Default for TableTypBuilder<S> {
26 fn default() -> Self {
27 Self {
28 ast: Default::default(),
29 _p: Default::default(),
30 }
31 }
32}
33
34impl<S> TableTypBuilder<S> {
35 pub fn table<T: Table<Schema = S>>(&mut self) {
36 let table = hash::Table::new::<T>();
37 let old = self.ast.tables.insert(T::NAME.to_owned(), table);
38 debug_assert!(old.is_none());
39 }
40}
41
/// A database schema at one specific version.
pub trait Schema: Sized + 'static {
    /// Version number of this schema; the migrator writes it to the database's
    /// `user_version` pragma once the database is at this schema.
    const VERSION: i64;
    /// Registers the tables of this schema on `b`.
    ///
    /// NOTE(review): presumably each table is added via [`TableTypBuilder::table`];
    /// confirm against the `schema` macro that implements this trait.
    fn typs(b: &mut TableTypBuilder<Self>);
}
46
47fn new_table_inner(conn: &Connection, table: &crate::hash::Table, alias: impl IntoTableRef) {
48 let mut create = table.create();
49 create
50 .table(alias)
51 .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
52 let mut sql = create.to_string(SqliteQueryBuilder);
53 sql.push_str(" STRICT");
54 conn.execute(&sql, []).unwrap();
55}
56
/// One step of a migration chain, moving a database from schema `From` to schema `To`.
///
/// Returned by the closure given to `Migrator::migrate`; the migrator then calls
/// [`SchemaMigration::tables`] and afterwards sets `user_version` to `To::VERSION`.
pub trait SchemaMigration<'a> {
    /// The schema the database is at before this step.
    type From: Schema;
    /// The schema the database is at after this step.
    type To: Schema;

    /// Records the table changes of this step on the builder `b`.
    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
}
63
impl<S: Schema> Database<S> {
    /// Opens the database described by `config` and prepares it for migration.
    ///
    /// All work happens in one exclusive transaction that is handed to the returned
    /// [`Migrator`]. If the database has no schema yet (`schema_version` is 0), every
    /// table and index of `S` is created, `config.init` is run and `user_version` is set
    /// to `S::VERSION`.
    ///
    /// Returns `None` if the stored `user_version` is lower than `S::VERSION`, i.e. the
    /// database is older than `S` and cannot be migrated by this chain.
    ///
    /// # Panics
    /// Panics if the connection cannot be opened, the transaction cannot be started, or
    /// any setup statement fails.
    pub fn migrator(config: Config) -> Option<Migrator<S>> {
        let synchronous = config.synchronous.as_str();
        let foreign_keys = config.foreign_keys.as_str();
        // Connection setup registered with the manager: WAL journaling, the configured
        // `synchronous` / `foreign_keys` pragmas, no double-quoted string literals in
        // DDL or DML, and SQLite's defensive mode.
        let manager = config.manager.with_init(move |inner| {
            inner.pragma_update(None, "journal_mode", "WAL")?;
            inner.pragma_update(None, "synchronous", synchronous)?;
            inner.pragma_update(None, "foreign_keys", foreign_keys)?;
            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
            Ok(())
        });

        use r2d2::ManageConnection;
        // The migration uses a connection opened directly from the manager, not a pool.
        let conn = manager.connect().unwrap();
        // NOTE(review): foreign key enforcement is switched off for the migration
        // connection, presumably so tables can be dropped/renamed mid-migration;
        // violations are checked explicitly with `PRAGMA foreign_key_check` instead.
        // This pragma must be set before the transaction starts (SQLite ignores it
        // inside a transaction).
        conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
        let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
            Some(
                conn.borrow_mut()
                    .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
                    .unwrap(),
            )
        });

        // `schema_version` 0 means no schema objects exist yet: create schema `S` from scratch.
        if schema_version(txn.get()) == 0 {
            let schema = crate::hash::Schema::new::<S>();

            for (table_name, table) in &schema.tables {
                let table_name_ref = Alias::new(table_name);
                new_table_inner(txn.get(), table, table_name_ref);
                for stmt in table.create_indices(table_name) {
                    txn.get().execute(&stmt, []).unwrap();
                }
            }
            (config.init)(txn.get());
            set_user_version(txn.get(), S::VERSION).unwrap();
        }

        let user_version = user_version(txn.get()).unwrap();
        // The database predates `S`; this migrator cannot bring it forward.
        if user_version < S::VERSION {
            return None;
        }
        // Debug builds only: the database must start out without foreign key violations.
        debug_assert_eq!(
            foreign_key_check(txn.get()),
            None,
            "foreign key constraint violated"
        );

        Some(Migrator {
            indices_fixed: false,
            manager,
            transaction: txn,
            _p: PhantomData,
        })
    }
}
126
127pub struct Migrator<S> {
132 manager: r2d2_sqlite::SqliteConnectionManager,
133 transaction: OwnedTransaction,
134 indices_fixed: bool,
135 _p: PhantomData<S>,
136}
137
138impl<S: Schema> Migrator<S> {
139 pub fn migrate<'x, M>(
143 mut self,
144 m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
145 ) -> Migrator<M::To>
146 where
147 M: SchemaMigration<'x, From = S>,
148 {
149 if user_version(self.transaction.get()).unwrap() == S::VERSION {
150 let res = std::thread::scope(|s| {
151 s.spawn(|| {
152 TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
153 let txn = Transaction::new_ref();
154
155 check_schema::<S>(txn);
156 if !self.indices_fixed {
157 fix_indices::<S>(txn);
158 self.indices_fixed = true;
159 }
160
161 let mut txn = TransactionMigrate {
162 inner: Transaction::new(),
163 scope: Default::default(),
164 rename_map: HashMap::new(),
165 extra_index: Vec::new(),
166 };
167 let m = m(&mut txn);
168
169 let mut builder = SchemaBuilder {
170 drop: vec![],
171 foreign_key: HashMap::new(),
172 inner: txn,
173 };
174 m.tables(&mut builder);
175
176 let transaction = TXN.take().unwrap();
177
178 for drop in builder.drop {
179 let sql = drop.to_string(SqliteQueryBuilder);
180 transaction.get().execute(&sql, []).unwrap();
181 }
182 for (to, tmp) in builder.inner.rename_map {
183 let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
184 let sql = rename.to_string(SqliteQueryBuilder);
185 transaction.get().execute(&sql, []).unwrap();
186 }
187 if let Some(fk) = foreign_key_check(transaction.get()) {
188 (builder.foreign_key.remove(&*fk).unwrap())();
189 }
190 #[allow(
191 unreachable_code,
192 reason = "rustc is stupid and thinks this is unreachable"
193 )]
194 for stmt in builder.inner.extra_index {
196 transaction.get().execute(&stmt, []).unwrap();
197 }
198 set_user_version(transaction.get(), M::To::VERSION).unwrap();
199
200 transaction.into_owner()
201 })
202 .join()
203 });
204 match res {
205 Ok(val) => self.transaction = val,
206 Err(payload) => std::panic::resume_unwind(payload),
207 }
208 }
209
210 Migrator {
211 indices_fixed: self.indices_fixed,
212 manager: self.manager,
213 transaction: self.transaction,
214 _p: PhantomData,
215 }
216 }
217
218 pub fn finish(mut self) -> Option<Database<S>> {
224 if user_version(self.transaction.get()).unwrap() != S::VERSION {
225 return None;
226 }
227
228 let res = std::thread::scope(|s| {
229 s.spawn(|| {
230 TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
231 let txn = Transaction::new_ref();
232
233 check_schema::<S>(txn);
234 if !self.indices_fixed {
235 fix_indices::<S>(txn);
236 self.indices_fixed = true;
237 }
238
239 TXN.take().unwrap().into_owner()
240 })
241 .join()
242 });
243 match res {
244 Ok(val) => self.transaction = val,
245 Err(payload) => std::panic::resume_unwind(payload),
246 }
247
248 self.transaction
250 .get()
251 .execute_batch("PRAGMA optimize;")
252 .unwrap();
253
254 let schema_version = schema_version(self.transaction.get());
255 self.transaction.with(|x| x.commit().unwrap());
256
257 Some(Database {
258 manager: self.manager,
259 schema_version: AtomicI64::new(schema_version),
260 schema: PhantomData,
261 mut_lock: parking_lot::FairMutex::new(()),
262 })
263 }
264}
265
266fn fix_indices<S: Schema>(txn: &Transaction<S>) {
267 let schema = read_schema(txn);
268 let expected_schema = crate::hash::Schema::new::<S>();
269
270 for (name, table) in schema.tables {
271 let expected_table = &expected_schema.tables[&name];
272
273 if expected_table.indices != table.indices {
274 for index_name in read_index_names_for_table(&crate::Transaction::new(), &name) {
276 let sql = sea_query::Index::drop()
277 .name(index_name)
278 .build(SqliteQueryBuilder);
279 txn.execute(&sql);
280 }
281
282 for sql in expected_table.create_indices(&name) {
284 txn.execute(&sql);
285 }
286 }
287 }
288
289 assert_eq!(expected_schema, read_schema(txn));
290}
291
292impl<S> Transaction<S> {
293 pub(crate) fn execute(&self, sql: &str) {
294 TXN.with_borrow(|txn| txn.as_ref().unwrap().get().execute(sql, []))
295 .unwrap();
296 }
297}
298
299pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
300 conn.pragma_query_value(None, "schema_version", |r| r.get(0))
301 .unwrap()
302}
303
304pub fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
306 conn.query_row("PRAGMA user_version", [], |row| row.get(0))
307}
308
309fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
311 conn.pragma_update(None, "user_version", v)
312}
313
314pub(crate) fn check_schema<S: Schema>(txn: &Transaction<S>) {
315 pretty_assertions::assert_eq!(
317 crate::hash::Schema::new::<S>().normalize(),
318 read_schema(txn).normalize(),
319 "schema is different (expected left, but got right)",
320 );
321}
322
323fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
324 let error = conn
325 .prepare("PRAGMA foreign_key_check")
326 .unwrap()
327 .query_map([], |row| row.get(2))
328 .unwrap()
329 .next();
330 error.transpose().unwrap()
331}
332
333impl<S> Transaction<S> {
334 #[cfg(test)]
335 pub(crate) fn schema(&self) -> Vec<String> {
336 TXN.with_borrow(|x| {
337 x.as_ref()
338 .unwrap()
339 .get()
340 .prepare("SELECT sql FROM 'main'.'sqlite_schema'")
341 .unwrap()
342 .query_map([], |row| row.get("sql"))
343 .unwrap()
344 .map(|x| x.unwrap())
345 .collect()
346 })
347 }
348}
349
350impl<S: Send + Sync + Schema> Database<S> {
351 #[cfg(test)]
352 fn check_schema(&self, expect: expect_test::Expect) {
353 let mut schema = self.transaction(|txn| txn.schema());
354 schema.sort();
355 expect.assert_eq(&schema.join("\n"));
356 }
357}
358
/// Checks that opening a database reconciles its indices with the declared schema.
///
/// The two schemas below are both at version 0 and differ only in an index on
/// `foo.bar`. Opening the same file with each one adds or removes that index.
#[test]
fn fix_indices_test() {
    // Table `Foo { bar }` without an index.
    mod without_index {
        #[crate::migration::schema(Schema)]
        pub mod vN {
            pub struct Foo {
                pub bar: String,
            }
        }
    }

    // The same table, with an index on `bar`.
    mod with_index {
        #[crate::migration::schema(Schema)]
        pub mod vN {
            pub struct Foo {
                #[index]
                pub bar: String,
            }
        }
    }

    // NOTE(review): `index_test.sqlite` is created relative to the working directory and
    // is not removed by this test.
    let db = Database::<without_index::v0::Schema>::migrator(Config::open("index_test.sqlite"))
        .unwrap()
        .finish()
        .unwrap();
    db.check_schema(expect_test::expect![[
        r#"CREATE TABLE "foo" ( "bar" text NOT NULL, "id" integer PRIMARY KEY ) STRICT"#
    ]]);

    // Opening the same file with the indexed schema creates the index.
    let db_with_index =
        Database::<with_index::v0::Schema>::migrator(Config::open("index_test.sqlite"))
            .unwrap()
            .finish()
            .unwrap();
    db_with_index.check_schema(expect_test::expect![[r#"
        CREATE INDEX "foo_index_0" ON "foo" ("bar")
        CREATE TABLE "foo" ( "bar" text NOT NULL, "id" integer PRIMARY KEY ) STRICT"#]]);

    // The earlier handle reads the same file, so it sees the new index as well.
    db.check_schema(expect_test::expect![[r#"
        CREATE INDEX "foo_index_0" ON "foo" ("bar")
        CREATE TABLE "foo" ( "bar" text NOT NULL, "id" integer PRIMARY KEY ) STRICT"#]]);

    // Reopening with the index-less schema removes the index again.
    let db = Database::<without_index::v0::Schema>::migrator(Config::open("index_test.sqlite"))
        .unwrap()
        .finish()
        .unwrap();
    db.check_schema(expect_test::expect![[
        r#"CREATE TABLE "foo" ( "bar" text NOT NULL, "id" integer PRIMARY KEY ) STRICT"#
    ]]);
}