Skip to main content

sql_fun_core/extensions/
name_version_pair.rs

1use combine::{
2    EasyParser, ParseError, Parser, Stream, choice, many1,
3    parser::char::{alpha_num, char},
4    satisfy,
5};
6
7use std::{
8    fmt::{Display, Write},
9    path::{Path, PathBuf},
10};
11
12use crate::{ExtensionConfigError, ExtensionVersion};
13
14/// Postgres extensin name version pair.
15#[derive(Debug, serde::Deserialize, Clone, Default)]
16pub struct ExtensionNameVersionPair {
17    name: String,
18    version: Option<ExtensionVersion>,
19}
20
21impl Display for ExtensionNameVersionPair {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.write_str(&self.name)?;
24        if let Some(ver) = &self.version {
25            f.write_char('@')?;
26            write!(f, "{ver}")?;
27        }
28        Ok(())
29    }
30}
31
#[cfg(test)]
mod test_extension_name_version_pair_display {
    use super::ExtensionNameVersionPair;
    use crate::ExtensionVersion;

    /// A pair carrying a version renders as `name@version`.
    #[test]
    fn test_extension_name_version_pair_display() {
        let pair = ExtensionNameVersionPair {
            name: "ext".to_owned(),
            version: Some(ExtensionVersion::new("1.0")),
        };
        assert_eq!("ext@1.0", pair.to_string());
    }
}
46
47impl ExtensionNameVersionPair {
48    /// validate str for extension name
49    ///
50    /// # Errors
51    ///
52    /// Returns [`ExtensionConfigError::InvalidExtensionName`] when the provided
53    /// name contains unsupported characters.
54    pub fn validate_name(name: &str) -> Result<(), ExtensionConfigError> {
55        if name == "." || name == ".." || name.contains([':', '/', '\\']) {
56            ExtensionConfigError::invalid_extension_name(name)?;
57        }
58        Ok(())
59    }
60}
61
#[cfg(test)]
mod test_validate_name {
    use super::ExtensionNameVersionPair as Pair;

    /// A dot entry and a name containing `:` are both refused.
    #[test]
    fn test_validate_name_invalid() {
        for rejected in ["..", "name:invalid"] {
            assert!(Pair::validate_name(rejected).is_err());
        }
    }

    /// A plain identifier-like name passes validation.
    #[test]
    fn test_validate_name_valid() {
        let outcome = Pair::validate_name("ext_name");
        assert!(outcome.is_ok());
    }
}
77
78impl ExtensionNameVersionPair {
79    /// creates instance
80    ///
81    /// # Errors
82    ///
83    /// Returns [`ExtensionConfigError`] if the extension name is invalid.
84    pub fn try_new(
85        name: &str,
86        version: &Option<ExtensionVersion>,
87    ) -> Result<Self, ExtensionConfigError> {
88        Self::validate_name(name)?;
89
90        Ok(Self {
91            name: name.to_string(),
92            version: version.clone(),
93        })
94    }
95}
96
#[cfg(test)]
mod test_try_new {
    use super::{ExtensionNameVersionPair, ExtensionVersion};

    /// A valid name with a version yields a populated pair.
    #[test]
    fn test_try_new_ok() {
        let requested = Some(ExtensionVersion::new("1.0"));
        let pair = ExtensionNameVersionPair::try_new("ext", &requested).unwrap();
        assert_eq!("ext", pair.name);
        assert!(pair.version.is_some());
    }

