Skip to main content

tree_sitter_cli/
test.rs

1use std::{
2    collections::BTreeMap,
3    ffi::OsStr,
4    fmt::Display as _,
5    fs,
6    io::{self, Write},
7    path::{Path, PathBuf},
8    str,
9    sync::LazyLock,
10    time::Duration,
11};
12
13use anstyle::AnsiColor;
14use anyhow::{anyhow, Context, Result};
15use clap::ValueEnum;
16use indoc::indoc;
17use regex::{
18    bytes::{Regex as ByteRegex, RegexBuilder as ByteRegexBuilder},
19    Regex,
20};
21use schemars::{JsonSchema, Schema, SchemaGenerator};
22use serde::Serialize;
23use similar::{ChangeTag, TextDiff};
24use tree_sitter::{format_sexp, Language, LogType, Parser, Query, Tree};
25use walkdir::WalkDir;
26
27use super::util;
28use crate::{
29    logger::paint,
30    parse::{
31        render_cst, ParseDebugType, ParseFileOptions, ParseOutput, ParseStats, ParseTheme, Stats,
32    },
33};
34
/// Matches a test header block: a line of three or more `=` (optionally followed by a suffix
/// that does not start with `=`), then one or more lines holding the test name and any
/// attribute markers, then a closing `===+` line with the same optional-suffix rule.
///
/// Named captures: `equals` (the opening `=` run), `suffix1` / `suffix2` (text after the
/// opening / closing `=` run), and `test_name_and_markers`. Each name/marker line must start
/// with a character other than `=`, `\r` or `\n`, or with whitespace followed by `:`.
/// `(?x)` makes whitespace inside the pattern insignificant; `multi_line(true)` lets `^` match
/// at the start of any line.
///
/// NOTE(review): `(?:=+){3,}` matches exactly the same strings as `={3,}`.
static HEADER_REGEX: LazyLock<ByteRegex> = LazyLock::new(|| {
    ByteRegexBuilder::new(
        r"^(?x)
           (?P<equals>(?:=+){3,})
           (?P<suffix1>[^=\r\n][^\r\n]*)?
           \r?\n
           (?P<test_name_and_markers>(?:([^=\r\n]|\s+:)[^\r\n]*\r?\n)+)
           ===+
           (?P<suffix2>[^=\r\n][^\r\n]*)?\r?\n",
    )
    .multi_line(true)
    .build()
    .unwrap()
});

/// Matches a divider line of three or more `-`, optionally followed by a suffix that does not
/// start with `-`. Named captures: `hyphens` and `suffix`. Multi-line mode lets `^` match at
/// the start of any line.
static DIVIDER_REGEX: LazyLock<ByteRegex> = LazyLock::new(|| {
    ByteRegexBuilder::new(r"^(?P<hyphens>(?:-+){3,})(?P<suffix>[^-\r\n][^\r\n]*)?\r?\n")
        .multi_line(true)
        .build()
        .unwrap()
});

/// Matches a line whose first non-whitespace character is `;`, through the end of that line.
/// Because `\s*` also matches newlines, a match can start on preceding blank lines.
static COMMENT_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?m)^\s*;.*$").unwrap());

/// Matches any non-empty run of whitespace.
static WHITESPACE_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\s+").unwrap());

/// Matches a field label that precedes a parenthesized child in an S-expression, e.g.
/// ` name: (`. NOTE(review): presumably used by `strip_sexp_fields` (defined outside this view).
static SEXP_FIELD_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r" \w+: \(").unwrap());

/// Matches two comma-separated integers in square brackets together with surrounding
/// whitespace, e.g. ` [1, 2] ` (presumably a `[row, column]` point — confirm at use site).
static POINT_REGEX: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"\s*\[\s*\d+\s*,\s*\d+\s*\]\s*").unwrap());
65
/// A node in a tree of corpus tests: either a named group of entries or a single example.
#[derive(Debug, PartialEq, Eq)]
pub enum TestEntry {
    /// A named collection of child entries.
    Group {
        name: String,
        children: Vec<Self>,
        /// File this group was read from, if any.
        file_path: Option<PathBuf>,
    },
    /// A single test case.
    Example {
        name: String,
        /// Source bytes handed to the parser.
        input: Vec<u8>,
        /// Expected output, compared against the S-expression (or rendered CST when
        /// `attributes.cst` is set) of the parsed tree.
        output: String,
        /// Delimiter lengths of the test's `=` header and `-` divider; carried into
        /// `TestCorrection` when updating. NOTE(review): computed by the parser outside this view.
        header_delim_len: usize,
        divider_delim_len: usize,
        /// When set, field names are kept in the actual S-expression before comparison.
        has_fields: bool,
        /// Raw attribute text, written verbatim into corrected entries.
        attributes_str: String,
        /// Parsed form of `attributes_str`.
        attributes: TestAttributes,
        file_name: Option<String>,
    },
}
85
/// Per-example attributes that control how a test is run and judged.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TestAttributes {
    /// Record the test with the `Skipped` outcome instead of running it.
    pub skip: bool,
    /// Whether the test applies to the current platform; `false` records it with the
    /// `Platform` outcome instead of running it.
    pub platform: bool,
    /// Ask the runner to stop after this test (`run_tests` returns `Ok(false)` to bail).
    pub fail_fast: bool,
    /// The test passes when the parsed tree contains an error, rather than by matching output.
    pub error: bool,
    /// Compare against a rendered concrete syntax tree instead of an S-expression.
    pub cst: bool,
    /// Names of the languages to run the example against; an empty name keeps whatever
    /// language the parser already has.
    pub languages: Vec<Box<str>>,
}
95
96impl Default for TestEntry {
97    fn default() -> Self {
98        Self::Group {
99            name: String::new(),
100            children: Vec::new(),
101            file_path: None,
102        }
103    }
104}
105
106impl Default for TestAttributes {
107    fn default() -> Self {
108        Self {
109            skip: false,
110            platform: true,
111            fail_fast: false,
112            error: false,
113            cst: false,
114            languages: vec!["".into()],
115        }
116    }
117}
118
// Selects which per-test parse-rate statistics the text summary shows. The aggregate
// `parse_stats` block is printed at the end of the summary in every mode.
#[derive(ValueEnum, Default, Debug, Copy, Clone, PartialEq, Eq, Serialize)]
pub enum TestStats {
    // Every parse test shows its true parse rate (bytes/ms), plus slow-parse warnings.
    All,
    // Only slow-parse warnings are shown next to tests (the default).
    #[default]
    OutliersAndTotal,
    // No per-test statistics are shown.
    TotalOnly,
}
126
/// Options controlling a corpus test run.
pub struct TestOptions<'a> {
    /// Location handed to `parse_tests` to load the test tree.
    pub path: PathBuf,
    /// Install a logger that echoes parser log messages to stderr (lexer messages indented).
    pub debug: bool,
    /// Log parser debug graphs to `log.html` via `util::log_graphs`; takes precedence over
    /// `debug`.
    pub debug_graph: bool,
    /// NOTE(review): not read in this view; presumably a test-name filter applied by the runner.
    pub include: Option<Regex>,
    /// NOTE(review): not read in this view; presumably a test-name exclusion filter.
    pub exclude: Option<Regex>,
    /// NOTE(review): not read in this view.
    pub file_name: Option<String>,
    /// Collect corrected entries with the actual output so tests can be rewritten.
    pub update: bool,
    /// Forwarded to `util::log_graphs` (presumably opens the log once written).
    pub open_log: bool,
    /// Languages by name, looked up for tests whose attributes name specific languages.
    pub languages: BTreeMap<&'a str, &'a Language>,
    /// NOTE(review): not read in this view; presumably enables colored output.
    pub color: bool,
    /// Keep field names in the actual S-expression even if the expected output has none.
    pub show_fields: bool,
    /// NOTE(review): not read in this view; `TestSummary` has its own `overview_only`.
    pub overview_only: bool,
}
141
/// A stateful object used to collect results from running a grammar's test suite
#[derive(Debug, Default, Serialize, JsonSchema)]
pub struct TestSummary {
    // Parse test results and associated data
    // Nested tree of parse-test outcomes; serialized (and described in the schema) as a plain
    // array of its top-level entries.
    #[schemars(schema_with = "schema_as_array")]
    #[serde(serialize_with = "serialize_as_array")]
    pub parse_results: TestResultHierarchy,
    // One entry per failed parse test, printed as diffs after the per-test results.
    pub parse_failures: Vec<TestFailure>,
    // Aggregate counters (parses, successes, bytes, total duration) updated by the runner.
    pub parse_stats: Stats,
    // When set, failures cannot be accepted as updates: `run_tests_at_path` returns an error
    // in update mode. NOTE(review): assigned outside this view.
    #[schemars(skip)]
    #[serde(skip)]
    pub has_parse_errors: bool,
    // Which per-test parse-rate statistics the text summary shows.
    #[schemars(skip)]
    #[serde(skip)]
    pub parse_stat_display: TestStats,

    // Other test results
    // Assertion-based results for syntax highlighting, tags and queries; each is rendered as
    // its own section only when non-empty.
    #[schemars(schema_with = "schema_as_array")]
    #[serde(serialize_with = "serialize_as_array")]
    pub highlight_results: TestResultHierarchy,
    #[schemars(schema_with = "schema_as_array")]
    #[serde(serialize_with = "serialize_as_array")]
    pub tag_results: TestResultHierarchy,
    #[schemars(schema_with = "schema_as_array")]
    #[serde(serialize_with = "serialize_as_array")]
    pub query_results: TestResultHierarchy,

