1use combine::{
2 EasyParser, ParseError, Parser, Stream, choice, many1,
3 parser::char::{alpha_num, char},
4 satisfy,
5};
6
7use std::{
8 fmt::{Display, Write},
9 path::{Path, PathBuf},
10};
11
12use crate::{ExtensionConfigError, ExtensionVersion};
13
14#[derive(Debug, serde::Deserialize, Clone, Default)]
16pub struct ExtensionNameVersionPair {
17 name: String,
18 version: Option<ExtensionVersion>,
19}
20
21impl Display for ExtensionNameVersionPair {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.write_str(&self.name)?;
24 if let Some(ver) = &self.version {
25 f.write_char('@')?;
26 write!(f, "{ver}")?;
27 }
28 Ok(())
29 }
30}
31
#[cfg(test)]
mod test_extension_name_version_pair_display {
    use crate::ExtensionVersion;

    use super::ExtensionNameVersionPair;

    /// A pair with a version renders as `name@version`.
    #[test]
    fn test_extension_name_version_pair_display() {
        let item = ExtensionNameVersionPair {
            name: String::from("ext"),
            version: Some(ExtensionVersion::new("1.0")),
        };
        let display_result = item.to_string();
        assert_eq!("ext@1.0", display_result)
    }

    /// A pair without a version renders as the bare name, with no `@`.
    #[test]
    fn test_extension_name_version_pair_display_without_version() {
        let item = ExtensionNameVersionPair {
            name: String::from("ext"),
            version: None,
        };
        assert_eq!("ext", item.to_string());
    }
}
46
47impl ExtensionNameVersionPair {
48 pub fn validate_name(name: &str) -> Result<(), ExtensionConfigError> {
55 if name == "." || name == ".." || name.contains([':', '/', '\\']) {
56 ExtensionConfigError::invalid_extension_name(name)?;
57 }
58 Ok(())
59 }
60}
61
#[cfg(test)]
mod test_validate_name {
    use super::ExtensionNameVersionPair;

    /// Relative-path components and separator characters are rejected.
    #[test]
    fn test_validate_name_invalid() {
        for name in ["..", ".", "name:invalid", "dir/name", "dir\\name"] {
            assert!(
                ExtensionNameVersionPair::validate_name(name).is_err(),
                "expected {name:?} to be rejected"
            );
        }
    }

    /// Ordinary names, including single `-` and `.` characters, are accepted.
    #[test]
    fn test_validate_name_valid() {
        for name in ["ext_name", "ext-name", "ext.name"] {
            assert!(
                ExtensionNameVersionPair::validate_name(name).is_ok(),
                "expected {name:?} to be accepted"
            );
        }
    }
}
77
78impl ExtensionNameVersionPair {
79 pub fn try_new(
85 name: &str,
86 version: &Option<ExtensionVersion>,
87 ) -> Result<Self, ExtensionConfigError> {
88 Self::validate_name(name)?;
89
90 Ok(Self {
91 name: name.to_string(),
92 version: version.clone(),
93 })
94 }
95}
96
#[cfg(test)]
mod test_try_new {
    use super::{ExtensionNameVersionPair, ExtensionVersion};

    /// A valid name is stored as given and the version is carried over.
    #[test]
    fn test_try_new_ok() {
        let requested = Some(ExtensionVersion::new("1.0"));
        let pair = ExtensionNameVersionPair::try_new("ext", &requested).unwrap();
        assert_eq!(pair.name, "ext");
        assert!(pair.version.is_some());
    }

    /// Construction fails when the name does not pass validation.
    #[test]
    fn test_try_new_invalid_name() {
        let outcome = ExtensionNameVersionPair::try_new("..", &None);
        assert!(outcome.is_err());
    }
}
116
117impl ExtensionNameVersionPair {
118 pub fn resolve_version<P: AsRef<Path>>(
125 &mut self,
126 base_path: P,
127 ) -> Result<PathBuf, ExtensionConfigError> {
128 if let Some(version) = &self.version {
129 let path = base_path
130 .as_ref()
131 .join(PathBuf::from(format!("{}--{version}.sql", self.name)));
132 if path.exists() {
133 Ok(path)
134 } else {
135 ExtensionConfigError::extension_sql_not_exist(&self.name, version, path)
136 }
137 } else {
138 self.resolve_version_from_extension_files(base_path)
139 }
140 }
141
142 fn resolve_version_from_extension_files<P: AsRef<Path>>(
148 &mut self,
149 base_path: P,
150 ) -> Result<PathBuf, ExtensionConfigError> {
151 let name_prefix = format!("{}--", self.name);
152 let dir = base_path.as_ref().read_dir()?;
153 let mut version_candidates = Vec::new();
154 for entry in dir {
155 let entry = entry?;
156 if !entry.file_type()?.is_file() {
157 continue;
158 }
159 let name_path = PathBuf::from(entry.file_name());
160 let Some(ext) = name_path.extension().and_then(|v| v.to_str()) else {
161 continue;
162 };
163 let Some(stem) = name_path.file_stem().and_then(|v| v.to_str()) else {
164 continue;
165 };
166 if !ext.eq_ignore_ascii_case("sql") || !stem.starts_with(&name_prefix) {
167 continue;
168 }
169 let version = &stem[name_prefix.len()..];
170 if !version.is_empty() {
171 version_candidates.push(version.to_string());
172 }
173 }
174
175 if version_candidates.len() == 1 {
176 let version = &version_candidates[0];
177 self.version = Some(ExtensionVersion::new(version));
178 let path = base_path
179 .as_ref()
180 .join(PathBuf::from(format!("{}--{version}.sql", self.name)));
181 Ok(path)
182 } else if version_candidates.is_empty() {
183 ExtensionConfigError::extension_not_found(&self.name, base_path)
184 } else {
185 ExtensionConfigError::multiple_extension_version(
186 &self.name,
187 base_path,
188 &version_candidates,
189 )
190 }
191 }
192}
193
#[cfg(test)]
mod test_resolve_version {
    use std::fs::{self, File};

