1use std::collections::HashSet;
2use std::fs::File;
3use std::io::{stdin, stdout, Read, Write};
4use std::path::PathBuf;
5
6use clap::Args;
7
/// Errors reported while preprocessing conditional directives or doing I/O.
///
/// Every directive-related variant carries the 1-based line number of the
/// offending directive.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// An `#endif` was found while no conditional block was open.
    #[error("line {line}: unmatched #endif")]
    UnmatchedEndif { line: usize },
    /// An `#else` was found while no conditional block was open.
    #[error("line {line}: unmatched #else")]
    UnmatchedElse { line: usize },
    /// An `#elif` was found while no conditional block was open.
    #[error("line {line}: unmatched #elif")]
    UnmatchedElif { line: usize },
    /// The input ended while a block was still open; `line` is the line of the
    /// opening directive of that block.
    #[error("line {line}: unclosed #ifdef/#ifndef")]
    UnmatchedIfdef { line: usize },
    /// A second `#else` appeared in the same conditional block.
    #[error("line {line}: duplicate #else")]
    DuplicateElse { line: usize },
    /// An `#elif` appeared after the `#else` of the same conditional block.
    #[error("line {line}: #elif after #else")]
    ElifAfterElse { line: usize },
    /// A symbol-taking directive was written with nothing after the keyword.
    #[error("line {line}: {directive} missing symbol")]
    MissingSymbol { line: usize, directive: String },
    /// A directive was recognised by its prefix but is otherwise malformed
    /// (for example glued text or unexpected trailing tokens).
    #[error("line {line}: invalid {directive} syntax")]
    InvalidDirectiveSyntax { line: usize, directive: String },
    /// Opening or reading the input (file or standard input) failed.
    #[error("error reading file: {0}")]
    ReadFile(std::io::Error),
    /// Writing the preprocessed result to standard output failed.
    #[error("error writing output: {0}")]
    WriteOutput(std::io::Error),
}
31
32#[derive(Args, Debug, Clone)]
33#[command()]
34pub struct Cmd {
35 #[arg()]
37 pub input: Option<PathBuf>,
38
39 #[arg(long, value_delimiter = ',', conflicts_with = "all_features")]
41 pub features: Vec<String>,
42
43 #[arg(long)]
45 pub all_features: bool,
46}
47
/// Bookkeeping for one open conditional block on the preprocessor stack.
// Each flag records a separate fact about the block, hence the lint exemption.
#[allow(clippy::struct_excessive_bools)]
struct IfBlock {
    /// Whether lines of the branch currently being read are emitted.
    active: bool,
    /// Whether the enclosing context was emitting lines when this block opened.
    parent_active: bool,
    /// Whether some branch of this block has already been selected.
    any_branch_taken: bool,
    /// Whether the `#else` of this block has already been seen.
    else_seen: bool,
    /// 1-based line number of the opening directive, used in error reports.
    start_line: usize,
}
56
/// Reports whether `rest` holds nothing significant: after skipping leading
/// whitespace it is either empty or begins a `//` or `/*` comment.
fn is_comment_or_empty(rest: &str) -> bool {
    match rest.trim_start() {
        "" => true,
        tail => ["//", "/*"].iter().any(|opener| tail.starts_with(*opener)),
    }
}
61
62fn parse_symbol_directive<'a>(line: &'a str, directive: &str) -> Result<Option<&'a str>, ()> {
63 let Some(rest) = line.strip_prefix(directive) else {
64 return Ok(None);
65 };
66
67 if rest.is_empty() {
68 return Err(());
69 }
70
71 let mut chars = rest.chars();
72 let Some(first) = chars.next() else {
73 return Err(());
74 };
75 if !first.is_whitespace() {
76 return Err(());
77 }
78
79 let after_ws = rest.trim_start();
80 if after_ws.is_empty() || after_ws.starts_with("//") || after_ws.starts_with("/*") {
81 return Err(());
82 }
83
84 let symbol_end = after_ws.find(char::is_whitespace).unwrap_or(after_ws.len());
85 let symbol = &after_ws[..symbol_end];
86 if symbol.is_empty() {
87 return Err(());
88 }
89
90 let trailing = &after_ws[symbol_end..];
91 if !is_comment_or_empty(trailing) {
92 return Err(());
93 }
94
95 Ok(Some(symbol))
96}
97
98fn is_flagless_directive(line: &str, directive: &str) -> Result<bool, ()> {
99 let Some(rest) = line.strip_prefix(directive) else {
100 return Ok(false);
101 };
102
103 if rest.is_empty() {
104 return Ok(true);
105 }
106
107 let first = rest.chars().next().unwrap_or(' ');
108 if !first.is_whitespace() {
109 return Err(());
110 }
111
112 if is_comment_or_empty(rest) {
113 Ok(true)
114 } else {
115 Err(())
116 }
117}
118
119fn symbol_directive_err(trimmed: &str, directive: &str, line: usize) -> Error {
120 if trimmed == directive {
121 Error::MissingSymbol {
122 line,
123 directive: directive.to_string(),
124 }
125 } else {
126 Error::InvalidDirectiveSyntax {
127 line,
128 directive: directive.to_string(),
129 }
130 }
131}
132
133fn flagless_directive_err(directive: &str, line: usize) -> Error {
134 Error::InvalidDirectiveSyntax {
135 line,
136 directive: directive.to_string(),
137 }
138}
139
140fn collect_symbols(input: &str) -> Vec<String> {
143 let mut symbols = Vec::new();
144 let mut seen = HashSet::new();
145 for raw_line in input.split_inclusive('\n') {
146 let trimmed = raw_line.strip_suffix('\n').unwrap_or(raw_line).trim();
147 for directive in &["#ifdef", "#ifndef", "#elif"] {
148 if let Ok(Some(symbol)) = parse_symbol_directive(trimmed, directive) {
149 if seen.insert(symbol.to_string()) {
150 symbols.push(symbol.to_string());
151 }
152 break;
153 }
154 }
155 }
156 symbols
157}
158
159pub fn preprocess(input: &str, features: &[&str]) -> Result<String, Error> {
167 let mut output = String::new();
168 let mut stack: Vec<IfBlock> = Vec::new();
169 let feature_set: HashSet<&str> = features.iter().copied().collect();
170
171 for (i, raw_line) in input.split_inclusive('\n').enumerate() {
172 let line_num = i + 1;
173 let trimmed = raw_line.strip_suffix('\n').unwrap_or(raw_line).trim();
174 process_line(
175 trimmed,
176 line_num,
177 raw_line,
178 &mut stack,
179 &feature_set,
180 &mut output,
181 )?;
182 }
183
184 if let Some(block) = stack.last() {
185 return Err(Error::UnmatchedIfdef {
186 line: block.start_line,
187 });
188 }
189
190 Ok(output)
191}
192
193#[allow(clippy::too_many_lines)]
194fn process_line(
195 trimmed: &str,
196 line_num: usize,
197 raw_line: &str,
198 stack: &mut Vec<IfBlock>,
199 feature_set: &HashSet<&str>,
200 output: &mut String,
201) -> Result<(), Error> {
202 if trimmed.starts_with("#ifdef") {
203 let symbol = parse_symbol_directive(trimmed, "#ifdef")
204 .map_err(|()| symbol_directive_err(trimmed, "#ifdef", line_num))?
205 .ok_or_else(|| Error::InvalidDirectiveSyntax {
206 line: line_num,
207 directive: "#ifdef".to_string(),
208 })?;
209 let parent_active = stack.last().is_none_or(|b| b.active);
210 let active = parent_active && feature_set.contains(symbol);
211 stack.push(IfBlock {
212 active,
213 parent_active,
214 any_branch_taken: active,
215 else_seen: false,
216 start_line: line_num,
217 });
218 } else if trimmed.starts_with("#ifndef") {
219 let symbol = parse_symbol_directive(trimmed, "#ifndef")
220 .map_err(|()| symbol_directive_err(trimmed, "#ifndef", line_num))?
221 .ok_or_else(|| Error::InvalidDirectiveSyntax {
222 line: line_num,
223 directive: "#ifndef".to_string(),
224 })?;
225 let parent_active = stack.last().is_none_or(|b| b.active);
226 let active = parent_active && !feature_set.contains(symbol);
227 stack.push(IfBlock {
228 active,
229 parent_active,
230 any_branch_taken: active,
231 else_seen: false,
232 start_line: line_num,
233 });
234 } else if trimmed.starts_with("#elif") {
235 let symbol = parse_symbol_directive(trimmed, "#elif")
236 .map_err(|()| symbol_directive_err(trimmed, "#elif", line_num))?
237 .ok_or_else(|| Error::InvalidDirectiveSyntax {
238 line: line_num,
239 directive: "#elif".to_string(),
240 })?;
241 let block = stack
242 .last_mut()
243 .ok_or(Error::UnmatchedElif { line: line_num })?;
244 if block.else_seen {
245 return Err(Error::ElifAfterElse { line: line_num });
246 }
247 let newly_active =
248 block.parent_active && !block.any_branch_taken && feature_set.contains(symbol);
249 block.active = newly_active;
250 block.any_branch_taken = block.any_branch_taken || newly_active;
251 } else if trimmed.starts_with("#else") {
252 is_flagless_directive(trimmed, "#else")
253 .map_err(|()| flagless_directive_err("#else", line_num))?;
254 let block = stack
255 .last_mut()
256 .ok_or(Error::UnmatchedElse { line: line_num })?;
257 if block.else_seen {
258 return Err(Error::DuplicateElse { line: line_num });
259 }
260 block.else_seen = true;
261 block.active = block.parent_active && !block.any_branch_taken;
262 } else if trimmed.starts_with("#endif") {
263 is_flagless_directive(trimmed, "#endif")
264 .map_err(|()| flagless_directive_err("#endif", line_num))?;
265 if stack.pop().is_none() {
266 return Err(Error::UnmatchedEndif { line: line_num });
267 }
268 } else if stack.last().is_none_or(|b| b.active) {
269 output.push_str(raw_line);
270 }
271 Ok(())
272}
273
274impl Cmd {
275 pub fn run(&self) -> Result<(), Error> {
281 let mut content = String::new();
282 let mut reader: Box<dyn Read> = match &self.input {
283 Some(path) => Box::new(File::open(path).map_err(Error::ReadFile)?),
284 None => Box::new(stdin()),
285 };
286 reader
287 .read_to_string(&mut content)
288 .map_err(Error::ReadFile)?;
289
290 let collected;
291 let features: Vec<&str> = if self.all_features {
292 collected = collect_symbols(&content);
293 collected.iter().map(String::as_str).collect()
294 } else {
295 self.features.iter().map(String::as_str).collect()
296 };
297 let result = preprocess(&content, &features)?;
298 match stdout().write_all(result.as_bytes()) {
299 Ok(()) => Ok(()),
300 Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => Ok(()),
301 Err(e) => Err(Error::WriteOutput(e)),
302 }
303 }
304}
305
// Unit tests for `preprocess` and `collect_symbols`.
#[cfg(test)]
mod tests {
    use super::*;

