sqlx_repo/
migrator.rs

1use std::marker::PhantomData;
2
3use crate::{Result, SqlxDBNum};
4use futures::future::BoxFuture;
5use sqlx::migrate::{Migration as SqlxMigration, MigrationSource, MigrationType};
6
7#[derive(Debug)]
8pub struct RepoMigrationSource<D> {
9    migrations: Vec<Migration>,
10    marker: PhantomData<D>,
11}
12
/// A single schema migration described by static data.
///
/// `queries` holds alternative SQL scripts. The one that is actually run is picked by
/// index `SqlxDBNum::pos()` of the target database type. (Assumed: one entry per
/// supported backend, in the same order for every migration. Confirm with callers.)
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Migration {
    /// Migration name; used as the sqlx migration description.
    pub name: &'static str,
    /// Candidate SQL scripts, indexed by `SqlxDBNum::pos()`.
    pub queries: &'static [&'static str],
}
18
19impl<'a, D: SqlxDBNum> MigrationSource<'a> for RepoMigrationSource<D> {
20    fn resolve(self) -> BoxFuture<'a, Result<Vec<SqlxMigration>, sqlx::error::BoxDynError>> {
21        Box::pin(async move {
22            let migrations = self.migrations
23                .iter()
24                .enumerate()
25                .map(|(pos, migration)| {
26                    let query_pos = D::pos();
27                    let query = match migration.queries.get(query_pos) {
28                        Some(&query) => query,
29                        None => Err("failed to generate migration, tried to get query at index {query_pos}, which doesn't exist")?
30                    };
31                    Ok(SqlxMigration::new(pos as _, migration.name.into(), MigrationType::Simple, query.into(), false))
32                })
33                .collect::<Result<_>>()?;
34            Ok(migrations)
35        })
36    }
37}
38
39impl<D: SqlxDBNum> Default for RepoMigrationSource<D> {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl<D: SqlxDBNum> RepoMigrationSource<D> {
46    pub fn new() -> Self {
47        Self {
48            migrations: vec![],
49            marker: PhantomData,
50        }
51    }
52
53    pub fn add_migration(&mut self, migration: Migration) {
54        self.migrations.push(migration);
55    }
56}
57
58pub async fn init_migrator<D: SqlxDBNum>(
59    migrations: &[Migration],
60) -> Result<sqlx::migrate::Migrator, sqlx::migrate::MigrateError> {
61    let mut source = RepoMigrationSource::<D>::new();
62    migrations
63        .iter()
64        .for_each(|migration| source.add_migration(*migration));
65    sqlx::migrate::Migrator::new(source).await
66}