    // Data used during construction
    // Number given to the next recorded test; `TestSummary::new` starts it at 1.
    #[schemars(skip)]
    #[serde(skip)]
    pub test_num: usize,
    // Options passed in from the CLI which control how the summary is displayed
    // Paint the text summary with ANSI colors.
    #[schemars(skip)]
    #[serde(skip)]
    pub color: bool,
    // Suppress the per-failure diffs.
    #[schemars(skip)]
    #[serde(skip)]
    pub overview_only: bool,
    // Update mode: failures are listed as updates rather than diffed.
    #[schemars(skip)]
    #[serde(skip)]
    pub update: bool,
    // NOTE(review): not read in this view; presumably selects JSON output of the summary.
    #[schemars(skip)]
    #[serde(skip)]
    pub json: bool,
}
187
188impl TestSummary {
189    #[must_use]
190    pub fn new(
191        color: bool,
192        stat_display: TestStats,
193        parse_update: bool,
194        overview_only: bool,
195        json_summary: bool,
196    ) -> Self {
197        Self {
198            color,
199            parse_stat_display: stat_display,
200            update: parse_update,
201            overview_only,
202            json: json_summary,
203            test_num: 1,
204            ..Default::default()
205        }
206    }
207}
208
#[derive(Debug, Default, JsonSchema)]
pub struct TestResultHierarchy {
    // Top-level results; groups hold their own children.
    root_group: Vec<TestResult>,
    // Child indices leading from `root_group` down to the group currently being filled.
    // Empty means new results are appended directly to `root_group`.
    traversal_idxs: Vec<usize>,
}
214
215fn serialize_as_array<S>(results: &TestResultHierarchy, serializer: S) -> Result<S::Ok, S::Error>
216where
217    S: serde::Serializer,
218{
219    results.root_group.serialize(serializer)
220}
221
222fn schema_as_array(gen: &mut SchemaGenerator) -> Schema {
223    gen.subschema_for::<Vec<TestResult>>()
224}
225
226/// Stores arbitrarily nested parent test groups and child cases. Supports creation
227/// in DFS traversal order
228impl TestResultHierarchy {
229    /// Signifies the start of a new group's traversal during construction.
230    fn push_traversal(&mut self, idx: usize) {
231        self.traversal_idxs.push(idx);
232    }
233
234    /// Signifies the end of the current group's traversal during construction.
235    /// Must be paired with a prior call to [`TestResultHierarchy::add_group`].
236    pub fn pop_traversal(&mut self) {
237        self.traversal_idxs.pop();
238    }
239
240    /// Adds a new group as a child of the current group. Caller is responsible
241    /// for calling [`TestResultHierarchy::pop_traversal`] once the group is done
242    /// being traversed.
243    pub fn add_group(&mut self, group_name: &str) {
244        let new_group_idx = self.curr_group_len();
245        self.push(TestResult {
246            name: group_name.to_string(),
247            info: TestInfo::Group {
248                children: Vec::new(),
249            },
250        });
251        self.push_traversal(new_group_idx);
252    }
253
254    /// Adds a new test example as a child of the current group.
255    /// Asserts that `test_case.info` is not [`TestInfo::Group`].
256    pub fn add_case(&mut self, test_case: TestResult) {
257        assert!(!matches!(test_case.info, TestInfo::Group { .. }));
258        self.push(test_case);
259    }
260
261    /// Adds a new `TestResult` to the current group.
262    fn push(&mut self, result: TestResult) {
263        // If there are no traversal steps, we're adding to the root
264        if self.traversal_idxs.is_empty() {
265            self.root_group.push(result);
266            return;
267        }
268
269        #[allow(clippy::manual_let_else)]
270        let mut curr_group = match self.root_group[self.traversal_idxs[0]].info {
271            TestInfo::Group { ref mut children } => children,
272            _ => unreachable!(),
273        };
274        for idx in self.traversal_idxs.iter().skip(1) {
275            curr_group = match curr_group[*idx].info {
276                TestInfo::Group { ref mut children } => children,
277                _ => unreachable!(),
278            };
279        }
280
281        curr_group.push(result);
282    }
283
284    fn curr_group_len(&self) -> usize {
285        if self.traversal_idxs.is_empty() {
286            return self.root_group.len();
287        }
288
289        #[allow(clippy::manual_let_else)]
290        let mut curr_group = match self.root_group[self.traversal_idxs[0]].info {
291            TestInfo::Group { ref children } => children,
292            _ => unreachable!(),
293        };
294        for idx in self.traversal_idxs.iter().skip(1) {
295            curr_group = match curr_group[*idx].info {
296                TestInfo::Group { ref children } => children,
297                _ => unreachable!(),
298            };
299        }
300        curr_group.len()
301    }
302
303    #[allow(clippy::iter_without_into_iter)]
304    #[must_use]
305    pub fn iter(&self) -> TestResultIterWithDepth<'_> {
306        let mut stack = Vec::with_capacity(self.root_group.len());
307        for child in self.root_group.iter().rev() {
308            stack.push((0, child));
309        }
310        TestResultIterWithDepth { stack }
311    }
312}
313
314pub struct TestResultIterWithDepth<'a> {
315    stack: Vec<(usize, &'a TestResult)>,
316}
317
318impl<'a> Iterator for TestResultIterWithDepth<'a> {
319    type Item = (usize, &'a TestResult);
320
321    fn next(&mut self) -> Option<Self::Item> {
322        self.stack.pop().inspect(|(depth, result)| {
323            if let TestInfo::Group { children } = &result.info {
324                for child in children.iter().rev() {
325                    self.stack.push((depth + 1, child));
326                }
327            }
328        })
329    }
330}
331
// One named entry in a `TestResultHierarchy`. `info` is flattened, so its fields are
// serialized alongside `name`.
#[derive(Debug, Serialize, JsonSchema)]
pub struct TestResult {
    pub name: String,
    #[schemars(flatten)]
    #[serde(flatten)]
    pub info: TestInfo,
}
339
// Kind-specific payload of a `TestResult`. Being `untagged`, variants serialize without a
// variant name and are distinguished only by their fields.
#[derive(Debug, Serialize, JsonSchema)]
#[schemars(untagged)]
#[serde(untagged)]
pub enum TestInfo {
    // A group of nested results.
    Group {
        children: Vec<TestResult>,
    },
    // Result of a parse (corpus) test.
    ParseTest {
        outcome: TestOutcome,
        // True parse rate, adjusted parse rate
        // The true rate is bytes/ms; the adjusted rate comes from `adjusted_parse_rate`.
        // `None` when the test was not parsed (skipped or other platform). Only the true
        // rate is serialized.
        #[schemars(schema_with = "parse_rate_schema")]
        #[serde(serialize_with = "serialize_parse_rates")]
        parse_rate: Option<(f64, f64)>,
        // Sequential number shown in front of the test in the text summary.
        test_num: usize,
    },
    // Result of a highlight, tag or query assertion test.
    AssertionTest {
        outcome: TestOutcome,
        test_num: usize,
    },
}
360
361fn serialize_parse_rates<S>(
362    parse_rate: &Option<(f64, f64)>,
363    serializer: S,
364) -> Result<S::Ok, S::Error>
365where
366    S: serde::Serializer,
367{
368    match parse_rate {
369        None => serializer.serialize_none(),
370        Some((first, _)) => serializer.serialize_some(first),
371    }
372}
373
374fn parse_rate_schema(gen: &mut SchemaGenerator) -> Schema {
375    gen.subschema_for::<Option<f64>>()
376}
377
// Outcome of a single test.
#[derive(Debug, Clone, Eq, PartialEq, Serialize, JsonSchema)]
pub enum TestOutcome {
    // Parse outcomes
    Passed,
    Failed,
    // Rendered like a pass (blue check) in the summary.
    Updated,
    // The test's `skip` attribute was set.
    Skipped,
    // The test does not apply to the current platform.
    Platform,

    // Highlight/Tag/Query outcomes
    AssertionPassed { assertion_count: usize },
    AssertionFailed { error: String },
}
391
392impl TestSummary {
393    fn fmt_parse_results(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394        let (count, total_adj_parse_time) = self
395            .parse_results
396            .iter()
397            .filter_map(|(_, result)| match result.info {
398                TestInfo::Group { .. } => None,
399                TestInfo::ParseTest { parse_rate, .. } => parse_rate,
400                _ => unreachable!(),
401            })
402            .fold((0usize, 0.0f64), |(count, rate_accum), (_, adj_rate)| {
403                (count + 1, rate_accum + adj_rate)
404            });
405
406        let avg = total_adj_parse_time / count as f64;
407        let std_dev = {
408            let variance = self
409                .parse_results
410                .iter()
411                .filter_map(|(_, result)| match result.info {
412                    TestInfo::Group { .. } => None,
413                    TestInfo::ParseTest { parse_rate, .. } => parse_rate,
414                    _ => unreachable!(),
415                })
416                .map(|(_, rate_i)| (rate_i - avg).powi(2))
417                .sum::<f64>()
418                / count as f64;
419            variance.sqrt()
420        };
421
422        for (depth, entry) in self.parse_results.iter() {
423            write!(f, "{}", "  ".repeat(depth + 1))?;
424            match &entry.info {
425                TestInfo::Group { .. } => writeln!(f, "{}:", entry.name)?,
426                TestInfo::ParseTest {
427                    outcome,
428                    parse_rate,
429                    test_num,
430                } => {
431                    let (color, result_char) = match outcome {
432                        TestOutcome::Passed => (AnsiColor::Green, "✓"),
433                        TestOutcome::Failed => (AnsiColor::Red, "✗"),
434                        TestOutcome::Updated => (AnsiColor::Blue, "✓"),
435                        TestOutcome::Skipped => (AnsiColor::Yellow, "⌀"),
436                        TestOutcome::Platform => (AnsiColor::Magenta, "⌀"),
437                        _ => unreachable!(),
438                    };
439                    let stat_display = match (self.parse_stat_display, parse_rate) {
440                        (TestStats::TotalOnly, _) | (_, None) => String::new(),
441                        (display, Some((true_rate, adj_rate))) => {
442                            let mut stats = if display == TestStats::All {
443                                format!(" ({true_rate:.3} bytes/ms)")
444                            } else {
445                                String::new()
446                            };
447                            // 3 standard deviations below the mean, aka the "Empirical Rule"
448                            if *adj_rate < 3.0f64.mul_add(-std_dev, avg) {
449                                stats += &paint(
450                                    self.color.then_some(AnsiColor::Yellow),
451                                    &format!(
452                                        " -- Warning: Slow parse rate ({true_rate:.3} bytes/ms)"
453                                    ),
454                                );
455                            }
456                            stats
457                        }
458                    };
459                    writeln!(
460                        f,
461                        "{test_num:>3}. {result_char} {}{stat_display}",
462                        paint(self.color.then_some(color), &entry.name),
463                    )?;
464                }
465                TestInfo::AssertionTest { .. } => unreachable!(),
466            }
467        }
468
469        // Parse failure info
470        if !self.parse_failures.is_empty() && self.update && !self.has_parse_errors {
471            writeln!(
472                f,
473                "\n{} update{}:\n",
474                self.parse_failures.len(),
475                if self.parse_failures.len() == 1 {
476                    ""
477                } else {
478                    "s"
479                }
480            )?;
481
482            for (i, TestFailure { name, .. }) in self.parse_failures.iter().enumerate() {
483                writeln!(f, "  {}. {name}", i + 1)?;
484            }
485        } else if !self.parse_failures.is_empty() && !self.overview_only {
486            if !self.has_parse_errors {
487                writeln!(
488                    f,
489                    "\n{} failure{}:",
490                    self.parse_failures.len(),
491                    if self.parse_failures.len() == 1 {
492                        ""
493                    } else {
494                        "s"
495                    }
496                )?;
497            }
498
499            if self.color {
500                DiffKey.fmt(f)?;
501            }
502            for (
503                i,
504                TestFailure {
505                    name,
506                    actual,
507                    expected,
508                    is_cst,
509                },
510            ) in self.parse_failures.iter().enumerate()
511            {
512                if expected == "NO ERROR" {
513                    writeln!(f, "\n  {}. {name}:\n", i + 1)?;
514                    writeln!(f, "  Expected an ERROR node, but got:")?;
515                    let actual = if *is_cst {
516                        actual
517                    } else {
518                        &format_sexp(actual, 2)
519                    };
520                    writeln!(
521                        f,
522                        "  {}",
523                        paint(self.color.then_some(AnsiColor::Red), actual)
524                    )?;
525                } else {
526                    writeln!(f, "\n  {}. {name}:", i + 1)?;
527                    if *is_cst {
528                        writeln!(
529                            f,
530                            "{}",
531                            TestDiff::new(actual, expected).with_color(self.color)
532                        )?;
533                    } else {
534                        writeln!(
535                            f,
536                            "{}",
537                            TestDiff::new(&format_sexp(actual, 2), &format_sexp(expected, 2))
538                                .with_color(self.color,)
539                        )?;
540                    }
541                }
542            }
543        } else {
544            writeln!(f)?;
545        }
546
547        Ok(())
548    }
549}
550
551impl std::fmt::Display for TestSummary {
552    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
553        self.fmt_parse_results(f)?;
554
555        let mut render_assertion_results =
556            |name: &str, results: &TestResultHierarchy| -> std::fmt::Result {
557                writeln!(f, "{name}:")?;
558                for (depth, entry) in results.iter() {
559                    write!(f, "{}", "  ".repeat(depth + 2))?;
560                    match &entry.info {
561                        TestInfo::Group { .. } => writeln!(f, "{}", entry.name)?,
562                        TestInfo::AssertionTest { outcome, test_num } => match outcome {
563                            TestOutcome::AssertionPassed { assertion_count } => writeln!(
564                                f,
565                                "{:>3}. ✓ {} ({assertion_count} assertions)",
566                                test_num,
567                                paint(self.color.then_some(AnsiColor::Green), &entry.name)
568                            )?,
569                            TestOutcome::AssertionFailed { error } => {
570                                writeln!(
571                                    f,
572                                    "{:>3}. ✗ {}",
573                                    test_num,
574                                    paint(self.color.then_some(AnsiColor::Red), &entry.name)
575                                )?;
576                                writeln!(f, "{}  {error}", "  ".repeat(depth + 1))?;
577                            }
578                            _ => unreachable!(),
579                        },
580                        TestInfo::ParseTest { .. } => unreachable!(),
581                    }
582                }
583                Ok(())
584            };
585
586        if !self.highlight_results.root_group.is_empty() {
587            render_assertion_results("syntax highlighting", &self.highlight_results)?;
588        }
589
590        if !self.tag_results.root_group.is_empty() {
591            render_assertion_results("tags", &self.tag_results)?;
592        }
593
594        if !self.query_results.root_group.is_empty() {
595            render_assertion_results("queries", &self.query_results)?;
596        }
597
598        write!(f, "{}", self.parse_stats)?;
599
600        Ok(())
601    }
602}
603
604pub fn run_tests_at_path(
605    parser: &mut Parser,
606    opts: &TestOptions,
607    test_summary: &mut TestSummary,
608) -> Result<()> {
609    let test_entry = parse_tests(&opts.path)?;
610    let mut _log_session = None;
611
612    if opts.debug_graph {
613        _log_session = Some(util::log_graphs(parser, "log.html", opts.open_log)?);
614    } else if opts.debug {
615        parser.set_logger(Some(Box::new(|log_type, message| {
616            if log_type == LogType::Lex {
617                io::stderr().write_all(b"  ").unwrap();
618            }
619            writeln!(&mut io::stderr(), "{message}").unwrap();
620        })));
621    }
622
623    let mut corrected_entries = Vec::new();
624    run_tests(
625        parser,
626        test_entry,
627        opts,
628        test_summary,
629        &mut corrected_entries,
630        true,
631    )?;
632
633    parser.stop_printing_dot_graphs();
634
635    if test_summary.parse_failures.is_empty() || (opts.update && !test_summary.has_parse_errors) {
636        Ok(())
637    } else if opts.update && test_summary.has_parse_errors {
638        Err(anyhow!(indoc! {"
639                Some tests failed to parse with unexpected `ERROR` or `MISSING` nodes, as shown above, and cannot be updated automatically.
640                Either fix the grammar or manually update the tests if this is expected."}))
641    } else {
642        Err(anyhow!(""))
643    }
644}
645
646pub fn check_queries_at_path(language: &Language, path: &Path) -> Result<()> {
647    if path.exists() {
648        for entry in WalkDir::new(path)
649            .into_iter()
650            .filter_map(std::result::Result::ok)
651            .filter(|e| {
652                e.file_type().is_file()
653                    && e.path().extension().and_then(OsStr::to_str) == Some("scm")
654                    && !e.path().starts_with(".")
655            })
656        {
657            let filepath = entry.file_name().to_str().unwrap_or("");
658            let content = fs::read_to_string(entry.path())
659                .with_context(|| format!("Error reading query file {filepath:?}"))?;
660            Query::new(language, &content)
661                .with_context(|| format!("Error in query file {filepath:?}"))?;
662        }
663    }
664    Ok(())
665}
666
667pub struct DiffKey;
668
669impl std::fmt::Display for DiffKey {
670    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
671        write!(
672            f,
673            "\ncorrect / {} / {}",
674            paint(Some(AnsiColor::Green), "expected"),
675            paint(Some(AnsiColor::Red), "unexpected")
676        )?;
677        Ok(())
678    }
679}
680
681impl DiffKey {
682    /// Writes [`DiffKey`] to stdout
683    pub fn print() {
684        println!("{Self}");
685    }
686}
687
/// A line-based diff of actual against expected test output, rendered through `Display`.
pub struct TestDiff<'a> {
    /// Output the parser produced.
    pub actual: &'a str,
    /// Output the test expects.
    pub expected: &'a str,
    /// Render with ANSI colors instead of `+` / `-` / space line prefixes.
    pub color: bool,
}

impl<'a> TestDiff<'a> {
    /// Creates a diff of `actual` against `expected`, with color enabled.
    #[must_use]
    pub const fn new(actual: &'a str, expected: &'a str) -> Self {
        Self {
            color: true,
            expected,
            actual,
        }
    }

    /// Returns this diff with color rendering switched on or off.
    #[must_use]
    pub const fn with_color(self, color: bool) -> Self {
        Self { color, ..self }
    }
}
710
711impl std::fmt::Display for TestDiff<'_> {
712    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
713        let diff = TextDiff::from_lines(self.actual, self.expected);
714        for diff in diff.iter_all_changes() {
715            match diff.tag() {
716                ChangeTag::Equal => {
717                    if self.color {
718                        write!(f, "{diff}")?;
719                    } else {
720                        write!(f, " {diff}")?;
721                    }
722                }
723                ChangeTag::Insert => {
724                    if self.color {
725                        write!(
726                            f,
727                            "{}",
728                            paint(Some(AnsiColor::Green), diff.as_str().unwrap())
729                        )?;
730                    } else {
731                        write!(f, "+{diff}")?;
732                    }
733                    if diff.missing_newline() {
734                        writeln!(f)?;
735                    }
736                }
737                ChangeTag::Delete => {
738                    if self.color {
739                        write!(f, "{}", paint(Some(AnsiColor::Red), diff.as_str().unwrap()))?;
740                    } else {
741                        write!(f, "-{diff}")?;
742                    }
743                    if diff.missing_newline() {
744                        writeln!(f)?;
745                    }
746                }
747            }
748        }
749
750        Ok(())
751    }
752}
753
// A failed parse test, kept for the failure report printed after the results.
#[derive(Debug, Serialize, JsonSchema)]
pub struct TestFailure {
    name: String,
    // Tree the parser produced: an S-expression, or a rendered CST when `is_cst` is set.
    actual: String,
    // Expected output. The literal "NO ERROR" marks an error test whose tree had no error.
    expected: String,
    // Whether `actual`/`expected` are rendered CSTs (shown verbatim) rather than
    // S-expressions (pretty-printed with `format_sexp` before display).
    is_cst: bool,
}
761
762impl TestFailure {
763    fn new<T, U, V>(name: T, actual: U, expected: V, is_cst: bool) -> Self
764    where
765        T: Into<String>,
766        U: Into<String>,
767        V: Into<String>,
768    {
769        Self {
770            name: name.into(),
771            actual: actual.into(),
772            expected: expected.into(),
773            is_cst,
774        }
775    }
776}
777
/// A test case re-rendered with updated expected output, collected by `run_tests` when
/// `TestOptions::update` is set. NOTE(review): the code that writes these back to the corpus
/// lies outside this view.
struct TestCorrection {
    name: String,
    /// Test source text.
    input: String,
    /// New expected output.
    output: String,
    /// Attribute text carried over verbatim from the original test.
    attributes_str: String,
    /// Delimiter lengths carried over from the original test's header and divider.
    header_delim_len: usize,
    divider_delim_len: usize,
}

