Skip to main content

spec_core/
generator.rs

1//! Generator module: Generate Rust code from ResolvedSpec
2//!
3//! Implements the M1 generation path from PLAN.md:
4//! - prepend `use ...` imports for imports + deps
5//! - write generated `.rs` files
6//! - generate `mod.rs` contents
7//! - owned-tree orphan cleanup with `.spec-generated` marker safety rails
8
9use crate::syntax::validate_expect_expr;
10use crate::types::ResolvedSpec;
11use crate::{Result, SpecError};
12use std::collections::HashSet;
13use std::fs::{self, File};
14use std::io::Write;
15use std::path::{Component, Path, PathBuf};
16use walkdir::WalkDir;
17
/// Name of the marker file that flags a directory tree as generator-owned;
/// `clean_output_dir` refuses to prune a tree that does not contain it.
const GENERATED_MARKER: &str = ".spec-generated";

/// Tunables for code generation.
///
/// `GenerateOptions::default()` is the strict configuration: every flag is `false`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct GenerateOptions {
    /// Forwarded as the second argument of `validate_expect_expr` for every
    /// local-test `expect`; when `true`, expressions the strict mode rejects
    /// (such as block expressions) are let through.
    pub allow_unsafe_local_test_expect: bool,
}
24
25fn build_fn_signature(spec: &ResolvedSpec) -> String {
26    let params = spec
27        .contract
28        .as_ref()
29        .and_then(|c| c.inputs.as_ref())
30        .map(|inputs| {
31            inputs
32                .iter()
33                .map(|(name, ty)| format!("{name}: {ty}"))
34                .collect::<Vec<_>>()
35                .join(", ")
36        })
37        .unwrap_or_default();
38
39    let return_type = spec
40        .contract
41        .as_ref()
42        .and_then(|c| c.returns.as_ref())
43        .map(|r| format!(" -> {}", r));
44
45    match return_type {
46        Some(ret) => format!("pub fn {}({}){}", spec.fn_name, params, ret),
47        None => format!("pub fn {}({})", spec.fn_name, params),
48    }
49}
50
/// Render `intent_why` as `///` doc-comment lines.
///
/// Returns `None` when the text is empty or whitespace-only. The text is trimmed
/// as a whole. Interior blank lines become bare `///` lines so paragraph breaks
/// survive. Trailing whitespace on each line is dropped. Every emitted line ends
/// with `\n`.
///
/// Line breaks may be `\n`, `\r\n`, or a lone `\r`; all are normalized to `\n`.
/// This matters because rustc rejects an isolated carriage return inside a doc
/// comment, so passing one through verbatim would make the generated file fail
/// to compile.
fn build_doc_comment(intent_why: &str) -> Option<String> {
    let trimmed = intent_why.trim();
    if trimmed.is_empty() {
        return None;
    }

    // Fold CRLF first so it does not become two line breaks, then any bare CR.
    let normalized = trimmed.replace("\r\n", "\n").replace('\r', "\n");

    let mut output = String::new();
    for line in normalized.lines() {
        let line = line.trim_end();
        if line.is_empty() {
            output.push_str("///\n");
        } else {
            output.push_str("/// ");
            output.push_str(line);
            output.push('\n');
        }
    }

    Some(output)
}
70
71pub fn generate_code(spec: &ResolvedSpec) -> Result<String> {
72    generate_code_with_options(spec, &GenerateOptions::default())
73}
74
75pub fn generate_code_with_options(
76    spec: &ResolvedSpec,
77    options: &GenerateOptions,
78) -> Result<String> {
79    let (import_statements, dep_statements) = build_use_groups(spec)?;
80    let mut output = String::new();
81
82    for statement in import_statements {
83        output.push_str(&statement);
84        output.push('\n');
85    }
86
87    if !spec.imports.is_empty() && !spec.deps.is_empty() {
88        output.push('\n');
89    }
90
91    for statement in dep_statements {
92        output.push_str(&statement);
93        output.push('\n');
94    }
95
96    if !spec.imports.is_empty() || !spec.deps.is_empty() {
97        output.push('\n');
98    }
99
100    if let Some(doc_comment) = build_doc_comment(&spec.intent_why) {
101        output.push_str(&doc_comment);
102    }
103
104    let signature = build_fn_signature(spec);
105    let block = spec.body_rust.trim();
106    output.push_str(&format!("{signature} {block}"));
107    output.push('\n');
108
109    if !spec.local_tests.is_empty() {
110        // One blank line between the generated unit body and the tests module.
111        output.push('\n');
112        output.push_str("#[cfg(test)]\n");
113        output.push_str("mod tests {\n");
114        output.push_str("    use super::*;\n\n");
115
116        for (index, local_test) in spec.local_tests.iter().enumerate() {
117            let expect = local_test.expect.trim();
118            validate_expect_expr(expect, options.allow_unsafe_local_test_expect).map_err(
119                |err| SpecError::Generator {
120                    message: format!(
121                        "invalid local test expect for unit '{}' test '{}': {}",
122                        spec.id,
123                        local_test.id,
124                        err.message()
125                    ),
126                },
127            )?;
128            output.push_str("    #[test]\n");
129            output.push_str(&format!("    fn test_{}() {{\n", local_test.id));
130            output.push_str(&format!("        assert!({expect});\n"));
131            output.push_str("    }\n");
132
133            if index + 1 != spec.local_tests.len() {
134                output.push('\n');
135            }
136        }
137
138        output.push_str("}\n");
139    }
140    Ok(output)
141}
142
143pub fn write_generated_file(output_path: &str, content: &str) -> Result<()> {
144    let path = Path::new(output_path);
145
146    if let Some(parent) = path.parent() {
147        fs::create_dir_all(parent).map_err(|err| SpecError::OutputDir {
148            message: format!(
149                "Unable to create output directory {}: {}",
150                parent.display(),
151                err
152            ),
153        })?;
154    } else {
155        return Err(SpecError::Generator {
156            message: format!(
157                "Unable to write {}: missing parent directory",
158                path.display()
159            ),
160        });
161    }
162
163    let parent_dir = path
164        .parent()
165        .ok_or_else(|| SpecError::Generator {
166            message: format!(
167                "Unable to write {}: missing parent directory",
168                path.display()
169            ),
170        })?
171        .to_path_buf();
172
173    // Write to a temp file in the same directory and rename into place (per-file atomic).
174    let mut tmp = tempfile::Builder::new()
175        .prefix(".spec-tmp-")
176        .suffix(".tmp")
177        .tempfile_in(&parent_dir)
178        .map_err(|err| SpecError::Generator {
179            message: format!(
180                "Unable to create temp file in {}: {}",
181                parent_dir.display(),
182                err
183            ),
184        })?;
185
186    tmp.write_all(content.as_bytes())
187        .map_err(|err| SpecError::Generator {
188            message: format!("Unable to write temp file for {}: {}", path.display(), err),
189        })?;
190
191    if !content.ends_with('\n') {
192        tmp.write_all(b"\n").map_err(|err| SpecError::Generator {
193            message: format!(
194                "Unable to finalize temp file for {}: {}",
195                path.display(),
196                err
197            ),
198        })?;
199    }
200
201    tmp.flush().map_err(|err| SpecError::Generator {
202        message: format!("Unable to flush temp file for {}: {}", path.display(), err),
203    })?;
204
205    // On Windows, renaming over an existing file fails; remove it first.
206    if cfg!(windows) && path.exists() {
207        fs::remove_file(path).map_err(|err| SpecError::Generator {
208            message: format!("Unable to remove existing {}: {}", path.display(), err),
209        })?;
210    }
211
212    let tmp_path = tmp.into_temp_path();
213    fs::rename(&tmp_path, path).map_err(|err| SpecError::Generator {
214        message: format!(
215            "Unable to rename temp file into {}: {}",
216            path.display(),
217            err
218        ),
219    })?;
220
221    Ok(())
222}
223
224pub fn clean_output_dir(
225    output_base: &Path,
226    generated_rs_rel_paths: &HashSet<PathBuf>,
227) -> Result<()> {
228    let base = safe_output_path(output_base)?;
229
230    let marker = base.join(GENERATED_MARKER);
231    if !marker.exists() {
232        return Err(SpecError::MissingMarker {
233            path: base.display().to_string(),
234        });
235    }
236
237    // Remove orphaned `.rs` files (anything not in the generated set).
238    for entry in WalkDir::new(&base).follow_links(false) {
239        let entry = entry.map_err(SpecError::from)?;
240        if !entry.file_type().is_file() {
241            continue;
242        }
243        let path = entry.path();
244        if path.extension().and_then(|ext| ext.to_str()) != Some("rs") {
245            continue;
246        }
247
248        let rel = path
249            .strip_prefix(&base)
250            .map_err(|err| SpecError::Generator {
251                message: format!(
252                    "Unable to compute relative path for {}: {}",
253                    path.display(),
254                    err
255                ),
256            })?;
257
258        if !generated_rs_rel_paths.contains(rel) {
259            fs::remove_file(path).map_err(|err| SpecError::Generator {
260                message: format!("Unable to remove {}: {}", path.display(), err),
261            })?;
262        }
263    }
264
265    // Remove empty directories bottom-up (but never remove the base itself).
266    for entry in WalkDir::new(&base).follow_links(false).contents_first(true) {
267        let entry = entry.map_err(SpecError::from)?;
268        if !entry.file_type().is_dir() || entry.file_type().is_symlink() {
269            continue;
270        }
271        let path = entry.path();
272        if path == base {
273            continue;
274        }
275
276        let mut entries = fs::read_dir(path).map_err(|err| SpecError::Generator {
277            message: format!("Unable to read dir {}: {}", path.display(), err),
278        })?;
279        if entries.next().is_none() {
280            fs::remove_dir(path).map_err(|err| SpecError::Generator {
281                message: format!("Unable to remove dir {}: {}", path.display(), err),
282            })?;
283        }
284    }
285
286    File::create(&marker).map_err(|err| SpecError::Generator {
287        message: format!("Unable to recreate marker {}: {}", marker.display(), err),
288    })?;
289
290    Ok(())
291}
292
293pub fn generate_mod_rs(unit_files: &[String], subdirs: &[String]) -> Result<String> {
294    let mut seen = HashSet::new();
295    let mut unit_mods = Vec::new();
296    let mut subdir_mods = Vec::new();
297
298    for unit_file in unit_files {
299        if let Some(name) = module_item_name(unit_file) {
300            let decl = format!("pub mod {};", name);
301            if seen.insert(decl.clone()) {
302                unit_mods.push(decl);
303            }
304        }
305    }
306
307    for subdir in subdirs {
308        if let Some(name) = module_item_name(subdir) {
309            let decl = format!("pub mod {};", name);
310            if seen.insert(decl.clone()) {
311                subdir_mods.push(decl);
312            }
313        }
314    }
315
316    unit_mods.sort();
317    subdir_mods.sort();
318
319    let mut output = String::new();
320    for line in &unit_mods {
321        output.push_str(line);
322        output.push('\n');
323    }
324
325    if !unit_mods.is_empty() && !subdir_mods.is_empty() {
326        output.push('\n');
327    }
328
329    for line in &subdir_mods {
330        output.push_str(line);
331        output.push('\n');
332    }
333
334    Ok(output)
335}
336
337fn build_use_groups(spec: &ResolvedSpec) -> Result<(Vec<String>, Vec<String>)> {
338    if let Some((dep1, dep2)) = ResolvedSpec::has_dep_collision(&spec.deps) {
339        return Err(SpecError::DepCollision {
340            dep1: dep1.clone(),
341            dep2: dep2.clone(),
342            fn_name: ResolvedSpec::dep_fn_name(dep1).to_string(),
343            path: spec.id.clone(),
344        });
345    }
346
347    let mut import_seen = HashSet::new();
348    let mut import_statements = Vec::new();
349
350    for import_path in &spec.imports {
351        if import_seen.insert(import_path.clone()) {
352            import_statements.push(format!("use {};", import_path));
353        }
354    }
355
356    let mut dep_seen = HashSet::new();
357    let mut dep_statements = Vec::new();
358
359    for dep in &spec.deps {
360        if dep_seen.insert(dep.clone()) {
361            dep_statements.push(format!("use {}", ResolvedSpec::dep_to_use_path(dep)));
362        }
363    }
364
365    Ok((import_statements, dep_statements))
366}
367
/// Derive a module name from a file or directory fragment.
///
/// Takes the final path component and strips every trailing `.rs`. Returns
/// `None` when there is no final component, it is not valid UTF-8, or nothing
/// is left after stripping.
fn module_item_name(fragment: &str) -> Option<String> {
    let file_name = Path::new(fragment).file_name()?.to_str()?;
    let stem = file_name.trim_end_matches(".rs");
    if stem.is_empty() {
        None
    } else {
        Some(stem.to_owned())
    }
}
375
376pub fn safe_output_path<P: AsRef<Path>>(path: P) -> Result<PathBuf> {
377    let path = path.as_ref();
378    let project_root = canonicalize_existing_path(&std::env::current_dir().map_err(|err| {
379        SpecError::OutputDir {
380            message: format!("Unable to determine project root: {err}"),
381        }
382    })?)?;
383    let output_base = canonicalize_output_path(path)?;
384
385    if !output_base.starts_with(&project_root) {
386        return Err(SpecError::OutputDir {
387            message: format!(
388                "Refusing to generate into {}: output path is outside the project root {}",
389                output_base.display(),
390                project_root.display()
391            ),
392        });
393    }
394
395    Ok(output_base)
396}
397
/// Make `path` absolute using purely lexical normalization.
///
/// Relative paths are anchored at the current working directory, or at `"."`
/// if it cannot be read. `.` components are dropped. Each `..` pops the last
/// accumulated component; popping at the root leaves the root in place.
///
/// No existence checks are made and symlinks are not resolved. A `..` after a
/// symlink therefore cancels the link's name, not its target.
pub fn normalized_absolute_path<P: AsRef<Path>>(path: P) -> PathBuf {
    let path = path.as_ref();
    let anchor = if path.is_absolute() {
        PathBuf::new()
    } else {
        std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))
    };

    path.components().fold(anchor, |mut acc, component| {
        match component {
            Component::CurDir => {}
            Component::ParentDir => {
                acc.pop();
            }
            Component::Normal(segment) => acc.push(segment),
            Component::RootDir => acc.push(component.as_os_str()),
            Component::Prefix(prefix) => acc.push(prefix.as_os_str()),
        }
        acc
    })
}
420
421fn canonicalize_output_path(path: &Path) -> Result<PathBuf> {
422    let absolute = normalized_absolute_path(path);
423    if absolute.exists() {
424        return canonicalize_existing_path(&absolute);
425    }
426
427    let mut current = absolute.as_path();
428    let mut missing_segments = Vec::new();
429
430    while !current.exists() {
431        let segment = current.file_name().ok_or_else(|| SpecError::OutputDir {
432            message: format!(
433                "Unable to resolve output path {}: no existing ancestor found",
434                absolute.display()
435            ),
436        })?;
437        missing_segments.push(segment.to_os_string());
438        current = current.parent().ok_or_else(|| SpecError::OutputDir {
439            message: format!(
440                "Unable to resolve output path {}: no existing ancestor found",
441                absolute.display()
442            ),
443        })?;
444    }
445
446    let mut resolved = canonicalize_existing_path(current)?;
447    for segment in missing_segments.iter().rev() {
448        resolved.push(segment);
449    }
450
451    Ok(resolved)
452}
453
454fn canonicalize_existing_path(path: &Path) -> Result<PathBuf> {
455    path.canonicalize().map_err(|err| SpecError::OutputDir {
456        message: format!("Unable to canonicalize {}: {}", path.display(), err),
457    })
458}
459
#[cfg(test)]
mod tests {
    // Tests for code generation, atomic file writes, mod.rs rendering, orphan
    // cleanup, and output-path resolution. Filesystem tests that go through
    // `safe_output_path` create their temp dirs under the current working
    // directory, because `safe_output_path` rejects anything outside it.
    use super::*;
    use crate::syntax::{ExpectExprErrorKind, validate_expect_expr};
    use crate::types::{Body, Intent, LocalTest, ResolvedSpec, SpecStruct};
    #[cfg(unix)]
    use std::os::unix::fs as unix_fs;
    use tempfile::TempDir;