    // Input without directives is returned unchanged.
    #[test]
    fn test_no_directives_passthrough() {
        let input = "line one\nline two\nline three\n";
        let result = preprocess(input, &[]).unwrap();
        assert_eq!(result, "line one\nline two\nline three\n");
    }

    // `#ifdef` keeps its body when the symbol is defined.
    #[test]
    fn test_ifdef_defined_includes_content() {
        let input = "\
#ifdef FEAT_A
included
#endif
";
        let result = preprocess(input, &["FEAT_A"]).unwrap();
        assert_eq!(result, "included\n");
    }

    // `#ifdef` drops its body when the symbol is not defined.
    #[test]
    fn test_ifdef_undefined_excludes_content() {
        let input = "\
#ifdef FEAT_A
excluded
#endif
";
        let result = preprocess(input, &[]).unwrap();
        assert_eq!(result, "");
    }

    // With the symbol defined, the `#ifdef` branch wins over `#else`.
    #[test]
    fn test_ifdef_else_defined() {
        let input = "\
#ifdef FEAT_A
yes
#else
no
#endif
";
        let result = preprocess(input, &["FEAT_A"]).unwrap();
        assert_eq!(result, "yes\n");
    }

    // Without the symbol, the `#else` branch is emitted.
    #[test]
    fn test_ifdef_else_undefined() {
        let input = "\
#ifdef FEAT_A
yes
#else
no
#endif
";
        let result = preprocess(input, &[]).unwrap();
        assert_eq!(result, "no\n");
    }

