trailbase_refinery/
util.rs

1use crate::Migration;
2use crate::error::{Error, Kind};
3use crate::runner::Type;
4use regex::Regex;
5use std::ffi::OsStr;
6use std::path::{Path, PathBuf};
7use std::sync::OnceLock;
8use walkdir::{DirEntry, WalkDir};
9
/// Pattern for the stem of a migration file: a `U` or `V` prefix, a numeric version (optionally
/// followed by a `.`-separated fractional part), a double underscore, and a name made of word
/// characters.
///
/// The prefix is written as the character class `[UV]`. Inside a character class `|` is a literal
/// character rather than alternation, so spelling it `[U|V]` would also accept a leading `|`.
const STEM_RE: &str = r"^([UV])(\d+(?:\.\d+)?)__(\w+)";
11
12/// Matches the stem of a migration file.
13fn file_stem_re() -> &'static Regex {
14  static RE: OnceLock<Regex> = OnceLock::new();
15  RE.get_or_init(|| Regex::new(STEM_RE).unwrap())
16}
17
18/// Matches the stem + extension of a SQL migration file.
19fn file_re_sql() -> &'static Regex {
20  static RE: OnceLock<Regex> = OnceLock::new();
21  RE.get_or_init(|| Regex::new([STEM_RE, r"\.sql$"].concat().as_str()).unwrap())
22}
23
24/// Matches the stem + extension of any migration file.
25fn file_re_all() -> &'static Regex {
26  static RE: OnceLock<Regex> = OnceLock::new();
27  RE.get_or_init(|| Regex::new([STEM_RE, r"\.(rs|sql)$"].concat().as_str()).unwrap())
28}
29
/// Selects which kinds of migration files a search should pick up: only `.sql` files, or both
/// `.sql` and `.rs` files.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MigrationType {
  /// Match both `.sql` and `.rs` migration files.
  All,
  /// Match `.sql` migration files only.
  Sql,
}
36
37impl MigrationType {
38  fn file_match_re(&self) -> &'static Regex {
39    match self {
40      MigrationType::All => file_re_all(),
41      MigrationType::Sql => file_re_sql(),
42    }
43  }
44}
45
46/// Parse a migration filename stem into a prefix, version, and name.
47pub fn parse_migration_name(name: &str) -> Result<(Type, i32, String), Error> {
48  let captures = file_stem_re()
49    .captures(name)
50    .filter(|caps| caps.len() == 4)
51    .ok_or_else(|| Error::new(Kind::InvalidName, None))?;
52  let version: i32 = captures[2]
53    .parse()
54    .map_err(|_| Error::new(Kind::InvalidVersion, None))?;
55
56  let name: String = (&captures[3]).into();
57  let prefix = match &captures[1] {
58    "V" => Type::Versioned,
59    "U" => Type::Unversioned,
60    _ => unreachable!(),
61  };
62
63  Ok((prefix, version, name))
64}
65
66/// find migrations on file system recursively across directories given a location and
67/// [MigrationType]
68pub fn find_migration_files(
69  location: impl AsRef<Path>,
70  migration_type: MigrationType,
71) -> Result<impl Iterator<Item = PathBuf>, Error> {
72  let location: &Path = location.as_ref();
73  let location = location.canonicalize().map_err(|err| {
74    Error::new(
75      Kind::InvalidMigrationPath(location.to_path_buf(), err),
76      None,
77    )
78  })?;
79
80  let re = migration_type.file_match_re();
81  let file_paths = WalkDir::new(location)
82        .into_iter()
83        .filter_map(Result::ok)
84        .map(DirEntry::into_path)
85        // filter by migration file regex
86        .filter(
87            move |entry| match entry.file_name().and_then(OsStr::to_str) {
88                Some(_) if entry.is_dir() => false,
89                Some(file_name) if re.is_match(file_name) => true,
90                Some(file_name) => {
91                    log::warn!(
92                        "File \"{}\" does not adhere to the migration naming convention. Migrations must be named in the format [U|V]{{1}}__{{2}}.sql or [U|V]{{1}}__{{2}}.rs, where {{1}} represents the migration version and {{2}} the name.",
93                        file_name
94                    );
95                    false
96                }
97                None => false,
98            },
99        );
100
101  Ok(file_paths)
102}
103
104/// Loads SQL migrations from a path. This enables dynamic migration discovery, as opposed to
105/// embedding. The resulting collection is ordered by version.
106pub fn load_sql_migrations(location: impl AsRef<Path>) -> Result<Vec<Migration>, Error> {
107  let migration_files = find_migration_files(location, MigrationType::Sql)?;
108
109  let mut migrations = vec![];
110
111  for path in migration_files {
112    let sql = std::fs::read_to_string(path.as_path()).map_err(|e| {
113      let path = path.to_owned();
114      let kind = match e.kind() {
115        std::io::ErrorKind::NotFound => Kind::InvalidMigrationPath(path, e),
116        _ => Kind::InvalidMigrationFile(path, e),
117      };
118
119      Error::new(kind, None)
120    })?;
121
122    //safe to call unwrap as find_migration_filenames returns canonical paths
123    let filename = path
124      .file_stem()
125      .and_then(|file| file.to_os_string().into_string().ok())
126      .unwrap();
127
128    let migration = Migration::unapplied(&filename, &sql)?;
129    migrations.push(migration);
130  }
131
132  migrations.sort();
133  Ok(migrations)
134}
135
#[cfg(test)]
mod tests {
  use super::{MigrationType, find_migration_files, load_sql_migrations};
  use std::fs;
  use std::path::PathBuf;
  use tempfile::TempDir;