    // Build a spec with id `pricing/apply_discount`, no contract and no local
    // tests. The expected outputs below show it generates `pub fn apply_discount()`.
    fn test_spec_with_intent(
        deps: Vec<&str>,
        imports: Vec<&str>,
        body: &str,
        intent_why: &str,
    ) -> ResolvedSpec {
        ResolvedSpec::from_spec(SpecStruct {
            id: "pricing/apply_discount".to_string(),
            kind: "function".to_string(),
            intent: Intent {
                why: intent_why.to_string(),
            },
            contract: None,
            deps: deps.into_iter().map(|dep| dep.to_string()).collect(),
            imports: imports
                .into_iter()
                .map(|import| import.to_string())
                .collect(),
            body: Body {
                rust: body.to_string(),
            },
            local_tests: vec![],
            links: None,
            spec_version: None,
        })
    }

    // Same as `test_spec_with_intent`, but with a whitespace-only intent. That is
    // treated as blank by `build_doc_comment`, so no doc comment is emitted.
    fn test_spec_with(deps: Vec<&str>, imports: Vec<&str>, body: &str) -> ResolvedSpec {
        test_spec_with_intent(deps, imports, body, " ")
    }

    // Deps-only variant of `test_spec_with` (no `imports`).
    fn test_spec(deps: Vec<&str>, body: &str) -> ResolvedSpec {
        test_spec_with(deps, vec![], body)
    }