    // `#ifndef` drops its body when the symbol is defined.
    #[test]
    fn test_ifndef_defined_excludes() {
        let input = "\
#ifndef FEAT_A
included
#endif
";
        let result = preprocess(input, &["FEAT_A"]).unwrap();
        assert_eq!(result, "");
    }

    // `#ifndef` keeps its body when the symbol is not defined.
    #[test]
    fn test_ifndef_undefined_includes() {
        let input = "\
#ifndef FEAT_A
included
#endif
";
        let result = preprocess(input, &[]).unwrap();
        assert_eq!(result, "included\n");
    }

    // Nested blocks are resolved independently inside an active parent.
    #[test]
    fn test_nested_ifdefs() {
        let input = "\
#ifdef OUTER
outer_content
#ifdef INNER
inner_content
#endif
after_inner
#endif
";
        // Both symbols defined.
        let result = preprocess(input, &["OUTER", "INNER"]).unwrap();
        assert_eq!(result, "outer_content\ninner_content\nafter_inner\n");

        // Only the outer symbol defined.
        let result = preprocess(input, &["OUTER"]).unwrap();
        assert_eq!(result, "outer_content\nafter_inner\n");

        // Nothing defined.
        let result = preprocess(input, &[]).unwrap();
        assert_eq!(result, "");
    }

