trailbase_refinery/traits/
async.rs1use crate::error::WrapMigrationError;
2use crate::traits::{
3 ASSERT_MIGRATIONS_TABLE_QUERY, GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY,
4 insert_migration_query, verify_migrations,
5};
6use crate::{Error, Migration, Report, Target};
7
8use async_trait::async_trait;
9use std::string::ToString;
10
11#[async_trait]
12pub trait AsyncTransaction {
13 type Error: std::error::Error + Send + Sync + 'static;
14
15 async fn execute<'a, T: Iterator<Item = &'a str> + Send>(
16 &mut self,
17 queries: T,
18 ) -> Result<usize, Self::Error>;
19}
20
21#[async_trait]
22pub trait AsyncQuery<T>: AsyncTransaction {
23 async fn query(&mut self, query: &str) -> Result<T, Self::Error>;
24}
25
26async fn migrate<T: AsyncTransaction>(
27 transaction: &mut T,
28 migrations: Vec<Migration>,
29 target: Target,
30 migration_table_name: &str,
31) -> Result<Report, Error> {
32 let mut applied_migrations = vec![];
33
34 for mut migration in migrations.into_iter() {
35 if let Target::Version(input_target) = target {
36 if input_target < migration.version() {
37 log::info!(
38 "stopping at migration: {}, due to user option",
39 input_target
40 );
41 break;
42 }
43 }
44
45 log::info!("applying migration: {}", migration);
46 migration.set_applied();
47 let update_query = insert_migration_query(&migration, migration_table_name);
48 transaction
49 .execute(
50 [
51 migration.sql().as_ref().expect("sql must be Some!"),
52 update_query.as_str(),
53 ]
54 .into_iter(),
55 )
56 .await
57 .migration_err(
58 &format!("error applying migration {}", migration),
59 Some(&applied_migrations),
60 )?;
61 applied_migrations.push(migration);
62 }
63 Ok(Report::new(applied_migrations))
64}
65
66async fn migrate_grouped<T: AsyncTransaction>(
67 transaction: &mut T,
68 migrations: Vec<Migration>,
69 target: Target,
70 migration_table_name: &str,
71) -> Result<Report, Error> {
72 let mut grouped_migrations = Vec::new();
73 let mut applied_migrations = Vec::new();
74
75 for mut migration in migrations.into_iter() {
76 if let Target::Version(input_target) | Target::FakeVersion(input_target) = target {
77 if input_target < migration.version() {
78 break;
79 }
80 }
81
82 migration.set_applied();
83 let query = insert_migration_query(&migration, migration_table_name);
84
85 let sql = migration.sql().expect("sql must be Some!").to_string();
86
87 if !matches!(target, Target::Fake | Target::FakeVersion(_)) {
89 applied_migrations.push(migration);
90 grouped_migrations.push(sql);
91 }
92 grouped_migrations.push(query);
93 }
94
95 match target {
96 Target::Fake | Target::FakeVersion(_) => {
97 log::info!("not going to apply any migration as fake flag is enabled");
98 }
99 Target::Latest | Target::Version(_) => {
100 log::info!(
101 "going to apply batch migrations in single transaction: {:#?}",
102 applied_migrations.iter().map(ToString::to_string)
103 );
104 }
105 };
106
107 if let Target::Version(input_target) = target {
108 log::info!(
109 "stopping at migration: {}, due to user option",
110 input_target
111 );
112 }
113
114 let refs = grouped_migrations.iter().map(AsRef::as_ref);
115
116 transaction
117 .execute(refs)
118 .await
119 .migration_err("error applying migrations", None)?;
120
121 Ok(Report::new(applied_migrations))
122}
123
124#[async_trait]
125pub trait AsyncMigrate: AsyncQuery<Vec<Migration>>
126where
127 Self: Sized,
128{
129 fn assert_migrations_table_query(migration_table_name: &str) -> String {
132 ASSERT_MIGRATIONS_TABLE_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
133 }
134
135 async fn get_last_applied_migration(
136 &mut self,
137 migration_table_name: &str,
138 ) -> Result<Option<Migration>, Error> {
139 let mut migrations = self
140 .query(
141 &GET_LAST_APPLIED_MIGRATION_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name),
142 )
143 .await
144 .migration_err("error getting last applied migration", None)?;
145
146 Ok(migrations.pop())
147 }
148
149 async fn get_applied_migrations(
150 &mut self,
151 migration_table_name: &str,
152 ) -> Result<Vec<Migration>, Error> {
153 let migrations = self
154 .query(&GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name))
155 .await
156 .migration_err("error getting applied migrations", None)?;
157
158 Ok(migrations)
159 }
160
161 async fn migrate(
162 &mut self,
163 migrations: &[Migration],
164 abort_divergent: bool,
165 abort_missing: bool,
166 grouped: bool,
167 target: Target,
168 migration_table_name: &str,
169 ) -> Result<Report, Error> {
170 self
171 .execute([Self::assert_migrations_table_query(migration_table_name).as_str()].into_iter())
172 .await
173 .migration_err("error asserting migrations table", None)?;
174
175 let applied_migrations = self
176 .query(&GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name))
177 .await
178 .migration_err("error getting current schema version", None)?;
179
180 let migrations = verify_migrations(
181 applied_migrations,
182 migrations.to_vec(),
183 abort_divergent,
184 abort_missing,
185 )?;
186
187 if migrations.is_empty() {
188 log::info!("no migrations to apply");
189 }
190
191 if grouped || matches!(target, Target::Fake | Target::FakeVersion(_)) {
192 migrate_grouped(self, migrations, target, migration_table_name).await
193 } else {
194 migrate(self, migrations, target, migration_table_name).await
195 }
196 }
197}