    #[test]
    fn generate_code_includes_doc_comment_from_intent() {
        let spec = test_spec_with_intent(
            vec![],
            vec!["rust_decimal::Decimal"],
            "{\n    Decimal::ZERO\n}",
            "Apply a percentage discount.",
        );

        let code = generate_code(&spec).unwrap();
        assert_eq!(
            code,
            "use rust_decimal::Decimal;\n\n/// Apply a percentage discount.\npub fn apply_discount() {\n    Decimal::ZERO\n}\n"
        );
    }

    #[test]
    fn generate_code_multiline_intent_produces_multiline_doc_comment() {
        // Surrounding newlines are trimmed; the interior blank line becomes `///`.
        let spec = test_spec_with_intent(
            vec![],
            vec![],
            "{\n    Decimal::ZERO\n}",
            "\nFirst line.\n\nSecond line.\n",
        );

        let code = generate_code(&spec).unwrap();
        assert_eq!(
            code,
            "/// First line.\n///\n/// Second line.\npub fn apply_discount() {\n    Decimal::ZERO\n}\n"
        );
    }

    #[test]
    fn generate_code_omits_doc_comment_for_blank_intent() {
        let spec = test_spec_with_intent(vec![], vec![], "{\n    Decimal::ZERO\n}", "   \n  ");

        let code = generate_code(&spec).unwrap();
        assert_eq!(code, "pub fn apply_discount() {\n    Decimal::ZERO\n}\n");
    }