    // An inner block stays hidden when its parent is inactive.
    #[test]
    fn test_nested_ifdef_parent_inactive() {
        let input = "\
#ifdef OUTER
#ifdef INNER
should_not_appear
#endif
#endif
";
        // INNER is defined but OUTER is not.
        let result = preprocess(input, &["INNER"]).unwrap();
        assert_eq!(result, "");
    }

    // `#endif` with no open block is an error.
    #[test]
    fn test_error_unmatched_endif() {
        let input = "#endif\n";
        let err = preprocess(input, &[]).unwrap_err();
        assert!(matches!(err, Error::UnmatchedEndif { line: 1 }));
    }

    // A block still open at end of input is reported at its opening line.
    #[test]
    fn test_error_unclosed_ifdef() {
        let input = "#ifdef FEAT_A\ncontent\n";
        let err = preprocess(input, &[]).unwrap_err();
        assert!(matches!(err, Error::UnmatchedIfdef { line: 1 }));
    }

    // A second `#else` in the same block is an error.
    #[test]
    fn test_error_duplicate_else() {
        let input = "\
#ifdef FEAT_A
a
#else
b
#else
c
#endif
";
        let err = preprocess(input, &[]).unwrap_err();
        assert!(matches!(err, Error::DuplicateElse { line: 5 }));
    }

