tern_derive/internal/
parse.rs

1use regex::Regex;
2use std::path::{Path, PathBuf};
3use std::{env, ffi::OsStr, fs, sync::OnceLock};
4
/// Returns the manifest directory of the crate being compiled, read from the
/// `CARGO_MANIFEST_DIR` environment variable.
///
/// # Panics
///
/// Panics if `CARGO_MANIFEST_DIR` is unset or does not hold valid Unicode.
pub fn cargo_manifest_dir() -> PathBuf {
    env::var("CARGO_MANIFEST_DIR")
        .map(PathBuf::from)
        .expect("CARGO_MANIFEST_DIR not set")
}
9
10fn filename_re() -> &'static Regex {
11    static RE: OnceLock<Regex> = OnceLock::new();
12    RE.get_or_init(|| Regex::new(r"^V(\d+)__(\w+)\.(sql|rs)$").unwrap())
13}
14
/// A static SQL migration read from the migrations directory.
#[derive(Clone, Debug)]
pub struct SqlSource {
    /// File stem of the migration file, e.g. `V1__init`.
    pub module: String,
    /// Version parsed from the `V<version>__` file-name prefix.
    pub version: i64,
    /// Description parsed from the file name after the `__` separator.
    pub description: String,
    /// Full text of the `.sql` file.
    pub content: String,
    /// Whether the first line carries the `tern:noTransaction` annotation.
    pub no_tx: bool,
}

/// A Rust migration read from the migrations directory.
#[derive(Clone, Debug)]
pub struct RustSource {
    /// File stem of the migration file, e.g. `V2__seed`.
    pub module: String,
    /// Version parsed from the `V<version>__` file-name prefix.
    pub version: i64,
    /// Description parsed from the file name after the `__` separator.
    pub description: String,
    /// Full text of the `.rs` file.
    pub content: String,
}

