Skip to main content

spec_core/
loader.rs

1//! Loader module: Read and parse .unit.spec files from the filesystem
2//!
3//! Functions:
4//! - Load single .unit.spec file
5//! - Load directory recursively
6//! - UTF-8 validation before YAML parsing
7//! - Error tracking with file paths
8
9use crate::types::{LoadedSpec, SpecSource, SpecStruct};
10use crate::validator::validate_raw_yaml;
11use crate::{Result, SpecError};
12use std::fs;
13use std::path::Path;
14#[cfg(test)]
15use std::path::PathBuf;
16use walkdir::WalkDir;
17
18#[cfg(test)]
19use crate::validator::validate_semantic;
20
21/// Result of a collect-all directory load.
22#[derive(Debug, Default)]
23pub struct DirectoryLoadReport {
24    pub specs: Vec<LoadedSpec>,
25    pub errors: Vec<SpecError>,
26    pub warnings: Vec<crate::SpecWarning>,
27    pub total_files: usize,
28}
29
30fn read_yaml_value<P: AsRef<Path>>(path: P) -> Result<(String, serde_yaml_bw::Value)> {
31    let path = path.as_ref();
32    let path_str = path.to_string_lossy().to_string();
33
34    // Read file as bytes for UTF-8 validation
35    let bytes = fs::read(path)?;
36
37    // Validate UTF-8
38    if std::str::from_utf8(&bytes).is_err() {
39        return Err(SpecError::InvalidUtf8 { path: path_str });
40    }
41
42    // Parse YAML to Value (preserves raw author input)
43    let yaml_value: serde_yaml_bw::Value =
44        serde_yaml_bw::from_slice(&bytes).map_err(|e| SpecError::YamlParse {
45            message: e.to_string(),
46            path: path_str.clone(),
47        })?;
48
49    Ok((path_str, yaml_value))
50}
51
52/// Load a single .unit.spec file
53///
54/// Returns the parsed SpecStruct with its source file information.
55/// Performs UTF-8 validation before YAML parsing.
56pub fn load_file<P: AsRef<Path>>(path: P) -> Result<LoadedSpec> {
57    let (path_str, yaml_value) = read_yaml_value(path)?;
58
59    // Validate the raw authored YAML before serde can normalize or drop fields.
60    validate_raw_yaml(&yaml_value, &path_str)?;
61
62    // Deserialize to SpecStruct
63    let spec: SpecStruct =
64        serde_yaml_bw::from_value(yaml_value).map_err(|e| SpecError::YamlParse {
65            message: e.to_string(),
66            path: path_str.clone(),
67        })?;
68
69    Ok(LoadedSpec {
70        source: SpecSource {
71            file_path: path_str,
72            id: spec.id.clone(),
73        },
74        spec,
75    })
76}
77
78/// Load all .unit.spec files from a directory recursively
79///
80/// Returns a vector of LoadedSpec, sorted by file path.
81/// Non-.unit.spec files are skipped.
82/// Empty directories return an empty vec (not an error).
83pub fn load_directory<P: AsRef<Path>>(dir: P) -> Result<Vec<LoadedSpec>> {
84    let report = load_directory_report(dir);
85    if let Some(err) = report.errors.into_iter().next() {
86        return Err(err);
87    }
88    Ok(report.specs)
89}
90
91/// Load all .unit.spec files from a directory recursively, collecting traversal
92/// warnings and continuing past symlink cycles.
93pub fn load_directory_report<P: AsRef<Path>>(dir: P) -> DirectoryLoadReport {
94    let dir = dir.as_ref();
95    let mut report = DirectoryLoadReport::default();
96
97    for entry in WalkDir::new(dir).follow_links(true) {
98        match entry {
99            Ok(entry) => {
100                let path = entry.path();
101
102                if !path.is_file() {
103                    continue;
104                }
105
106                let Some(name) = path.file_name().and_then(|n| n.to_str()) else {
107                    continue;
108                };
109
110                if !name.ends_with(".unit.spec") {
111                    continue;
112                }
113
114                report.total_files += 1;
115                match load_file(path) {
116                    Ok(spec) => report.specs.push(spec),
117                    Err(err) => report.errors.push(err),
118                }
119            }
120            Err(err) => {
121                if let Some(warning) = walkdir_cycle_warning(&err) {
122                    report.warnings.push(warning);
123                } else {
124                    report.errors.push(walkdir_error(err));
125                }
126            }
127        }
128    }
129
130    report
131        .specs
132        .sort_by(|a, b| a.source.file_path.cmp(&b.source.file_path));
133    report
134}
135
/// Load every .unit.spec file under `dir`, then run semantic validation on each
/// loaded spec, collecting all failures.
///
/// Unlike `load_directory`, this helper continues after failures so callers can
/// present grouped diagnostics for the full directory. Specs that fail semantic
/// validation are removed from `specs` and their error is appended to `errors`.
#[cfg(test)]
pub(crate) fn load_directory_collect_all<P: AsRef<Path>>(dir: P) -> DirectoryLoadReport {
    let mut report = load_directory_report(dir);

    for candidate in std::mem::take(&mut report.specs) {
        if let Err(semantic_err) = validate_semantic(&candidate) {
            report.errors.push(semantic_err);
        } else {
            report.specs.push(candidate);
        }
    }

    report
        .specs
        .sort_by(|lhs, rhs| lhs.source.file_path.cmp(&rhs.source.file_path));
    report
}
157
158fn walkdir_cycle_warning(err: &walkdir::Error) -> Option<crate::SpecWarning> {
159    err.loop_ancestor()
160        .map(|_| crate::SpecWarning::SymlinkCycleSkipped {
161            path: err
162                .path()
163                .map(|path| path.display().to_string())
164                .unwrap_or_else(|| "<unknown>".to_string()),
165        })
166}
167
168fn walkdir_error(err: walkdir::Error) -> SpecError {
169    SpecError::Traversal {
170        message: err.to_string(),
171        path: err
172            .path()
173            .map(|path| path.display().to_string())
174            .unwrap_or_else(|| "<unknown>".to_string()),
175    }
176}
177
/// Return `true` when the final component of `path` is valid UTF-8 and ends
/// with `.unit.spec`.
///
/// Paths without a file name, or whose file name is not valid UTF-8, are never
/// treated as spec files.
pub fn is_unit_spec(path: &Path) -> bool {
    match path.file_name().and_then(|file_name| file_name.to_str()) {
        Some(file_name) => file_name.ends_with(".unit.spec"),
        None => false,
    }
}
185
/// Get the output directory for a generated file based on its module path.
///
/// Returns the directory path where the .rs file should be written. Each
/// `/`-separated segment of `module_path` is appended to `output_base` as its
/// own path component. Empty segments (from a leading, trailing or doubled `/`)
/// are skipped, so a leading `/` can no longer make `Path::join` replace
/// `output_base` with an absolute path.
///
/// E.g., for module path "pricing" with output base "./generated/spec",
/// returns "./generated/spec/pricing"; an empty module path returns the base.
#[cfg(test)]
pub(crate) fn output_dir_for_spec(output_base: impl AsRef<Path>, module_path: &str) -> PathBuf {
    let mut path = output_base.as_ref().to_path_buf();
    for segment in module_path.split('/').filter(|segment| !segment.is_empty()) {
        path.push(segment);
    }
    path
}
199
/// Get the file path for a generated .rs file.
///
/// Every `/`-separated segment of `id` except the last becomes a directory
/// under `output_base`; the last segment becomes the file stem, to which `.rs`
/// is appended.
///
/// E.g., for ID "pricing/apply_discount" with output base "./generated/spec",
/// returns "./generated/spec/pricing/apply_discount.rs"
#[cfg(test)]
pub(crate) fn output_file_path(output_base: impl AsRef<Path>, id: &str) -> PathBuf {
    let mut file = output_base.as_ref().to_path_buf();

    // Everything before the last '/' is the module directory; the rest is the stem.
    let stem = match id.rsplit_once('/') {
        Some((module_dir, stem)) => {
            for segment in module_dir.split('/') {
                file.push(segment);
            }
            stem
        }
        None => id,
    };

    file.push(format!("{stem}.rs"));
    file
}
222
/// Get the directory path for a module's mod.rs file.
///
/// Each `/`-separated segment of `module_path` is appended to `output_base` as
/// its own path component. Empty segments are skipped, so a leading `/` cannot
/// make `Path::join` replace `output_base` with an absolute path. An empty
/// module path returns `output_base` unchanged.
///
/// E.g., for module path "pricing" with output base "./generated/spec",
/// returns "./generated/spec/pricing"
#[cfg(test)]
pub(crate) fn mod_rs_dir(output_base: impl AsRef<Path>, module_path: &str) -> PathBuf {
    module_path
        .split('/')
        .filter(|segment| !segment.is_empty())
        .fold(output_base.as_ref().to_path_buf(), |mut dir, segment| {
            dir.push(segment);
            dir
        })
}
237
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;
    use tempfile::TempDir;

    #[test]
    fn test_load_valid_file() {
        let mut temp_file = NamedTempFile::with_suffix(".unit.spec").unwrap();
        let yaml = r#"
id: pricing/apply_discount
kind: function
intent:
  why: Apply a percentage discount.
body:
  rust: |
    pub fn apply_discount(subtotal: f64, rate: f64) -> f64 {
        subtotal - subtotal * rate
    }
"#;
        temp_file.write_all(yaml.as_bytes()).unwrap();
        temp_file.flush().unwrap();

        let loaded = load_file(temp_file.path()).unwrap();
        assert_eq!(loaded.spec.id, "pricing/apply_discount");
        assert_eq!(loaded.spec.kind, "function");
        assert_eq!(loaded.spec.intent.why, "Apply a percentage discount.");
        // The recorded source mirrors the spec id and the file it was read from.
        assert_eq!(loaded.source.id, "pricing/apply_discount");
        assert_eq!(loaded.source.file_path, temp_file.path().to_string_lossy());
    }

    #[test]
    fn test_load_file_rejects_unknown_fields_before_deserialization() {
        let mut temp_file = NamedTempFile::with_suffix(".unit.spec").unwrap();
        let yaml = r#"
id: pricing/apply_discount
kind: function
intent:
  why: Apply a percentage discount.
body:
  rust: |
    pub fn apply_discount(subtotal: f64, rate: f64) -> f64 {
        subtotal - subtotal * rate
    }
extra_field: should_fail
"#;
        temp_file.write_all(yaml.as_bytes()).unwrap();
        temp_file.flush().unwrap();

        let result = load_file(temp_file.path());
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Schema validation failed"));
        assert!(err.contains("unknown field"));
    }

    #[test]
    fn test_load_file_not_found() {
        let result = load_file("/nonexistent/file.unit.spec");
        assert!(result.is_err());
        // The OS error text is platform-specific; only Unix guarantees "No such file".
        #[cfg(unix)]
        assert!(result.unwrap_err().to_string().contains("No such file"));
    }

    #[test]
    fn test_load_invalid_yaml() {
        let mut temp_file = NamedTempFile::with_suffix(".unit.spec").unwrap();
        temp_file.write_all(b"invalid: [").unwrap();
        temp_file.flush().unwrap();

        let result = load_file(temp_file.path());
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("parse") || err_msg.contains("YAML") || err_msg.contains("mapping")
        );
    }

    #[test]
    fn test_load_non_utf8() {
        let mut temp_file = NamedTempFile::with_suffix(".unit.spec").unwrap();
        // Write invalid UTF-8 bytes
        temp_file.write_all(&[0x80, 0x81, 0x82, 0x83]).unwrap();
        temp_file.flush().unwrap();

        let result = load_file(temp_file.path());
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("File is not valid UTF-8"));
    }

    #[test]
    fn test_load_directory() {
        let temp_dir = TempDir::new().unwrap();

        // Create valid .unit.spec file
        let file1 = temp_dir.path().join("pricing.unit.spec");
        fs::write(
            &file1,
            r#"
id: pricing/apply
kind: function
intent:
  why: Apply pricing.
body:
  rust: pub fn apply() {}
"#,
        )
        .unwrap();

        // Create nested subdirectory with spec
        let subdir = temp_dir.path().join("utils");
        fs::create_dir(&subdir).unwrap();
        let file2 = subdir.join("math.unit.spec");
        fs::write(
            &file2,
            r#"
id: utils/math/round
kind: function
intent:
  why: Round numbers.
body:
  rust: pub fn round() {}
"#,
        )
        .unwrap();

        // Create a non-.unit.spec file (should be skipped)
        let other_file = temp_dir.path().join("readme.txt");
        fs::write(&other_file, "# Readme").unwrap();

        let specs = load_directory(temp_dir.path()).unwrap();
        assert_eq!(specs.len(), 2);
        // Results are sorted by file path: "pricing.unit.spec" < "utils/...".
        assert_eq!(specs[0].spec.id, "pricing/apply");
        assert_eq!(specs[1].spec.id, "utils/math/round");
    }

    #[test]
    fn test_load_directory_collect_all() {
        let temp_dir = TempDir::new().unwrap();

        fs::write(
            temp_dir.path().join("good.unit.spec"),
            r#"
id: pricing/apply
kind: function
intent:
  why: Apply pricing.
body:
  rust: "{ }"
"#,
        )
        .unwrap();

        fs::write(
            temp_dir.path().join("bad.unit.spec"),
            r#"
id: pricing/type
kind: function
intent:
  why: Bad keyword id.
body:
  rust: "{ }"
"#,
        )
        .unwrap();

        fs::write(temp_dir.path().join("notes.txt"), "ignore me").unwrap();

        let report = load_directory_collect_all(temp_dir.path());
        // Only the two .unit.spec files count; notes.txt is skipped.
        assert_eq!(report.total_files, 2);
        assert_eq!(report.specs.len(), 1);
        assert_eq!(report.errors.len(), 1);
        assert!(
            report.errors[0]
                .to_string()
                .contains("Rust reserved keyword")
        );
    }

    #[test]
    #[cfg(unix)]
    fn test_load_directory_report_skips_symlink_cycle_with_warning() {
        use std::os::unix::fs as unix_fs;

        let temp_dir = TempDir::new().unwrap();
        let units_dir = temp_dir.path().join("units");
        fs::create_dir_all(units_dir.join("pricing")).unwrap();
        fs::write(
            units_dir.join("pricing/apply.unit.spec"),
            r#"
id: pricing/apply
kind: function
intent:
  why: Apply pricing.
body:
  rust: "{ }"
"#,
        )
        .unwrap();

        unix_fs::symlink(&units_dir, units_dir.join("loop")).unwrap();

        let report = load_directory_report(&units_dir);
        assert_eq!(report.specs.len(), 1);
        assert!(report.errors.is_empty());
        assert_eq!(report.warnings.len(), 1);
        assert!(
            report.warnings[0]
                .to_string()
                .contains("skipped symlink cycle")
        );
    }

    #[test]
    fn test_load_empty_directory() {
        let temp_dir = TempDir::new().unwrap();
        let specs = load_directory(temp_dir.path()).unwrap();
        assert!(specs.is_empty());
    }

    #[test]
    fn test_is_unit_spec_requires_exact_suffix() {
        assert!(is_unit_spec(Path::new("pricing/apply_discount.unit.spec")));
        assert!(!is_unit_spec(Path::new(
            "pricing/apply_discount.unit.spec.bak"
        )));
        assert!(!is_unit_spec(Path::new("pricing/apply_discount.spec")));
    }

    #[test]
    fn test_output_dir_for_spec() {
        let base = Path::new("./generated/spec");

        assert_eq!(
            output_dir_for_spec(base, "pricing"),
            PathBuf::from("./generated/spec/pricing")
        );

        assert_eq!(
            output_dir_for_spec(base, "utils/math"),
            PathBuf::from("./generated/spec/utils/math")
        );

        assert_eq!(
            output_dir_for_spec(base, ""),
            PathBuf::from("./generated/spec")
        );
    }

    #[test]
    fn test_output_file_path() {
        let base = Path::new("./generated/spec");

        assert_eq!(
            output_file_path(base, "pricing/apply_discount"),
            PathBuf::from("./generated/spec/pricing/apply_discount.rs")
        );

        assert_eq!(
            output_file_path(base, "utils/math/round"),
            PathBuf::from("./generated/spec/utils/math/round.rs")
        );
    }

    #[test]
    fn test_mod_rs_dir() {
        let base = Path::new("./generated/spec");

        assert_eq!(
            mod_rs_dir(base, "pricing"),
            PathBuf::from("./generated/spec/pricing")
        );

        assert_eq!(mod_rs_dir(base, ""), PathBuf::from("./generated/spec"));
    }

    #[test]
    fn test_empty_file() {
        let mut temp_file = NamedTempFile::with_suffix(".unit.spec").unwrap();
        temp_file.write_all(b"").unwrap();
        temp_file.flush().unwrap();

        let result = load_file(temp_file.path());
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("missing")
                || err.contains("EOF")
                || err.contains("end of file")
                || err.contains("Unknown entry")
                || err.contains("Schema validation failed")
        );
    }
}