    // A bare `#ifdef` reports a missing symbol.
    #[test]
    fn test_error_ifdef_missing_symbol() {
        let input = "#ifdef\n#endif\n";
        let err = preprocess(input, &[]).unwrap_err();
        assert!(matches!(
            err,
            Error::MissingSymbol {
                line: 1,
                directive: _
            }
        ));
    }

    // A bare `#ifndef` reports a missing symbol.
    #[test]
    fn test_error_ifndef_missing_symbol() {
        let input = "#ifndef\n#endif\n";
        let err = preprocess(input, &[]).unwrap_err();
        assert!(matches!(
            err,
            Error::MissingSymbol {
                line: 1,
                directive: _
            }
        ));
    }

    // End-to-end check on an XDR-like definition file.
    #[test]
    fn test_realistic_xdr_pattern() {
        let input = "\
enum SCValType
{
 SCV_BOOL = 0,
 SCV_VOID = 1,
 SCV_ERROR = 2,
#ifdef XDR_SPARSE_MAP
 SCV_SPARSE_MAP = 22,
#endif
 SCV_U32 = 4
};

#ifndef XDR_SPARSE_MAP
struct FallbackDef {
 int x;
};
#endif
";
        // Feature enabled: the extra enum member appears, the fallback does not.
        let result = preprocess(input, &["XDR_SPARSE_MAP"]).unwrap();
        assert_eq!(
            result,
            "\
enum SCValType
{
 SCV_BOOL = 0,
 SCV_VOID = 1,
 SCV_ERROR = 2,
 SCV_SPARSE_MAP = 22,
 SCV_U32 = 4
};

"
        );

        // Feature disabled: the enum member is dropped, the fallback is kept.
        let result = preprocess(input, &[]).unwrap();
        assert_eq!(
            result,
            "\
enum SCValType
{
 SCV_BOOL = 0,
 SCV_VOID = 1,
 SCV_ERROR = 2,
 SCV_U32 = 4
};

struct FallbackDef {
 int x;
};
"
        );
    }

    // Several features can be enabled at once.
    #[test]
    fn test_multiple_features() {
        let input = "\
#ifdef FEAT_A
a
#endif
#ifdef FEAT_B
b
#endif
";
        let result = preprocess(input, &["FEAT_A", "FEAT_B"]).unwrap();
        assert_eq!(result, "a\nb\n");

        let result = preprocess(input, &["FEAT_A"]).unwrap();
        assert_eq!(result, "a\n");
    }

    // Empty input yields empty output.
    #[test]
    fn test_empty_input() {
        let result = preprocess("", &[]).unwrap();
        assert_eq!(result, "");
    }

    // Leading whitespace of emitted lines is kept as written.
    #[test]
    fn test_preserves_indentation() {
        let input = "\
#ifdef FEAT
 indented line
 deeply indented
#endif
";
        let result = preprocess(input, &["FEAT"]).unwrap();
        assert_eq!(result, " indented line\n deeply indented\n");
    }

    // Text glued to `#ifdef` is a syntax error.
    #[test]
    fn test_error_invalid_ifdef_syntax_missing_space() {
        let input = "#ifdefFEAT_A\ncontent\n#endif\n";
        let err = preprocess(input, &["FEAT_A"]).unwrap_err();
        assert!(matches!(
            err,
            Error::InvalidDirectiveSyntax {
                line: 1,
                directive: _
            }
        ));
    }

    // Text glued to `#endif` is a syntax error.
    #[test]
    fn test_error_invalid_endif_syntax_missing_space() {
        let input = "#ifdef FEAT_A\ncontent\n#endifX\n";
        let err = preprocess(input, &["FEAT_A"]).unwrap_err();
        assert!(matches!(
            err,
            Error::InvalidDirectiveSyntax {
                line: 3,
                directive: _
            }
        ));
    }

