1use markdown::{
2 mdast::{Heading, Node},
3 to_mdast, ParseOptions,
4};
5use std::mem::take;
6
/// Produces a new value by layering serialized settings over an existing one.
///
/// `get_test_cases` calls this with the options of the current section as
/// `self` and the raw text of an `options` code block as `source`, and stores
/// the returned value as that section's options.
///
/// # Errors
///
/// Implementations return a human-readable message when `source` cannot be
/// parsed; `get_test_cases` turns it into a panic naming the block's line.
pub trait MergeSerialized {
    /// Returns a copy of `self` updated with the settings found in `source`.
    ///
    /// The `Self: Sized` bound lives on the method (not the trait) so the
    /// trait itself remains object-safe.
    fn merge_serialized(&self, source: String) -> Result<Self, String>
    where
        Self: Sized;
}
12
/// One open heading on the current section path, together with the options
/// in effect beneath it.
///
/// The struct itself carries no trait bounds; requirements such as
/// `MergeSerialized` are stated on the `impl` blocks that actually need them
/// (Rust API Guidelines, C-STRUCT-BOUNDS).
#[derive(Debug, Clone)]
struct Section<Options> {
    /// Heading level as reported by the parser (`Heading::depth`).
    pub depth: u8,
    /// Text of the heading.
    pub name: String,
    /// Line of the heading in the source (`position.start.line`).
    pub line: usize,
    /// Options for this section: initially a copy of the parent's, possibly
    /// replaced later by an `options` code block inside the section.
    pub options: Options,
}
19
20struct SectionStack<Options: MergeSerialized + Clone> {
21 root_options: Options,
22 sections: Vec<Section<Options>>,
23}
24
25impl<Options: MergeSerialized + Clone> SectionStack<Options> {
26 pub fn new(root_options: Options) -> Self {
27 Self {
28 root_options,
29 sections: Vec::<Section<Options>>::new(),
30 }
31 }
32
33 pub fn push_heading(&mut self, heading: Heading) {
34 let Node::Text(text) = heading.children.into_iter().nth(0).unwrap() else {
35 panic!("Markdown headings must contain plain text.")
36 };
37 let depth = heading.depth;
38 self.sections.retain(|s| s.depth < depth);
39 let section = Section {
40 depth,
41 line: heading.position.unwrap().start.line,
42 name: text.value,
43 options: self.get_options().clone(),
44 };
45 self.sections.push(section);
46 }
47
48 pub fn set_options(&mut self, options: Options) {
49 if let Some(last_section) = self.sections.last_mut() {
50 last_section.options = options;
51 } else {
52 self.root_options = options;
53 }
54 }
55
56 pub fn get_options(&self) -> &Options {
57 self.sections
58 .last()
59 .map(|s| &s.options)
60 .unwrap_or_else(|| &self.root_options)
61 }
62
63 pub fn get_headings(&self) -> Vec<String> {
64 self.sections.iter().map(|s| s.name.clone()).collect()
65 }
66}
67
/// A single test case extracted from a Markdown document.
///
/// A test case gathers the non-`options` code blocks that appear directly
/// under one heading, up to the next heading. The struct carries no trait
/// bounds; the constructor's `impl` states what it needs.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct TestCase<Options> {
    /// Innermost heading the code blocks appear under, or `"(Unnamed test)"`
    /// for code blocks that precede the first heading.
    pub name: String,
    /// Enclosing headings, outermost first, not including `name`.
    pub headings: Vec<String>,
    /// Source line of the heading named by `name`, or 0 when there is none.
    pub line_number: usize,
    /// Options of the enclosing section at the time the case was collected.
    pub options: Options,
    /// Contents of the case's code blocks, in document order.
    pub args: Vec<String>,
}
76
77impl<Options: MergeSerialized + Clone> TestCase<Options> {
78 fn new(args: Vec<String>, section_stack: &SectionStack<Options>) -> TestCase<Options> {
79 let options = section_stack.get_options().clone();
80 let mut headings = section_stack.get_headings();
81 let name = headings
82 .pop()
83 .unwrap_or_else(|| "(Unnamed test)".to_string());
84 TestCase {
85 name,
86 headings,
87 line_number: section_stack.sections.last().map(|s| s.line).unwrap_or(0),
88 options,
89 args,
90 }
91 }
92}
93
94pub fn get_test_cases<Options: MergeSerialized + Clone>(
95 content: String,
96 root_options: Options,
97) -> Vec<TestCase<Options>> {
98 let ast = to_mdast(&content, &ParseOptions::default()).unwrap();
99 let Node::Root(root_node) = ast else {
100 panic!("No root node found")
101 };
102 let nodes = root_node.children;
103 let mut section_stack = SectionStack::new(root_options);
104 let mut test_cases: Vec<TestCase<Options>> = vec![];
105 let mut args: Vec<String> = vec![];
106 let mut push_test_case = |s: &SectionStack<Options>, a: &mut Vec<String>| {
107 if a.len() > 0 {
108 test_cases.push(TestCase::new(take(a), &s));
109 }
110 };
111 for node in nodes {
112 match node {
113 Node::Heading(heading) => {
114 push_test_case(§ion_stack, &mut args);
115 section_stack.push_heading(heading);
116 }
117 Node::Code(code) => {
118 if code.meta.as_deref() == Some("options") {
119 let options = section_stack
120 .get_options()
121 .merge_serialized(code.value)
122 .unwrap_or_else(|error| {
123 let line = code.position.unwrap().start.line;
124 panic!(
125 "Failed to parse options from code block at line {}: {}",
126 line, error
127 );
128 });
129 section_stack.set_options(options)
130 } else {
131 args.push(code.value)
132 }
133 }
134 _ => {}
135 }
136 }
137 push_test_case(§ion_stack, &mut args);
138 test_cases
139}
140
#[cfg(test)]
mod tests {
    use crate::{get_test_cases, MergeSerialized, TestCase};
    use std::path::PathBuf;
    use toml::{from_str, Table};