/// A migration found on disk: either static SQL or Rust code.
#[derive(Clone, Debug)]
pub enum MigrationSource {
    /// A `.sql` migration.
    Sql(SqlSource),
    /// A `.rs` migration.
    Rs(RustSource),
}
37
38impl MigrationSource {
39    pub fn from_migration_dir(
40        migration_dir: impl AsRef<Path>,
41    ) -> Result<Vec<MigrationSource>, SourceError> {
42        let location = migration_dir.as_ref().canonicalize().map_err(|e| {
43            SourceError::Path(
44                format!(
45                    "invalid migration path {:?}",
46                    migration_dir.as_ref().to_path_buf()
47                ),
48                e.to_string(),
49            )
50        })?;
51        let mut sources = Self::parse_sources(location)?;
52
53        // order asc by version
54        sources.sort_by_key(|s| match s {
55            Self::Sql(s) => s.version,
56            Self::Rs(s) => s.version,
57        });
58        Validator::new(sources.iter().map(|v| v.migration_id()).collect::<Vec<_>>()).validate()?;
59
60        Ok(sources)
61    }
62
63    fn parse_sources(location: PathBuf) -> Result<Vec<MigrationSource>, SourceError> {
64        let sources = fs::read_dir(location)
65            .map_err(|_| SourceError::Directory("could not read migration directory".to_string()))?
66            .filter_map(|entry| {
67                let e = entry.ok()?;
68                if e.file_name()
69                    .to_str()
70                    .is_some_and(|f| f == "mod.rs" || f.starts_with("."))
71                {
72                    None
73                } else {
74                    Some(e.path())
75                }
76            })
77            .map(Self::parse)
78            .collect::<Result<Vec<_>, _>>()?;
79
80        Ok(sources)
81    }
82
83    fn migration_id(&self) -> (i64, String) {
84        match self {
85            Self::Sql(SqlSource {
86                version,
87                description,
88                ..
89            }) => (*version, description.clone()),
90            Self::Rs(RustSource {
91                version,
92                description,
93                ..
94            }) => (*version, description.clone()),
95        }
96    }
97
98    fn parse(filepath: impl AsRef<Path>) -> Result<Self, SourceError> {
99        let filepath = filepath.as_ref();
100        let module = filepath.file_stem().ok_or(SourceError::Name(format!(
101            "no filename stem found {:?}",
102            filepath.to_str()
103        )))?;
104        let (ver, description, ext) = filepath
105            .file_name()
106            .and_then(|n| {
107                let filename = OsStr::to_str(n)?;
108                let captures = filename_re().captures(filename)?;
109                let version = captures.get(1)?.as_str();
110                let description = captures.get(2)?.as_str();
111                let source_type = captures.get(3)?.as_str();
112                Some((version, description, source_type))
113            })
114            .ok_or(SourceError::Name(format!(
115                r"format is `^V(\d+)__(\w+)\.(sql|rs)$`, got {:?}",
116                filepath.to_str(),
117            )))?;
118        let version: i64 = ver
119            .parse()
120            .map_err(|_| SourceError::Name("invalid version, expected i64".to_string()))?;
121        let source_type = SourceType::from_ext(ext)?;
122        let content = fs::read_to_string(filepath).map_err(|e| SourceError::Io(e.to_string()))?;
123        let module = module
124            .to_str()
125            .ok_or(SourceError::Name(
126                "utf-8 decoding filename failed".to_string(),
127            ))?
128            .to_string();
129        let this = match source_type {
130            SourceType::Sql => {
131                let no_tx = Self::no_tx(&content);
132                let sql_source = SqlSource {
133                    module,
134                    version,
135                    description: description.to_string(),
136                    content,
137                    no_tx,
138                };
139                Self::Sql(sql_source)
140            }
141            _ => {
142                let rust_source = RustSource {
143                    module,
144                    version,
145                    description: description.to_string(),
146                    content,
147                };
148                Self::Rs(rust_source)
149            }
150        };
151
152        Ok(this)
153    }
154
155    /// For static SQL migrations, parse the first line to see if the special
156    /// `tern:noTransaction` annotation is present.
157    fn no_tx(content: &str) -> bool {
158        content
159            .lines()
160            .take(1)
161            .next()
162            .map(|l| l.contains("tern:noTransaction"))
163            .unwrap_or_default()
164    }
165}
166
/// Checks that a set of migration ids forms a valid version sequence.
///
/// Ids are `(version, description)` pairs. A valid set has unique versions
/// that run contiguously from 1 up to the number of migrations.
struct Validator {
    /// Migration ids, sorted ascending by version by `Validator::new`.
    ids: Vec<(i64, String)>,
}

impl Validator {
    /// Creates a validator over `ids`, sorting them by version so the checks
    /// can compare neighbouring entries and positions.
    fn new(mut ids: Vec<(i64, String)>) -> Self {
        ids.sort_by_key(|(v, _)| *v);
        Self { ids }
    }

    /// Fails if any version occurs more than once.
    ///
    /// Each duplicated version is reported once, in ascending order.
    fn duplicate_versions(&self) -> Result<(), SourceError> {
        // `ids` is sorted, so equal versions are adjacent.
        let mut offending_versions: Vec<i64> = self
            .ids
            .windows(2)
            .filter(|pair| pair[0].0 == pair[1].0)
            .map(|pair| pair[0].0)
            .collect();
        // A version present three or more times matches several windows.
        offending_versions.dedup();

        if offending_versions.is_empty() {
            Ok(())
        } else {
            Err(SourceError::Version(Version {
                message: "duplicate migration version found".to_string(),
                offending_versions,
            }))
        }
    }

