// rust_sql_organizer/searcher/mod.rs

pub mod error;

use std::path::{Path, PathBuf};

use glob::{glob, GlobError, Pattern};

use error::Error;

6#[derive(Clone, Debug)]
7pub struct FileExtension {
8    extension: String,
9}
10
11impl FileExtension {
12    pub fn new(extension: &str) -> Result<FileExtension, Error> {
13        let extension = extension.trim();
14        if extension.len() == 0 {
15            return Err(Error::EmptyFileExtension);
16        }
17        Ok(FileExtension {
18            extension: extension.to_string(),
19        })
20    }
21
22    fn get_glob(&self) -> String {
23        return format!("**/*.{}", self.extension);
24    }
25}
26
27pub fn get_all_files(
28    path: &Path,
29    file_formats: &[FileExtension],
30) -> Result<Vec<Result<PathBuf, GlobError>>, Error> {
31    let mut result: Vec<Result<PathBuf, GlobError>> = Vec::new();
32    for file_format in file_formats {
33        let glob_str = file_format.get_glob();
34        let pattern_path = path.join(Path::new(&glob_str));
35        let pattern = pattern_path.to_str().expect("UTF-8 error in the pattern");
36        result.extend(glob(&pattern)?);
37    }
38    Ok(result)
39}
40
#[cfg(test)]
mod searcher_test {
    use super::{get_all_files, FileExtension};
    use std::fs::File;
    use tempdir::TempDir;

    #[test]
    fn test_file_extension() {
        let result = FileExtension::new("sql");
        assert!(result.is_ok());
        assert_eq!(result.unwrap().extension, "sql")
    }

    #[test]
    fn test_file_extension_error() {
        assert!(FileExtension::new("").is_err());
        // Whitespace-only input is trimmed to empty and must be rejected as well.
        assert!(FileExtension::new("   ").is_err());
    }

    #[test]
    fn test_file_extension_get_glob() {
        let file_extension = FileExtension {
            extension: "sql".to_string(),
        };
        let glob = file_extension.get_glob();
        assert_eq!(glob, "**/*.sql")
    }

    /// Creates a fresh temporary directory (named with `prefix`) holding one empty file
    /// per entry of `file_names`. The directory is removed when the returned guard drops.
    fn create_temp_files(prefix: &str, file_names: &[&str]) -> TempDir {
        let tmp_dir = TempDir::new(prefix).unwrap();
        for &file_name in file_names {
            File::create(tmp_dir.path().join(file_name)).unwrap();
        }
        tmp_dir
    }

    #[test]
    fn test_get_all_files() {
        let files = ["test.sql", "test_2.sql", "test_3.txt", "test_4.snowsql"];
        let tmp_dir = create_temp_files("test_get_all_files", &files);
        let file_extensions = [
            FileExtension {
                extension: "sql".to_string(),
            },
            FileExtension {
                extension: "snowsql".to_string(),
            },
        ];
        let all_files =
            get_all_files(tmp_dir.path(), &file_extensions).expect("glob patterns are valid");

        let mut found: Vec<String> = all_files
            .into_iter()
            .map(|entry| {
                let path = entry.expect("walking a fresh temp dir should not fail");
                path.file_name().unwrap().to_str().unwrap().to_string()
            })
            .collect();
        found.sort();
        // Every matching file must be reported exactly once, and nothing else
        // (in particular not `test_3.txt`).
        assert_eq!(found, ["test.sql", "test_2.sql", "test_4.snowsql"]);
    }
}