    #[test]
    fn test_generate_code_prepends_use_statements() {
        let spec = test_spec(
            vec!["money/round", "utils/math/normalize"],
            "{\n    round(Decimal::ONE)\n}",
        );

        let code = generate_code(&spec).unwrap();
        assert_eq!(
            code,
            "use crate::money::round::round;\nuse crate::utils::math::normalize::normalize;\n\npub fn apply_discount() {\n    round(Decimal::ONE)\n}\n"
        );
    }

    #[test]
    fn test_generate_code_rejects_dep_collision() {
        // Both deps end in `round`; the expected error is a fn-name collision.
        let spec = test_spec(
            vec!["money/round", "utils/round"],
            "{\n    round(Decimal::ONE)\n}",
        );

        let err = generate_code(&spec).unwrap_err();
        assert!(err.to_string().contains("Dep fn_name collision"));
    }

    #[test]
    fn imports_field_generates_correct_use_statement() {
        let spec = test_spec_with(
            vec![],
            vec!["rust_decimal::Decimal"],
            "{\n    Decimal::ZERO\n}",
        );

        let code = generate_code(&spec).unwrap();
        assert_eq!(
            code,
            "use rust_decimal::Decimal;\n\npub fn apply_discount() {\n    Decimal::ZERO\n}\n"
        );
    }