    use testresult::TestResult;

    use super::ExtensionNameVersionPair;

    /// An explicit version resolves to its existing script.
    #[test]
    fn test_resolve_version_with_version() -> TestResult {
        let tmp_dir = tempfile::tempdir()?;
        let file_path = tmp_dir.path().join("ext--1.0.sql");
        File::create(&file_path)?;

        let version = Some(super::ExtensionVersion::try_from("1.0")?);
        let mut pair = ExtensionNameVersionPair::try_new("ext", &version)?;
        let resolved = pair.resolve_version(tmp_dir.path())?;

        assert_eq!(file_path, resolved);
        Ok(())
    }

    /// An explicit version without a script is an error.
    #[test]
    fn test_resolve_version_missing_file() -> TestResult {
        let tmp_dir = tempfile::tempdir()?;
        let version = Some(super::ExtensionVersion::try_from("1.0")?);
        let mut pair = ExtensionNameVersionPair::try_new("ext", &version)?;

        let result = pair.resolve_version(tmp_dir.path());
        assert!(result.is_err());
        Ok(())
    }

    /// A single script determines the version.
    #[test]
    fn test_resolve_version_from_extension_files_single() -> TestResult {
        let tmp_dir = tempfile::tempdir()?;
        let file_path = tmp_dir.path().join("ext--1.0.sql");
        File::create(&file_path)?;

        let version = None;
        let mut pair = ExtensionNameVersionPair::try_new("ext", &version)?;
        let resolved = pair.resolve_version_from_extension_files(tmp_dir.path())?;

        assert_eq!(file_path, resolved);
        assert_eq!(Some(super::ExtensionVersion::new("1.0")), pair.version);
        Ok(())
    }

    /// Two install scripts make the version ambiguous.
    #[test]
    fn test_resolve_version_from_extension_files_multiple() -> TestResult {
        let tmp_dir = tempfile::tempdir()?;
        File::create(tmp_dir.path().join("ext--1.0.sql"))?;
        File::create(tmp_dir.path().join("ext--2.0.sql"))?;

        let version = None;
        let mut pair = ExtensionNameVersionPair::try_new("ext", &version)?;
        let result = pair.resolve_version_from_extension_files(tmp_dir.path());

        assert!(result.is_err());
        Ok(())
    }

    /// No install script at all is an error.
    #[test]
    fn test_resolve_version_from_extension_files_missing() -> TestResult {
        let tmp_dir = tempfile::tempdir()?;

        let version = None;
        let mut pair = ExtensionNameVersionPair::try_new("ext", &version)?;
        let result = pair.resolve_version_from_extension_files(tmp_dir.path());

        assert!(result.is_err());
        Ok(())
    }

    /// The following are ignored during discovery: other extensions' scripts,
    /// non-SQL files, scripts with an empty version, and directories that merely
    /// look like scripts.
    #[test]
    fn test_resolve_version_from_extension_files_ignores_unrelated_entries() -> TestResult {
        let tmp_dir = tempfile::tempdir()?;
        let file_path = tmp_dir.path().join("ext--1.0.sql");
        File::create(&file_path)?;
        File::create(tmp_dir.path().join("other--2.0.sql"))?;
        File::create(tmp_dir.path().join("ext--3.0.txt"))?;
        File::create(tmp_dir.path().join("ext--.sql"))?;
        fs::create_dir(tmp_dir.path().join("ext--4.0.sql"))?;

        let mut pair = ExtensionNameVersionPair::try_new("ext", &None)?;
        let resolved = pair.resolve_version(tmp_dir.path())?;

        assert_eq!(file_path, resolved);
        assert_eq!(Some(super::ExtensionVersion::new("1.0")), pair.version);
        Ok(())
    }
}
268
269impl ExtensionNameVersionPair {
270 fn name_parser<Input>() -> impl Parser<Input, Output = String>
271 where
272 Input: Stream<Token = char>,
273 Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
274 {
275 many1(choice((alpha_num(), char('_'), char('-'), char('.'))))
276 }
277
278 fn version_parser<Input>() -> impl Parser<Input, Output = String>
279 where
280 Input: Stream<Token = char>,
281 Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
282 {
283 many1(satisfy(|c: char| c != ',' && !c.is_whitespace()))
284 }
285}
286
#[cfg(test)]
mod test_parsers {
    use combine::EasyParser;