    /// An invalid name is refused even when no version is given.
    #[test]
    fn test_try_new_invalid_name() {
        assert!(ExtensionNameVersionPair::try_new("..", &None).is_err());
    }
}
116
117impl ExtensionNameVersionPair {
118    /// get extension sql path
119    ///
120    /// # Errors
121    ///
122    /// Returns [`ExtensionConfigError`] when the SQL file cannot be located or
123    /// metadata resolution fails.
124    pub fn resolve_version<P: AsRef<Path>>(
125        &mut self,
126        base_path: P,
127    ) -> Result<PathBuf, ExtensionConfigError> {
128        if let Some(version) = &self.version {
129            let path = base_path
130                .as_ref()
131                .join(PathBuf::from(format!("{}--{version}.sql", self.name)));
132            if path.exists() {
133                Ok(path)
134            } else {
135                ExtensionConfigError::extension_sql_not_exist(&self.name, version, path)
136            }
137        } else {
138            self.resolve_version_from_extension_files(base_path)
139        }
140    }
141
142    ///
143    /// # Errors
144    ///
145    /// Returns [`ExtensionConfigError`] when directory traversal fails or when
146    /// the version cannot be inferred uniquely.
147    fn resolve_version_from_extension_files<P: AsRef<Path>>(
148        &mut self,
149        base_path: P,
150    ) -> Result<PathBuf, ExtensionConfigError> {
151        let name_prefix = format!("{}--", self.name);
152        let dir = base_path.as_ref().read_dir()?;
153        let mut version_candidates = Vec::new();
154        for entry in dir {
155            let entry = entry?;
156            if !entry.file_type()?.is_file() {
157                continue;
158            }
159            let name_path = PathBuf::from(entry.file_name());
160            let Some(ext) = name_path.extension().and_then(|v| v.to_str()) else {
161                continue;
162            };
163            let Some(stem) = name_path.file_stem().and_then(|v| v.to_str()) else {
164                continue;
165            };
166            if !ext.eq_ignore_ascii_case("sql") || !stem.starts_with(&name_prefix) {
167                continue;
168            }
169            let version = &stem[name_prefix.len()..];
170            if !version.is_empty() {
171                version_candidates.push(version.to_string());
172            }
173        }
174
175        if version_candidates.len() == 1 {
176            let version = &version_candidates[0];
177            self.version = Some(ExtensionVersion::new(version));
178            let path = base_path
179                .as_ref()
180                .join(PathBuf::from(format!("{}--{version}.sql", self.name)));
181            Ok(path)
182        } else if version_candidates.is_empty() {
183            ExtensionConfigError::extension_not_found(&self.name, base_path)
184        } else {
185            ExtensionConfigError::multiple_extension_version(
186                &self.name,
187                base_path,
188                &version_candidates,
189            )
190        }
191    }
192}
193
#[cfg(test)]
mod test_resolve_version {
    use std::fs::File;

    use testresult::TestResult;

    use super::{ExtensionNameVersionPair, ExtensionVersion};

    /// With an explicit version, the matching script path is returned.
    #[test]
    fn test_resolve_version_with_version() -> TestResult {
        let dir = tempfile::tempdir()?;
        let script = dir.path().join("ext--1.0.sql");
        File::create(&script)?;

        let requested = Some(ExtensionVersion::try_from("1.0")?);
        let mut pair = ExtensionNameVersionPair::try_new("ext", &requested)?;

        assert_eq!(script, pair.resolve_version(dir.path())?);
        Ok(())
    }

    /// With an explicit version but no script on disk, resolution fails.
    #[test]
    fn test_resolve_version_missing_file() -> TestResult {
        let dir = tempfile::tempdir()?;
        let requested = Some(ExtensionVersion::try_from("1.0")?);
        let mut pair = ExtensionNameVersionPair::try_new("ext", &requested)?;

        assert!(pair.resolve_version(dir.path()).is_err());
        Ok(())
    }

    /// A single script on disk determines both the path and the version.
    #[test]
    fn test_resolve_version_from_extension_files_single() -> TestResult {
        let dir = tempfile::tempdir()?;
        let script = dir.path().join("ext--1.0.sql");
        File::create(&script)?;

        let mut pair = ExtensionNameVersionPair::try_new("ext", &None)?;
        let found = pair.resolve_version_from_extension_files(dir.path())?;

        assert_eq!(script, found);
        assert_eq!(Some(ExtensionVersion::new("1.0")), pair.version);
        Ok(())
    }

    /// Two scripts for the same extension make the version ambiguous.
    #[test]
    fn test_resolve_version_from_extension_files_multiple() -> TestResult {
        let dir = tempfile::tempdir()?;
        for file_name in ["ext--1.0.sql", "ext--2.0.sql"] {
            File::create(dir.path().join(file_name))?;
        }

        let mut pair = ExtensionNameVersionPair::try_new("ext", &None)?;
        let outcome = pair.resolve_version_from_extension_files(dir.path());

        assert!(outcome.is_err());
        Ok(())
    }

    /// An empty directory yields an error.
    #[test]
    fn test_resolve_version_from_extension_files_missing() -> TestResult {
        let dir = tempfile::tempdir()?;

        let mut pair = ExtensionNameVersionPair::try_new("ext", &None)?;
        let outcome = pair.resolve_version_from_extension_files(dir.path());

        assert!(outcome.is_err());
        Ok(())
    }
}
268
269impl ExtensionNameVersionPair {
270    fn name_parser<Input>() -> impl Parser<Input, Output = String>
271    where
272        Input: Stream<Token = char>,
273        Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
274    {
275        many1(choice((alpha_num(), char('_'), char('-'), char('.'))))
276    }
277
278    fn version_parser<Input>() -> impl Parser<Input, Output = String>
279    where
280        Input: Stream<Token = char>,
281        Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
282    {
283        many1(satisfy(|c: char| c != ',' && !c.is_whitespace()))
284    }
285}
286
#[cfg(test)]
mod test_parsers {
    use combine::EasyParser;