    /// Fails unless the sorted versions are exactly `1, 2, ..., n`.
    ///
    /// Reports the first version that differs from its expected position.
    fn missing_versions(&self) -> Result<(), SourceError> {
        let size = self.ids.len();
        // Check every position: comparing only the last version against `size`
        // misses sets such as `[0, 2]`, which end at `size` yet lack version 1.
        let mismatch = self
            .ids
            .iter()
            .map(|(version, _)| *version)
            .zip(1_i64..)
            .find(|(version, expected)| version != expected);

        match mismatch {
            None => Ok(()),
            Some((version, expected)) => Err(SourceError::Version(Version {
                message: format!(
                    "expected version {expected} for a set with {size} migrations"
                ),
                offending_versions: vec![version],
            })),
        }
    }

    /// Runs every check, reporting duplicates before gaps.
    fn validate(&self) -> Result<(), SourceError> {
        self.duplicate_versions()?;
        self.missing_versions()?;
        Ok(())
    }
}

/// Errors produced while discovering, parsing or validating migration sources.
#[derive(Debug)]
// NOTE(review): presumably here because payloads are only read through the
// derived `Debug` and `Sql` is not constructed in this module — confirm.
#[allow(dead_code)]
pub enum SourceError {
    /// The migration directory could not be read.
    Directory(String),
    /// The migration path could not be resolved: (context, underlying error).
    Path(String, String),
    /// A migration file name is malformed or could not be decoded.
    Name(String),
    /// A file has an extension other than `sql` or `rs`.
    Ext(String),
    /// An I/O operation on a migration file failed.
    Io(String),
    /// NOTE(review): presumably an SQL problem tied to a migration version;
    /// not constructed in this module.
    Sql(i64, String),
    /// The set of migration versions was rejected by the validator.
    Version(Version),
}

impl From<Version> for SourceError {
    fn from(value: Version) -> Self {
        Self::Version(value)
    }
}

/// Details of a rejected set of migration versions.
pub struct Version {
    /// Human-readable description of the problem.
    message: String,
    /// Versions that triggered the error.
    offending_versions: Vec<i64>,
}

impl std::fmt::Debug for Version {
    // Renders `offending_versions` as a compact string such as `"[3,6]"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let joined = self
            .offending_versions
            .iter()
            .map(i64::to_string)
            .collect::<Vec<_>>()
            .join(",");
        f.debug_struct("Version")
            .field("message", &self.message)
            .field("offending_versions", &format!("[{joined}]"))
            .finish()
    }
}
261
262#[derive(Debug, Clone, Copy)]
263enum SourceType {
264    Sql,
265    Rust,
266}
267
268impl SourceType {
269    pub fn from_ext(ext: &str) -> Result<Self, SourceError> {
270        match ext {
271            "sql" => Ok(Self::Sql),
272            "rs" => Ok(Self::Rust),
273            _ => Err(SourceError::Ext(format!(
274                "got file extension {ext}, expected `sql` or `rs`"
275            ))),
276        }
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::{SourceError, Validator, Version};

    /// Builds a validator whose descriptions are just the stringified versions.
    fn validator_for(versions: Vec<i64>) -> Validator {
        Validator::new(versions.into_iter().map(|v| (v, v.to_string())).collect())
    }

    /// Returns the offending versions carried by a `SourceError::Version`, or
    /// `None` for `Ok` and for any other error variant.
    fn offending(result: Result<(), SourceError>) -> Option<Vec<i64>> {
        match result {
            Err(SourceError::Version(Version {
                offending_versions, ..
            })) => Some(offending_versions),
            _ => None,
        }
    }

    #[test]
    fn duplicate_version() {
        let validator = validator_for(vec![1, 2, 3, 3, 4, 5, 6, 6, 7]);
        assert_eq!(offending(validator.duplicate_versions()), Some(vec![3, 6]));
    }

    #[test]
    fn missing_version() {
        let validator = validator_for(vec![1, 2, 3, 4, 5, 6, 8, 9, 10, 11]);
        assert_eq!(offending(validator.missing_versions()), Some(vec![8]));
    }

    #[test]
    fn source_ok() {
        let validator = validator_for(vec![1, 2, 3, 4, 5, 6]);
        assert!(validator.validate().is_ok());
    }
}