  /// Creates a `migrations` directory inside a fresh temporary directory and puts one empty
  /// file in it for every entry of `file_names`.
  ///
  /// Returns the temporary-directory guard, the `migrations` directory, and the created file
  /// paths in the order given. The guard must be kept alive for as long as the files are used.
  fn scaffold(file_names: &[&str]) -> (TempDir, PathBuf, Vec<PathBuf>) {
    let guard = TempDir::new().unwrap();
    let dir = guard.path().join("migrations");
    fs::create_dir(&dir).unwrap();

    let mut created = Vec::with_capacity(file_names.len());
    for file_name in file_names {
      let file_path = dir.join(file_name);
      fs::File::create(&file_path).unwrap();
      created.push(file_path);
    }

    (guard, dir, created)
  }

  /// Creates the two given files and asserts that a search for all migration types returns
  /// their canonical paths as the first two sorted results.
  #[track_caller]
  fn assert_both_found(file_names: [&str; 2]) {
    let (_guard, dir, created) = scaffold(&file_names);

    let mut found: Vec<PathBuf> = find_migration_files(dir, MigrationType::All)
      .unwrap()
      .collect();
    found.sort();

    assert_eq!(created[0].canonicalize().unwrap(), found[0]);
    assert_eq!(created[1].canonicalize().unwrap(), found[1]);
  }

  /// Creates the two given files and asserts that a search for all migration types returns
  /// nothing.
  #[track_caller]
  fn assert_none_found(file_names: [&str; 2]) {
    let (_guard, dir, _created) = scaffold(&file_names);

    let mut found = find_migration_files(dir, MigrationType::All).unwrap();
    assert!(found.next().is_none());
  }

  #[test]
  fn finds_mod_migrations() {
    assert_both_found(["V1__first.rs", "V2__second.rs"]);
  }

  #[test]
  fn ignores_mod_files_without_migration_regex_match() {
    assert_none_found(["V1first.rs", "V2second.rs"]);
  }

  #[test]
  fn finds_sql_migrations() {
    assert_both_found(["V1__first.sql", "V2__second.sql"]);
  }

  #[test]
  fn finds_unversioned_migrations() {
    assert_both_found(["U1__first.sql", "U2__second.sql"]);
  }

  #[test]
  fn ignores_sql_files_without_migration_regex_match() {
    assert_none_found(["V1first.sql", "V2second.sql"]);
  }

  #[test]
  fn loads_migrations_from_path() {
    // The `.rs` file must be ignored when loading SQL migrations.
    let (_guard, dir, _created) =
      scaffold(&["V1__first.sql", "V2__second.sql", "V3__third.rs"]);

    let migrations = load_sql_migrations(dir).unwrap();
    assert_eq!(migrations.len(), 2);
    assert_eq!(&migrations[0].to_string(), "V1__first");
    assert_eq!(&migrations[1].to_string(), "V2__second");
  }
}