trailbase_refinery_core/
util.rs

1use crate::error::{Error, Kind};
2use crate::runner::Type;
3use crate::Migration;
4use regex::Regex;
5use std::ffi::OsStr;
6use std::path::{Path, PathBuf};
7use std::sync::OnceLock;
8use walkdir::{DirEntry, WalkDir};
9
/// Regex source matching the stem of a migration file name.
///
/// Capture groups: 1 = prefix (`U` or `V`), 2 = version (digits with an optional `.digits`
/// suffix), 3 = name. The prefix class is `[UV]`, not `[U|V]`: inside a character class `|` is
/// a literal, so `[U|V]` would also accept a leading `|`.
const STEM_RE: &str = r"^([UV])(\d+(?:\.\d+)?)__(\w+)";
11
12/// Matches the stem of a migration file.
13fn file_stem_re() -> &'static Regex {
14    static RE: OnceLock<Regex> = OnceLock::new();
15    RE.get_or_init(|| Regex::new(STEM_RE).unwrap())
16}
17
18/// Matches the stem + extension of a SQL migration file.
19fn file_re_sql() -> &'static Regex {
20    static RE: OnceLock<Regex> = OnceLock::new();
21    RE.get_or_init(|| Regex::new([STEM_RE, r"\.sql$"].concat().as_str()).unwrap())
22}
23
24/// Matches the stem + extension of any migration file.
25fn file_re_all() -> &'static Regex {
26    static RE: OnceLock<Regex> = OnceLock::new();
27    RE.get_or_init(|| Regex::new([STEM_RE, r"\.(rs|sql)$"].concat().as_str()).unwrap())
28}
29
/// Selects which migration files [`find_migration_files`] searches for: either only `.sql`
/// files, or both `.sql` and `.rs` files.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MigrationType {
    /// Match both `.sql` and `.rs` migration files.
    All,
    /// Match only `.sql` migration files.
    Sql,
}
36
37impl MigrationType {
38    fn file_match_re(&self) -> &'static Regex {
39        match self {
40            MigrationType::All => file_re_all(),
41            MigrationType::Sql => file_re_sql(),
42        }
43    }
44}
45
46/// Parse a migration filename stem into a prefix, version, and name.
47pub fn parse_migration_name(name: &str) -> Result<(Type, i32, String), Error> {
48    let captures = file_stem_re()
49        .captures(name)
50        .filter(|caps| caps.len() == 4)
51        .ok_or_else(|| Error::new(Kind::InvalidName, None))?;
52    let version: i32 = captures[2]
53        .parse()
54        .map_err(|_| Error::new(Kind::InvalidVersion, None))?;
55
56    let name: String = (&captures[3]).into();
57    let prefix = match &captures[1] {
58        "V" => Type::Versioned,
59        "U" => Type::Unversioned,
60        _ => unreachable!(),
61    };
62
63    Ok((prefix, version, name))
64}
65
66/// find migrations on file system recursively across directories given a location and
67/// [MigrationType]
68pub fn find_migration_files(
69    location: impl AsRef<Path>,
70    migration_type: MigrationType,
71) -> Result<impl Iterator<Item = PathBuf>, Error> {
72    let location: &Path = location.as_ref();
73    let location = location.canonicalize().map_err(|err| {
74        Error::new(
75            Kind::InvalidMigrationPath(location.to_path_buf(), err),
76            None,
77        )
78    })?;
79
80    let re = migration_type.file_match_re();
81    let file_paths = WalkDir::new(location)
82        .into_iter()
83        .filter_map(Result::ok)
84        .map(DirEntry::into_path)
85        // filter by migration file regex
86        .filter(
87            move |entry| match entry.file_name().and_then(OsStr::to_str) {
88                Some(_) if entry.is_dir() => false,
89                Some(file_name) if re.is_match(file_name) => true,
90                Some(file_name) => {
91                    log::warn!(
92                        "File \"{}\" does not adhere to the migration naming convention. Migrations must be named in the format [U|V]{{1}}__{{2}}.sql or [U|V]{{1}}__{{2}}.rs, where {{1}} represents the migration version and {{2}} the name.",
93                        file_name
94                    );
95                    false
96                }
97                None => false,
98            },
99        );
100
101    Ok(file_paths)
102}
103
104/// Loads SQL migrations from a path. This enables dynamic migration discovery, as opposed to
105/// embedding. The resulting collection is ordered by version.
106pub fn load_sql_migrations(location: impl AsRef<Path>) -> Result<Vec<Migration>, Error> {
107    let migration_files = find_migration_files(location, MigrationType::Sql)?;
108
109    let mut migrations = vec![];
110
111    for path in migration_files {
112        let sql = std::fs::read_to_string(path.as_path()).map_err(|e| {
113            let path = path.to_owned();
114            let kind = match e.kind() {
115                std::io::ErrorKind::NotFound => Kind::InvalidMigrationPath(path, e),
116                _ => Kind::InvalidMigrationFile(path, e),
117            };
118
119            Error::new(kind, None)
120        })?;
121
122        //safe to call unwrap as find_migration_filenames returns canonical paths
123        let filename = path
124            .file_stem()
125            .and_then(|file| file.to_os_string().into_string().ok())
126            .unwrap();
127
128        let migration = Migration::unapplied(&filename, &sql)?;
129        migrations.push(migration);
130    }
131
132    migrations.sort();
133    Ok(migrations)
134}
135
#[cfg(test)]
mod tests {
    use super::{find_migration_files, load_sql_migrations, MigrationType};
    use std::fs;
    use std::path::{Path, PathBuf};
    use tempfile::TempDir;