    // A trailing comment after `#else` is accepted.
    #[test]
    fn test_else_with_trailing_comment() {
        let input = "\
#ifdef FEAT_A
yes
#else // fallback
no
#endif
";
        let result = preprocess(input, &["FEAT_A"]).unwrap();
        assert_eq!(result, "yes\n");

        let result = preprocess(input, &[]).unwrap();
        assert_eq!(result, "no\n");
    }

    // No newline is added when the input lacks a final one.
    #[test]
    fn test_preserves_missing_final_newline() {
        let result = preprocess("A", &[]).unwrap();
        assert_eq!(result, "A");
    }

    // CRLF line endings pass through byte-for-byte.
    #[test]
    fn test_preserves_crlf_newlines() {
        let input = "A\r\nB\r\n";
        let result = preprocess(input, &[]).unwrap();
        assert_eq!(result.as_bytes(), input.as_bytes());
    }

    // When several branches match, only the first one is emitted.
    #[test]
    fn test_elif_first_branch() {
        let input = "\
#ifdef A
a
#elif B
b
#elif C
c
#else
d
#endif
";
        let result = preprocess(input, &["A", "B", "C"]).unwrap();
        assert_eq!(result, "a\n");
    }

    // The first `#elif` is chosen when only its symbol is defined.
    #[test]
    fn test_elif_second_branch() {
        let input = "\
#ifdef A
a
#elif B
b
#elif C
c
#else
d
#endif
";
        let result = preprocess(input, &["B"]).unwrap();
        assert_eq!(result, "b\n");
    }

    // The second `#elif` is chosen when only its symbol is defined.
    #[test]
    fn test_elif_third_branch() {
        let input = "\
#ifdef A
a
#elif B
b
#elif C
c
#else
d
#endif
";
        let result = preprocess(input, &["C"]).unwrap();
        assert_eq!(result, "c\n");
    }

    // `#else` is used when no earlier branch matched.
    #[test]
    fn test_elif_else_fallback() {
        let input = "\
#ifdef A
a
#elif B
b
#else
fallback
#endif
";
        let result = preprocess(input, &[]).unwrap();
        assert_eq!(result, "fallback\n");
    }

    // An `#elif` chain without `#else` may emit nothing.
    #[test]
    fn test_elif_no_else() {
        let input = "\
#ifdef A
a
#elif B
b
#endif
";
        // No branch matches.
        let result = preprocess(input, &[]).unwrap();
        assert_eq!(result, "");

        // The `#elif` branch matches.
        let result = preprocess(input, &["B"]).unwrap();
        assert_eq!(result, "b\n");
    }

    // `#elif` inside a nested block respects the parent's state.
    #[test]
    fn test_elif_nested() {
        let input = "\
#ifdef OUTER
#ifdef A
a
#elif B
b
#endif
#endif
";
        // Parent inactive: nothing is emitted even though B is defined.
        let result = preprocess(input, &["B"]).unwrap();
        assert_eq!(result, "");

        // Parent active: the `#elif` branch is emitted.
        let result = preprocess(input, &["OUTER", "B"]).unwrap();
        assert_eq!(result, "b\n");
    }

    // `#elif` after `#else` is an error.
    #[test]
    fn test_error_elif_after_else() {
        let input = "\
#ifdef A
a
#else
b
#elif C
c
#endif
";
        let err = preprocess(input, &[]).unwrap_err();
        assert!(matches!(err, Error::ElifAfterElse { line: 5 }));
    }

    // A bare `#elif` reports a missing symbol.
    #[test]
    fn test_error_elif_missing_symbol() {
        let input = "#ifdef A\n#elif\n#endif\n";
        let err = preprocess(input, &[]).unwrap_err();
        assert!(matches!(
            err,
            Error::MissingSymbol {
                line: 2,
                directive: _
            }
        ));
    }

    // `#elif` with no open block is an error.
    #[test]
    fn test_error_elif_without_ifdef() {
        let input = "#elif A\n";
        let err = preprocess(input, &[]).unwrap_err();
        assert!(matches!(err, Error::UnmatchedElif { line: 1 }));
    }