    #[test]
    fn deps_unchanged_after_imports_split() {
        let spec = test_spec_with(
            vec!["money/round"],
            vec![],
            "{\n    round(Decimal::ZERO)\n}",
        );

        let code = generate_code(&spec).unwrap();
        assert!(code.contains("use crate::money::round::round;"));
    }

    #[test]
    fn imports_emitted_before_deps_in_use_statements() {
        // Imports and deps are separate groups divided by one blank line.
        let spec = test_spec_with(
            vec!["money/round"],
            vec!["rust_decimal::Decimal"],
            "{\n    round(Decimal::ZERO)\n}",
        );

        let code = generate_code(&spec).unwrap();
        assert_eq!(
            code,
            "use rust_decimal::Decimal;\n\nuse crate::money::round::round;\n\npub fn apply_discount() {\n    round(Decimal::ZERO)\n}\n"
        );
    }

    #[test]
    fn test_write_generated_file_creates_parent_dirs() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir
            .path()
            .join("generated/spec/pricing/apply_discount.rs");

        write_generated_file(
            file_path.to_str().unwrap(),
            "pub fn apply_discount() -> Decimal { Decimal::ZERO }\n",
        )
        .unwrap();

        let contents = fs::read_to_string(&file_path).unwrap();
        assert_eq!(
            contents,
            "pub fn apply_discount() -> Decimal { Decimal::ZERO }\n"
        );
    }

    #[test]
    fn test_generate_mod_rs_lists_units_and_subdirs() {
        // Units first, then subdirs, each group sorted, with a blank line between.
        let content = generate_mod_rs(
            &["apply_discount.rs".to_string(), "refund.rs".to_string()],
            &["taxes".to_string(), "discounts".to_string()],
        )
        .unwrap();

        assert_eq!(
            content,
            "pub mod apply_discount;\npub mod refund;\n\npub mod discounts;\npub mod taxes;\n"
        );
    }

    #[test]
    fn generate_local_tests_produces_cfg_test_block() {
        let mut spec = test_spec_with(
            vec![],
            vec!["rust_decimal::Decimal"],
            "{\n    Decimal::ZERO\n}",
        );
        spec.local_tests = vec![LocalTest {
            id: "happy_path".to_string(),
            expect: "apply_discount() == Decimal::ZERO".to_string(),
        }];

        let code = generate_code(&spec).unwrap();
        assert!(code.contains("#[cfg(test)]\nmod tests {"));
        assert!(code.contains("use super::*;"));
        assert!(code.contains("fn test_happy_path() {"));
        assert!(code.contains("assert!(apply_discount() == Decimal::ZERO);"));
    }

    #[test]
    fn generate_no_local_tests_produces_no_test_block() {
        let spec = test_spec_with(vec![], vec![], "{ }");
        let code = generate_code(&spec).unwrap();
        assert!(!code.contains("#[cfg(test)]"));
        assert!(!code.contains("mod tests {"));
    }

    #[test]
    fn generate_code_rejects_unsafe_expect_at_sink() {
        // With default options, a block-expression `expect` is rejected by the
        // generator itself, and the error names both the unit and the test.
        let mut spec = test_spec_with(vec![], vec![], "{ true }");
        spec.local_tests = vec![LocalTest {
            id: "unsafe_attempt".to_string(),
            expect: "{ let ok = apply_discount(); ok }".to_string(),
        }];

        let err = generate_code(&spec).unwrap_err().to_string();
        assert!(err.contains("pricing/apply_discount"), "{err}");
        assert!(err.contains("unsafe_attempt"), "{err}");
        assert!(err.contains("block, unsafe, closure"), "{err}");
    }

    #[test]
    fn generate_code_rejects_deeply_nested_expect_at_sink() {
        // 200 levels of parentheses exceeds the depth limit of 128 asserted below.
        let mut spec = test_spec_with(vec![], vec![], "{ true }");
        spec.local_tests = vec![LocalTest {
            id: "deep".to_string(),
            expect: format!("{}true{}", "(".repeat(200), ")".repeat(200)),
        }];

        let err = generate_code(&spec).unwrap_err().to_string();
        assert!(err.contains("pricing/apply_discount"), "{err}");
        assert!(err.contains("deep"), "{err}");
        assert!(err.contains("maximum depth of 128"), "{err}");
    }

    #[test]
    fn generate_code_sink_guard_includes_unit_and_test_id_in_error() {
        // A lone `(` is not a valid expression.
        let mut spec = test_spec_with(vec![], vec![], "{ true }");
        spec.local_tests = vec![LocalTest {
            id: "broken".to_string(),
            expect: "(".to_string(),
        }];

        let err = generate_code(&spec).unwrap_err().to_string();
        assert!(err.contains("pricing/apply_discount"), "{err}");
        assert!(err.contains("broken"), "{err}");
    }

    #[test]
    fn generate_code_with_options_preserves_escape_hatch() {
        // Same `expect` as the rejected unsafe case, accepted once the option is set.
        let mut spec = test_spec_with(vec![], vec![], "{ true }");
        spec.local_tests = vec![LocalTest {
            id: "unsafe_allowed".to_string(),
            expect: "{ let ok = apply_discount(); ok }".to_string(),
        }];

        let code = generate_code_with_options(
            &spec,
            &GenerateOptions {
                allow_unsafe_local_test_expect: true,
            },
        )
        .unwrap();

        assert!(code.contains("assert!({ let ok = apply_discount(); ok });"));
    }

    #[test]
    fn shared_expect_validation_reports_too_deep_before_syn_parse() {
        let result = validate_expect_expr(
            &format!("{}true{}", "(".repeat(200), ")".repeat(200)),
            false,
        );
        match result {
            Err(ExpectExprErrorKind::TooDeep { max_depth }) => assert_eq!(max_depth, 128),
            Err(other) => panic!("expected too-deep error, got {:?}", other),
            Ok(_) => panic!("expected too-deep error, got success"),
        }
    }

    #[test]
    fn clean_output_dir_removes_stale_module_from_prior_run() {
        let temp_dir = TempDir::new_in(std::env::current_dir().unwrap()).unwrap();
        let base = temp_dir.path().join("generated/spec");
        let pricing = base.join("pricing");
        let test_mod = base.join("test");
        fs::create_dir_all(&pricing).unwrap();
        fs::create_dir_all(&test_mod).unwrap();
        fs::write(base.join(GENERATED_MARKER), "").unwrap();

        // "Current run" generated files.
        fs::write(base.join("mod.rs"), "pub mod pricing;\n").unwrap();
        fs::write(pricing.join("apply_discount.rs"), "fn a() {}\n").unwrap();
        fs::write(pricing.join("mod.rs"), "pub mod apply_discount;\n").unwrap();

        // Stale module files from a prior run should be removed (and the empty dir pruned).
        fs::write(test_mod.join("foo.rs"), "fn stale() {}\n").unwrap();
        fs::write(test_mod.join("mod.rs"), "pub mod foo;\n").unwrap();

        let mut generated = HashSet::new();
        generated.insert(PathBuf::from("pricing/apply_discount.rs"));
        generated.insert(PathBuf::from("pricing/mod.rs"));
        generated.insert(PathBuf::from("mod.rs"));

        clean_output_dir(&base, &generated).unwrap();

        assert!(pricing.join("apply_discount.rs").exists());
        assert!(pricing.join("mod.rs").exists());

        assert!(!test_mod.join("foo.rs").exists());
        assert!(!test_mod.join("mod.rs").exists());
        assert!(!test_mod.exists(), "stale module dir should be removed");

        assert!(base.join("mod.rs").exists());
        assert!(base.join(GENERATED_MARKER).exists());
    }

    #[test]
    fn test_clean_output_dir_refuses_without_marker() {
        let temp_dir = TempDir::new_in(std::env::current_dir().unwrap()).unwrap();
        let base = temp_dir.path().join("generated/spec");
        fs::create_dir_all(&base).unwrap();

        let generated = HashSet::new();
        let err = clean_output_dir(&base, &generated).unwrap_err();
        assert!(matches!(err, SpecError::MissingMarker { .. }));
    }

    #[test]
    #[cfg(unix)]
    fn test_clean_output_dir_does_not_follow_symlink_dirs() {
        let temp_dir = TempDir::new_in(std::env::current_dir().unwrap()).unwrap();
        let base = temp_dir.path().join("generated/spec");
        let pricing = base.join("pricing");
        fs::create_dir_all(&pricing).unwrap();
        fs::write(base.join(GENERATED_MARKER), "").unwrap();
        fs::write(pricing.join("apply_discount.rs"), "fn a() {}\n").unwrap();

        let outside_dir = temp_dir.path().join("outside");
        fs::create_dir_all(&outside_dir).unwrap();
        let outside_rs = outside_dir.join("outside.rs");
        fs::write(&outside_rs, "fn outside() {}\n").unwrap();

        unix_fs::symlink(&outside_dir, pricing.join("link")).unwrap();

        // Empty keep-set: every real `.rs` file under `base` is an orphan.
        let generated = HashSet::new();
        clean_output_dir(&base, &generated).unwrap();

        assert!(!pricing.join("apply_discount.rs").exists());
        assert!(
            outside_rs.exists(),
            "clean_output_dir must not delete files through symlinks"
        );
        assert!(base.join(GENERATED_MARKER).exists());
    }

    #[test]
    fn safe_output_path_accepts_existing_path_inside_project_root() {
        let temp_dir = TempDir::new_in(std::env::current_dir().unwrap()).unwrap();
        let path = temp_dir.path().join("generated/spec");
        fs::create_dir_all(&path).unwrap();

        let resolved = safe_output_path(&path).unwrap();
        assert_eq!(resolved, path.canonicalize().unwrap());
    }

    #[test]
    fn safe_output_path_resolves_nonexistent_nested_path_from_existing_ancestor() {
        let temp_dir = TempDir::new_in(std::env::current_dir().unwrap()).unwrap();
        let base = temp_dir.path().join("generated");
        fs::create_dir_all(&base).unwrap();
        let path = base.join("spec/pricing");

        let resolved = safe_output_path(&path).unwrap();
        assert_eq!(
            resolved,
            base.canonicalize().unwrap().join("spec").join("pricing")
        );
    }

    #[test]
    #[cfg(unix)]
    fn safe_output_path_rejects_symlink_escape_in_nonexistent_path() {
        // NOTE(review): assumes the system temp dir is not under the current
        // working directory, so `outside` really lies outside the project root.
        let temp_dir = TempDir::new_in(std::env::current_dir().unwrap()).unwrap();
        let outside = tempfile::TempDir::new().unwrap();
        let link = temp_dir.path().join("escape");
        unix_fs::symlink(outside.path(), &link).unwrap();

        let err = safe_output_path(link.join("generated/spec")).unwrap_err();
        assert!(err.to_string().contains("outside the project root"));
    }
}