// rust_sql_organizer/searcher/mod.rs
pub mod error;

use glob::{glob, GlobError, Pattern};
use std::path::{Path, PathBuf};

use error::Error;

6#[derive(Clone, Debug)]
7pub struct FileExtension {
8 extension: String,
9}
10
11impl FileExtension {
12 pub fn new(extension: &str) -> Result<FileExtension, Error> {
13 let extension = extension.trim();
14 if extension.len() == 0 {
15 return Err(Error::EmptyFileExtension);
16 }
17 Ok(FileExtension {
18 extension: extension.to_string(),
19 })
20 }
21
22 fn get_glob(&self) -> String {
23 return format!("**/*.{}", self.extension);
24 }
25}
26
27pub fn get_all_files(
28 path: &Path,
29 file_formats: &[FileExtension],
30) -> Result<Vec<Result<PathBuf, GlobError>>, Error> {
31 let mut result: Vec<Result<PathBuf, GlobError>> = Vec::new();
32 for file_format in file_formats {
33 let glob_str = file_format.get_glob();
34 let pattern_path = path.join(Path::new(&glob_str));
35 let pattern = pattern_path.to_str().expect("UTF-8 error in the pattern");
36 result.extend(glob(&pattern)?);
37 }
38 Ok(result)
39}
40
#[cfg(test)]
mod searcher_test {
    use super::{get_all_files, FileExtension};
    use std::fs::File;
    use std::path::Path;
    use tempdir::TempDir;

    #[test]
    fn test_file_extension() {
        let result = FileExtension::new("sql");
        assert!(result.is_ok());
        assert_eq!(result.unwrap().extension, "sql")
    }

    #[test]
    fn test_file_extension_trims_whitespace() {
        let extension = FileExtension::new("  sql\t").unwrap();
        assert_eq!(extension.extension, "sql");
    }

    #[test]
    fn test_file_extension_error() {
        assert!(FileExtension::new("").is_err());
        // Whitespace-only input is trimmed to empty and must be rejected as well.
        assert!(FileExtension::new("   ").is_err());
    }

    #[test]
    fn test_file_extension_get_glob() {
        let file_extension = FileExtension {
            extension: "sql".to_string(),
        };
        let glob = file_extension.get_glob();
        assert_eq!(glob, "**/*.sql")
    }

    /// Creates a fresh temporary directory (named with `prefix`) containing one empty
    /// file per entry in `file_names`. The directory is deleted when the returned
    /// `TempDir` is dropped.
    fn create_temp_files(prefix: &str, file_names: &[&str]) -> TempDir {
        let tmp_dir = TempDir::new(prefix).unwrap();
        for &file_name in file_names {
            File::create(tmp_dir.path().join(Path::new(file_name))).unwrap();
        }
        tmp_dir
    }

    #[test]
    fn test_get_all_files() {
        let files = ["test.sql", "test_2.sql", "test_3.txt", "test_4.snowsql"];
        let tmp_dir = create_temp_files("test_get_all_files", &files);
        let file_extensions = [
            FileExtension {
                extension: "sql".to_string(),
            },
            FileExtension {
                extension: "snowsql".to_string(),
            },
        ];

        let all_files = get_all_files(tmp_dir.path(), &file_extensions)
            .expect("glob patterns built from valid extensions should compile");

        let mut found: Vec<String> = all_files
            .into_iter()
            .map(|entry| {
                let path = entry.expect("every temporary file should be readable");
                path.file_name().unwrap().to_string_lossy().into_owned()
            })
            .collect();
        found.sort();

        // Every matching file must be found exactly once, and `test_3.txt` must be excluded;
        // comparing the full set also catches an (incorrectly) empty result.
        assert_eq!(found, ["test.sql", "test_2.sql", "test_4.snowsql"]);
    }
}
104}