1use std::{fs, path::Path};
2
3use anyhow::{anyhow, Result};
4use tree_sitter::Point;
5use tree_sitter_highlight::{Highlight, HighlightConfiguration, HighlightEvent, Highlighter};
6use tree_sitter_loader::{Config, Loader};
7
8use crate::{
9 query_testing::{parse_position_comments, to_utf8_point, Assertion, Utf8Point},
10 test::{TestInfo, TestOutcome, TestResult, TestSummary},
11 util,
12};
13
/// A highlighting assertion that did not hold.
///
/// Carries the location that was checked, the highlight the assertion asked for, and
/// every highlight that was actually seen there, so the failure can be rendered as a
/// readable message.
#[derive(Debug)]
pub struct Failure {
    /// Row of the asserted position.
    row: usize,
    /// Column reported for the failure (`iterate_assertions` fills in the last column
    /// covered by the assertion).
    column: usize,
    /// Highlight name the assertion expected (or, for a negative assertion, rejected).
    expected_highlight: String,
    /// Highlight names found at the asserted position; empty when nothing was found.
    actual_highlights: Vec<String>,
}

impl std::error::Error for Failure {}

impl std::fmt::Display for Failure {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(
            f,
            "Failure - row: {}, column: {}, expected highlight '{}', actual highlights: ",
            self.row, self.column, self.expected_highlight
        )?;
        // Render the observed highlights as a quoted, comma-separated list, or "none."
        // when there were no highlights at all.
        match self.actual_highlights.split_first() {
            None => write!(f, "none."),
            Some((actual_highlight, others)) => {
                write!(f, "'{actual_highlight}'")?;
                others.iter().try_for_each(|actual_highlight| {
                    write!(f, ", ")?;
                    write!(f, "'{actual_highlight}'")
                })
            }
        }
    }
}
44
45pub fn test_highlights(
46 loader: &Loader,
47 loader_config: &Config,
48 highlighter: &mut Highlighter,
49 directory: &Path,
50 test_summary: &mut TestSummary,
51) -> Result<()> {
52 let mut failed = false;
53
54 for highlight_test_file in fs::read_dir(directory)? {
55 let highlight_test_file = highlight_test_file?;
56 let test_file_path = highlight_test_file.path();
57 let test_file_name = highlight_test_file.file_name();
58 if test_file_path.is_dir() && test_file_path.read_dir()?.next().is_some() {
59 test_summary
60 .highlight_results
61 .add_group(test_file_name.to_string_lossy().as_ref());
62 if test_highlights(
63 loader,
64 loader_config,
65 highlighter,
66 &test_file_path,
67 test_summary,
68 )
69 .is_err()
70 {
71 failed = true;
72 }
73 test_summary.highlight_results.pop_traversal();
74 } else {
75 let (language, language_config) = loader
76 .language_configuration_for_file_name(&test_file_path)?
77 .ok_or_else(|| {
78 anyhow!(
79 "{}",
80 util::lang_not_found_for_path(test_file_path.as_path(), loader_config)
81 )
82 })?;
83 let highlight_config = language_config
84 .highlight_config(language, None)?
85 .ok_or_else(|| {
86 anyhow!(
87 "No highlighting config found for {}",
88 test_file_path.display()
89 )
90 })?;
91 match test_highlight(
92 loader,
93 highlighter,
94 highlight_config,
95 fs::read(&test_file_path)?.as_slice(),
96 ) {
97 Ok(assertion_count) => {
98 test_summary.highlight_results.add_case(TestResult {
99 name: test_file_name.to_string_lossy().to_string(),
100 info: TestInfo::AssertionTest {
101 outcome: TestOutcome::AssertionPassed { assertion_count },
102 test_num: test_summary.test_num,
103 },
104 });
105 }
106 Err(e) => {
107 test_summary.highlight_results.add_case(TestResult {
108 name: test_file_name.to_string_lossy().to_string(),
109 info: TestInfo::AssertionTest {
110 outcome: TestOutcome::AssertionFailed {
111 error: e.to_string(),
112 },
113 test_num: test_summary.test_num,
114 },
115 });
116 failed = true;
117 }
118 }
119 test_summary.test_num += 1;
120 }
121 }
122
123 if failed {
124 Err(anyhow!(""))
125 } else {
126 Ok(())
127 }
128}
129pub fn iterate_assertions(
130 assertions: &[Assertion],
131 highlights: &[(Utf8Point, Utf8Point, Highlight)],
132 highlight_names: &[String],
133) -> Result<usize> {
134 let mut i = 0;
137 let mut actual_highlights = Vec::new();
138 for Assertion {
139 position,
140 length,
141 negative,
142 expected_capture_name: expected_highlight,
143 } in assertions
144 {
145 let mut passed = false;
146 let mut end_column = position.column + length - 1;
147 actual_highlights.clear();
148
149 'highlight_loop: while let Some(highlight) = highlights.get(i) {
152 if highlight.1 <= *position {
153 i += 1;
154 continue;
155 }
156
157 let mut j = i;
160 while let (false, Some(highlight)) = (passed, highlights.get(j)) {
161 end_column = position.column + length - 1;
162 if highlight.0.row >= position.row && highlight.0.column > end_column {
163 break 'highlight_loop;
164 }
165
166 let highlight_name = &highlight_names[(highlight.2).0];
172 if (*highlight_name == *expected_highlight) == *negative {
173 actual_highlights.push(highlight_name);
174 } else {
175 passed = true;
176 break 'highlight_loop;
177 }
178
179 j += 1;
180 }
181 }
182
183 if !passed {
184 return Err(Failure {
185 row: position.row,
186 column: end_column,
187 expected_highlight: expected_highlight.clone(),
188 actual_highlights: actual_highlights.into_iter().cloned().collect(),
189 }
190 .into());
191 }
192 }
193
194 Ok(assertions.len())
195}
196
197pub fn test_highlight(
198 loader: &Loader,
199 highlighter: &mut Highlighter,
200 highlight_config: &HighlightConfiguration,
201 source: &[u8],
202) -> Result<usize> {
203 let highlight_names = loader.highlight_names();
205 let highlights = get_highlight_positions(loader, highlighter, highlight_config, source)?;
206 let assertions =
207 parse_position_comments(highlighter.parser(), &highlight_config.language, source)?;
208
209 iterate_assertions(&assertions, &highlights, &highlight_names)
210}
211
/// Runs the highlighter over `source` and returns one `(start, end, highlight)` entry per
/// highlighted source span.
///
/// For each `HighlightEvent::Source` span, only the innermost active highlight (the top
/// of the highlight stack) is recorded. Spans with no active highlight are omitted.
/// Positions are tracked as byte columns while walking the text, then converted with
/// `to_utf8_point` (presumably to character columns — confirm against `query_testing`).
///
/// The input is decoded with `String::from_utf8_lossy`, so invalid UTF-8 is replaced
/// before highlighting.
///
/// NOTE(review): when a span reaches end-of-input, `char_indices` runs out before
/// `byte_offset` reaches `end`. That span's end position then appears to be the start of
/// its last character rather than the end of the text. Confirm whether callers care.
///
/// # Errors
///
/// Propagates errors from starting the highlighter and from individual highlight events.
pub fn get_highlight_positions(
    loader: &Loader,
    highlighter: &mut Highlighter,
    highlight_config: &HighlightConfiguration,
    source: &[u8],
) -> Result<Vec<(Utf8Point, Utf8Point, Highlight)>> {
    // `row`/`column` are the (byte-column) position of the most recently consumed
    // character, which begins at `byte_offset`.
    let mut row = 0;
    let mut column = 0;
    let mut byte_offset = 0;
    // Set after consuming a '\n' so that the next character starts a new row.
    let mut was_newline = false;
    let mut result = Vec::new();
    let mut highlight_stack = Vec::new();
    let source = String::from_utf8_lossy(source);
    let mut char_indices = source.char_indices();
    // The closure is the highlighter's callback for resolving injected languages.
    for event in highlighter.highlight(highlight_config, source.as_bytes(), None, |string| {
        loader.highlight_config_for_injection_string(string)
    })? {
        match event? {
            HighlightEvent::HighlightStart(h) => highlight_stack.push(h),
            HighlightEvent::HighlightEnd => {
                highlight_stack.pop();
            }
            HighlightEvent::Source { start, end } => {
                let mut start_position = Point::new(row, column);
                // Walk forward one character at a time until reaching `end`, remembering
                // the position of the last character at or before `start`.
                while byte_offset < end {
                    if byte_offset <= start {
                        start_position = Point::new(row, column);
                    }
                    if let Some((i, c)) = char_indices.next() {
                        if was_newline {
                            row += 1;
                            column = 0;
                        } else {
                            // Advance by the byte length of the previous character.
                            column += i - byte_offset;
                        }
                        was_newline = c == '\n';
                        byte_offset = i;
                    } else {
                        break;
                    }
                }
                if let Some(highlight) = highlight_stack.last() {
                    let utf8_start_position = to_utf8_point(start_position, source.as_bytes());
                    let utf8_end_position =
                        to_utf8_point(Point::new(row, column), source.as_bytes());
                    result.push((utf8_start_position, utf8_end_position, *highlight));
                }
            }
        }
    }
    Ok(result)
}