1use crate::error::{DatabaseError as _, TernResult};
16
17use chrono::{DateTime, Utc};
18use futures_core::{future::BoxFuture, Future};
19use std::{fmt::Write, time::Instant};
20
21pub trait MigrationContext
23where
24 Self: MigrationSource<Ctx = Self> + Send + Sync + 'static,
25{
26 const HISTORY_TABLE: &str;
32
33 type Exec: Executor;
35
36 fn executor(&mut self) -> &mut Self::Exec;
38
39 fn apply<'migration, 'conn: 'migration, M>(
43 &'conn mut self,
44 migration: &'migration M,
45 ) -> BoxFuture<'migration, TernResult<AppliedMigration>>
46 where
47 M: Migration<Ctx = Self> + Send + Sync + ?Sized,
48 {
49 Box::pin(async move {
50 let start = Instant::now();
51 let query = M::build(migration, self).await?;
52 let executor = self.executor();
53
54 if migration.no_tx() {
55 executor
56 .apply_no_tx(&query)
57 .await
58 .void_tern_migration_result(migration)?;
59 } else {
60 executor
61 .apply_tx(&query)
62 .await
63 .void_tern_migration_result(migration)?;
64 }
65
66 let applied_at = Utc::now();
67 let duration_ms = start.elapsed().as_millis() as i64;
68 let applied = migration.to_applied(duration_ms, applied_at, query.sql());
69 executor
70 .insert_applied_migration(Self::HISTORY_TABLE, &applied)
71 .await?;
72
73 Ok(applied)
74 })
75 }
76
77 fn latest_version(&mut self) -> BoxFuture<'_, TernResult<Option<i64>>> {
79 Box::pin(async move {
80 let latest = self
81 .executor()
82 .get_all_applied(Self::HISTORY_TABLE)
83 .await?
84 .into_iter()
85 .fold(None, |acc, m| match acc {
86 None => Some(m.version),
87 Some(v) if m.version > v => Some(m.version),
88 _ => acc,
89 });
90
91 Ok(latest)
92 })
93 }
94
95 fn previously_applied(&mut self) -> BoxFuture<'_, TernResult<Vec<AppliedMigration>>> {
97 Box::pin(self.executor().get_all_applied(Self::HISTORY_TABLE))
98 }
99
100 fn check_history_table(&mut self) -> BoxFuture<'_, TernResult<()>> {
102 Box::pin(
103 self.executor()
104 .create_history_if_not_exists(Self::HISTORY_TABLE),
105 )
106 }
107
108 fn drop_history_table(&mut self) -> BoxFuture<'_, TernResult<()>> {
110 Box::pin(self.executor().drop_history(Self::HISTORY_TABLE))
111 }
112
113 fn insert_applied<'migration, 'conn: 'migration>(
115 &'conn mut self,
116 applied: &'migration AppliedMigration,
117 ) -> BoxFuture<'migration, TernResult<()>> {
118 Box::pin(
119 self.executor()
120 .insert_applied_migration(Self::HISTORY_TABLE, applied),
121 )
122 }
123
124 fn upsert_applied<'migration, 'conn: 'migration>(
126 &'conn mut self,
127 applied: &'migration AppliedMigration,
128 ) -> BoxFuture<'migration, TernResult<()>> {
129 Box::pin(
130 self.executor()
131 .upsert_applied_migration(Self::HISTORY_TABLE, applied),
132 )
133 }
134}
135
136pub trait Executor
139where
140 Self: Send + Sync + 'static,
141{
142 type Queries: QueryRepository;
145
146 fn apply_tx(&mut self, query: &Query) -> impl Future<Output = TernResult<()>> + Send;
148
149 fn apply_no_tx(&mut self, query: &Query) -> impl Future<Output = TernResult<()>> + Send;
151
152 fn create_history_if_not_exists(
154 &mut self,
155 history_table: &str,
156 ) -> impl Future<Output = TernResult<()>> + Send;
157
158 fn drop_history(&mut self, history_table: &str) -> impl Future<Output = TernResult<()>> + Send;
160
161 fn get_all_applied(
163 &mut self,
164 history_table: &str,
165 ) -> impl Future<Output = TernResult<Vec<AppliedMigration>>> + Send;
166
167 fn insert_applied_migration(
169 &mut self,
170 history_table: &str,
171 applied: &AppliedMigration,
172 ) -> impl Future<Output = TernResult<()>> + Send;
173
174 fn upsert_applied_migration(
176 &mut self,
177 history_table: &str,
178 applied: &AppliedMigration,
179 ) -> impl Future<Output = TernResult<()>> + Send;
180}
181
182pub trait QueryRepository {
185 fn create_history_if_not_exists_query(history_table: &str) -> Query;
188
189 fn drop_history_query(history_table: &str) -> Query;
191
192 fn insert_into_history_query(history_table: &str, applied: &AppliedMigration) -> Query;
194
195 fn select_star_from_history_query(history_table: &str) -> Query;
197
198 fn upsert_history_query(history_table: &str, applied: &AppliedMigration) -> Query;
200}
201
202pub trait Migration
204where
205 Self: Send + Sync,
206{
207 type Ctx: MigrationContext;
209
210 fn migration_id(&self) -> MigrationId;
212
213 fn content(&self) -> String;
216
217 fn no_tx(&self) -> bool;
219
220 fn build<'a>(&'a self, ctx: &'a mut Self::Ctx) -> BoxFuture<'a, TernResult<Query>>;
222
223 fn version(&self) -> i64 {
225 self.migration_id().version()
226 }
227
228 fn to_applied(
231 &self,
232 duration_ms: i64,
233 applied_at: DateTime<Utc>,
234 content: &str,
235 ) -> AppliedMigration {
236 AppliedMigration::new(self.migration_id(), content, duration_ms, applied_at)
237 }
238}
239
240pub trait MigrationSource {
243 type Ctx: MigrationContext;
246
247 fn migration_set(&self, last_applied: Option<i64>) -> MigrationSet<Self::Ctx>;
249}
250
251pub struct MigrationSet<Ctx: ?Sized> {
254 pub migrations: Vec<Box<dyn Migration<Ctx = Ctx>>>,
255}
256
257impl<Ctx> MigrationSet<Ctx>
258where
259 Ctx: MigrationContext,
260{
261 pub fn new<T>(vs: T) -> MigrationSet<Ctx>
262 where
263 T: Into<Vec<Box<dyn Migration<Ctx = Ctx>>>>,
264 {
265 let mut migrations = vs.into();
266 migrations.sort_by_key(|m| m.version());
267 MigrationSet { migrations }
268 }
269
270 pub fn len(&self) -> usize {
272 self.migrations.len()
273 }
274
275 pub fn versions(&self) -> Vec<i64> {
277 self.migrations
278 .iter()
279 .map(|m| m.version())
280 .collect::<Vec<_>>()
281 }
282
283 pub fn migration_ids(&self) -> Vec<MigrationId> {
285 self.migrations
286 .iter()
287 .map(|m| m.migration_id())
288 .collect::<Vec<_>>()
289 }
290
291 pub fn max(&self) -> Option<i64> {
293 self.versions().iter().max().copied()
294 }
295
296 pub fn is_empty(&self) -> bool {
298 self.len() == 0
299 }
300}
301
302pub trait QueryBuilder {
308 type Ctx: MigrationContext;
310
311 fn build(&self, ctx: &mut Self::Ctx) -> impl Future<Output = TernResult<Query>> + Send;
313}
314
/// Raw SQL text to be executed, held exactly as provided.
#[derive(Clone, Debug)]
pub struct Query(pub(crate) String);
318
319impl Query {
320 pub fn new(sql: String) -> Self {
321 Self(sql)
322 }
323
324 fn sanitize(&self) -> String {
325 use regex::Regex;
326 let block_comment = Regex::new(r"\/\*(?s).*?(?-s)\*\/").unwrap();
327 let sql = self
328 .sql()
329 .trim()
330 .lines()
331 .filter(|line| {
332 let line = line.trim();
333 !line.starts_with("--") || line.is_empty()
334 })
335 .map(|line| {
336 let mut stripped = line.to_string();
338 let offset = stripped.find("--").unwrap_or(stripped.len());
339 stripped.replace_range(offset.., "");
340 stripped.trim_end().to_string()
341 })
342 .collect::<Vec<_>>()
343 .join("\n");
344 let stripped = block_comment.replace_all(&sql, "");
345
346 if !stripped.ends_with(";") {
347 format!("{stripped};")
348 } else {
349 stripped.to_string()
350 }
351 }
352
353 pub fn sql(&self) -> &str {
354 &self.0
355 }
356
357 pub fn append(&mut self, other: Self) -> TernResult<()> {
359 let mut buf = String::new();
360 writeln!(buf, "{}", self.0)?;
361 writeln!(buf, "{}", other.0)?;
362 self.0 = buf;
363 Ok(())
364 }
365
366 pub fn split_statements(&self) -> TernResult<Vec<String>> {
382 let mut statements = Vec::new();
383 self.sanitize()
384 .lines()
385 .try_fold(String::new(), |mut buf, line| {
386 if line.trim().is_empty() {
387 return Ok(buf);
388 }
389 writeln!(buf, "{line}")?;
390 if line.ends_with(";") {
393 statements.push(buf);
394 Ok::<String, std::fmt::Error>(String::new())
395 } else {
396 Ok(buf)
397 }
398 })?;
399
400 Ok(statements)
401 }
402}
403
404impl std::fmt::Display for Query {
405 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
406 self.0.fmt(f)
407 }
408}
409
/// Identifies a migration by numeric version plus a description.
///
/// The derived ordering compares `version` first and then `description`,
/// following the field declaration order.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct MigrationId {
    /// Numeric version of the migration.
    version: i64,
    /// Human-readable description of the migration.
    description: String,
}
418
419impl MigrationId {
420 pub fn new(version: i64, description: String) -> Self {
421 Self {
422 version,
423 description,
424 }
425 }
426
427 pub fn version(&self) -> i64 {
428 self.version
429 }
430
431 pub fn description(&self) -> String {
432 self.description.clone()
433 }
434}
435
436impl std::fmt::Display for MigrationId {
437 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
438 write!(f, "V{}__{}", self.version, self.description)
439 }
440}
441
442impl From<AppliedMigration> for MigrationId {
443 fn from(value: AppliedMigration) -> Self {
444 Self {
445 version: value.version,
446 description: value.description,
447 }
448 }
449}
450
451#[derive(Debug, Clone)]
454#[cfg_attr(feature = "sqlx", derive(sqlx::FromRow))]
455pub struct AppliedMigration {
456 pub version: i64,
458 pub description: String,
460 pub content: String,
462 pub duration_ms: i64,
464 pub applied_at: DateTime<Utc>,
466}
467
468impl AppliedMigration {
469 pub fn new(
470 id: MigrationId,
471 content: &str,
472 duration_ms: i64,
473 applied_at: DateTime<Utc>,
474 ) -> Self {
475 Self {
476 version: id.version,
477 description: id.description,
478 content: content.into(),
479 duration_ms,
480 applied_at,
481 }
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::Query;

    // A single statement preceded by a whole-line `--` comment and missing
    // its terminating `;`.
    const SQL_IN1: &str = "
-- This is a comment.
SELECT
    *
FROM
    the_schema.the_table
WHERE
    everything = 'is_good'
";
    // Expected: the comment line is gone, `;` is appended, and the statement
    // ends with the `\n` added by `split_statements`.
    const SQL_OUT1: &str = "SELECT
    *
FROM
    the_schema.the_table
WHERE
    everything = 'is_good';
";
    // Two statements mixing whole-line `--` comments, a trailing `--`
    // comment, an inline block comment, a multi-line block comment, a blank
    // line, and a `;` on its own line.
    const SQL_IN2: &str = "
-- tern:noTransaction
SELECT count(e.*),
    e.x,
    e.y -- This is the column called `y`
FROM /* A comment block can even be like this */ the_table
    as e
JOIN another USING (id)
/*
This is a multi
line
comment
*/
WHERE false;

SELECT a
from x
-- Asdfsdfsdfsdfsdsdf /* Unnecessary comment */
where false

;
";
    // Expected first statement. Removing the inline block comment keeps the
    // spaces on both sides of it, hence the double space after `FROM`.
    const SQL_OUT21: &str = "SELECT count(e.*),
    e.x,
    e.y
FROM  the_table
    as e
JOIN another USING (id)
WHERE false;
";

    // Expected second statement: the blank line before the lone `;` is
    // skipped, and the `;` line closes the statement.
    const SQL_OUT22: &str = "SELECT a
from x
where false
;
";

    // One statement in, one sanitized statement out.
    #[test]
    fn split_one() {
        let q1 = Query::new(SQL_IN1.to_string());
        let res1 = q1.split_statements();
        assert!(res1.is_ok());
        let split1 = res1.unwrap();
        assert_eq!(split1, vec![SQL_OUT1.to_string()]);
    }

    // Two statements in, split at the lines ending with `;`.
    #[test]
    fn split_two() {
        let q2 = Query::new(SQL_IN2.to_string());
        let res2 = q2.split_statements();
        assert!(res2.is_ok());
        let split2 = res2.unwrap();
        assert_eq!(split2, vec![SQL_OUT21.to_string(), SQL_OUT22.to_string()]);
    }
}
559}