1use crate::syntax::validate_expect_expr;
10use crate::types::ResolvedSpec;
11use crate::{Result, SpecError};
12use std::collections::HashSet;
13use std::fs::{self, File};
14use std::io::Write;
15use std::path::{Component, Path, PathBuf};
16use walkdir::WalkDir;
17
/// Name of the marker file that identifies an output directory as generator-owned.
///
/// `clean_output_dir` refuses to clean a directory that lacks this file, and recreates it
/// after cleaning.
const GENERATED_MARKER: &str = ".spec-generated";
19
/// Knobs for [`generate_code_with_options`].
///
/// The `Default` value is what [`generate_code`] uses.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct GenerateOptions {
    /// Passed as the second argument to `validate_expect_expr` for every local test
    /// `expect`. When `true`, expressions that the default mode rejects (for example block
    /// expressions such as `{ let ok = f(); ok }`) are emitted into the generated tests.
    pub allow_unsafe_local_test_expect: bool,
}
24
25fn build_fn_signature(spec: &ResolvedSpec) -> String {
26 let params = spec
27 .contract
28 .as_ref()
29 .and_then(|c| c.inputs.as_ref())
30 .map(|inputs| {
31 inputs
32 .iter()
33 .map(|(name, ty)| format!("{name}: {ty}"))
34 .collect::<Vec<_>>()
35 .join(", ")
36 })
37 .unwrap_or_default();
38
39 let return_type = spec
40 .contract
41 .as_ref()
42 .and_then(|c| c.returns.as_ref())
43 .map(|r| format!(" -> {}", r));
44
45 match return_type {
46 Some(ret) => format!("pub fn {}({}){}", spec.fn_name, params, ret),
47 None => format!("pub fn {}({})", spec.fn_name, params),
48 }
49}
50
/// Turns the spec's intent text into a `///` doc comment for the generated function.
///
/// Surrounding whitespace is trimmed first, and a blank intent yields `None`. Every
/// remaining line becomes one `/// …` line. Whitespace-only lines become a bare `///` so
/// paragraph breaks survive.
///
/// Lines are split on `\n`, `\r\n` and also on isolated `\r`. The Rust lexer rejects an
/// isolated carriage return inside a doc comment, so passing one through would make the
/// generated file fail to compile.
fn build_doc_comment(intent_why: &str) -> Option<String> {
    let trimmed = intent_why.trim();
    if trimmed.is_empty() {
        return None;
    }

    let mut output = String::with_capacity(trimmed.len() + 8);
    // `str::lines` already consumes `\n` and `\r\n`; any `\r` left inside a line is an
    // isolated CR and is treated as a line break of its own.
    for line in trimmed.lines().flat_map(|line| line.split('\r')) {
        if line.trim().is_empty() {
            output.push_str("///\n");
        } else {
            output.push_str("/// ");
            output.push_str(line);
            output.push('\n');
        }
    }

    Some(output)
}
70
71pub fn generate_code(spec: &ResolvedSpec) -> Result<String> {
72 generate_code_with_options(spec, &GenerateOptions::default())
73}
74
75pub fn generate_code_with_options(
76 spec: &ResolvedSpec,
77 options: &GenerateOptions,
78) -> Result<String> {
79 let (import_statements, dep_statements) = build_use_groups(spec)?;
80 let mut output = String::new();
81
82 for statement in import_statements {
83 output.push_str(&statement);
84 output.push('\n');
85 }
86
87 if !spec.imports.is_empty() && !spec.deps.is_empty() {
88 output.push('\n');
89 }
90
91 for statement in dep_statements {
92 output.push_str(&statement);
93 output.push('\n');
94 }
95
96 if !spec.imports.is_empty() || !spec.deps.is_empty() {
97 output.push('\n');
98 }
99
100 if let Some(doc_comment) = build_doc_comment(&spec.intent_why) {
101 output.push_str(&doc_comment);
102 }
103
104 let signature = build_fn_signature(spec);
105 let block = spec.body_rust.trim();
106 output.push_str(&format!("{signature} {block}"));
107 output.push('\n');
108
109 if !spec.local_tests.is_empty() {
110 output.push('\n');
112 output.push_str("#[cfg(test)]\n");
113 output.push_str("mod tests {\n");
114 output.push_str(" use super::*;\n\n");
115
116 for (index, local_test) in spec.local_tests.iter().enumerate() {
117 let expect = local_test.expect.trim();
118 validate_expect_expr(expect, options.allow_unsafe_local_test_expect).map_err(
119 |err| SpecError::Generator {
120 message: format!(
121 "invalid local test expect for unit '{}' test '{}': {}",
122 spec.id,
123 local_test.id,
124 err.message()
125 ),
126 },
127 )?;
128 output.push_str(" #[test]\n");
129 output.push_str(&format!(" fn test_{}() {{\n", local_test.id));
130 output.push_str(&format!(" assert!({expect});\n"));
131 output.push_str(" }\n");
132
133 if index + 1 != spec.local_tests.len() {
134 output.push('\n');
135 }
136 }
137
138 output.push_str("}\n");
139 }
140 Ok(output)
141}
142
143pub fn write_generated_file(output_path: &str, content: &str) -> Result<()> {
144 let path = Path::new(output_path);
145
146 if let Some(parent) = path.parent() {
147 fs::create_dir_all(parent).map_err(|err| SpecError::OutputDir {
148 message: format!(
149 "Unable to create output directory {}: {}",
150 parent.display(),
151 err
152 ),
153 })?;
154 } else {
155 return Err(SpecError::Generator {
156 message: format!(
157 "Unable to write {}: missing parent directory",
158 path.display()
159 ),
160 });
161 }
162
163 let parent_dir = path
164 .parent()
165 .ok_or_else(|| SpecError::Generator {
166 message: format!(
167 "Unable to write {}: missing parent directory",
168 path.display()
169 ),
170 })?
171 .to_path_buf();
172
173 let mut tmp = tempfile::Builder::new()
175 .prefix(".spec-tmp-")
176 .suffix(".tmp")
177 .tempfile_in(&parent_dir)
178 .map_err(|err| SpecError::Generator {
179 message: format!(
180 "Unable to create temp file in {}: {}",
181 parent_dir.display(),
182 err
183 ),
184 })?;
185
186 tmp.write_all(content.as_bytes())
187 .map_err(|err| SpecError::Generator {
188 message: format!("Unable to write temp file for {}: {}", path.display(), err),
189 })?;
190
191 if !content.ends_with('\n') {
192 tmp.write_all(b"\n").map_err(|err| SpecError::Generator {
193 message: format!(
194 "Unable to finalize temp file for {}: {}",
195 path.display(),
196 err
197 ),
198 })?;
199 }
200
201 tmp.flush().map_err(|err| SpecError::Generator {
202 message: format!("Unable to flush temp file for {}: {}", path.display(), err),
203 })?;
204
205 if cfg!(windows) && path.exists() {
207 fs::remove_file(path).map_err(|err| SpecError::Generator {
208 message: format!("Unable to remove existing {}: {}", path.display(), err),
209 })?;
210 }
211
212 let tmp_path = tmp.into_temp_path();
213 fs::rename(&tmp_path, path).map_err(|err| SpecError::Generator {
214 message: format!(
215 "Unable to rename temp file into {}: {}",
216 path.display(),
217 err
218 ),
219 })?;
220
221 Ok(())
222}
223
224pub fn clean_output_dir(
225 output_base: &Path,
226 generated_rs_rel_paths: &HashSet<PathBuf>,
227) -> Result<()> {
228 let base = safe_output_path(output_base)?;
229
230 let marker = base.join(GENERATED_MARKER);
231 if !marker.exists() {
232 return Err(SpecError::MissingMarker {
233 path: base.display().to_string(),
234 });
235 }
236
237 for entry in WalkDir::new(&base).follow_links(false) {
239 let entry = entry.map_err(SpecError::from)?;
240 if !entry.file_type().is_file() {
241 continue;
242 }
243 let path = entry.path();
244 if path.extension().and_then(|ext| ext.to_str()) != Some("rs") {
245 continue;
246 }
247
248 let rel = path
249 .strip_prefix(&base)
250 .map_err(|err| SpecError::Generator {
251 message: format!(
252 "Unable to compute relative path for {}: {}",
253 path.display(),
254 err
255 ),
256 })?;
257
258 if !generated_rs_rel_paths.contains(rel) {
259 fs::remove_file(path).map_err(|err| SpecError::Generator {
260 message: format!("Unable to remove {}: {}", path.display(), err),
261 })?;
262 }
263 }
264
265 for entry in WalkDir::new(&base).follow_links(false).contents_first(true) {
267 let entry = entry.map_err(SpecError::from)?;
268 if !entry.file_type().is_dir() || entry.file_type().is_symlink() {
269 continue;
270 }
271 let path = entry.path();
272 if path == base {
273 continue;
274 }
275
276 let mut entries = fs::read_dir(path).map_err(|err| SpecError::Generator {
277 message: format!("Unable to read dir {}: {}", path.display(), err),
278 })?;
279 if entries.next().is_none() {
280 fs::remove_dir(path).map_err(|err| SpecError::Generator {
281 message: format!("Unable to remove dir {}: {}", path.display(), err),
282 })?;
283 }
284 }
285
286 File::create(&marker).map_err(|err| SpecError::Generator {
287 message: format!("Unable to recreate marker {}: {}", marker.display(), err),
288 })?;
289
290 Ok(())
291}
292
293pub fn generate_mod_rs(unit_files: &[String], subdirs: &[String]) -> Result<String> {
294 let mut seen = HashSet::new();
295 let mut unit_mods = Vec::new();
296 let mut subdir_mods = Vec::new();
297
298 for unit_file in unit_files {
299 if let Some(name) = module_item_name(unit_file) {
300 let decl = format!("pub mod {};", name);
301 if seen.insert(decl.clone()) {
302 unit_mods.push(decl);
303 }
304 }
305 }
306
307 for subdir in subdirs {
308 if let Some(name) = module_item_name(subdir) {
309 let decl = format!("pub mod {};", name);
310 if seen.insert(decl.clone()) {
311 subdir_mods.push(decl);
312 }
313 }
314 }
315
316 unit_mods.sort();
317 subdir_mods.sort();
318
319 let mut output = String::new();
320 for line in &unit_mods {
321 output.push_str(line);
322 output.push('\n');
323 }
324
325 if !unit_mods.is_empty() && !subdir_mods.is_empty() {
326 output.push('\n');
327 }
328
329 for line in &subdir_mods {
330 output.push_str(line);
331 output.push('\n');
332 }
333
334 Ok(output)
335}
336
337fn build_use_groups(spec: &ResolvedSpec) -> Result<(Vec<String>, Vec<String>)> {
338 if let Some((dep1, dep2)) = ResolvedSpec::has_dep_collision(&spec.deps) {
339 return Err(SpecError::DepCollision {
340 dep1: dep1.clone(),
341 dep2: dep2.clone(),
342 fn_name: ResolvedSpec::dep_fn_name(dep1).to_string(),
343 path: spec.id.clone(),
344 });
345 }
346
347 let mut import_seen = HashSet::new();
348 let mut import_statements = Vec::new();
349
350 for import_path in &spec.imports {
351 if import_seen.insert(import_path.clone()) {
352 import_statements.push(format!("use {};", import_path));
353 }
354 }
355
356 let mut dep_seen = HashSet::new();
357 let mut dep_statements = Vec::new();
358
359 for dep in &spec.deps {
360 if dep_seen.insert(dep.clone()) {
361 dep_statements.push(format!("use {}", ResolvedSpec::dep_to_use_path(dep)));
362 }
363 }
364
365 Ok((import_statements, dep_statements))
366}
367
/// Derives a module name from a file or directory fragment.
///
/// Takes the fragment's final path component and strips trailing `.rs` (via
/// `trim_end_matches`). Returns `None` when there is no UTF-8 final component or the
/// resulting name is empty.
fn module_item_name(fragment: &str) -> Option<String> {
    let file_name = Path::new(fragment).file_name()?.to_str()?;
    let stem = file_name.trim_end_matches(".rs");
    if stem.is_empty() {
        None
    } else {
        Some(stem.to_owned())
    }
}
375
376pub fn safe_output_path<P: AsRef<Path>>(path: P) -> Result<PathBuf> {
377 let path = path.as_ref();
378 let project_root = canonicalize_existing_path(&std::env::current_dir().map_err(|err| {
379 SpecError::OutputDir {
380 message: format!("Unable to determine project root: {err}"),
381 }
382 })?)?;
383 let output_base = canonicalize_output_path(path)?;
384
385 if !output_base.starts_with(&project_root) {
386 return Err(SpecError::OutputDir {
387 message: format!(
388 "Refusing to generate into {}: output path is outside the project root {}",
389 output_base.display(),
390 project_root.display()
391 ),
392 });
393 }
394
395 Ok(output_base)
396}
397
/// Makes `path` absolute and removes `.` and `..` components lexically, without touching
/// the filesystem.
///
/// A relative `path` is resolved against the current directory (or `.` if that cannot be
/// determined). Each `..` drops the last accumulated component; at the root it is a no-op.
/// Because this is purely lexical, `a/link/..` becomes `a` even if `link` is a symlink.
pub fn normalized_absolute_path<P: AsRef<Path>>(path: P) -> PathBuf {
    let path = path.as_ref();
    let start = if path.is_absolute() {
        PathBuf::new()
    } else {
        std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))
    };

    path.components().fold(start, |mut resolved, component| {
        match component {
            Component::CurDir => {}
            Component::ParentDir => {
                resolved.pop();
            }
            Component::Prefix(_) | Component::RootDir | Component::Normal(_) => {
                resolved.push(component.as_os_str());
            }
        }
        resolved
    })
}
420
421fn canonicalize_output_path(path: &Path) -> Result<PathBuf> {
422 let absolute = normalized_absolute_path(path);
423 if absolute.exists() {
424 return canonicalize_existing_path(&absolute);
425 }
426
427 let mut current = absolute.as_path();
428 let mut missing_segments = Vec::new();
429
430 while !current.exists() {
431 let segment = current.file_name().ok_or_else(|| SpecError::OutputDir {
432 message: format!(
433 "Unable to resolve output path {}: no existing ancestor found",
434 absolute.display()
435 ),
436 })?;
437 missing_segments.push(segment.to_os_string());
438 current = current.parent().ok_or_else(|| SpecError::OutputDir {
439 message: format!(
440 "Unable to resolve output path {}: no existing ancestor found",
441 absolute.display()
442 ),
443 })?;
444 }
445
446 let mut resolved = canonicalize_existing_path(current)?;
447 for segment in missing_segments.iter().rev() {
448 resolved.push(segment);
449 }
450
451 Ok(resolved)
452}
453
454fn canonicalize_existing_path(path: &Path) -> Result<PathBuf> {
455 path.canonicalize().map_err(|err| SpecError::OutputDir {
456 message: format!("Unable to canonicalize {}: {}", path.display(), err),
457 })
458}
459
// Unit tests for this module. They cover:
// - rendering of generated unit files (use groups, doc comments, local test blocks and
//   the `expect` sink guard);
// - `mod.rs` rendering;
// - temp-file-and-rename writes;
// - stale-output cleanup;
// - output-path containment.
//
// Tests that go through `safe_output_path` create their temporary directories under the
// current working directory, because paths outside it are rejected as outside the project
// root.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::syntax::{ExpectExprErrorKind, validate_expect_expr};
    use crate::types::{Body, Intent, LocalTest, ResolvedSpec, SpecStruct};
    #[cfg(unix)]
    use std::os::unix::fs as unix_fs;
    use tempfile::TempDir;

    /// Builds a resolved spec for unit `pricing/apply_discount` with no contract and no
    /// local tests. Deps, imports, body and intent text come from the caller.
    fn test_spec_with_intent(
        deps: Vec<&str>,
        imports: Vec<&str>,
        body: &str,
        intent_why: &str,
    ) -> ResolvedSpec {
        ResolvedSpec::from_spec(SpecStruct {
            id: "pricing/apply_discount".to_string(),
            kind: "function".to_string(),
            intent: Intent {
                why: intent_why.to_string(),
            },
            contract: None,
            deps: deps.into_iter().map(|dep| dep.to_string()).collect(),
            imports: imports
                .into_iter()
                .map(|import| import.to_string())
                .collect(),
            body: Body {
                rust: body.to_string(),
            },
            local_tests: vec![],
            links: None,
            spec_version: None,
        })
    }

    /// Like `test_spec_with_intent`, with a whitespace-only intent so no doc comment is
    /// emitted.
    fn test_spec_with(deps: Vec<&str>, imports: Vec<&str>, body: &str) -> ResolvedSpec {
        test_spec_with_intent(deps, imports, body, " ")
    }

    /// Like `test_spec_with`, with no imports.
    fn test_spec(deps: Vec<&str>, body: &str) -> ResolvedSpec {
        test_spec_with(deps, vec![], body)
    }

    #[test]
    fn generate_code_includes_doc_comment_from_intent() {
        let spec = test_spec_with_intent(
            vec![],
            vec!["rust_decimal::Decimal"],
            "{\n Decimal::ZERO\n}",
            "Apply a percentage discount.",
        );

        let code = generate_code(&spec).unwrap();
        assert_eq!(
            code,
            "use rust_decimal::Decimal;\n\n/// Apply a percentage discount.\npub fn apply_discount() {\n Decimal::ZERO\n}\n"
        );
    }

    // Leading/trailing blank lines are trimmed; an inner blank line becomes a bare `///`.
    #[test]
    fn generate_code_multiline_intent_produces_multiline_doc_comment() {
        let spec = test_spec_with_intent(
            vec![],
            vec![],
            "{\n Decimal::ZERO\n}",
            "\nFirst line.\n\nSecond line.\n",
        );

        let code = generate_code(&spec).unwrap();
        assert_eq!(
            code,
            "/// First line.\n///\n/// Second line.\npub fn apply_discount() {\n Decimal::ZERO\n}\n"
        );
    }

    #[test]
    fn generate_code_omits_doc_comment_for_blank_intent() {
        let spec = test_spec_with_intent(vec![], vec![], "{\n Decimal::ZERO\n}", " \n ");

        let code = generate_code(&spec).unwrap();
        assert_eq!(code, "pub fn apply_discount() {\n Decimal::ZERO\n}\n");
    }

    #[test]
    fn test_generate_code_prepends_use_statements() {
        let spec = test_spec(
            vec!["money/round", "utils/math/normalize"],
            "{\n round(Decimal::ONE)\n}",
        );

        let code = generate_code(&spec).unwrap();
        assert_eq!(
            code,
            "use crate::money::round::round;\nuse crate::utils::math::normalize::normalize;\n\npub fn apply_discount() {\n round(Decimal::ONE)\n}\n"
        );
    }

    // Two deps ending in the same name (`round`) cannot both be imported.
    #[test]
    fn test_generate_code_rejects_dep_collision() {
        let spec = test_spec(
            vec!["money/round", "utils/round"],
            "{\n round(Decimal::ONE)\n}",
        );

        let err = generate_code(&spec).unwrap_err();
        assert!(err.to_string().contains("Dep fn_name collision"));
    }

    #[test]
    fn imports_field_generates_correct_use_statement() {
        let spec = test_spec_with(
            vec![],
            vec!["rust_decimal::Decimal"],
            "{\n Decimal::ZERO\n}",
        );

        let code = generate_code(&spec).unwrap();
        assert_eq!(
            code,
            "use rust_decimal::Decimal;\n\npub fn apply_discount() {\n Decimal::ZERO\n}\n"
        );
    }

    #[test]
    fn deps_unchanged_after_imports_split() {
        let spec = test_spec_with(
            vec!["money/round"],
            vec![],
            "{\n round(Decimal::ZERO)\n}",
        );

        let code = generate_code(&spec).unwrap();
        assert!(code.contains("use crate::money::round::round;"));
    }

    // Imports come first, a blank line separates them from deps, and another blank line
    // precedes the function.
    #[test]
    fn imports_emitted_before_deps_in_use_statements() {
        let spec = test_spec_with(
            vec!["money/round"],
            vec!["rust_decimal::Decimal"],
            "{\n round(Decimal::ZERO)\n}",
        );

        let code = generate_code(&spec).unwrap();
        assert_eq!(
            code,
            "use rust_decimal::Decimal;\n\nuse crate::money::round::round;\n\npub fn apply_discount() {\n round(Decimal::ZERO)\n}\n"
        );
    }

    #[test]
    fn test_write_generated_file_creates_parent_dirs() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir
            .path()
            .join("generated/spec/pricing/apply_discount.rs");

        write_generated_file(
            file_path.to_str().unwrap(),
            "pub fn apply_discount() -> Decimal { Decimal::ZERO }\n",
        )
        .unwrap();

        let contents = fs::read_to_string(&file_path).unwrap();
        assert_eq!(
            contents,
            "pub fn apply_discount() -> Decimal { Decimal::ZERO }\n"
        );
    }

    // Unit modules are listed (sorted) before subdirectory modules (sorted), separated by a
    // blank line.
    #[test]
    fn test_generate_mod_rs_lists_units_and_subdirs() {
        let content = generate_mod_rs(
            &["apply_discount.rs".to_string(), "refund.rs".to_string()],
            &["taxes".to_string(), "discounts".to_string()],
        )
        .unwrap();

        assert_eq!(
            content,
            "pub mod apply_discount;\npub mod refund;\n\npub mod discounts;\npub mod taxes;\n"
        );
    }

    #[test]
    fn generate_local_tests_produces_cfg_test_block() {
        let mut spec = test_spec_with(
            vec![],
            vec!["rust_decimal::Decimal"],
            "{\n Decimal::ZERO\n}",
        );
        spec.local_tests = vec![LocalTest {
            id: "happy_path".to_string(),
            expect: "apply_discount() == Decimal::ZERO".to_string(),
        }];

        let code = generate_code(&spec).unwrap();
        assert!(code.contains("#[cfg(test)]\nmod tests {"));
        assert!(code.contains("use super::*;"));
        assert!(code.contains("fn test_happy_path() {"));
        assert!(code.contains("assert!(apply_discount() == Decimal::ZERO);"));
    }

    #[test]
    fn generate_no_local_tests_produces_no_test_block() {
        let spec = test_spec_with(vec![], vec![], "{ }");
        let code = generate_code(&spec).unwrap();
        assert!(!code.contains("#[cfg(test)]"));
        assert!(!code.contains("mod tests {"));
    }

    // With default options, a block expression in `expect` is rejected by the generator
    // itself, and the error names both the unit and the test.
    #[test]
    fn generate_code_rejects_unsafe_expect_at_sink() {
        let mut spec = test_spec_with(vec![], vec![], "{ true }");
        spec.local_tests = vec![LocalTest {
            id: "unsafe_attempt".to_string(),
            expect: "{ let ok = apply_discount(); ok }".to_string(),
        }];

        let err = generate_code(&spec).unwrap_err().to_string();
        assert!(err.contains("pricing/apply_discount"), "{err}");
        assert!(err.contains("unsafe_attempt"), "{err}");
        assert!(err.contains("block, unsafe, closure"), "{err}");
    }

    // 200 levels of parentheses exceed the validator's depth limit of 128.
    #[test]
    fn generate_code_rejects_deeply_nested_expect_at_sink() {
        let mut spec = test_spec_with(vec![], vec![], "{ true }");
        spec.local_tests = vec![LocalTest {
            id: "deep".to_string(),
            expect: format!("{}true{}", "(".repeat(200), ")".repeat(200)),
        }];

        let err = generate_code(&spec).unwrap_err().to_string();
        assert!(err.contains("pricing/apply_discount"), "{err}");
        assert!(err.contains("deep"), "{err}");
        assert!(err.contains("maximum depth of 128"), "{err}");
    }

    // An unparseable `expect` still yields an error naming the unit and the test.
    #[test]
    fn generate_code_sink_guard_includes_unit_and_test_id_in_error() {
        let mut spec = test_spec_with(vec![], vec![], "{ true }");
        spec.local_tests = vec![LocalTest {
            id: "broken".to_string(),
            expect: "(".to_string(),
        }];

        let err = generate_code(&spec).unwrap_err().to_string();
        assert!(err.contains("pricing/apply_discount"), "{err}");
        assert!(err.contains("broken"), "{err}");
    }

    // Opting in via `allow_unsafe_local_test_expect` lets the block expression through
    // verbatim.
    #[test]
    fn generate_code_with_options_preserves_escape_hatch() {
        let mut spec = test_spec_with(vec![], vec![], "{ true }");
        spec.local_tests = vec![LocalTest {
            id: "unsafe_allowed".to_string(),
            expect: "{ let ok = apply_discount(); ok }".to_string(),
        }];

        let code = generate_code_with_options(
            &spec,
            &GenerateOptions {
                allow_unsafe_local_test_expect: true,
            },
        )
        .unwrap();

        assert!(code.contains("assert!({ let ok = apply_discount(); ok });"));
    }

    #[test]
    fn shared_expect_validation_reports_too_deep_before_syn_parse() {
        let result = validate_expect_expr(
            &format!("{}true{}", "(".repeat(200), ")".repeat(200)),
            false,
        );
        match result {
            Err(ExpectExprErrorKind::TooDeep { max_depth }) => assert_eq!(max_depth, 128),
            Err(other) => panic!("expected too-deep error, got {:?}", other),
            Ok(_) => panic!("expected too-deep error, got success"),
        }
    }

    #[test]
    fn clean_output_dir_removes_stale_module_from_prior_run() {
        let temp_dir = TempDir::new_in(std::env::current_dir().unwrap()).unwrap();
        let base = temp_dir.path().join("generated/spec");
        let pricing = base.join("pricing");
        let test_mod = base.join("test");
        fs::create_dir_all(&pricing).unwrap();
        fs::create_dir_all(&test_mod).unwrap();
        fs::write(base.join(GENERATED_MARKER), "").unwrap();

        fs::write(base.join("mod.rs"), "pub mod pricing;\n").unwrap();
        // Output of the current run: listed in `generated` below, so it must survive.
        fs::write(pricing.join("apply_discount.rs"), "fn a() {}\n").unwrap();
        fs::write(pricing.join("mod.rs"), "pub mod apply_discount;\n").unwrap();

        // Leftovers from a prior run: not listed in `generated`, so they must be removed
        // together with their then-empty directory.
        fs::write(test_mod.join("foo.rs"), "fn stale() {}\n").unwrap();
        fs::write(test_mod.join("mod.rs"), "pub mod foo;\n").unwrap();

        let mut generated = HashSet::new();
        generated.insert(PathBuf::from("pricing/apply_discount.rs"));
        generated.insert(PathBuf::from("pricing/mod.rs"));
        generated.insert(PathBuf::from("mod.rs"));

        clean_output_dir(&base, &generated).unwrap();

        assert!(pricing.join("apply_discount.rs").exists());
        assert!(pricing.join("mod.rs").exists());

        assert!(!test_mod.join("foo.rs").exists());
        assert!(!test_mod.join("mod.rs").exists());
        assert!(!test_mod.exists(), "stale module dir should be removed");

        assert!(base.join("mod.rs").exists());
        assert!(base.join(GENERATED_MARKER).exists());
    }

    #[test]
    fn test_clean_output_dir_refuses_without_marker() {
        let temp_dir = TempDir::new_in(std::env::current_dir().unwrap()).unwrap();
        let base = temp_dir.path().join("generated/spec");
        fs::create_dir_all(&base).unwrap();

        let generated = HashSet::new();
        let err = clean_output_dir(&base, &generated).unwrap_err();
        assert!(matches!(err, SpecError::MissingMarker { .. }));
    }

    // A symlink inside the output tree points at a directory holding a `.rs` file. Cleaning
    // with an empty keep-set must delete in-tree files but leave the linked file alone.
    #[test]
    #[cfg(unix)]
    fn test_clean_output_dir_does_not_follow_symlink_dirs() {
        let temp_dir = TempDir::new_in(std::env::current_dir().unwrap()).unwrap();
        let base = temp_dir.path().join("generated/spec");
        let pricing = base.join("pricing");
        fs::create_dir_all(&pricing).unwrap();
        fs::write(base.join(GENERATED_MARKER), "").unwrap();
        fs::write(pricing.join("apply_discount.rs"), "fn a() {}\n").unwrap();

        let outside_dir = temp_dir.path().join("outside");
        fs::create_dir_all(&outside_dir).unwrap();
        let outside_rs = outside_dir.join("outside.rs");
        fs::write(&outside_rs, "fn outside() {}\n").unwrap();

        unix_fs::symlink(&outside_dir, pricing.join("link")).unwrap();

        let generated = HashSet::new();
        clean_output_dir(&base, &generated).unwrap();

        assert!(!pricing.join("apply_discount.rs").exists());
        assert!(
            outside_rs.exists(),
            "clean_output_dir must not delete files through symlinks"
        );
        assert!(base.join(GENERATED_MARKER).exists());
    }

    #[test]
    fn safe_output_path_accepts_existing_path_inside_project_root() {
        let temp_dir = TempDir::new_in(std::env::current_dir().unwrap()).unwrap();
        let path = temp_dir.path().join("generated/spec");
        fs::create_dir_all(&path).unwrap();

        let resolved = safe_output_path(&path).unwrap();
        assert_eq!(resolved, path.canonicalize().unwrap());
    }

    // Only `generated` exists; `spec/pricing` is appended to its canonical form.
    #[test]
    fn safe_output_path_resolves_nonexistent_nested_path_from_existing_ancestor() {
        let temp_dir = TempDir::new_in(std::env::current_dir().unwrap()).unwrap();
        let base = temp_dir.path().join("generated");
        fs::create_dir_all(&base).unwrap();
        let path = base.join("spec/pricing");

        let resolved = safe_output_path(&path).unwrap();
        assert_eq!(
            resolved,
            base.canonicalize().unwrap().join("spec").join("pricing")
        );
    }

    // The existing ancestor is a symlink to a directory outside the project root. It must
    // be resolved and the resulting path rejected.
    #[test]
    #[cfg(unix)]
    fn safe_output_path_rejects_symlink_escape_in_nonexistent_path() {
        let temp_dir = TempDir::new_in(std::env::current_dir().unwrap()).unwrap();
        let outside = tempfile::TempDir::new().unwrap();
        let link = temp_dir.path().join("escape");
        unix_fs::symlink(outside.path(), &link).unwrap();

        let err = safe_output_path(link.join("generated/spec")).unwrap_err();
        assert!(err.to_string().contains("outside the project root"));
    }
}