1use std::collections::HashMap;
2use std::fmt;
3use std::path::Path;
4
5use crate::parser::AstParser;
6
/// Function-pointer signature of a per-language test detector.
///
/// The arguments are, in order: a parser, the file path, the file content,
/// and a start/end line pair. The return value says whether that span should
/// be treated as test code.
type TestDetector = fn(&mut dyn AstParser, &Path, &str, usize, usize) -> bool;
8
/// Per-language settings: the tree-sitter grammar, a lookup table from node
/// kind names to [`NodeRole`]s, and the function used to detect test code.
pub struct LanguageConfig {
    /// Grammar handle exposed through [`LanguageConfig::language`].
    language: tree_sitter::Language,
    /// Node kind name -> role; consulted by [`LanguageConfig::classify`].
    roles: HashMap<&'static str, NodeRole>,
    /// Callback invoked by [`LanguageConfig::is_test_context`].
    test_detector: TestDetector,
}
14
15impl fmt::Debug for LanguageConfig {
16 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17 f.debug_struct("LanguageConfig")
18 .field("roles", &self.roles)
19 .finish_non_exhaustive()
20 }
21}
22
23impl LanguageConfig {
24 fn new(
25 language: tree_sitter::Language,
26 table: &[(NodeRole, &[&'static str])],
27 test_detector: TestDetector,
28 ) -> Self {
29 let mut roles = HashMap::new();
30 for &(role, kinds) in table {
31 for &kind in kinds {
32 roles.insert(kind, role);
33 }
34 }
35 Self {
36 language,
37 roles,
38 test_detector,
39 }
40 }
41
42 #[must_use]
43 pub fn language(&self) -> &tree_sitter::Language {
44 &self.language
45 }
46
47 #[must_use]
48 pub fn classify(&self, kind: &str) -> NodeRole {
49 self.roles.get(kind).copied().unwrap_or(NodeRole::Other)
50 }
51
52 #[must_use]
57 pub fn is_test_context(
58 &self,
59 parser: &mut dyn AstParser,
60 path: &Path,
61 content: &str,
62 start_line: usize,
63 end_line: usize,
64 ) -> bool {
65 (self.test_detector)(parser, path, content, start_line, end_line)
66 }
67}
68
/// Coarse role assigned to a syntax node based on its kind name.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the role can be used
/// as a set member or map key; the enum has no fields, so equality is total.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeRole {
    /// Names: identifiers, type names, field names and similar.
    Identifier,
    /// Literal values such as strings and numbers.
    Literal,
    /// Source comments.
    Comment,
    /// Attributes or decorators attached to items.
    Decoration,
    /// Any node kind that is not listed in a language's role table.
    Other,
}
77
78fn rust_is_test(
79 parser: &mut dyn AstParser,
80 path: &Path,
81 content: &str,
82 start_line: usize,
83 end_line: usize,
84) -> bool {
85 if path.components().any(|c| c.as_os_str() == "tests") {
86 return true;
87 }
88 let ranges = rust_test_ranges(parser, content);
89 ranges
90 .iter()
91 .any(|&(range_start, range_end)| start_line >= range_start && end_line <= range_end)
92}
93
94fn rust_test_ranges(parser: &mut dyn AstParser, content: &str) -> Vec<(usize, usize)> {
96 let Ok(tree) = parser.parse(content, &tree_sitter_rust::LANGUAGE.into()) else {
97 return vec![];
98 };
99
100 let src = content.as_bytes();
101 let mut ranges = vec![];
102 collect_test_ranges(tree.root_node(), src, &mut ranges);
103 ranges
104}
105
106fn collect_test_ranges(parent: tree_sitter::Node, src: &[u8], ranges: &mut Vec<(usize, usize)>) {
107 let mut cursor = parent.walk();
108 for node in parent.children(&mut cursor) {
109 match node.kind() {
110 "mod_item" if has_preceding_attr(&node, src, is_cfg_test_attr) => {
111 push_range_with_attrs(&node, ranges);
112 }
113 "mod_item" => recurse_into_mod_body(node, src, ranges),
114 "function_item" if has_preceding_attr(&node, src, |t| t == "#[test]") => {
115 push_range_with_attrs(&node, ranges);
116 }
117 _ => {}
118 }
119 }
120}
121
122fn push_range_with_attrs(node: &tree_sitter::Node, ranges: &mut Vec<(usize, usize)>) {
123 let start = first_preceding_attr_row(node).unwrap_or(node.start_position().row);
124 ranges.push((start + 1, node.end_position().row + 1));
125}
126
127fn recurse_into_mod_body(node: tree_sitter::Node, src: &[u8], ranges: &mut Vec<(usize, usize)>) {
128 if let Some(body) = node.child_by_field_name("body") {
129 collect_test_ranges(body, src, ranges);
130 }
131}
132
133fn is_cfg_test_attr(attr_text: &str) -> bool {
136 attr_text.starts_with("#[cfg(")
137 && !attr_text.contains("not(test)")
138 && (attr_text == "#[cfg(test)]"
139 || attr_text.contains("(test,")
140 || attr_text.contains("(test)")
141 || attr_text.contains(", test)")
142 || attr_text.contains(", test,"))
143}
144
145fn has_preceding_attr(node: &tree_sitter::Node, src: &[u8], pred: impl Fn(&str) -> bool) -> bool {
146 let mut sibling = node.prev_sibling();
147 while let Some(s) = sibling {
148 if s.kind() != "attribute_item" {
149 break;
150 }
151 if s.utf8_text(src).is_ok_and(&pred) {
152 return true;
153 }
154 sibling = s.prev_sibling();
155 }
156 false
157}
158
159fn first_preceding_attr_row(node: &tree_sitter::Node) -> Option<usize> {
160 let mut first_row = None;
161 let mut sibling = node.prev_sibling();
162 while let Some(s) = sibling {
163 if s.kind() != "attribute_item" {
164 break;
165 }
166 first_row = Some(s.start_position().row);
167 sibling = s.prev_sibling();
168 }
169 first_row
170}
171
172fn js_is_test(
174 _parser: &mut dyn AstParser,
175 path: &Path,
176 _content: &str,
177 _start_line: usize,
178 _end_line: usize,
179) -> bool {
180 let stem = path.file_stem().and_then(|s| s.to_str()).unwrap_or("");
181 Path::new(stem)
182 .extension()
183 .is_some_and(|ext| ext == "test" || ext == "spec")
184 || path.components().any(|c| c.as_os_str() == "__tests__")
185}
186
187#[must_use]
188pub fn rust() -> LanguageConfig {
189 LanguageConfig::new(
190 tree_sitter_rust::LANGUAGE.into(),
191 &[
192 (
193 NodeRole::Identifier,
194 &[
195 "identifier",
196 "type_identifier",
197 "field_identifier",
198 "shorthand_field_identifier",
199 "primitive_type",
200 "lifetime",
201 "self",
202 "metavariable",
203 "crate",
204 "super",
205 ],
206 ),
207 (
208 NodeRole::Literal,
209 &[
210 "string_literal",
211 "raw_string_literal",
212 "char_literal",
213 "integer_literal",
214 "float_literal",
215 "boolean_literal",
216 ],
217 ),
218 (NodeRole::Comment, &["line_comment", "block_comment"]),
219 (
220 NodeRole::Decoration,
221 &["attribute_item", "inner_attribute_item"],
222 ),
223 ],
224 rust_is_test,
225 )
226}
227
/// Node-kind role table used by both the `typescript` and `typescript_tsx`
/// configurations.
const TS_ROLES: &[(NodeRole, &[&str])] = &[
    (
        NodeRole::Identifier,
        &[
            "identifier",
            "shorthand_property_identifier",
            "shorthand_property_identifier_pattern",
            "property_identifier",
            // The two type-related kinds below are not in the JavaScript table.
            "type_identifier",
            "predefined_type",
        ],
    ),
    (
        NodeRole::Literal,
        &[
            "string",
            "template_string",
            "number",
            "true",
            "false",
            "null",
            "undefined",
            "regex",
        ],
    ),
    (NodeRole::Comment, &["comment"]),
    (NodeRole::Decoration, &["decorator"]),
];
256
257#[must_use]
258pub fn javascript() -> LanguageConfig {
259 LanguageConfig::new(
260 tree_sitter_javascript::LANGUAGE.into(),
261 &[
262 (
263 NodeRole::Identifier,
264 &[
265 "identifier",
266 "shorthand_property_identifier",
267 "shorthand_property_identifier_pattern",
268 "property_identifier",
269 ],
270 ),
271 (
272 NodeRole::Literal,
273 &[
274 "string",
275 "template_string",
276 "number",
277 "true",
278 "false",
279 "null",
280 "undefined",
281 "regex",
282 ],
283 ),
284 (NodeRole::Comment, &["comment"]),
285 (NodeRole::Decoration, &["decorator"]),
286 ],
287 js_is_test,
288 )
289}
290
291#[must_use]
292pub fn typescript() -> LanguageConfig {
293 LanguageConfig::new(
294 tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
295 TS_ROLES,
296 js_is_test,
297 )
298}
299
300#[must_use]
301pub fn typescript_tsx() -> LanguageConfig {
302 LanguageConfig::new(
303 tree_sitter_typescript::LANGUAGE_TSX.into(),
304 TS_ROLES,
305 js_is_test,
306 )
307}
308
// Unit tests for Rust test-range detection. Each test feeds a small Rust
// snippet to `rust_test_ranges` and checks the reported `(start, end)` line
// pairs, which count lines from 1.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::TreeSitterParser;

    /// Runs `rust_test_ranges` on `src` with a fresh parser.
    fn parse_rust_test_ranges(src: &str) -> Vec<(usize, usize)> {
        let mut parser = TreeSitterParser::new();
        rust_test_ranges(&mut parser, src)
    }

    // A `#[cfg(test)]` module is reported as one range that starts at the
    // attribute line and ends at the module's closing brace.
    #[test]
    fn rust_test_ranges_finds_cfg_test_module() {
        let src = "\
fn production() -> i32 { 42 }

#[cfg(test)]
mod tests {
    fn helper(x: i32) -> i32 { x + 1 }
}
";
        assert_eq!(parse_rust_test_ranges(src), vec![(3, 6)]);
    }

    // A top-level `#[test]` function outside any test module is reported.
    #[test]
    fn detects_naked_test_fn_as_test_context() {
        let src = "\
fn production() -> i32 { 42 }

#[test]
fn test_something() {
    let x = production();
    assert_eq!(x, 42);
}
";
        assert_eq!(parse_rust_test_ranges(src), vec![(3, 7)]);
    }

    // `#[test]` is found even when another attribute sits between it and the
    // function; the range starts at the first attribute.
    #[test]
    fn walks_past_multiple_attributes_to_find_test() {
        let src = "\
#[test]
#[should_panic]
fn test_something() {
    panic!(\"expected\");
}
";
        assert_eq!(parse_rust_test_ranges(src), vec![(1, 5)]);
    }

    // A module gated on `not(test)` must not be reported.
    #[test]
    fn rejects_cfg_not_test_module() {
        let src = "\
#[cfg(not(test))]
mod prod_only {
    fn helper() -> i32 { 42 }
}
";
        assert!(parse_rust_test_ranges(src).is_empty());
    }

    // `test` combined with another predicate inside `all(...)` still counts.
    #[test]
    fn detects_compound_cfg_test_as_test_context() {
        let src = "\
#[cfg(all(test, feature = \"integration\"))]
mod integration_tests {
    fn helper(x: i32) -> i32 { x + 1 }
}
";
        assert_eq!(parse_rust_test_ranges(src), vec![(1, 4)]);
    }

    // Modules without a cfg-test attribute are searched for `#[test]`
    // functions; only the function (with its attribute) is reported.
    #[test]
    fn finds_test_fn_nested_in_non_test_module() {
        let src = "\
mod integration {
    #[test]
    fn test_flow() {
        assert!(true);
    }
}
";
        assert_eq!(parse_rust_test_ranges(src), vec![(2, 5)]);
    }
}