1#![cfg_attr(target_arch = "wasm32", allow(unused_imports))]
4
5use std::{collections::HashSet, fmt::Display};
6
7use serde::Serialize;
8
9use topiary_tree_sitter_facade::{
10 Node, Parser, Point, Query, QueryCapture, QueryCursor, QueryMatch, QueryPredicate, Tree,
11};
12
13use streaming_iterator::StreamingIterator;
14
15use crate::{
16 atom_collection::{AtomCollection, QueryPredicates},
17 error::FormatterError,
18 FormatterResult,
19};
20
/// The visualisation formats that can be selected.
#[derive(Clone, Copy, Debug)]
pub enum Visualisation {
    /// GraphViz format.
    GraphViz,
    /// JSON format.
    Json,
}
27
28#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize)]
32pub struct Position {
33 pub row: u32,
34 pub column: u32,
35}
36
37impl Display for Position {
38 fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
39 write!(f, "({},{})", self.row, self.column)
40 }
41}
42
43#[derive(Debug)]
47pub struct TopiaryQuery {
48 pub query: Query,
49 pub query_content: String,
50}
51
52impl TopiaryQuery {
53 pub fn new(
61 grammar: &topiary_tree_sitter_facade::Language,
62 query_content: &str,
63 ) -> FormatterResult<TopiaryQuery> {
64 let query = Query::new(grammar, query_content)
65 .map_err(|e| FormatterError::Query("Error parsing query file".into(), Some(e)))?;
66
67 Ok(TopiaryQuery {
68 query,
69 query_content: query_content.to_owned(),
70 })
71 }
72
73 #[cfg(not(target_arch = "wasm32"))]
76 pub fn pattern_position(&self, pattern_index: usize) -> Position {
77 let byte_offset = self.query.start_byte_for_pattern(pattern_index);
78 let (row, column) =
79 self.query_content[..byte_offset]
80 .chars()
81 .fold((0, 0), |(row, column), c| {
82 if c == '\n' {
83 (row + 1, 0)
84 } else {
85 (row, column + 1)
86 }
87 });
88 Position {
89 row: row + 1,
90 column: column + 1,
91 }
92 }
93
94 #[cfg(target_arch = "wasm32")]
95 pub fn pattern_position(&self, _pattern_index: usize) -> Position {
96 unimplemented!()
97 }
98}
99
100impl From<Point> for Position {
101 fn from(point: Point) -> Self {
102 Self {
103 row: point.row() + 1,
104 column: point.column() + 1,
105 }
106 }
107}
108
109#[derive(Serialize)]
111pub struct SyntaxNode {
112 #[serde(skip_serializing)]
113 pub id: usize,
114
115 pub kind: String,
116 pub is_named: bool,
117 is_extra: bool,
118 is_error: bool,
119 is_missing: bool,
120 start: Position,
121 end: Position,
122
123 pub children: Vec<SyntaxNode>,
124}
125
126impl From<Node<'_>> for SyntaxNode {
127 fn from(node: Node) -> Self {
128 let mut walker = node.walk();
129 let children = node.children(&mut walker).map(Self::from).collect();
130
131 Self {
132 id: node.id(),
133
134 kind: node.kind().into(),
135 is_named: node.is_named(),
136 is_extra: node.is_extra(),
137 is_error: node.is_error(),
138 is_missing: node.is_missing(),
139 start: node.start_position().into(),
140 end: node.end_position().into(),
141
142 children,
143 }
144 }
145}
146
/// Extension trait for rendering a node as a human-readable string.
pub trait NodeExt {
    /// Returns a textual description of `self`.
    ///
    /// The implementations in this file print the node kind followed by its
    /// start and end positions, with rows and columns counted from one.
    fn display_one_based(&self) -> String;
}
156
157impl NodeExt for Node<'_> {
158 fn display_one_based(&self) -> String {
159 format!(
160 "{{Node {:?} {} - {}}}",
161 self.kind(),
162 Position::from(self.start_position()),
163 Position::from(self.end_position()),
164 )
165 }
166}
167
168#[cfg(not(target_arch = "wasm32"))]
169impl NodeExt for tree_sitter::Node<'_> {
170 fn display_one_based(&self) -> String {
171 format!(
172 "{{Node {:?} {} - {}}}",
173 self.kind(),
174 Position::from(<tree_sitter::Point as Into<Point>>::into(
175 self.start_position()
176 )),
177 Position::from(<tree_sitter::Point as Into<Point>>::into(
178 self.end_position()
179 )),
180 )
181 }
182}
183
184#[derive(Debug)]
185struct LocalQueryMatch<'a> {
188 pattern_index: usize,
189 captures: Vec<QueryCapture<'a>>,
190}
191
192impl Display for LocalQueryMatch<'_> {
193 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
194 write!(
195 f,
196 "LocalQueryMatch {{ pattern_index: {}, captures: [ ",
197 self.pattern_index
198 )?;
199 for (index, capture) in self.captures.iter().enumerate() {
200 if index > 0 {
201 write!(f, ", ")?;
202 }
203 write!(f, "{}", capture.node().display_one_based())?;
207 }
208 write!(f, " ] }}")?;
209 Ok(())
210 }
211}
212
/// Result of a query coverage check.
#[derive(Clone, Debug, PartialEq)]
pub struct CoverageData {
    /// Share of patterns counted as covered. As computed by
    /// `check_query_coverage` in this file, this is a ratio between `0.0`
    /// and `1.0`.
    pub cover_percentage: f32,
    /// Source text of each pattern that was counted as not covered.
    pub missing_patterns: Vec<String>,
}
219
220pub fn apply_query(
231 input_content: &str,
232 query: &TopiaryQuery,
233 grammar: &topiary_tree_sitter_facade::Language,
234 tolerate_parsing_errors: bool,
235) -> FormatterResult<AtomCollection> {
236 let tree = parse(input_content, grammar, tolerate_parsing_errors)?;
237 let root = tree.root_node();
238 let source = input_content.as_bytes();
239
240 let mut cursor = QueryCursor::new();
242 let mut matches: Vec<LocalQueryMatch> = Vec::new();
243 let capture_names = query.query.capture_names();
244
245 let mut query_matches = query.query.matches(&root, source, &mut cursor);
246 #[allow(clippy::while_let_on_iterator)] while let Some(query_match) = query_matches.next() {
248 let local_captures: Vec<QueryCapture> = query_match.captures().collect();
249
250 matches.push(LocalQueryMatch {
251 pattern_index: query_match.pattern_index(),
252 captures: local_captures,
253 });
254 }
255
256 let specified_leaf_nodes: HashSet<usize> = collect_leaf_ids(&matches, capture_names.clone());
259
260 let mut atoms = AtomCollection::collect_leaves(&root, source, specified_leaf_nodes)?;
262
263 log::debug!("List of atoms before formatting: {atoms:?}");
264
265 let mut pattern_positions: Vec<Option<Position>> = Vec::new();
267
268 #[cfg(not(target_arch = "wasm32"))]
271 if log::log_enabled!(log::Level::Info) {
272 pattern_positions.resize(query.query.pattern_count(), None);
273 }
274
275 for m in matches {
284 let mut predicates = QueryPredicates::default();
285
286 for p in query.query.general_predicates(m.pattern_index) {
287 predicates = handle_predicate(&p, &predicates)?;
288 }
289 check_predicates(&predicates)?;
290
291 if log::log_enabled!(log::Level::Info) {
293 #[cfg(target_arch = "wasm32")]
294 if m.pattern_index >= pattern_positions.len() {
296 pattern_positions.resize(m.pattern_index + 1, None);
297 }
298
299 let pos = pattern_positions[m.pattern_index].unwrap_or_else(|| {
301 let pos = query.pattern_position(m.pattern_index);
302 pattern_positions[m.pattern_index] = Some(pos);
303 pos
304 });
305
306 let query_name_info = if let Some(name) = &predicates.query_name {
307 format!(" of query \"{name}\"")
308 } else {
309 "".into()
310 };
311
312 log::info!("Processing match{query_name_info}: {m} at location {pos}");
313 }
314
315 if m.captures
317 .iter()
318 .any(|c| c.name(capture_names.as_slice()) == "do_nothing")
319 {
320 continue;
321 }
322
323 for c in m.captures {
324 let name = c.name(capture_names.as_slice());
325 atoms.resolve_capture(&name, &c.node(), &predicates)?;
326 }
327 }
328
329 atoms.apply_prepends_and_appends();
331
332 Ok(atoms)
333}
334
335pub fn parse(
337 content: &str,
338 grammar: &topiary_tree_sitter_facade::Language,
339 tolerate_parsing_errors: bool,
340) -> FormatterResult<Tree> {
341 let mut parser = Parser::new()?;
342 parser.set_language(grammar).map_err(|_| {
343 FormatterError::Internal("Could not apply Tree-sitter grammar".into(), None)
344 })?;
345
346 let tree = parser
347 .parse(content, None)?
348 .ok_or_else(|| FormatterError::Internal("Could not parse input".into(), None))?;
349
350 if !tolerate_parsing_errors {
352 check_for_error_nodes(&tree.root_node())?;
353 }
354
355 Ok(tree)
356}
357
358fn check_for_error_nodes(node: &Node) -> FormatterResult<()> {
359 if node.kind() == "ERROR" {
360 let start = node.start_position();
361 let end = node.end_position();
362
363 return Err(FormatterError::Parsing {
365 start_line: start.row() + 1,
366 start_column: start.column() + 1,
367 end_line: end.row() + 1,
368 end_column: end.column() + 1,
369 });
370 }
371
372 for child in node.children(&mut node.walk()) {
373 check_for_error_nodes(&child)?;
374 }
375
376 Ok(())
377}
378
379fn collect_leaf_ids(matches: &[LocalQueryMatch], capture_names: Vec<&str>) -> HashSet<usize> {
384 let mut ids = HashSet::new();
385
386 for m in matches {
387 for c in &m.captures {
388 if c.name(capture_names.as_slice()) == "leaf" {
389 ids.insert(c.node().id());
390 }
391 }
392 }
393 ids
394}
395
396fn handle_predicate(
414 predicate: &QueryPredicate,
415 predicates: &QueryPredicates,
416) -> FormatterResult<QueryPredicates> {
417 let operator = &*predicate.operator();
418 if "delimiter!" == operator {
419 let arg =
420 predicate.args().into_iter().next().ok_or_else(|| {
421 FormatterError::Query(format!("{operator} needs an argument"), None)
422 })?;
423 Ok(QueryPredicates {
424 delimiter: Some(arg),
425 ..predicates.clone()
426 })
427 } else if "scope_id!" == operator {
428 let arg =
429 predicate.args().into_iter().next().ok_or_else(|| {
430 FormatterError::Query(format!("{operator} needs an argument"), None)
431 })?;
432 Ok(QueryPredicates {
433 scope_id: Some(arg),
434 ..predicates.clone()
435 })
436 } else if "single_line_only!" == operator {
437 Ok(QueryPredicates {
438 single_line_only: true,
439 ..predicates.clone()
440 })
441 } else if "multi_line_only!" == operator {
442 Ok(QueryPredicates {
443 multi_line_only: true,
444 ..predicates.clone()
445 })
446 } else if "single_line_scope_only!" == operator {
447 let arg =
448 predicate.args().into_iter().next().ok_or_else(|| {
449 FormatterError::Query(format!("{operator} needs an argument"), None)
450 })?;
451 Ok(QueryPredicates {
452 single_line_scope_only: Some(arg),
453 ..predicates.clone()
454 })
455 } else if "multi_line_scope_only!" == operator {
456 let arg =
457 predicate.args().into_iter().next().ok_or_else(|| {
458 FormatterError::Query(format!("{operator} needs an argument"), None)
459 })?;
460 Ok(QueryPredicates {
461 multi_line_scope_only: Some(arg),
462 ..predicates.clone()
463 })
464 } else if "query_name!" == operator {
465 let arg =
466 predicate.args().into_iter().next().ok_or_else(|| {
467 FormatterError::Query(format!("{operator} needs an argument"), None)
468 })?;
469 Ok(QueryPredicates {
470 query_name: Some(arg),
471 ..predicates.clone()
472 })
473 } else {
474 Err(FormatterError::Query(
475 format!("{operator} is an unknown predicate. Maybe you forgot a \"!\"?"),
476 None,
477 ))
478 }
479}
480
481fn check_predicates(predicates: &QueryPredicates) -> FormatterResult<()> {
497 let mut incompatible_predicates = 0;
498 if predicates.single_line_only {
499 incompatible_predicates += 1;
500 }
501 if predicates.multi_line_only {
502 incompatible_predicates += 1;
503 }
504 if predicates.single_line_scope_only.is_some() {
505 incompatible_predicates += 1;
506 }
507 if predicates.multi_line_scope_only.is_some() {
508 incompatible_predicates += 1;
509 }
510 if incompatible_predicates > 1 {
511 Err(FormatterError::Query(
512 "A query can contain at most one #single/multi_line[_scope]_only! predicate".into(),
513 None,
514 ))
515 } else {
516 Ok(())
517 }
518}
519
520#[cfg(not(target_arch = "wasm32"))]
521pub fn check_query_coverage(
525 input_content: &str,
526 original_query: &TopiaryQuery,
527 grammar: &topiary_tree_sitter_facade::Language,
528) -> FormatterResult<CoverageData> {
529 let tree = parse(input_content, grammar, false)?;
530 let root = tree.root_node();
531 let source = input_content.as_bytes();
532 let mut missing_patterns = Vec::new();
533
534 let mut cursor = QueryCursor::new();
536 let ref_match_count = original_query
537 .query
538 .matches(&root, source, &mut cursor)
539 .count();
540 let pattern_count = original_query.query.pattern_count();
541 let query_content = &original_query.query_content;
542
543 if pattern_count == 0 {
546 let cover_percentage = 0.0;
547 return Ok(CoverageData {
548 cover_percentage,
549 missing_patterns,
550 });
551 }
552
553 if pattern_count == 1 {
556 let mut cover_percentage = 1.0;
557 if ref_match_count == 0 {
558 missing_patterns.push(query_content.into());
559 cover_percentage = 0.0
560 }
561 return Ok(CoverageData {
562 cover_percentage,
563 missing_patterns,
564 });
565 }
566
567 let mut ok_patterns = 0.0;
568 for i in 0..pattern_count {
569 let mut query = Query::new(grammar, query_content)
573 .map_err(|e| FormatterError::Query("Error parsing query file".into(), Some(e)))?;
574 query.disable_pattern(i);
575 let mut cursor = QueryCursor::new();
576 let match_count = query.matches(&root, source, &mut cursor).count();
577 if match_count == ref_match_count {
578 let index_start = query.start_byte_for_pattern(i);
579 let index_end = if i == pattern_count - 1 {
580 query_content.len()
581 } else {
582 query.start_byte_for_pattern(i + 1)
583 };
584 let pattern_content = &query_content[index_start..index_end];
585 missing_patterns.push(pattern_content.into());
586 } else {
587 ok_patterns += 1.0;
588 }
589 }
590
591 let cover_percentage = ok_patterns / pattern_count as f32;
592 Ok(CoverageData {
593 cover_percentage,
594 missing_patterns,
595 })
596}
597
/// Query coverage checking is not implemented on `wasm32`: calling this
/// function panics.
#[cfg(target_arch = "wasm32")]
pub fn check_query_coverage(
    _input_content: &str,
    _original_query: &TopiaryQuery,
    _grammar: &topiary_tree_sitter_facade::Language,
) -> FormatterResult<CoverageData> {
    unimplemented!()
}
605}