1use crate::prelude::*;
25use db_support::{get_user_version, set_user_version_stmt, table_column_names};
26
27use bsql::{ErasedError1 as EE1, Context1 as _};
29
30use Exception as E;
31type Err = DbError<EE1>;
32
/// The full set of known schemas and the migrations between them.
///
/// Use [`MigrationData::check`] to sanity-check a value of this type.
#[derive(Debug, Clone)]
pub struct MigrationData {
    /// Known schemas; looked up by `version`, which must be unique.
    pub schemas: Vec<Schema>,
    /// Known migrations; looked up by their `(old_version, new_version)` pair.
    pub migrations: Vec<Migration>,
}
38
/// One version of the database schema.
#[derive(Debug, Clone)]
pub struct Schema {
    /// The version this schema text belongs to.
    pub version: SchemaVersion,
    /// SQL text, executed as a batch to create the schema.
    pub schema: String,
}
44
/// Description of one migration step between two schema versions.
///
/// `MigrationData::check` requires `new_version == old_version + 1`.
#[derive(Debug, Clone)]
pub struct Migration {
    /// Version the database must have before the migration.
    pub old_version: SchemaVersion,
    /// Version the database has after the migration.
    pub new_version: SchemaVersion,
    /// Deviations from the default copy-everything-and-verify procedure.
    pub exceptions: Vec<Exception>,
}
51
/// What to do once a migration has run through all its steps successfully.
#[derive(Debug, Clone, Copy)]
pub enum OnComplete {
    /// Do not commit: the migration was only a trial run
    /// (reported as [`Outcome::Tested`]).
    Rollback,
    /// Commit the migrated database
    /// (reported as [`Outcome::Migrated`]).
    Commit,
}
57
/// Result of a successful call to the migration machinery.
#[derive(Debug, Clone, Copy, Eq, PartialEq, derive_more::Display)]
pub enum Outcome {
    /// The database was already at (at least) the wanted version;
    /// the payload is the version found in the database.
    #[display("already version {_0}")]
    Already(SchemaVersion),
    /// The migration was performed and committed; the payload is the new version.
    #[display("migrated to version {_0}")]
    Migrated(SchemaVersion),
    /// The migration ran successfully but was not committed.
    #[display("migration tested, from {current} to {tested}")]
    Tested {
        /// Version the migration started from.
        current: SchemaVersion,
        /// Version the trial migration produced.
        tested: SchemaVersion,
    },
}
70
/// A deviation from the default migration procedure.
///
/// By default every table of the new schema is copied column-by-column
/// from the old database, and its row count must be unchanged afterwards.
///
/// Each exception must actually be used by the migration; an exception
/// that is left over at the end is reported as an error.
#[derive(Debug, Clone)]
pub enum Exception {
    /// When copying `table`, fill `col` from the SQL expression `val_sql`
    /// instead of from the old column of the same name.
    ///
    /// The expression is placed in a `SELECT ... FROM main.<table> AS old`.
    /// At most one per `(table, col)`.
    ReplacementColumnValue {
        table: &'static str,
        col: &'static str,
        val_sql: String,
    },
    /// `table` is not copied from the old database.
    NewTable {
        table: &'static str,
    },
    /// SQL batch to run after all tables have been copied out of the old
    /// database.  At most one per migration.
    GeneralAfterCopyOut {
        sql: String,
    },
    /// SQL batch to run after all data has been copied back into the main
    /// database.  At most one per migration.
    GeneralFinal {
        sql: String,
    },
    /// The row count `table` must have after the migration is the value
    /// returned by `select_sql`, which is run before any data is moved.
    ExpectedRowCount {
        table: &'static str,
        select_sql: String,
    },
    /// The row count of `table` is not checked at all.
    UncheckedRowCount {
        table: &'static str,
    },
    /// The row count of `table` may grow: it must be at least the count
    /// found before the migration, with no upper limit.
    IncreasedRowCount {
        table: &'static str,
    },
}
116
/// Name of the one `sqlite_*` table the migration handles itself:
/// it is processed before all other tables and is never dropped from
/// the main database.  All other `sqlite_*` tables are ignored.
const SQLITE_SEQUENCE: &str = "sqlite_sequence";
119
/// Per-table row count rule, derived from the row-count [`Exception`]s.
#[derive(Debug, Clone, Eq, PartialEq)]
enum RowCountException {
    /// Do not check the row count.
    Unchecked,
    /// The expected count is whatever this SQL query returns.
    Expect(String),
    /// The count must not shrink, but may grow.
    Increased,
}
127
/// The exceptions of one [`Migration`], indexed for use during migration.
///
/// `migration_core` removes (or takes) each entry as it is applied, and at
/// the end requires the value to equal `Default::default()`; anything left
/// over is an unused exception and is reported as an error.
#[derive(Debug, Eq, PartialEq, Default)]
struct PreprocessedExceptions {
    /// `(table, column)` => SQL expression providing the new column value.
    replacement_col_vals: HashMap<(String, String), String>,
    /// table => row count rule.
    expected_row_count: HashMap<String, RowCountException>,
    /// Tables that are not copied from the old database.
    no_copy_tables: HashSet<String>,
    /// SQL to run after copying the data back; empty if there is none.
    general_final: String,
    /// SQL to run after copying the data out; empty if there is none.
    general_after_copy_out: String,
}
141
142impl MigrationData {
143 pub fn check(
144 &self,
145 target_version: SchemaVersion,
146 max_version: SchemaVersion,
147 ) -> Result<(), EE1> {
148 expect1!(
149 self.schemas.iter()
150 .map(|sch| sch.version)
151 .all_unique()
152 )?;
153 let mut mig_from_to = HashSet::new();
154 for mig in &self.migrations {
155 expect1!(mig.new_version == mig.old_version + 1)?;
156 expect1!(mig.new_version <= max_version)?;
157 expect1!(mig_from_to.insert((mig.old_version, mig.new_version)))?;
158 let _: PreprocessedExceptions = mig.preprocess_exceptions()?;
159 };
160 expect1!(
161 self.migrations.iter().next().is_none() ||
162 self.get_schema(target_version).is_ok()
163 )?;
164 Ok(())
165 }
166
167 fn get_schema(&self, version: SchemaVersion) -> Result<&str, EE1> {
168 let found = &self.schemas.iter()
169 .find(|sch| sch.version == version)
170 .ok_or_else(|| anyerror1!(
171 "missing schema for {version}"
172 ))?;
173 Ok(&found.schema)
174 }
175}
176
177impl Migration {
178 fn preprocess_exceptions(&self) -> Result<PreprocessedExceptions, EE1> {
179 let mut replacement_col_vals = HashMap::new();
180 let mut expected_row_count = HashMap::new();
181 let mut no_copy_tables = HashSet::new();
182 let mut general_final = None;
183 let mut general_after_copy_out = None;
184 let os = |s: &str| -> String { s.to_owned() };
185
186 let store_general = |g: &mut Option<String>, sql: &str| {
187 expect1!(g.is_none())?;
188 *g = Some(sql.into());
189 Ok::<_, EE1>(())
190 };
191 let mut row_count = |table: &str, ex| {
192 expect1!(
193 expected_row_count.insert((*table).to_owned(), ex)
194 .is_none()
195 )?;
196 Ok::<_, EE1>(())
197 };
198
199 for e in &self.exceptions {
200 match e {
201 E::ReplacementColumnValue { table, col, val_sql } => expect1!(
202 replacement_col_vals.insert(
203 (os(table), os(col)),
204 val_sql.clone(),
205 )
206 .is_none()
207 )?,
208 E::NewTable { table } => expect1!(
209 no_copy_tables.insert(os(table))
210 )?,
211 E::ExpectedRowCount { table, select_sql } => row_count(
212 table,
213 RowCountException::Expect(select_sql.clone()),
214 )?,
215 E::UncheckedRowCount { table } => row_count(
216 table,
217 RowCountException::Unchecked,
218 )?,
219 E::IncreasedRowCount { table } => row_count(
220 table,
221 RowCountException::Increased,
222 )?,
223 E::GeneralFinal { sql } => {
224 store_general(&mut general_final, sql)?;
225 },
226 E::GeneralAfterCopyOut { sql } => {
227 store_general(&mut general_after_copy_out, sql)?;
228 },
229 }
230 }
231
232 let general_final = general_final.unwrap_or_default();
233 let general_after_copy_out
234 = general_after_copy_out.unwrap_or_default();
235
236 Ok(PreprocessedExceptions {
237 replacement_col_vals,
238 expected_row_count,
239 no_copy_tables,
240 general_final,
241 general_after_copy_out,
242 })
243 }
244}
245
246fn set_foreign_keys(conn: &rusqlite::Connection, enable: bool)
247 -> Result<(), Err>
248{
249 let stmt = format!("PRAGMA foreign_keys={}",
250 if enable { "ON" } else { "OFF" });
251 conn.execute_batch(&stmt)
252 .with_db_context(|| stmt.clone())?;
253 Ok(())
254}
255
256fn exec_batch_logged(
257 progress: &mut dyn io::Write,
258 conn: &rusqlite::Connection,
259 what: &str,
260 stmt: &str
261)
262 -> Result<(), Err>
263{
264 writeln!(progress, "migration, executing statement, {what}:\n{stmt}\n")
265 .context1("report progress")?;
266 conn.execute_batch(stmt)
267 .db_context(what)
268}
269
270impl OnComplete {
271 pub fn from_commit_bool(commit: bool) -> Self {
272 if commit {
273 OnComplete::Commit
274 } else {
275 OnComplete::Rollback
276 }
277 }
278}
279
280impl Outcome {
281 fn version_after(&self) -> SchemaVersion {
282 *match self {
283 Outcome::Migrated(m) => m,
284 Outcome::Already(m) => m,
285 Outcome::Tested { tested, .. } => tested,
286 }
287 }
288}
289
290pub fn prepare_idempotent(
308 db_file: &Path,
309 temp_dir: &Path,
310 timeout: &bsql::Timeout,
311 progress: &mut dyn io::Write,
312 target_version: SchemaVersion,
313 max_version: SchemaVersion,
314 migration_data: &MigrationData,
315 on_step_complete: OnComplete,
316) -> Result<Outcome, AE> {
317 let mk_conn = || -> Result<_, Err> {
318 rusqlite::Connection::open(db_file)
319 .db_context("access db during schema preparation")
320 };
321
322 macro_rules! retry_loop { { $f:expr } => {
324 timeout.generic_retry_loop_erasederror1($f)
325 } }
326
327 loop {
328 let mut conn = retry_loop!(mk_conn)?;
329
330 let stored_version = retry_loop!(|| -> Result<_, Err> {
331 let dbt = conn
332 .transaction_with_behavior(
333 rusqlite::TransactionBehavior::Immediate
334 )
335 .db_context("start transaction, for schema setup")?;
336
337 let stored_version = get_user_version(&dbt)
338 .map_err(DbError::Sql)?;
339
340 if stored_version == 0 {
341 writeln!(progress, "initialising schema in empty database")
342 .context1("write")?;
343
344 let target_schema = migration_data.get_schema(target_version)
345 .map_err(DbError::Other)?;
346
347 dbt.execute_batch(&target_schema)
351 .db_context("install fresh schema")?;
352 dbt
353 .execute_batch(&set_user_version_stmt(target_version))
354 .db_context("install fresh schema: set user_version")?;
355 dbt.commit()
356 .db_context("install fresh schema: commit")?;
357 return Ok(target_version);
358 }
359
360 Ok::<_, Err>(stored_version)
361 })?;
362
363 if stored_version >= target_version {
364 if stored_version > max_version {
365 Err(anyhow!(
366 "db already contains schema version {stored_version} and we only support {target_version}"
367 ))?;
368 }
369 return Ok(Outcome::Already(stored_version));
370 }
371
372 writeln!(
375 progress,
376 "attempting migration from schema version {stored_version}"
377 )
378 .context("write")?;
379
380 let mut conn = Some(conn);
381 let outcome = retry_loop!(|| {
382 let conn = conn.take().map(Ok)
383 .unwrap_or_else(|| mk_conn())?;
384
385 migration_core(
386 conn,
387 temp_dir,
388 progress,
389 stored_version,
390 stored_version + 1,
391 migration_data,
392 on_step_complete,
393 )
394 })?;
395
396 if outcome.version_after() == target_version {
401 return Ok(outcome);
402 }
403 match outcome {
404 Outcome::Migrated(_) | Outcome::Already(_) => {}, Outcome::Tested { current, tested } => Err(anyhow!(
407 "needed multi-step migration: db has {current}, tested migration to {tested} OK, but target was {target_version}"
408 ))?,
409 }
410 }
411}
412
413pub fn migration_core<'w>(
414 mut main_conn: rusqlite::Connection,
415 temp_dir: &Path,
416 progress: &'w mut dyn io::Write,
417 old_version: SchemaVersion,
418 new_version: SchemaVersion,
419 migration_data: &MigrationData,
420 on_complete: OnComplete,
421) -> Result<Outcome, Err> {
422 let temp_dir = temp_dir.to_str()
423 .ok_or_else(|| anyerror1!("temp dir path not utf-8 {temp_dir:?}"))?;
424 let tmp_db = format!("{temp_dir}/migration-aside.db");
425
426 macro_rules! progressln { { $($m:tt)* } => {
427 writeln!(progress, $($m)*).context1("report progress")
428 } }
429
430 progressln!("migration, preparing")?;
431
432 migration_data.check(new_version, SchemaVersion::MAX)?;
435
436 let migration = migration_data
437 .migrations.iter()
438 .find(|exc|
439 exc.old_version == old_version &&
440 exc.new_version == new_version)
441 .ok_or_else(|| anyerror1!(
442 "unsupported migration {old_version} => {new_version}"
443 ))?;
444
445 let new_schema = migration_data.get_schema(new_version)?;
446
447 let mut es = migration.preprocess_exceptions()
448 .context1("preprocess exceptions")?;
449
450 let preclean_glob = format!("{tmp_db}*");
453 for preclean in glob::glob(&preclean_glob)
454 .with_context1(|| preclean_glob.clone())
455 .context1("glob tmp dirs")?
456 {
457 let preclean = preclean.context1("identify path to clean")?;
458 match fs::remove_file(&preclean) {
459 Err(e) if e.kind() == io::ErrorKind::NotFound => Ok(()),
460 other => other,
461 }
462 .with_context1(|| preclean.to_string_lossy().to_string())
463 .context1("clean tmp db file")?;
464 }
465
466 let tmp_conn = rusqlite::Connection::open(&tmp_db)
472 .with_db_context(|| format!("open tmp db {tmp_db:?}"))?;
473 tmp_conn.execute_batch(&new_schema)
474 .db_context("execute current schema")?;
475 tmp_conn.close()
476 .map_err(|(_, e)| e)
477 .db_context("close on tmp db current schema")?;
478
479 main_conn.execute("ATTACH DATABASE ? AS aside", [&tmp_db])
482 .db_context("attach temp database")?;
483
484 set_foreign_keys(&main_conn, false)?;
487
488 let dbt = main_conn
491 .transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)
492 .db_context("start transaction")?;
493
494 let found_old_version = get_user_version(&dbt)
497 .map_err(DbError::Sql)?;
498 if found_old_version >= new_version {
499 return Ok(Outcome::Already(found_old_version));
500 }
501 if found_old_version != old_version {
502 Err(anyerror1!(
503 "running schema migration from {old_version} but db has {found_old_version}"
504 ))?;
505 }
506
507 let tables = {
512 let mut tables = dbt
513 .prepare(r#"
515 SELECT name FROM aside.sqlite_schema WHERE type = 'table'
516 ORDER BY name ASC
517 "#)
518 .db_context("prepare find tables")?
519 .query_map([], |row| row.get::<_, String>(0))
520 .db_context("find tables")?
521 .collect::<Result<Vec<_>, _>>()
522 .db_context("convert table name")?
523 .into_iter()
524 .filter_map(|table| Some(match &*table {
527 SQLITE_SEQUENCE => (0, table),
528 s if s.starts_with("sqlite_") => return None,
529 _ => (1, table),
530 }))
531 .collect_vec();
532
533 tables.sort();
535 tables.into_iter().map(|(_, t)| t).collect_vec()
536 };
537
538 let expected_row_count: HashMap<_, (u64, Option<_>)> = {
541 let mut out = HashMap::new();
542 for table in &tables {
543 let get_exp = |sql: &str| -> Result<u64, Err> {
544 let exp = dbt.query_one(sql, [], |row| row.get(0))
545 .db_context("query existing row count")?;
546 Ok(exp)
547 };
548 let get_current = || get_exp(&format!(
549 "SELECT count(*) FROM main.{table}"
550 ));
551 let exp = match es.expected_row_count.remove(&**table) {
552 None => {
553 let e = get_current()?;
554 (e, Some(e))
555 },
556 Some(RowCountException::Expect(exp)) => {
557 let e = get_exp(&exp)?;
558 (e, Some(e))
559 },
560 Some(RowCountException::Increased) => {
561 let e = get_current()?;
562 (e, None)
563 },
564 Some(RowCountException::Unchecked) => continue,
565 };
566 out.insert(table, exp.into());
567 }
568 out
569 };
570
571 progressln!("migration, copying data out")?;
574
575 for table in &tables {
576 if es.no_copy_tables.remove(&*table) { continue };
577
578 let new_cols = table_column_names(&dbt, &format!("aside.{table}"))
579 .map_err(DbError::Sql)?;
580 let new_col_names = new_cols.iter().map(|s| &**s)
581 .intersperse(", ").collect::<String>();
582
583 let mut s = String::new();
584
585 s += &format!(r#"
586 INSERT INTO aside.{table}
587 ( {new_col_names} )
588 SELECT
589"#);
590
591 let new_col_settings = new_cols.iter().map(|col| {
592 let col_id = (table.clone(), col.clone());
593 let new_val = match es.replacement_col_vals.remove(&col_id) {
594 None => format!("old.{col}"),
595 Some(val_sql) => format!("{val_sql}"),
596 };
597 format!(
598r#" {new_val} AS {col}"#)
599 }).collect_vec();
600
601 s.extend(
602 new_col_settings.iter().map(|s| &**s)
603 .intersperse(",\n")
604 );
605 s += &format!(r#"
606 FROM main.{table} AS old
607"#);
608
609 exec_batch_logged(progress, &dbt, "copy data out", &s)?;
610 }
611
612 exec_batch_logged(progress, &dbt, "GeneralAfterCopyOut",
613 &mem::take(&mut es.general_after_copy_out))?;
614
615 progressln!("migration, turning around")?;
616
617 for table in &tables {
627 if table == SQLITE_SEQUENCE {
628 continue
630 }
631
632 exec_batch_logged(progress, &dbt, "clear old table", &format!(r#"
633 DROP TABLE IF EXISTS main.{table}
634"#
635 ))?;
636 }
637
638 dbt.execute_batch(&new_schema)
643 .db_context("set up new schema in main")?;
644
645 progressln!("migration, copying data back")?;
648
649 for table in &tables {
652 exec_batch_logged(progress, &dbt, "copy data in", &format!(
653r#" INSERT INTO main.{table} SELECT * FROM aside.{table}"#
654 ))?;
655 }
656
657 exec_batch_logged(progress, &dbt, "GeneralFinal",
660 &mem::take(&mut es.general_final))?;
661
662 progressln!("migration, checking")?;
665
666 exec_batch_logged(progress, &dbt, "check FK", r#"
668 PRAGMA foreign_key_check;
669"#)?;
670 set_foreign_keys(&dbt, true)?;
673
674 expect1!(expected_row_count.iter().next().is_some())?;
676 for (table, (exp_min, exp_max)) in expected_row_count {
677 let got: u64 = dbt.query_one(
678 &format!("SELECT count(*) FROM main.{table}"),
679 [], |row| row.get(0),
680 ).db_context("get new row count")?;
681 (got >= exp_min).then_some(()).ok_or_else(|| anyerror1!(
682 "table {table}: expected >= {exp_min} rows after migration but had {got}"
683 ))?;
684 if let Some(exp_max) = exp_max {
685 (got <= exp_max).then_some(()).ok_or_else(|| anyerror1!(
686 "table {table}: expected <= {exp_max} rows after migration but had {got}"
687 ))?;
688 }
689 }
690
691 expect1!(es == PreprocessedExceptions::default(),
692 "unused exception(s) {es:#?}")?;
693
694 dbt.execute_batch(&set_user_version_stmt(new_version))
698 .db_context("set user version after migration")?;
699
700 match on_complete {
702 OnComplete::Commit => {
703 progressln!("migration complete, committing")?;
704 dbt.commit().db_context("commit")?;
705 progressln!("migrated to schema version {new_version}")?;
706 Ok(Outcome::Migrated(new_version))
707 },
708 OnComplete::Rollback => {
709 progressln!("migration check, apparently successful")?;
710 Ok(Outcome::Tested {
711 current: old_version,
712 tested: new_version,
713 })
714 }
715 }
716}
717
718#[cfg(test)]
719mod db_migr_test;