ts_sql/
migrations.rs

1//! Helpers for running migrations
2//!
3use std::{
4    env::current_dir,
5    ffi::OsStr,
6    fs::{self, DirEntry},
7    io,
8    path::PathBuf,
9};
10
11/// Error variants for migrating a database.
12#[derive(Debug)]
13#[non_exhaustive]
14#[allow(missing_docs)]
15pub enum MigrationError {
16    #[non_exhaustive]
17    ExecuteMigration {
18        source: postgres::Error,
19        sql: String,
20    },
21
22    #[non_exhaustive]
23    ReadMigrationDirectory { source: io::Error },
24
25    #[non_exhaustive]
26    ReadMigrationFile { source: io::Error },
27}
28impl core::fmt::Display for MigrationError {
29    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
30        match &self {
31            Self::ReadMigrationDirectory { .. } => write!(f, "could not read migration directory"),
32            Self::ReadMigrationFile { .. } => write!(f, "could not read a migration file"),
33            Self::ExecuteMigration { sql, .. } => write!(f, "migration `{sql}` failed to execute"),
34        }
35    }
36}
37impl core::error::Error for MigrationError {
38    fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
39        match &self {
40            Self::ReadMigrationFile { source, .. }
41            | Self::ReadMigrationDirectory { source, .. } => Some(source),
42            Self::ExecuteMigration { source, .. } => Some(source),
43        }
44    }
45}
46
47/// Runs the migrations in `current_dir()/migrations/*.sql` on the client, migrations are executed
48/// in name order.
49pub fn perform_migrations(
50    client: &mut postgres::Client,
51    migrations_directory: Option<PathBuf>,
52) -> Result<(), MigrationError> {
53    let Some(entries) = get_migration_targets(migrations_directory)? else {
54        return Ok(());
55    };
56
57    for entry in entries {
58        let sql = fs::read_to_string(entry.path())
59            .map_err(|source| MigrationError::ReadMigrationFile { source })?;
60        client
61            .batch_execute(&sql)
62            .map_err(|source| MigrationError::ExecuteMigration { source, sql })?;
63    }
64
65    Ok(())
66}
67
/// Async counterpart of [`perform_migrations`]: runs every `*.sql` migration file against
/// `client`, in file-name order.
///
/// Migrations are loaded from `migrations_directory` when it is `Some`, and from
/// `current_dir()/migrations` otherwise. Nothing is run and `Ok(())` is returned in two cases:
/// the resolved directory does not exist, or no directory was given and the current working
/// directory cannot be determined.
///
/// Each file's full contents are sent as one batch with `Client::batch_execute`. Execution
/// stops at the first file that fails to read or execute. Files executed before that point
/// are not undone by this function. No record of applied migrations is kept, so every call
/// executes every file.
///
/// Only the database calls are awaited. Listing the directory and reading the files use
/// blocking `std::fs` calls.
///
/// # Errors
///
/// - [`MigrationError::ReadMigrationDirectory`] or [`MigrationError::ReadMigrationFile`] if the
///   migrations directory or one of its files cannot be read.
/// - [`MigrationError::ExecuteMigration`] if a migration fails to execute; the error carries
///   that migration's SQL text.
#[cfg(feature = "async")]
pub async fn perform_migrations_async(
    client: &tokio_postgres::Client,
    migrations_directory: Option<PathBuf>,
) -> Result<(), MigrationError> {
    let Some(entries) = get_migration_targets(migrations_directory)? else {
        // No migrations directory to read from: nothing to do.
        return Ok(());
    };

    for entry in entries {
        let sql = fs::read_to_string(entry.path())
            .map_err(|source| MigrationError::ReadMigrationFile { source })?;
        // The borrow of `sql` ends once the future completes, so it can be moved into the error.
        client
            .batch_execute(&sql)
            .await
            .map_err(|source| MigrationError::ExecuteMigration { source, sql })?;
    }

    Ok(())
}
90
91/// Returns the files that should be used to migrate.
92fn get_migration_targets(
93    migrations_directory: Option<PathBuf>,
94) -> Result<Option<Vec<DirEntry>>, MigrationError> {
95    let path = match migrations_directory {
96        Some(path) => path,
97        None => {
98            let Ok(current_dir) = current_dir() else {
99                return Ok(None);
100            };
101            current_dir.join("migrations")
102        }
103    };
104
105    if !fs::exists(&path).map_err(|source| MigrationError::ReadMigrationDirectory { source })? {
106        return Ok(None);
107    }
108
109    let directory =
110        fs::read_dir(&path).map_err(|source| MigrationError::ReadMigrationDirectory { source })?;
111    let mut entries: Vec<_> = directory
112        .filter_map(|entry| match entry {
113            Ok(entry) => entry
114                .path()
115                .extension()
116                .is_some_and(|extension| extension == OsStr::new("sql"))
117                .then_some(Ok(entry)),
118            Err(error) => Some(Err(error)),
119        })
120        .collect::<Result<_, _>>()
121        .map_err(|source| MigrationError::ReadMigrationFile { source })?;
122    entries.sort_by_key(DirEntry::file_name);
123
124    Ok(Some(entries))
125}