    use super::ExtensionNameVersionPair;

    /// The name parser stops at the first character that is not a name character.
    #[test]
    fn test_name_parser() {
        let (name, rest) = ExtensionNameVersionPair::name_parser()
            .easy_parse("extension-1.0 rest")
            .expect("name should parse");
        assert_eq!(name, "extension-1.0");
        assert_eq!(rest, " rest");
    }

    /// The version parser stops before a comma.
    #[test]
    fn test_version_parser() {
        let (version, rest) = ExtensionNameVersionPair::version_parser()
            .easy_parse("1.0,rest")
            .expect("version should parse");
        assert_eq!(version, "1.0");
        assert_eq!(rest, ",rest");
    }
}
309
310impl ExtensionNameVersionPair {
311 pub(crate) fn parser<Input>()
312 -> impl Parser<Input, Output = Result<ExtensionNameVersionPair, ExtensionConfigError>>
313 where
314 Input: Stream<Token = char>,
315 Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
316 {
317 (
318 Self::name_parser(),
319 combine::optional(char('@').with(Self::version_parser())),
320 )
321 .map(|(name, ver)| {
322 let version = ver
323 .map(|v| ExtensionVersion::try_from(v.as_str()))
324 .transpose()?;
325 ExtensionNameVersionPair::try_new(&name, &version)
326 })
327 }
328}
329
#[cfg(test)]
mod test_parser {
    use combine::EasyParser;

    use super::ExtensionNameVersionPair;

    /// `name@version` yields both parts and stops before the comma.
    #[test]
    fn test_parser_with_version() {
        let (pair, remaining) = ExtensionNameVersionPair::parser()
            .easy_parse("ext@1.0,rest")
            .unwrap();

        let pair = pair.unwrap();
        assert_eq!("ext", pair.name);
        assert_eq!(Some(super::ExtensionVersion::new("1.0")), pair.version);
        assert_eq!(",rest", remaining);
    }

    /// A bare name yields no version.
    #[test]
    fn test_parser_without_version() {
        let (pair, remaining) = ExtensionNameVersionPair::parser()
            .easy_parse("ext rest")
            .unwrap();

        let pair = pair.unwrap();
        assert_eq!("ext", pair.name);
        assert!(pair.version.is_none());
        assert_eq!(" rest", remaining);
    }

    /// A syntactically valid but rejected name parses successfully and reports
    /// the validation failure through the inner `Result`.
    #[test]
    fn test_parser_invalid_name() {
        let (pair, remaining) = ExtensionNameVersionPair::parser()
            .easy_parse("..@1.0 rest")
            .unwrap();

        assert!(pair.is_err());
        assert_eq!(" rest", remaining);
    }
}
360
361impl ExtensionNameVersionPair {
362 pub fn parse_name(expect_name: &str) -> Result<(String, &str), String> {
372 let parse_result = Self::name_parser().easy_parse(expect_name);
373 parse_result.map_err(|e| e.to_string())
374 }
375}
376
#[cfg(test)]
mod test_parse_name {
    use super::ExtensionNameVersionPair;

    /// The name is split from the remaining input.
    #[test]
    fn test_parse_name_ok() {
        let (parsed, remaining) = ExtensionNameVersionPair::parse_name("ext rest").unwrap();
        assert_eq!("ext", parsed);
        assert_eq!(" rest", remaining);
    }

    /// Input that does not start with a name character is rejected.
    #[test]
    fn test_parse_name_err() {
        assert!(ExtensionNameVersionPair::parse_name("").is_err());
        assert!(ExtensionNameVersionPair::parse_name("@ext").is_err());
        assert!(ExtensionNameVersionPair::parse_name(" ext").is_err());
    }
}
388
389impl ExtensionNameVersionPair {
390 pub fn parse_version(input: &str) -> Result<(String, &str), String> {
398 Self::version_parser()
399 .easy_parse(input)
400 .map_err(|e| e.to_string())
401 }
402}
403
#[cfg(test)]
mod test_parse_version {
    use super::ExtensionNameVersionPair;

    /// The version stops before whitespace.
    #[test]
    fn test_parse_version_ok() {
        let (parsed, remaining) = ExtensionNameVersionPair::parse_version("1.0 rest").unwrap();
        assert_eq!("1.0", parsed);
        assert_eq!(" rest", remaining);
    }

    /// The version stops before a comma.
    #[test]
    fn test_parse_version_stops_at_comma() {
        let (parsed, remaining) = ExtensionNameVersionPair::parse_version("1.0,rest").unwrap();
        assert_eq!("1.0", parsed);
        assert_eq!(",rest", remaining);
    }

    /// Input that does not start with a version character is rejected.
    #[test]
    fn test_parse_version_err() {
        assert!(ExtensionNameVersionPair::parse_version("").is_err());
        assert!(ExtensionNameVersionPair::parse_version(",1.0").is_err());
        assert!(ExtensionNameVersionPair::parse_version(" 1.0").is_err());
    }
}
414}