    /// Minimal options type used to exercise option inheritance and merging.
    #[derive(Default, PartialEq, Eq, Debug, Clone, Copy)]
    struct Options {
        foo: i64,
        bar: bool,
    }

    impl MergeSerialized for Options {
        /// Parses `source` as a TOML table. Keys that are absent or have the
        /// wrong type keep the value from `self`.
        fn merge_serialized(&self, source: String) -> Result<Self, String> {
            let table: Table = from_str(&source).map_err(|err| err.to_string())?;
            let mut merged = *self;
            if let Some(foo) = table.get("foo").and_then(|v| v.as_integer()) {
                merged.foo = foo;
            }
            if let Some(bar) = table.get("bar").and_then(|v| v.as_bool()) {
                merged.bar = bar;
            }
            Ok(merged)
        }
    }

    /// Builds the expected test case for one leaf section of `test.md`.
    fn case(
        name: &str,
        headings: [&str; 2],
        line_number: usize,
        options: Options,
        args: [&str; 2],
    ) -> TestCase<Options> {
        TestCase {
            name: name.to_owned(),
            headings: headings.iter().map(|h| h.to_string()).collect(),
            line_number,
            options,
            args: args.iter().map(|a| a.to_string()).collect(),
        }
    }

    #[test]
    fn test_basic() {
        let path: PathBuf = [env!("CARGO_MANIFEST_DIR"), "src", "test.md"]
            .iter()
            .collect();
        let content = std::fs::read_to_string(path).unwrap();
        let result = get_test_cases(content, Options::default());

        let expected = [
            case(
                "Apple",
                ["Tests", "Fruits"],
                10,
                Options { foo: 5, bar: true },
                ["Granny Smith", "red"],
            ),
            case(
                "Pear",
                ["Tests", "Fruits"],
                20,
                Options { foo: 5, bar: false },
                ["Bartlett", "yellow"],
            ),
            case(
                "Potato",
                ["Tests", "Vegetables"],
                40,
                Options { foo: 11, bar: true },
                ["Russet", "brown"],
            ),
        ];
        assert_eq!(result, expected);
    }
}