1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177
use std::fs;
use anyhow::{anyhow, Context, Result};
use lazy_static::lazy_static;
use regex::Regex;
use tree_sitter::{Language, Parser, Point};
lazy_static! {
    /// Matches the capture name that follows an assertion arrow: a run of word
    /// characters, underscores, hyphens and dots (e.g. `function.builtin`).
    static ref CAPTURE_NAME_REGEX: Regex = Regex::new(r"[\w_\-.]+").unwrap();
}
/// The name and source span of a capture that `assert_expected_captures`
/// compares against the assertion comments of a test file.
#[derive(PartialEq, Eq, Debug)]
pub struct CaptureInfo {
    /// Capture name, compared against an assertion's expected capture name.
    pub name: String,
    /// Start of the captured span.
    pub start: Point,
    /// End of the captured span (treated as exclusive when matching assertions).
    pub end: Point,
}
/// A single assertion parsed from a comment in a test file.
#[derive(PartialEq, Eq, Debug)]
pub struct Assertion {
    /// The source position the assertion points at.
    pub position: Point,
    /// Set when the comment has a `!` between the arrow and the capture name.
    pub negative: bool,
    /// The capture name written in the assertion comment.
    pub expected_capture_name: String,
}
impl Assertion {
    /// Builds an assertion pointing at (`row`, `col`).
    ///
    /// `negative` corresponds to an assertion written with a `!` before the
    /// capture name; `expected_capture_name` is the name written in the comment.
    #[must_use]
    pub fn new(row: usize, col: usize, negative: bool, expected_capture_name: String) -> Self {
        let position = Point::new(row, col);
        Self {
            expected_capture_name,
            negative,
            position,
        }
    }
}
/// Parse the given source code, finding all of the comments that contain
/// highlighting assertions. Return a vector of (position, expected highlight name)
/// pairs.
pub fn parse_position_comments(
parser: &mut Parser,
language: &Language,
source: &[u8],
) -> Result<Vec<Assertion>> {
let mut result = Vec::new();
let mut assertion_ranges = Vec::new();
// Parse the code.
parser.set_included_ranges(&[]).unwrap();
parser.set_language(language).unwrap();
let tree = parser.parse(source, None).unwrap();
// Walk the tree, finding comment nodes that contain assertions.
let mut ascending = false;
let mut cursor = tree.root_node().walk();
loop {
if ascending {
let node = cursor.node();
// Find every comment node.
if node.kind().to_lowercase().contains("comment") {
if let Ok(text) = node.utf8_text(source) {
let mut position = node.start_position();
if position.row > 0 {
// Find the arrow character ("^" or '<-") in the comment. A left arrow
// refers to the column where the comment node starts. An up arrow refers
// to its own column.
let mut has_left_caret = false;
let mut has_arrow = false;
let mut negative = false;
let mut arrow_end = 0;
for (i, c) in text.char_indices() {
arrow_end = i + 1;
if c == '-' && has_left_caret {
has_arrow = true;
break;
}
if c == '^' {
has_arrow = true;
position.column += i;
break;
}
has_left_caret = c == '<';
}
// find any ! after arrows but before capture name
if has_arrow {
for (i, c) in text[arrow_end..].char_indices() {
if c == '!' {
negative = true;
arrow_end += i + 1;
break;
} else if !c.is_whitespace() {
break;
}
}
}
// If the comment node contains an arrow and a highlight name, record the
// highlight name and the position.
if let (true, Some(mat)) =
(has_arrow, CAPTURE_NAME_REGEX.find(&text[arrow_end..]))
{
assertion_ranges.push((node.start_position(), node.end_position()));
result.push(Assertion {
position,
negative,
expected_capture_name: mat.as_str().to_string(),
});
}
}
}
}
// Continue walking the tree.
if cursor.goto_next_sibling() {
ascending = false;
} else if !cursor.goto_parent() {
break;
}
} else if !cursor.goto_first_child() {
ascending = true;
}
}
// Adjust the row number in each assertion's position to refer to the line of
// code *above* the assertion. There can be multiple lines of assertion comments,
// so the positions may have to be decremented by more than one row.
let mut i = 0;
for assertion in &mut result {
loop {
let on_assertion_line = assertion_ranges[i..]
.iter()
.any(|(start, _)| start.row == assertion.position.row);
if on_assertion_line {
assertion.position.row -= 1;
} else {
while i < assertion_ranges.len()
&& assertion_ranges[i].0.row < assertion.position.row
{
i += 1;
}
break;
}
}
}
// The assertions can end up out of order due to the line adjustments.
result.sort_unstable_by_key(|a| a.position);
Ok(result)
}
pub fn assert_expected_captures(
infos: &[CaptureInfo],
path: String,
parser: &mut Parser,
language: &Language,
) -> Result<()> {
let contents = fs::read_to_string(path)?;
let pairs = parse_position_comments(parser, language, contents.as_bytes())?;
for info in infos {
if let Some(found) = pairs.iter().find(|p| {
p.position.row == info.start.row && p.position >= info.start && p.position < info.end
}) {
if found.expected_capture_name != info.name && info.name != "name" {
Err(anyhow!(
"Assertion failed: at {}, found {}, expected {}",
info.start,
found.expected_capture_name,
info.name
))?;
}
}
}
Ok(())
}