    // A comment in place of the symbol is a syntax error.
    #[test]
    fn test_error_ifdef_comment_as_symbol() {
        let input = "#ifdef // comment\n#endif\n";
        let err = preprocess(input, &[]).unwrap_err();
        assert!(matches!(
            err,
            Error::InvalidDirectiveSyntax {
                line: 1,
                directive: _
            }
        ));
    }

    // Extra tokens after `#endif` are a syntax error.
    #[test]
    fn test_error_endif_trailing_tokens() {
        let input = "#ifdef FEAT_A\ncontent\n#endif EXTRA\n";
        let err = preprocess(input, &["FEAT_A"]).unwrap_err();
        assert!(matches!(
            err,
            Error::InvalidDirectiveSyntax {
                line: 3,
                directive: _
            }
        ));
    }

    // Extra tokens after `#else` are a syntax error.
    #[test]
    fn test_error_else_trailing_tokens() {
        let input = "#ifdef FEAT_A\na\n#else EXTRA\nb\n#endif\n";
        let err = preprocess(input, &[]).unwrap_err();
        assert!(matches!(
            err,
            Error::InvalidDirectiveSyntax {
                line: 3,
                directive: _
            }
        ));
    }

    // Symbols are collected in order of appearance.
    #[test]
    fn test_collect_symbols_basic() {
        let input = "\
#ifdef FEAT_A
a
#endif
#ifdef FEAT_B
b
#endif
";
        let symbols = collect_symbols(input);
        assert_eq!(symbols, vec!["FEAT_A", "FEAT_B"]);
    }

    // A symbol used more than once is listed once.
    #[test]
    fn test_collect_symbols_deduplicates() {
        let input = "\
#ifdef FEAT_A
a
#endif
#ifdef FEAT_A
b
#endif
";
        let symbols = collect_symbols(input);
        assert_eq!(symbols, vec!["FEAT_A"]);
    }

    // `#ifdef`, `#elif` and `#ifndef` all contribute symbols.
    #[test]
    fn test_collect_symbols_all_directive_types() {
        let input = "\
#ifdef A
a
#elif B
b
#endif
#ifndef C
c
#endif
";
        let symbols = collect_symbols(input);
        assert_eq!(symbols, vec!["A", "B", "C"]);
    }

    // Input without directives yields no symbols.
    #[test]
    fn test_collect_symbols_empty() {
        let symbols = collect_symbols("no directives here\n");
        assert!(symbols.is_empty());
    }

    // Enabling every collected symbol also defines `#ifndef` symbols, so those
    // bodies are dropped.
    #[test]
    fn test_all_features_enables_everything() {
        let input = "\
#ifdef FEAT_A
a
#endif
#ifdef FEAT_B
b
#endif
#ifndef FEAT_C
c
#endif
";
        let all = collect_symbols(input);
        let features: Vec<&str> = all.iter().map(String::as_str).collect();
        let result = preprocess(input, &features).unwrap();
        assert_eq!(result, "a\nb\n");
    }

    // `#elif` and `#else` also work after an `#ifndef` opener.
    #[test]
    fn test_ifndef_with_elif() {
        let input = "\
#ifndef A
a
#elif B
b
#else
fallback
#endif
";
        // A undefined: the `#ifndef` branch is taken.
        let result = preprocess(input, &[]).unwrap();
        assert_eq!(result, "a\n");

        // A undefined, B defined: the first branch still wins.
        let result = preprocess(input, &["B"]).unwrap();
        assert_eq!(result, "a\n");

        // A and B defined: the `#elif` branch is taken.
        let result = preprocess(input, &["A", "B"]).unwrap();
        assert_eq!(result, "b\n");

        // Only A defined: the `#else` branch is taken.
        let result = preprocess(input, &["A"]).unwrap();
        assert_eq!(result, "fallback\n");
    }
}