impl TestCorrection {
    /// Builds a correction, converting each textual argument into an owned `String`.
    fn new<T, U, V, W>(
        name: T,
        input: U,
        output: V,
        attributes_str: W,
        header_delim_len: usize,
        divider_delim_len: usize,
    ) -> Self
    where
        T: Into<String>,
        U: Into<String>,
        V: Into<String>,
        W: Into<String>,
    {
        let name = name.into();
        let input = input.into();
        let output = output.into();
        let attributes_str = attributes_str.into();
        Self {
            name,
            input,
            output,
            attributes_str,
            header_delim_len,
            divider_delim_len,
        }
    }
}
812
813/// This will return false if we want to "fail fast". It will bail and not parse any more tests.
814fn run_tests(
815    parser: &mut Parser,
816    test_entry: TestEntry,
817    opts: &TestOptions,
818    test_summary: &mut TestSummary,
819    corrected_entries: &mut Vec<TestCorrection>,
820    is_root: bool,
821) -> Result<bool> {
822    match test_entry {
823        TestEntry::Example {
824            name,
825            input,
826            output,
827            header_delim_len,
828            divider_delim_len,
829            has_fields,
830            attributes_str,
831            attributes,
832            ..
833        } => {
834            if attributes.skip {
835                test_summary.parse_results.add_case(TestResult {
836                    name: name.clone(),
837                    info: TestInfo::ParseTest {
838                        outcome: TestOutcome::Skipped,
839                        parse_rate: None,
840                        test_num: test_summary.test_num,
841                    },
842                });
843                test_summary.test_num += 1;
844                return Ok(true);
845            }
846
847            if !attributes.platform {
848                test_summary.parse_results.add_case(TestResult {
849                    name: name.clone(),
850                    info: TestInfo::ParseTest {
851                        outcome: TestOutcome::Platform,
852                        parse_rate: None,
853                        test_num: test_summary.test_num,
854                    },
855                });
856                test_summary.test_num += 1;
857                return Ok(true);
858            }
859
860            for (i, language_name) in attributes.languages.iter().enumerate() {
861                if !language_name.is_empty() {
862                    let language = opts
863                        .languages
864                        .get(language_name.as_ref())
865                        .ok_or_else(|| anyhow!("Language not found: {language_name}"))?;
866                    parser.set_language(language)?;
867                }
868                let start = std::time::Instant::now();
869                let tree = parser.parse(&input, None).unwrap();
870                let parse_rate = {
871                    let parse_time = start.elapsed();
872                    let true_parse_rate = tree.root_node().byte_range().len() as f64
873                        / (parse_time.as_nanos() as f64 / 1_000_000.0);
874                    let adj_parse_rate = adjusted_parse_rate(&tree, parse_time);
875
876                    test_summary.parse_stats.total_parses += 1;
877                    test_summary.parse_stats.total_duration += parse_time;
878                    test_summary.parse_stats.total_bytes += tree.root_node().byte_range().len();
879
880                    Some((true_parse_rate, adj_parse_rate))
881                };
882
883                if attributes.error {
884                    if tree.root_node().has_error() {
885                        test_summary.parse_results.add_case(TestResult {
886                            name: name.clone(),
887                            info: TestInfo::ParseTest {
888                                outcome: TestOutcome::Passed,
889                                parse_rate,
890                                test_num: test_summary.test_num,
891                            },
892                        });
893                        test_summary.parse_stats.successful_parses += 1;
894                        if opts.update {
895                            let input = String::from_utf8(input.clone()).unwrap();
896                            let output = if attributes.cst {
897                                output.clone()
898                            } else {
899                                format_sexp(&output, 0)
900                            };
901                            corrected_entries.push(TestCorrection::new(
902                                &name,
903                                input,
904                                output,
905                                &attributes_str,
906                                header_delim_len,
907                                divider_delim_len,
908                            ));
909                        }
910                    } else {
911                        if opts.update {
912                            let input = String::from_utf8(input.clone()).unwrap();
913                            // Keep the original `expected` output if the actual output has no error
914                            let output = if attributes.cst {
915                                output.clone()
916                            } else {
917                                format_sexp(&output, 0)
918                            };
919                            corrected_entries.push(TestCorrection::new(
920                                &name,
921                                input,
922                                output,
923                                &attributes_str,
924                                header_delim_len,
925                                divider_delim_len,
926                            ));
927                        }
928                        test_summary.parse_results.add_case(TestResult {
929                            name: name.clone(),
930                            info: TestInfo::ParseTest {
931                                outcome: TestOutcome::Failed,
932                                parse_rate,
933                                test_num: test_summary.test_num,
934                            },
935                        });
936                        let actual = if attributes.cst {
937                            render_test_cst(&input, &tree)?
938                        } else {
939                            tree.root_node().to_sexp()
940                        };
941                        test_summary.parse_failures.push(TestFailure::new(
942                            &name,
943                            actual,
944                            "NO ERROR",
945                            attributes.cst,
946                        ));
947                    }
948
949                    if attributes.fail_fast {
950                        return Ok(false);
951                    }
952                } else {
953                    let mut actual = if attributes.cst {
954                        render_test_cst(&input, &tree)?
955                    } else {
956                        tree.root_node().to_sexp()
957                    };
958                    if !(attributes.cst || opts.show_fields || has_fields) {
959                        actual = strip_sexp_fields(&actual);
960                    }
961
962                    if actual == output {
963                        test_summary.parse_results.add_case(TestResult {
964                            name: name.clone(),
965                            info: TestInfo::ParseTest {
966                                outcome: TestOutcome::Passed,
967                                parse_rate,
968                                test_num: test_summary.test_num,
969                            },
970                        });
971                        test_summary.parse_stats.successful_parses += 1;
972                        if opts.update {
973                            let input = String::from_utf8(input.clone()).unwrap();
974                            let output = if attributes.cst {
975                                actual
976                            } else {
977                                format_sexp(&output, 0)
978                            };
979                            corrected_entries.push(TestCorrection::new(
980                                &name,
981                                input,
982                                output,
983                                &attributes_str,
984                                header_delim_len,
985                                divider_delim_len,
986                            ));
987                        }
988                    } else {
989                        if opts.update {
990                            let input = String::from_utf8(input.clone()).unwrap();
991                            let (expected_output, actual_output) = if attributes.cst {
992                                (output.clone(), actual.clone())
993                            } else {
994                                (format_sexp(&output, 0), format_sexp(&actual, 0))
995                            };
996
997                            // Only bail early before updating if the actual is not the output,
998                            // sometimes users want to test cases that
999                            // are intended to have errors, hence why this
1000                            // check isn't shown above
1001                            if actual.contains("ERROR") || actual.contains("MISSING") {
1002                                test_summary.has_parse_errors = true;
1003
1004                                // keep the original `expected` output if the actual output has an
1005                                // error
1006                                corrected_entries.push(TestCorrection::new(
1007                                    &name,
1008                                    input,
1009                                    expected_output,
1010                                    &attributes_str,
1011                                    header_delim_len,
1012                                    divider_delim_len,
1013                                ));
1014                            } else {
1015                                corrected_entries.push(TestCorrection::new(
1016                                    &name,
1017                                    input,
1018                                    actual_output,
1019                                    &attributes_str,
1020                                    header_delim_len,
1021                                    divider_delim_len,
1022                                ));
1023                                test_summary.parse_results.add_case(TestResult {
1024                                    name: name.clone(),
1025                                    info: TestInfo::ParseTest {
1026                                        outcome: TestOutcome::Updated,
1027                                        parse_rate,
1028                                        test_num: test_summary.test_num,
1029                                    },
1030                                });
1031                            }
1032                        } else {
1033                            test_summary.parse_results.add_case(TestResult {
1034                                name: name.clone(),
1035                                info: TestInfo::ParseTest {
1036                                    outcome: TestOutcome::Failed,
1037                                    parse_rate,
1038                                    test_num: test_summary.test_num,
1039                                },
1040                            });
1041                        }
1042                        test_summary.parse_failures.push(TestFailure::new(
1043                            &name,
1044                            actual,
1045                            &output,
1046                            attributes.cst,
1047                        ));
1048
1049                        if attributes.fail_fast {
1050                            return Ok(false);
1051                        }
1052                    }
1053                }
1054
1055                if i == attributes.languages.len() - 1 {
1056                    // reset to the first language
1057                    parser.set_language(opts.languages.values().next().unwrap())?;
1058                }
1059            }
1060            test_summary.test_num += 1;
1061        }
1062        TestEntry::Group {
1063            name,
1064            children,
1065            file_path,
1066        } => {
1067            if children.is_empty() {
1068                return Ok(true);
1069            }
1070
1071            let failure_count = test_summary.parse_failures.len();
1072            let mut ran_test_in_group = false;
1073
1074            let matches_filter = |name: &str, file_name: &Option<String>, opts: &TestOptions| {
1075                if let (Some(test_file_path), Some(filter_file_name)) = (file_name, &opts.file_name)
1076                {
1077                    if !filter_file_name.eq(test_file_path) {
1078                        return false;
1079                    }
1080                }
1081                if let Some(include) = &opts.include {
1082                    include.is_match(name)
1083                } else if let Some(exclude) = &opts.exclude {
1084                    !exclude.is_match(name)
1085                } else {
1086                    true
1087                }
1088            };
1089
1090            for child in children {
1091                if let TestEntry::Example {
1092                    ref name,
1093                    ref file_name,
1094                    ref input,
1095                    ref output,
1096                    ref attributes_str,
1097                    header_delim_len,
1098                    divider_delim_len,
1099                    ..
1100                } = child
1101                {
1102                    if !matches_filter(name, file_name, opts) {
1103                        if opts.update {
1104                            let input = String::from_utf8(input.clone()).unwrap();
1105                            let output = format_sexp(output, 0);
1106                            corrected_entries.push(TestCorrection::new(
1107                                name,
1108                                input,
1109                                output,
1110                                attributes_str,
1111                                header_delim_len,
1112                                divider_delim_len,
1113                            ));
1114                        }
1115
1116                        test_summary.test_num += 1;
1117                        continue;
1118                    }
1119                }
1120
1121                if !ran_test_in_group && !is_root {
1122                    test_summary.parse_results.add_group(&name);
1123                    ran_test_in_group = true;
1124                }
1125                if !run_tests(parser, child, opts, test_summary, corrected_entries, false)? {
1126                    // fail fast
1127                    return Ok(false);
1128                }
1129            }
1130            // Now that we're done traversing the children of the current group, pop
1131            // the index
1132            test_summary.parse_results.pop_traversal();
1133
1134            if let Some(file_path) = file_path {
1135                if opts.update && test_summary.parse_failures.len() - failure_count > 0 {
1136                    write_tests(&file_path, corrected_entries)?;
1137                }
1138                corrected_entries.clear();
1139            }
1140        }
1141    }
1142    Ok(true)
1143}
1144
1145/// Convenience wrapper to render a CST for a test entry.
1146fn render_test_cst(input: &[u8], tree: &Tree) -> Result<String> {
1147    let mut rendered_cst: Vec<u8> = Vec::new();
1148    let mut cursor = tree.walk();
1149    let opts = ParseFileOptions {
1150        edits: &[],
1151        output: ParseOutput::Cst,
1152        stats: &mut ParseStats::default(),
1153        print_time: false,
1154        timeout: 0,
1155        debug: ParseDebugType::Quiet,
1156        debug_graph: false,
1157        cancellation_flag: None,
1158        encoding: None,
1159        open_log: false,
1160        no_ranges: false,
1161        parse_theme: &ParseTheme::empty(),
1162    };
1163    render_cst(input, tree, &mut cursor, &opts, &mut rendered_cst)?;
1164    Ok(String::from_utf8_lossy(&rendered_cst).trim().to_string())
1165}
1166
1167// Parse time is interpreted in ns before converting to ms to avoid truncation issues
1168// Parse rates often have several outliers, leading to a large standard deviation. Taking
1169// the log of these rates serves to "flatten" out the distribution, yielding a more
1170// usable standard deviation for finding statistically significant slow parse rates
1171// NOTE: This is just a heuristic
1172#[must_use]
1173pub fn adjusted_parse_rate(tree: &Tree, parse_time: Duration) -> f64 {
1174    f64::ln(
1175        tree.root_node().byte_range().len() as f64 / (parse_time.as_nanos() as f64 / 1_000_000.0),
1176    )
1177}
1178
1179fn write_tests(file_path: &Path, corrected_entries: &[TestCorrection]) -> Result<()> {
1180    let mut buffer = fs::File::create(file_path)?;
1181    write_tests_to_buffer(&mut buffer, corrected_entries)
1182}
1183
1184fn write_tests_to_buffer(
1185    buffer: &mut impl Write,
1186    corrected_entries: &[TestCorrection],
1187) -> Result<()> {
1188    for (
1189        i,
1190        TestCorrection {
1191            name,
1192            input,
1193            output,
1194            attributes_str,
1195            header_delim_len,
1196            divider_delim_len,
1197        },
1198    ) in corrected_entries.iter().enumerate()
1199    {
1200        if i > 0 {
1201            writeln!(buffer)?;
1202        }
1203        writeln!(
1204            buffer,
1205            "{}\n{name}\n{}{}\n{input}\n{}\n\n{}",
1206            "=".repeat(*header_delim_len),
1207            if attributes_str.is_empty() {
1208                attributes_str.clone()
1209            } else {
1210                format!("{attributes_str}\n")
1211            },
1212            "=".repeat(*header_delim_len),
1213            "-".repeat(*divider_delim_len),
1214            output.trim()
1215        )?;
1216    }
1217    Ok(())
1218}
1219
1220pub fn parse_tests(path: &Path) -> io::Result<TestEntry> {
1221    let name = path
1222        .file_stem()
1223        .and_then(|s| s.to_str())
1224        .unwrap_or("")
1225        .to_string();
1226    if path.is_dir() {
1227        let mut children = Vec::new();
1228        for entry in fs::read_dir(path)? {
1229            let entry = entry?;
1230            let hidden = entry.file_name().to_str().unwrap_or("").starts_with('.');
1231            if !hidden {
1232                children.push(entry.path());
1233            }
1234        }
1235        children.sort_by(|a, b| {
1236            a.file_name()
1237                .unwrap_or_default()
1238                .cmp(b.file_name().unwrap_or_default())
1239        });
1240        let children = children
1241            .iter()
1242            .map(|path| parse_tests(path))
1243            .collect::<io::Result<Vec<TestEntry>>>()?;
1244        Ok(TestEntry::Group {
1245            name,
1246            children,
1247            file_path: None,
1248        })
1249    } else {
1250        let content = fs::read_to_string(path)?;
1251        Ok(parse_test_content(name, &content, Some(path.to_path_buf())))
1252    }
1253}
1254
1255#[must_use]
1256pub fn strip_sexp_fields(sexp: &str) -> String {
1257    SEXP_FIELD_REGEX.replace_all(sexp, " (").to_string()
1258}
1259
1260#[must_use]
1261pub fn strip_points(sexp: &str) -> String {
1262    POINT_REGEX.replace_all(sexp, "").to_string()
1263}
1264
1265fn parse_test_content(name: String, content: &str, file_path: Option<PathBuf>) -> TestEntry {
1266    let mut children = Vec::new();
1267    let bytes = content.as_bytes();
1268    let mut prev_name = String::new();
1269    let mut prev_attributes_str = String::new();
1270    let mut prev_header_end = 0;
1271
1272    // Find the first test header in the file, and determine if it has a
1273    // custom suffix. If so, then this suffix will be used to identify
1274    // all subsequent headers and divider lines in the file.
1275    let first_suffix = HEADER_REGEX
1276        .captures(bytes)
1277        .and_then(|c| c.name("suffix1"))
1278        .map(|m| String::from_utf8_lossy(m.as_bytes()));
1279
1280    // Find all of the `===` test headers, which contain the test names.
1281    // Ignore any matches whose suffix does not match the first header
1282    // suffix in the file.
1283    let header_matches = HEADER_REGEX.captures_iter(bytes).filter_map(|c| {
1284        let header_delim_len = c.name("equals").map_or(80, |m| m.as_bytes().len());
1285        let suffix1 = c
1286            .name("suffix1")
1287            .map(|m| String::from_utf8_lossy(m.as_bytes()));
1288        let suffix2 = c
1289            .name("suffix2")
1290            .map(|m| String::from_utf8_lossy(m.as_bytes()));
1291
1292        let (mut skip, mut platform, mut fail_fast, mut error, mut cst, mut languages) =
1293            (false, None, false, false, false, vec![]);
1294
1295        let test_name_and_markers = c
1296            .name("test_name_and_markers")
1297            .map_or("".as_bytes(), |m| m.as_bytes());
1298
1299        let mut test_name = String::new();
1300        let mut attributes_str = String::new();
1301
1302        let mut seen_marker = false;
1303
1304        let test_name_and_markers = str::from_utf8(test_name_and_markers).unwrap();
1305        for line in test_name_and_markers
1306            .split_inclusive('\n')
1307            .filter(|s| !s.is_empty())
1308        {
1309            let trimmed_line = line.trim();
1310            match trimmed_line.split('(').next().unwrap() {
1311                ":skip" => (seen_marker, skip) = (true, true),
1312                ":platform" => {
1313                    if let Some(platforms) = trimmed_line.strip_prefix(':').and_then(|s| {
1314                        s.strip_prefix("platform(")
1315                            .and_then(|s| s.strip_suffix(')'))
1316                    }) {
1317                        seen_marker = true;
1318                        platform = Some(
1319                            platform.unwrap_or(false) || platforms.trim() == std::env::consts::OS,
1320                        );
1321                    }
1322                }
1323                ":fail-fast" => (seen_marker, fail_fast) = (true, true),
1324                ":error" => (seen_marker, error) = (true, true),
1325                ":language" => {
1326                    if let Some(lang) = trimmed_line.strip_prefix(':').and_then(|s| {
1327                        s.strip_prefix("language(")
1328                            .and_then(|s| s.strip_suffix(')'))
1329                    }) {
1330                        seen_marker = true;
1331                        languages.push(lang.into());
1332                    }
1333                }
1334                ":cst" => (seen_marker, cst) = (true, true),
1335                _ if !seen_marker => {
1336                    test_name.push_str(line);
1337                }
1338                _ => {}
1339            }
1340        }
1341        attributes_str.push_str(test_name_and_markers.strip_prefix(&test_name).unwrap());
1342
1343        // prefer skip over error, both shouldn't be set
1344        if skip {
1345            error = false;
1346        }
1347
1348        // add a default language if none are specified, will defer to the first language
1349        if languages.is_empty() {
1350            languages.push("".into());
1351        }
1352
1353        if suffix1 == first_suffix && suffix2 == first_suffix {
1354            let header_range = c.get(0).unwrap().range();
1355            let test_name = if test_name.is_empty() {
1356                None
1357            } else {
1358                Some(test_name.trim_end().to_string())
1359            };
1360            let attributes_str = if attributes_str.is_empty() {
1361                None
1362            } else {
1363                Some(attributes_str.trim_end().to_string())
1364            };
1365            Some((
1366                header_delim_len,
1367                header_range,
1368                test_name,
1369                attributes_str,
1370                TestAttributes {
1371                    skip,
1372                    platform: platform.unwrap_or(true),
1373                    fail_fast,
1374                    error,
1375                    cst,
1376                    languages,
1377                },
1378            ))
1379        } else {
1380            None
1381        }
1382    });
1383
1384    let (mut prev_header_len, mut prev_attributes) = (80, TestAttributes::default());
1385    for (header_delim_len, header_range, test_name, attributes_str, attributes) in header_matches
1386        .chain(Some((
1387            80,
1388            bytes.len()..bytes.len(),
1389            None,
1390            None,
1391            TestAttributes::default(),
1392        )))
1393    {
1394        // Find the longest line of dashes following each test description. That line
1395        // separates the input from the expected output. Ignore any matches whose suffix
1396        // does not match the first suffix in the file.
1397        if prev_header_end > 0 {
1398            let divider_range = DIVIDER_REGEX
1399                .captures_iter(&bytes[prev_header_end..header_range.start])
1400                .filter_map(|m| {
1401                    let divider_delim_len = m.name("hyphens").map_or(80, |m| m.as_bytes().len());
1402                    let suffix = m
1403                        .name("suffix")
1404                        .map(|m| String::from_utf8_lossy(m.as_bytes()));
1405                    if suffix == first_suffix {
1406                        let range = m.get(0).unwrap().range();
1407                        Some((
1408                            divider_delim_len,
1409                            (prev_header_end + range.start)..(prev_header_end + range.end),
1410                        ))
1411                    } else {
1412                        None
1413                    }
1414                })
1415                .max_by_key(|(_, range)| range.len());
1416
1417            if let Some((divider_delim_len, divider_range)) = divider_range {
1418                if let Ok(output) = str::from_utf8(&bytes[divider_range.end..header_range.start]) {
1419                    let mut input = bytes[prev_header_end..divider_range.start].to_vec();
1420
1421                    // Remove trailing newline from the input.
1422                    input.pop();
1423                    if input.last() == Some(&b'\r') {
1424                        input.pop();
1425                    }
1426
1427                    let (output, has_fields) = if prev_attributes.cst {
1428                        (output.trim().to_string(), false)
1429                    } else {
1430                        // Remove all comments
1431                        let output = COMMENT_REGEX.replace_all(output, "").to_string();
1432
1433                        // Normalize the whitespace in the expected output.
1434                        let output = WHITESPACE_REGEX.replace_all(output.trim(), " ");
1435                        let output = output.replace(" )", ")");
1436
1437                        // Identify if the expected output has fields indicated. If not, then
1438                        // fields will not be checked.
1439                        let has_fields = SEXP_FIELD_REGEX.is_match(&output);
1440
1441                        (output, has_fields)
1442                    };
1443
1444                    let file_name = if let Some(ref path) = file_path {
1445                        path.file_name().map(|n| n.to_string_lossy().to_string())
1446                    } else {
1447                        None
1448                    };
1449
1450                    let t = TestEntry::Example {
1451                        name: prev_name,
1452                        input,
1453                        output,
1454                        header_delim_len: prev_header_len,
1455                        divider_delim_len,
1456                        has_fields,
1457                        attributes_str: prev_attributes_str,
1458                        attributes: prev_attributes,
1459                        file_name,
1460                    };
1461
1462                    children.push(t);
1463                }
1464            }
1465        }
1466        prev_attributes = attributes;
1467        prev_name = test_name.unwrap_or_default();
1468        prev_attributes_str = attributes_str.unwrap_or_default();
1469        prev_header_len = header_delim_len;
1470        prev_header_end = header_range.end;
1471    }
1472    TestEntry::Group {
1473        name,
1474        children,
1475        file_path,
1476    }
1477}
1478
1479#[cfg(test)]
1480mod tests {
1481    use serde_json::json;
1482
1483    use crate::tests::get_language;
1484
1485    use super::*;
1486
1487    #[test]
1488    fn test_parse_test_content_simple() {
1489        let entry = parse_test_content(
1490            "the-filename".to_string(),
1491            r"
1492===============
1493The first test
1494===============
1495
1496a b c
1497
1498---
1499
1500(a
1501    (b c))
1502
1503================
1504The second test
1505================
1506d
1507---
1508(d)
1509        "
1510            .trim(),
1511            None,
1512        );
1513
1514        assert_eq!(
1515            entry,
1516            TestEntry::Group {
1517                name: "the-filename".to_string(),
1518                children: vec![
1519                    TestEntry::Example {
1520                        name: "The first test".to_string(),
1521                        input: b"\na b c\n".to_vec(),
1522                        output: "(a (b c))".to_string(),
1523                        header_delim_len: 15,
1524                        divider_delim_len: 3,
1525                        has_fields: false,
1526                        attributes_str: String::new(),
1527                        attributes: TestAttributes::default(),
1528                        file_name: None,
1529                    },
1530                    TestEntry::Example {
1531                        name: "The second test".to_string(),
1532                        input: b"d".to_vec(),
1533                        output: "(d)".to_string(),
1534                        header_delim_len: 16,
1535                        divider_delim_len: 3,
1536                        has_fields: false,
1537                        attributes_str: String::new(),
1538                        attributes: TestAttributes::default(),
1539                        file_name: None,
1540                    },
1541                ],
1542                file_path: None,
1543            }
1544        );
1545    }
1546
1547    #[test]
1548    fn test_parse_test_content_with_dashes_in_source_code() {
1549        let entry = parse_test_content(
1550            "the-filename".to_string(),
1551            r"
1552==================
1553Code with dashes
1554==================
1555abc
1556---
1557defg
1558----
1559hijkl
1560-------
1561
1562(a (b))
1563
1564=========================
1565Code ending with dashes
1566=========================
1567abc
1568-----------
1569-------------------
1570
1571(c (d))
1572        "
1573            .trim(),
1574            None,
1575        );
1576
1577        assert_eq!(
1578            entry,
1579            TestEntry::Group {
1580                name: "the-filename".to_string(),
1581                children: vec![
1582                    TestEntry::Example {
1583                        name: "Code with dashes".to_string(),
1584                        input: b"abc\n---\ndefg\n----\nhijkl".to_vec(),
1585                        output: "(a (b))".to_string(),
1586                        header_delim_len: 18,
1587                        divider_delim_len: 7,
1588                        has_fields: false,
1589                        attributes_str: String::new(),
1590                        attributes: TestAttributes::default(),
1591                        file_name: None,
1592                    },
1593                    TestEntry::Example {
1594                        name: "Code ending with dashes".to_string(),
1595                        input: b"abc\n-----------".to_vec(),
1596                        output: "(c (d))".to_string(),
1597                        header_delim_len: 25,
1598                        divider_delim_len: 19,
1599                        has_fields: false,
1600                        attributes_str: String::new(),
1601                        attributes: TestAttributes::default(),
1602                        file_name: None,
1603                    },
1604                ],
1605                file_path: None,
1606            }
1607        );
1608    }
1609
1610    #[test]
1611    fn test_format_sexp() {
1612        assert_eq!(format_sexp("", 0), "");
1613        assert_eq!(
1614            format_sexp("(a b: (c) (d) e: (f (g (h (MISSING i)))))", 0),
1615            r"
1616(a
1617  b: (c)
1618  (d)
1619  e: (f
1620    (g
1621      (h
1622        (MISSING i)))))
1623"
1624            .trim()
1625        );
1626        assert_eq!(
1627            format_sexp("(program (ERROR (UNEXPECTED ' ')) (identifier))", 0),
1628            r"
1629(program
1630  (ERROR
1631    (UNEXPECTED ' '))
1632  (identifier))
1633"
1634            .trim()
1635        );
1636        assert_eq!(
1637            format_sexp(r#"(source_file (MISSING ")"))"#, 0),
1638            r#"
1639(source_file
1640  (MISSING ")"))
1641        "#
1642            .trim()
1643        );
1644        assert_eq!(
1645            format_sexp(
1646                r"(source_file (ERROR (UNEXPECTED 'f') (UNEXPECTED '+')))",
1647                0
1648            ),
1649            r"
1650(source_file
1651  (ERROR
1652    (UNEXPECTED 'f')
1653    (UNEXPECTED '+')))
1654"
1655            .trim()
1656        );
1657    }
1658
1659    #[test]
1660    fn test_write_tests_to_buffer() {
1661        let mut buffer = Vec::new();
1662        let corrected_entries = vec![
1663            TestCorrection::new(
1664                "title 1".to_string(),
1665                "input 1".to_string(),
1666                "output 1".to_string(),
1667                String::new(),
1668                80,
1669                80,
1670            ),
1671            TestCorrection::new(
1672                "title 2".to_string(),
1673                "input 2".to_string(),
1674                "output 2".to_string(),
1675                String::new(),
1676                80,
1677                80,
1678            ),
1679        ];
1680        write_tests_to_buffer(&mut buffer, &corrected_entries).unwrap();
1681        assert_eq!(
1682            String::from_utf8(buffer).unwrap(),
1683            r"
1684================================================================================
1685title 1
1686================================================================================
1687input 1
1688--------------------------------------------------------------------------------
1689
1690output 1
1691
1692================================================================================
1693title 2
1694================================================================================
1695input 2
1696--------------------------------------------------------------------------------
1697
1698output 2
1699"
1700            .trim_start()
1701            .to_string()
1702        );
1703    }
1704
1705    #[test]
1706    fn test_parse_test_content_with_comments_in_sexp() {
1707        let entry = parse_test_content(
1708            "the-filename".to_string(),
1709            r#"
1710==================
1711sexp with comment
1712==================
1713code
1714---
1715
1716; Line start comment
1717(a (b))
1718
1719==================
1720sexp with comment between
1721==================
1722code
1723---
1724
1725; Line start comment
1726(a
1727; ignore this
1728    (b)
1729    ; also ignore this
1730)
1731
1732=========================
1733sexp with ';'
1734=========================
1735code
1736---
1737
1738(MISSING ";")
1739        "#
1740            .trim(),
1741            None,
1742        );
1743
1744        assert_eq!(
1745            entry,
1746            TestEntry::Group {
1747                name: "the-filename".to_string(),
1748                children: vec![
1749                    TestEntry::Example {
1750                        name: "sexp with comment".to_string(),
1751                        input: b"code".to_vec(),
1752                        output: "(a (b))".to_string(),
1753                        header_delim_len: 18,
1754                        divider_delim_len: 3,
1755                        has_fields: false,
1756                        attributes_str: String::new(),
1757                        attributes: TestAttributes::default(),
1758                        file_name: None,
1759                    },
1760                    TestEntry::Example {
1761                        name: "sexp with comment between".to_string(),
1762                        input: b"code".to_vec(),
1763                        output: "(a (b))".to_string(),
1764                        header_delim_len: 18,
1765                        divider_delim_len: 3,
1766                        has_fields: false,
1767                        attributes_str: String::new(),
1768                        attributes: TestAttributes::default(),
1769                        file_name: None,
1770                    },
1771                    TestEntry::Example {
1772                        name: "sexp with ';'".to_string(),
1773                        input: b"code".to_vec(),
1774                        output: "(MISSING \";\")".to_string(),
1775                        header_delim_len: 25,
1776                        divider_delim_len: 3,
1777                        has_fields: false,
1778                        attributes_str: String::new(),
1779                        attributes: TestAttributes::default(),
1780                        file_name: None,
1781                    }
1782                ],
1783                file_path: None,
1784            }
1785        );
1786    }
1787
1788    #[test]
1789    fn test_parse_test_content_with_suffixes() {
1790        let entry = parse_test_content(
1791            "the-filename".to_string(),
1792            r"
1793==================asdf\()[]|{}*+?^$.-
1794First test
1795==================asdf\()[]|{}*+?^$.-
1796
1797=========================
1798NOT A TEST HEADER
1799=========================
1800-------------------------
1801
1802---asdf\()[]|{}*+?^$.-
1803
1804(a)
1805
1806==================asdf\()[]|{}*+?^$.-
1807Second test
1808==================asdf\()[]|{}*+?^$.-
1809
1810=========================
1811NOT A TEST HEADER
1812=========================
1813-------------------------
1814
1815---asdf\()[]|{}*+?^$.-
1816
1817(a)
1818
1819=========================asdf\()[]|{}*+?^$.-
1820Test name with = symbol
1821=========================asdf\()[]|{}*+?^$.-
1822
1823=========================
1824NOT A TEST HEADER
1825=========================
1826-------------------------
1827
1828---asdf\()[]|{}*+?^$.-
1829
1830(a)
1831
1832==============================asdf\()[]|{}*+?^$.-
1833Test containing equals
1834==============================asdf\()[]|{}*+?^$.-
1835
1836===
1837
1838------------------------------asdf\()[]|{}*+?^$.-
1839
1840(a)
1841
1842==============================asdf\()[]|{}*+?^$.-
1843Subsequent test containing equals
1844==============================asdf\()[]|{}*+?^$.-
1845
1846===
1847
1848------------------------------asdf\()[]|{}*+?^$.-
1849
1850(a)
1851"
1852            .trim(),
1853            None,
1854        );
1855
1856        let expected_input = b"\n=========================\n\
1857            NOT A TEST HEADER\n\
1858            =========================\n\
1859            -------------------------\n"
1860            .to_vec();
1861        pretty_assertions::assert_eq!(
1862            entry,
1863            TestEntry::Group {
1864                name: "the-filename".to_string(),
1865                children: vec![
1866                    TestEntry::Example {
1867                        name: "First test".to_string(),
1868                        input: expected_input.clone(),
1869                        output: "(a)".to_string(),
1870                        header_delim_len: 18,
1871                        divider_delim_len: 3,
1872                        has_fields: false,
1873                        attributes_str: String::new(),
1874                        attributes: TestAttributes::default(),
1875                        file_name: None,
1876                    },
1877                    TestEntry::Example {
1878                        name: "Second test".to_string(),
1879                        input: expected_input.clone(),
1880                        output: "(a)".to_string(),
1881                        header_delim_len: 18,
1882                        divider_delim_len: 3,
1883                        has_fields: false,
1884                        attributes_str: String::new(),
1885                        attributes: TestAttributes::default(),
1886                        file_name: None,
1887                    },
1888                    TestEntry::Example {
1889                        name: "Test name with = symbol".to_string(),
1890                        input: expected_input,
1891                        output: "(a)".to_string(),
1892                        header_delim_len: 25,
1893                        divider_delim_len: 3,
1894                        has_fields: false,
1895                        attributes_str: String::new(),
1896                        attributes: TestAttributes::default(),
1897                        file_name: None,
1898                    },
1899                    TestEntry::Example {
1900                        name: "Test containing equals".to_string(),
1901                        input: "\n===\n".into(),
1902                        output: "(a)".into(),
1903                        header_delim_len: 30,
1904                        divider_delim_len: 30,
1905                        has_fields: false,
1906                        attributes_str: String::new(),
1907                        attributes: TestAttributes::default(),
1908                        file_name: None,
1909                    },
1910                    TestEntry::Example {
1911                        name: "Subsequent test containing equals".to_string(),
1912                        input: "\n===\n".into(),
1913                        output: "(a)".into(),
1914                        header_delim_len: 30,
1915                        divider_delim_len: 30,
1916                        has_fields: false,
1917                        attributes_str: String::new(),
1918                        attributes: TestAttributes::default(),
1919                        file_name: None,
1920                    }
1921                ],
1922                file_path: None,
1923            }
1924        );
1925    }
1926
1927    #[test]
1928    fn test_parse_test_content_with_newlines_in_test_names() {
1929        let entry = parse_test_content(
1930            "the-filename".to_string(),
1931            r"
1932===============
1933name
1934with
1935newlines
1936===============
1937a
1938---
1939(b)
1940
1941====================
1942name with === signs
1943====================
1944code with ----
1945---
1946(d)
1947",
1948            None,
1949        );
1950
1951        assert_eq!(
1952            entry,
1953            TestEntry::Group {
1954                name: "the-filename".to_string(),
1955                file_path: None,
1956                children: vec![
1957                    TestEntry::Example {
1958                        name: "name\nwith\nnewlines".to_string(),
1959                        input: b"a".to_vec(),
1960                        output: "(b)".to_string(),
1961                        header_delim_len: 15,
1962                        divider_delim_len: 3,
1963                        has_fields: false,
1964                        attributes_str: String::new(),
1965                        attributes: TestAttributes::default(),
1966                        file_name: None,
1967                    },
1968                    TestEntry::Example {
1969                        name: "name with === signs".to_string(),
1970                        input: b"code with ----".to_vec(),
1971                        output: "(d)".to_string(),
1972                        header_delim_len: 20,
1973                        divider_delim_len: 3,
1974                        has_fields: false,
1975                        attributes_str: String::new(),
1976                        attributes: TestAttributes::default(),
1977                        file_name: None,
1978                    }
1979                ]
1980            }
1981        );
1982    }
1983
1984    #[test]
1985    fn test_parse_test_with_markers() {
1986        // do one with :skip, we should not see it in the entry output
1987
1988        let entry = parse_test_content(
1989            "the-filename".to_string(),
1990            r"
1991=====================
1992Test with skip marker
1993:skip
1994=====================
1995a
1996---
1997(b)
1998",
1999            None,
2000        );
2001
2002        assert_eq!(
2003            entry,
2004            TestEntry::Group {
2005                name: "the-filename".to_string(),
2006                file_path: None,
2007                children: vec![TestEntry::Example {
2008                    name: "Test with skip marker".to_string(),
2009                    input: b"a".to_vec(),
2010                    output: "(b)".to_string(),
2011                    header_delim_len: 21,
2012                    divider_delim_len: 3,
2013                    has_fields: false,
2014                    attributes_str: ":skip".to_string(),
2015                    attributes: TestAttributes {
2016                        skip: true,
2017                        platform: true,
2018                        fail_fast: false,
2019                        error: false,
2020                        cst: false,
2021                        languages: vec!["".into()]
2022                    },
2023                    file_name: None,
2024                }]
2025            }
2026        );
2027
2028        let entry = parse_test_content(
2029            "the-filename".to_string(),
2030            &format!(
2031                r"
2032=========================
2033Test with platform marker
2034:platform({})
2035:fail-fast
2036=========================
2037a
2038---
2039(b)
2040
2041=============================
2042Test with bad platform marker
2043:platform({})
2044
2045:language(foo)
2046=============================
2047a
2048---
2049(b)
2050
2051====================
2052Test with cst marker
2053:cst
2054====================
20551
2056---
20570:0 - 1:0   source_file
20580:0 - 0:1   expression
20590:0 - 0:1     number_literal `1`
2060",
2061                std::env::consts::OS,
2062                if std::env::consts::OS == "linux" {
2063                    "macos"
2064                } else {
2065                    "linux"
2066                }
2067            ),
2068            None,
2069        );
2070
2071        assert_eq!(
2072            entry,
2073            TestEntry::Group {
2074                name: "the-filename".to_string(),
2075                file_path: None,
2076                children: vec![
2077                    TestEntry::Example {
2078                        name: "Test with platform marker".to_string(),
2079                        input: b"a".to_vec(),
2080                        output: "(b)".to_string(),
2081                        header_delim_len: 25,
2082                        divider_delim_len: 3,
2083                        has_fields: false,
2084                        attributes_str: format!(":platform({})\n:fail-fast", std::env::consts::OS),
2085                        attributes: TestAttributes {
2086                            skip: false,
2087                            platform: true,
2088                            fail_fast: true,
2089                            error: false,
2090                            cst: false,
2091                            languages: vec!["".into()]
2092                        },
2093                        file_name: None,
2094                    },
2095                    TestEntry::Example {
2096                        name: "Test with bad platform marker".to_string(),
2097                        input: b"a".to_vec(),
2098                        output: "(b)".to_string(),
2099                        header_delim_len: 29,
2100                        divider_delim_len: 3,
2101                        has_fields: false,
2102                        attributes_str: if std::env::consts::OS == "linux" {
2103                            ":platform(macos)\n\n:language(foo)".to_string()
2104                        } else {
2105                            ":platform(linux)\n\n:language(foo)".to_string()
2106                        },
2107                        attributes: TestAttributes {
2108                            skip: false,
2109                            platform: false,
2110                            fail_fast: false,
2111                            error: false,
2112                            cst: false,
2113                            languages: vec!["foo".into()]
2114                        },
2115                        file_name: None,
2116                    },
2117                    TestEntry::Example {
2118                        name: "Test with cst marker".to_string(),
2119                        input: b"1".to_vec(),
2120                        output: "0:0 - 1:0   source_file
21210:0 - 0:1   expression
21220:0 - 0:1     number_literal `1`"
2123                            .to_string(),
2124                        header_delim_len: 20,
2125                        divider_delim_len: 3,
2126                        has_fields: false,
2127                        attributes_str: ":cst".to_string(),
2128                        attributes: TestAttributes {
2129                            skip: false,
2130                            platform: true,
2131                            fail_fast: false,
2132                            error: false,
2133                            cst: true,
2134                            languages: vec!["".into()]
2135                        },
2136                        file_name: None,
2137                    }
2138                ]
2139            }
2140        );
2141    }
2142
2143    fn clear_parse_rate(result: &mut TestResult) {
2144        let test_case_info = &mut result.info;
2145        match test_case_info {
2146            TestInfo::ParseTest {
2147                ref mut parse_rate, ..
2148            } => {
2149                assert!(parse_rate.is_some());
2150                *parse_rate = None;
2151            }
2152            TestInfo::Group { .. } | TestInfo::AssertionTest { .. } => {
2153                panic!("Unexpected test result")
2154            }
2155        }
2156    }
2157
2158    #[test]
2159    fn run_tests_simple() {
2160        let mut parser = Parser::new();
2161        let language = get_language("c");
2162        parser
2163            .set_language(&language)
2164            .expect("Failed to set language");
2165        let mut languages = BTreeMap::new();
2166        languages.insert("c", &language);
2167        let opts = TestOptions {
2168            path: PathBuf::from("foo"),
2169            debug: true,
2170            debug_graph: false,
2171            include: None,
2172            exclude: None,
2173            file_name: None,
2174            update: false,
2175            open_log: false,
2176            languages,
2177            color: true,
2178            show_fields: false,
2179            overview_only: false,
2180        };
2181
2182        // NOTE: The following test cases are combined to work around a race condition
2183        // in the loader
2184        {
2185            let test_entry = TestEntry::Group {
2186                name: "foo".to_string(),
2187                file_path: None,
2188                children: vec![TestEntry::Example {
2189                    name: "C Test 1".to_string(),
2190                    input: b"1;\n".to_vec(),
2191                    output: "(translation_unit (expression_statement (number_literal)))"
2192                        .to_string(),
2193                    header_delim_len: 25,
2194                    divider_delim_len: 3,
2195                    has_fields: false,
2196                    attributes_str: String::new(),
2197                    attributes: TestAttributes::default(),
2198                    file_name: None,
2199                }],
2200            };
2201
2202            let mut test_summary = TestSummary::new(true, TestStats::All, false, false, false);
2203            let mut corrected_entries = Vec::new();
2204            run_tests(
2205                &mut parser,
2206                test_entry,
2207                &opts,
2208                &mut test_summary,
2209                &mut corrected_entries,
2210                true,
2211            )
2212            .expect("Failed to run tests");
2213
2214            // parse rates will always be different, so we need to clear out these
2215            // fields to reliably assert equality below
2216            clear_parse_rate(&mut test_summary.parse_results.root_group[0]);
2217            test_summary.parse_stats.total_duration = Duration::from_secs(0);
2218
2219            let json_results = serde_json::to_string(&test_summary).unwrap();
2220
2221            assert_eq!(
2222                json_results,
2223                json!({
2224                  "parse_results": [
2225                    {
2226                      "name": "C Test 1",
2227                      "outcome": "Passed",
2228                      "parse_rate": null,
2229                      "test_num": 1
2230                    }
2231                  ],
2232                  "parse_failures": [],
2233                  "parse_stats": {
2234                    "successful_parses": 1,
2235                    "total_parses": 1,
2236                    "total_bytes": 3,
2237                    "total_duration": {
2238                      "secs": 0,
2239                      "nanos": 0,
2240                    }
2241                  },
2242                  "highlight_results": [],
2243                  "tag_results": [],
2244                  "query_results": []
2245                })
2246                .to_string()
2247            );
2248        }
2249        {
2250            let test_entry = TestEntry::Group {
2251                name: "corpus".to_string(),
2252                file_path: None,
2253                children: vec![
2254                    TestEntry::Group {
2255                        name: "group1".to_string(),
2256                        // This test passes
2257                        children: vec![TestEntry::Example {
2258                            name: "C Test 1".to_string(),
2259                            input: b"1;\n".to_vec(),
2260                            output: "(translation_unit (expression_statement (number_literal)))"
2261                                .to_string(),
2262                            header_delim_len: 25,
2263                            divider_delim_len: 3,
2264                            has_fields: false,
2265                            attributes_str: String::new(),
2266                            attributes: TestAttributes::default(),
2267                            file_name: None,
2268                        }],
2269                        file_path: None,
2270                    },
2271                    TestEntry::Group {
2272                        name: "group2".to_string(),
2273                        children: vec![
2274                            // This test passes
2275                            TestEntry::Example {
2276                                name: "C Test 2".to_string(),
2277                                input: b"1;\n".to_vec(),
2278                                output:
2279                                    "(translation_unit (expression_statement (number_literal)))"
2280                                        .to_string(),
2281                                header_delim_len: 25,
2282                                divider_delim_len: 3,
2283                                has_fields: false,
2284                                attributes_str: String::new(),
2285                                attributes: TestAttributes::default(),
2286                                file_name: None,
2287                            },
2288                            // This test fails, and is marked with fail-fast
2289                            TestEntry::Example {
2290                                name: "C Test 3".to_string(),
2291                                input: b"1;\n".to_vec(),
2292                                output:
2293                                    "(translation_unit (expression_statement (string_literal)))"
2294                                        .to_string(),
2295                                header_delim_len: 25,
2296                                divider_delim_len: 3,
2297                                has_fields: false,
2298                                attributes_str: String::new(),
2299                                attributes: TestAttributes {
2300                                    fail_fast: true,
2301                                    ..Default::default()
2302                                },
2303                                file_name: None,
2304                            },
2305                        ],
2306                        file_path: None,
2307                    },
2308                    // This group never runs because of the previous failure
2309                    TestEntry::Group {
2310                        name: "group3".to_string(),
2311                        // This test fails, and is marked with fail-fast
2312                        children: vec![TestEntry::Example {
2313                            name: "C Test 4".to_string(),
2314                            input: b"1;\n".to_vec(),
2315                            output: "(translation_unit (expression_statement (number_literal)))"
2316                                .to_string(),
2317                            header_delim_len: 25,
2318                            divider_delim_len: 3,
2319                            has_fields: false,
2320                            attributes_str: String::new(),
2321                            attributes: TestAttributes::default(),
2322                            file_name: None,
2323                        }],
2324                        file_path: None,
2325                    },
2326                ],
2327            };
2328
2329            let mut test_summary = TestSummary::new(true, TestStats::All, false, false, false);
2330            let mut corrected_entries = Vec::new();
2331            run_tests(
2332                &mut parser,
2333                test_entry,
2334                &opts,
2335                &mut test_summary,
2336                &mut corrected_entries,
2337                true,
2338            )
2339            .expect("Failed to run tests");
2340
2341            // parse rates will always be different, so we need to clear out these
2342            // fields to reliably assert equality below
2343            {
2344                let test_group_1_info = &mut test_summary.parse_results.root_group[0].info;
2345                match test_group_1_info {
2346                    TestInfo::Group {
2347                        ref mut children, ..
2348                    } => clear_parse_rate(&mut children[0]),
2349                    TestInfo::ParseTest { .. } | TestInfo::AssertionTest { .. } => {
2350                        panic!("Unexpected test result");
2351                    }
2352                }
2353                let test_group_2_info = &mut test_summary.parse_results.root_group[1].info;
2354                match test_group_2_info {
2355                    TestInfo::Group {
2356                        ref mut children, ..
2357                    } => {
2358                        clear_parse_rate(&mut children[0]);
2359                        clear_parse_rate(&mut children[1]);
2360                    }
2361                    TestInfo::ParseTest { .. } | TestInfo::AssertionTest { .. } => {
2362                        panic!("Unexpected test result");
2363                    }
2364                }
2365                test_summary.parse_stats.total_duration = Duration::from_secs(0);
2366            }
2367
2368            let json_results = serde_json::to_string(&test_summary).unwrap();
2369
2370            assert_eq!(
2371                json_results,
2372                json!({
2373                  "parse_results": [
2374                    {
2375                      "name": "group1",
2376                      "children": [
2377                        {
2378                          "name": "C Test 1",
2379                          "outcome": "Passed",
2380                          "parse_rate": null,
2381                          "test_num": 1
2382                        }
2383                      ]
2384                    },
2385                    {
2386                      "name": "group2",
2387                      "children": [
2388                        {
2389                          "name": "C Test 2",
2390                          "outcome": "Passed",
2391                          "parse_rate": null,
2392                          "test_num": 2
2393                        },
2394                        {
2395                          "name": "C Test 3",
2396                          "outcome": "Failed",
2397                          "parse_rate": null,
2398                          "test_num": 3
2399                        }
2400                      ]
2401                    }
2402                  ],
2403                  "parse_failures": [
2404                    {
2405                      "name": "C Test 3",
2406                      "actual": "(translation_unit (expression_statement (number_literal)))",
2407                      "expected": "(translation_unit (expression_statement (string_literal)))",
2408                      "is_cst": false,
2409                    }
2410                  ],
2411                  "parse_stats": {
2412                    "successful_parses": 2,
2413                    "total_parses": 3,
2414                    "total_bytes": 9,
2415                    "total_duration": {
2416                      "secs": 0,
2417                      "nanos": 0,
2418                    }
2419                  },
2420                  "highlight_results": [],
2421                  "tag_results": [],
2422                  "query_results": []
2423                })
2424                .to_string()
2425            );
2426        }
2427    }
2428}