    use super::ExtensionNameVersionPair;

    /// The name parser consumes `-` and `.` and stops at whitespace.
    #[test]
    fn test_name_parser() {
        let (name, rest) = ExtensionNameVersionPair::name_parser()
            .easy_parse("extension-1.0 rest")
            .unwrap();
        assert_eq!("extension-1.0", name);
        assert_eq!(" rest", rest);
    }

    /// The version parser stops in front of a comma.
    #[test]
    fn test_version_parser() {
        let (version, rest) = ExtensionNameVersionPair::version_parser()
            .easy_parse("1.0,rest")
            .unwrap();
        assert_eq!("1.0", version);
        assert_eq!(",rest", rest);
    }
}
309
310impl ExtensionNameVersionPair {
311    pub(crate) fn parser<Input>()
312    -> impl Parser<Input, Output = Result<ExtensionNameVersionPair, ExtensionConfigError>>
313    where
314        Input: Stream<Token = char>,
315        Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
316    {
317        (
318            Self::name_parser(),
319            combine::optional(char('@').with(Self::version_parser())),
320        )
321            .map(|(name, ver)| {
322                let version = ver
323                    .map(|v| ExtensionVersion::try_from(v.as_str()))
324                    .transpose()?;
325                ExtensionNameVersionPair::try_new(&name, &version)
326            })
327    }
328}
329
#[cfg(test)]
mod test_parser {
    use combine::EasyParser;

    use super::{ExtensionNameVersionPair, ExtensionVersion};

    /// `name@version` yields both parts and leaves the comma unconsumed.
    #[test]
    fn test_parser_with_version() {
        let (parsed, rest) = ExtensionNameVersionPair::parser()
            .easy_parse("ext@1.0,rest")
            .unwrap();
        let pair = parsed.unwrap();

        assert_eq!("ext", pair.name);
        assert_eq!(Some(ExtensionVersion::new("1.0")), pair.version);
        assert_eq!(",rest", rest);
    }

    /// A bare name yields no version and stops at whitespace.
    #[test]
    fn test_parser_without_version() {
        let (parsed, rest) = ExtensionNameVersionPair::parser()
            .easy_parse("ext rest")
            .unwrap();
        let pair = parsed.unwrap();

        assert_eq!("ext", pair.name);
        assert!(pair.version.is_none());
        assert_eq!(" rest", rest);
    }
}
360
361impl ExtensionNameVersionPair {
362    /// parse name from str
363    ///
364    /// returns (`parsed_name`, remaining)
365    ///
366    /// Note: parsed name was not strictly checked use [`Self::validate_name`]
367    ///
368    /// # Errors
369    ///
370    /// Returns [`String`] when parsing fails.
371    pub fn parse_name(expect_name: &str) -> Result<(String, &str), String> {
372        let parse_result = Self::name_parser().easy_parse(expect_name);
373        parse_result.map_err(|e| e.to_string())
374    }
375}
376
#[cfg(test)]
mod test_parse_name {
    use super::ExtensionNameVersionPair;

    /// The name is split off and the rest of the input is handed back.
    #[test]
    fn test_parse_name_ok() {
        let outcome = ExtensionNameVersionPair::parse_name("ext rest");
        let (name, rest) = outcome.unwrap();
        assert_eq!("ext", name);
        assert_eq!(" rest", rest);
    }
}
388
389impl ExtensionNameVersionPair {
390    /// parse version from str
391    ///
392    /// returns (`parsed_version`, remaining)
393    ///
394    /// # Errors
395    ///
396    /// Returns [`String`] when parsing fails.
397    pub fn parse_version(input: &str) -> Result<(String, &str), String> {
398        Self::version_parser()
399            .easy_parse(input)
400            .map_err(|e| e.to_string())
401    }
402}
403
#[cfg(test)]
mod test_parse_version {
    use super::ExtensionNameVersionPair;

    /// The version is split off and the rest of the input is handed back.
    #[test]
    fn test_parse_version_ok() {
        let outcome = ExtensionNameVersionPair::parse_version("1.0 rest");
        let (version, rest) = outcome.unwrap();
        assert_eq!("1.0", version);
        assert_eq!(" rest", rest);
    }
}
414}