    /// Creates an empty `migrations` directory inside `tmp_dir` and returns its path.
    fn make_migrations_dir(tmp_dir: &TempDir) -> PathBuf {
        let dir = tmp_dir.path().join("migrations");
        fs::create_dir(&dir).unwrap();
        dir
    }

    /// Creates an empty file called `name` inside `dir` and returns its path.
    fn touch(dir: &Path, name: &str) -> PathBuf {
        let path = dir.join(name);
        fs::File::create(&path).unwrap();
        path
    }

    /// Runs `find_migration_files` and returns every yielded path, sorted.
    fn find_sorted(dir: PathBuf, migration_type: MigrationType) -> Vec<PathBuf> {
        let mut found: Vec<PathBuf> = find_migration_files(dir, migration_type)
            .unwrap()
            .collect();
        found.sort();
        found
    }

    #[test]
    fn finds_mod_migrations() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = make_migrations_dir(&tmp_dir);
        let rs1 = touch(&migrations_dir, "V1__first.rs");
        let rs2 = touch(&migrations_dir, "V2__second.rs");

        let mods = find_sorted(migrations_dir, MigrationType::All);
        assert_eq!(
            mods,
            vec![rs1.canonicalize().unwrap(), rs2.canonicalize().unwrap()]
        );
    }

    #[test]
    fn ignores_mod_files_without_migration_regex_match() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = make_migrations_dir(&tmp_dir);
        touch(&migrations_dir, "V1first.rs");
        touch(&migrations_dir, "V2second.rs");

        assert!(find_sorted(migrations_dir, MigrationType::All).is_empty());
    }

    #[test]
    fn finds_sql_migrations() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = make_migrations_dir(&tmp_dir);
        let sql1 = touch(&migrations_dir, "V1__first.sql");
        let sql2 = touch(&migrations_dir, "V2__second.sql");

        let mods = find_sorted(migrations_dir, MigrationType::Sql);
        assert_eq!(
            mods,
            vec![sql1.canonicalize().unwrap(), sql2.canonicalize().unwrap()]
        );
    }

    #[test]
    fn sql_type_ignores_rs_migrations() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = make_migrations_dir(&tmp_dir);
        let sql1 = touch(&migrations_dir, "V1__first.sql");
        touch(&migrations_dir, "V2__second.rs");

        let mods = find_sorted(migrations_dir, MigrationType::Sql);
        assert_eq!(mods, vec![sql1.canonicalize().unwrap()]);
    }

    #[test]
    fn finds_unversioned_migrations() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = make_migrations_dir(&tmp_dir);
        let sql1 = touch(&migrations_dir, "U1__first.sql");
        let sql2 = touch(&migrations_dir, "U2__second.sql");

        let mods = find_sorted(migrations_dir, MigrationType::All);
        assert_eq!(
            mods,
            vec![sql1.canonicalize().unwrap(), sql2.canonicalize().unwrap()]
        );
    }

    #[test]
    fn ignores_sql_files_without_migration_regex_match() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = make_migrations_dir(&tmp_dir);
        touch(&migrations_dir, "V1first.sql");
        touch(&migrations_dir, "V2second.sql");

        assert!(find_sorted(migrations_dir, MigrationType::Sql).is_empty());
    }

    #[test]
    fn errors_on_missing_location() {
        let tmp_dir = TempDir::new().unwrap();
        let missing = tmp_dir.path().join("does-not-exist");

        assert!(find_migration_files(missing, MigrationType::All).is_err());
    }

    #[test]
    fn loads_migrations_from_path() {
        let tmp_dir = TempDir::new().unwrap();
        let migrations_dir = make_migrations_dir(&tmp_dir);
        touch(&migrations_dir, "V1__first.sql");
        touch(&migrations_dir, "V2__second.sql");
        touch(&migrations_dir, "V3__third.rs");

        let migrations = load_sql_migrations(migrations_dir).unwrap();
        assert_eq!(migrations.len(), 2);
        assert_eq!(&migrations[0].to_string(), "V1__first");
        assert_eq!(&migrations[1].to_string(), "V2__second");
    }
}