Skip to main content

tree_sitter_cli/
test.rs

1use std::{
2    collections::BTreeMap,
3    ffi::OsStr,
4    fmt::Display as _,
5    fs,
6    io::{self, Write},
7    path::{Path, PathBuf},
8    str,
9    sync::LazyLock,
10    time::Duration,
11};
12
13use anstyle::AnsiColor;
14use anyhow::{anyhow, Context, Result};
15use clap::ValueEnum;
16use indoc::indoc;
17use regex::{
18    bytes::{Regex as ByteRegex, RegexBuilder as ByteRegexBuilder},
19    Regex,
20};
21use schemars::{JsonSchema, Schema, SchemaGenerator};
22use serde::Serialize;
23use similar::{ChangeTag, TextDiff};
24use tree_sitter::{format_sexp, Language, LogType, Parser, Query, Tree};
25use walkdir::WalkDir;
26
27use super::util;
28use crate::{
29    logger::paint,
30    parse::{
31        render_cst, ParseDebugType, ParseFileOptions, ParseOutput, ParseStats, ParseTheme, Stats,
32    },
33};
34
35static HEADER_REGEX: LazyLock<ByteRegex> = LazyLock::new(|| {
36    ByteRegexBuilder::new(
37        r"^(?x)
38           (?P<equals>(?:=+){3,})
39           (?P<suffix1>[^=\r\n][^\r\n]*)?
40           \r?\n
41           (?P<test_name_and_markers>(?:([^=\r\n]|\s+:)[^\r\n]*\r?\n)+)
42           ===+
43           (?P<suffix2>[^=\r\n][^\r\n]*)?\r?\n",
44    )
45    .multi_line(true)
46    .build()
47    .unwrap()
48});
49
50static DIVIDER_REGEX: LazyLock<ByteRegex> = LazyLock::new(|| {
51    ByteRegexBuilder::new(r"^(?P<hyphens>(?:-+){3,})(?P<suffix>[^-\r\n][^\r\n]*)?\r?\n")
52        .multi_line(true)
53        .build()
54        .unwrap()
55});
56
57static COMMENT_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?m)^\s*;.*$").unwrap());
58
59static WHITESPACE_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\s+").unwrap());
60
61static SEXP_FIELD_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r" \w+: \(").unwrap());
62
63static POINT_REGEX: LazyLock<Regex> =
64    LazyLock::new(|| Regex::new(r"\s*\[\s*\d+\s*,\s*\d+\s*\]\s*").unwrap());
65
/// A node of a parsed corpus: either a named group of further entries or a single test example.
#[derive(Debug, PartialEq, Eq)]
pub enum TestEntry {
    /// A named collection of entries, optionally tied to the file it was read from.
    Group {
        name: String,
        children: Vec<Self>,
        file_path: Option<PathBuf>,
    },
    /// One test case: the raw `input` bytes to parse and the expected `output` text.
    ///
    /// `header_delim_len`/`divider_delim_len` record the delimiter lengths as written.
    /// When `has_fields` is false, field names are stripped from the actual s-expression
    /// before comparison (unless CST mode or `show_fields` is on).
    Example {
        name: String,
        input: Vec<u8>,
        output: String,
        header_delim_len: usize,
        divider_delim_len: usize,
        has_fields: bool,
        attributes_str: String,
        attributes: TestAttributes,
        file_name: Option<String>,
    },
}

/// Per-test flags that control whether and how an example is run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TestAttributes {
    /// Do not run the test; it is reported as skipped.
    pub skip: bool,
    /// Whether the test applies to the current platform; if not, it is reported but not run.
    pub platform: bool,
    /// Stop running further tests after this one.
    pub fail_fast: bool,
    /// The parse is expected to contain an error.
    pub error: bool,
    /// Compare a rendered CST instead of an s-expression.
    pub cst: bool,
    /// Language names to parse the input with; an empty name keeps the parser's current
    /// language.
    pub languages: Vec<Box<str>>,
}

impl Default for TestEntry {
    /// An unnamed group with no children and no source file.
    fn default() -> Self {
        let children = Vec::new();
        Self::Group {
            name: String::default(),
            children,
            file_path: None,
        }
    }
}

impl Default for TestAttributes {
    /// A plain test: not skipped, applicable everywhere, expecting no errors, compared as an
    /// s-expression, and parsed once with the parser's current language.
    fn default() -> Self {
        let current_language: Box<str> = Box::from("");
        Self {
            languages: vec![current_language],
            cst: false,
            error: false,
            fail_fast: false,
            platform: true,
            skip: false,
        }
    }
}
118
// How much parse-rate detail the summary shows for each parse test. The aggregate parse
// statistics are written at the end of the summary in every mode.
// (Plain `//` comments on purpose: clap's `ValueEnum` derive turns `///` doc comments into
// CLI help text.)
#[derive(ValueEnum, Default, Debug, Copy, Clone, PartialEq, Eq, Serialize)]
pub enum TestStats {
    // Every test's true parse rate, plus a warning for unusually slow tests
    All,
    // Only the slow-test warning (adjusted rate more than 3 standard deviations below the mean)
    #[default]
    OutliersAndTotal,
    // No per-test statistics
    TotalOnly,
}
126
/// Settings for running a corpus test suite with [`run_tests_at_path`].
pub struct TestOptions<'a> {
    /// Location of the tests, handed to `parse_tests`.
    pub path: PathBuf,
    /// Log parser activity to stderr (lexer messages indented by two spaces).
    pub debug: bool,
    /// Record parse graphs to `log.html` via `util::log_graphs` while the tests run.
    pub debug_graph: bool,
    /// NOTE(review): not read in this part of the file; presumably only tests whose names
    /// match are run — confirm.
    pub include: Option<Regex>,
    /// NOTE(review): not read in this part of the file; presumably tests whose names match
    /// are left out — confirm.
    pub exclude: Option<Regex>,
    /// NOTE(review): not read in this part of the file; presumably restricts the run to a
    /// single corpus file — confirm.
    pub file_name: Option<String>,
    /// Collect corrected expected output for tests instead of only reporting failures.
    pub update: bool,
    /// Forwarded to `util::log_graphs` together with `debug_graph`.
    pub open_log: bool,
    /// Languages a test can select by name through its `languages` attribute.
    pub languages: BTreeMap<&'a str, &'a Language>,
    /// NOTE(review): not read in this part of the file; presumably enables colored output.
    pub color: bool,
    /// Keep field names in the actual s-expression even when the expected output has none.
    pub show_fields: bool,
    /// NOTE(review): not read in this part of the file; presumably suppresses failure details.
    pub overview_only: bool,
}
141
/// A stateful object used to collect results from running a grammar's test suite
// (Further notes use `//` rather than `///`: the `JsonSchema` derive copies doc comments into
// the generated schema's descriptions.)
#[derive(Debug, Default, Serialize, JsonSchema)]
pub struct TestSummary {
    // Parse test results and associated data
    // Tree of parse-test outcomes; serialized (and described in the schema) as just its
    // top-level array of results.
    #[schemars(schema_with = "schema_as_array")]
    #[serde(serialize_with = "serialize_as_array")]
    pub parse_results: TestResultHierarchy,
    // Tests whose actual output did not match; detailed after the result list
    pub parse_failures: Vec<TestFailure>,
    // Aggregate parse counts, byte totals and durations across all parses
    pub parse_stats: Stats,
    // In update mode, makes `run_tests_at_path` reject the updates with the "unexpected
    // `ERROR` or `MISSING` nodes" error instead of accepting them.
    #[schemars(skip)]
    #[serde(skip)]
    pub has_parse_errors: bool,
    // How much per-test parse-rate detail to print
    #[schemars(skip)]
    #[serde(skip)]
    pub parse_stat_display: TestStats,

    // Other test results
    #[schemars(schema_with = "schema_as_array")]
    #[serde(serialize_with = "serialize_as_array")]
    pub highlight_results: TestResultHierarchy,
    #[schemars(schema_with = "schema_as_array")]
    #[serde(serialize_with = "serialize_as_array")]
    pub tag_results: TestResultHierarchy,
    #[schemars(schema_with = "schema_as_array")]
    #[serde(serialize_with = "serialize_as_array")]
    pub query_results: TestResultHierarchy,

    // Data used during construction
    // Number given to the next recorded test (starts at 1, see `TestSummary::new`)
    #[schemars(skip)]
    #[serde(skip)]
    pub test_num: usize,
    // Options passed in from the CLI which control how the summary is displayed
    #[schemars(skip)]
    #[serde(skip)]
    pub color: bool,
    // Omit per-failure details from the rendered summary
    #[schemars(skip)]
    #[serde(skip)]
    pub overview_only: bool,
    // Update mode: failures are listed as updates rather than diffed
    #[schemars(skip)]
    #[serde(skip)]
    pub update: bool,
    // NOTE(review): not read in this part of the file; presumably selects JSON output
    #[schemars(skip)]
    #[serde(skip)]
    pub json: bool,
}
187
188impl TestSummary {
189    #[must_use]
190    pub fn new(
191        color: bool,
192        stat_display: TestStats,
193        parse_update: bool,
194        overview_only: bool,
195        json_summary: bool,
196    ) -> Self {
197        Self {
198            color,
199            parse_stat_display: stat_display,
200            update: parse_update,
201            overview_only,
202            json: json_summary,
203            test_num: 1,
204            ..Default::default()
205        }
206    }
207}
208
// Nested tree of test results, built in depth-first order. `root_group` holds the
// top-level results. `traversal_idxs` is the chain of child indices from the root down to
// the group currently being filled, so it is only meaningful during construction.
// Serialization goes through `serialize_as_array`, which emits only `root_group`.
// (Plain `//` comments: the `JsonSchema` derive copies `///` docs into schema descriptions.)
#[derive(Debug, Default, JsonSchema)]
pub struct TestResultHierarchy {
    root_group: Vec<TestResult>,
    traversal_idxs: Vec<usize>,
}
214
215fn serialize_as_array<S>(results: &TestResultHierarchy, serializer: S) -> Result<S::Ok, S::Error>
216where
217    S: serde::Serializer,
218{
219    results.root_group.serialize(serializer)
220}
221
222fn schema_as_array(gen: &mut SchemaGenerator) -> Schema {
223    gen.subschema_for::<Vec<TestResult>>()
224}
225
226/// Stores arbitrarily nested parent test groups and child cases. Supports creation
227/// in DFS traversal order
228impl TestResultHierarchy {
229    /// Signifies the start of a new group's traversal during construction.
230    fn push_traversal(&mut self, idx: usize) {
231        self.traversal_idxs.push(idx);
232    }
233
234    /// Signifies the end of the current group's traversal during construction.
235    /// Must be paired with a prior call to [`TestResultHierarchy::add_group`].
236    pub fn pop_traversal(&mut self) {
237        self.traversal_idxs.pop();
238    }
239
240    /// Adds a new group as a child of the current group. Caller is responsible
241    /// for calling [`TestResultHierarchy::pop_traversal`] once the group is done
242    /// being traversed.
243    pub fn add_group(&mut self, group_name: &str) {
244        let new_group_idx = self.curr_group_len();
245        self.push(TestResult {
246            name: group_name.to_string(),
247            info: TestInfo::Group {
248                children: Vec::new(),
249            },
250        });
251        self.push_traversal(new_group_idx);
252    }
253
254    /// Adds a new test example as a child of the current group.
255    /// Asserts that `test_case.info` is not [`TestInfo::Group`].
256    pub fn add_case(&mut self, test_case: TestResult) {
257        assert!(!matches!(test_case.info, TestInfo::Group { .. }));
258        self.push(test_case);
259    }
260
261    /// Adds a new `TestResult` to the current group.
262    fn push(&mut self, result: TestResult) {
263        // If there are no traversal steps, we're adding to the root
264        if self.traversal_idxs.is_empty() {
265            self.root_group.push(result);
266            return;
267        }
268
269        #[allow(clippy::manual_let_else)]
270        let mut curr_group = match self.root_group[self.traversal_idxs[0]].info {
271            TestInfo::Group { ref mut children } => children,
272            _ => unreachable!(),
273        };
274        for idx in self.traversal_idxs.iter().skip(1) {
275            curr_group = match curr_group[*idx].info {
276                TestInfo::Group { ref mut children } => children,
277                _ => unreachable!(),
278            };
279        }
280
281        curr_group.push(result);
282    }
283
284    fn curr_group_len(&self) -> usize {
285        if self.traversal_idxs.is_empty() {
286            return self.root_group.len();
287        }
288
289        #[allow(clippy::manual_let_else)]
290        let mut curr_group = match self.root_group[self.traversal_idxs[0]].info {
291            TestInfo::Group { ref children } => children,
292            _ => unreachable!(),
293        };
294        for idx in self.traversal_idxs.iter().skip(1) {
295            curr_group = match curr_group[*idx].info {
296                TestInfo::Group { ref children } => children,
297                _ => unreachable!(),
298            };
299        }
300        curr_group.len()
301    }
302
303    #[allow(clippy::iter_without_into_iter)]
304    #[must_use]
305    pub fn iter(&self) -> TestResultIterWithDepth<'_> {
306        let mut stack = Vec::with_capacity(self.root_group.len());
307        for child in self.root_group.iter().rev() {
308            stack.push((0, child));
309        }
310        TestResultIterWithDepth { stack }
311    }
312}
313
314pub struct TestResultIterWithDepth<'a> {
315    stack: Vec<(usize, &'a TestResult)>,
316}
317
318impl<'a> Iterator for TestResultIterWithDepth<'a> {
319    type Item = (usize, &'a TestResult);
320
321    fn next(&mut self) -> Option<Self::Item> {
322        self.stack.pop().inspect(|(depth, result)| {
323            if let TestInfo::Group { children } = &result.info {
324                for child in children.iter().rev() {
325                    self.stack.push((depth + 1, child));
326                }
327            }
328        })
329    }
330}
331
// A named entry in a `TestResultHierarchy`: either a group of nested results or the outcome
// of one test. `info` is flattened, so its fields are serialized alongside `name`.
// (Plain `//` comments: the `JsonSchema` derive copies `///` docs into schema descriptions.)
#[derive(Debug, Serialize, JsonSchema)]
pub struct TestResult {
    pub name: String,
    #[schemars(flatten)]
    #[serde(flatten)]
    pub info: TestInfo,
}
339
// Payload of a `TestResult`. The enum is untagged, so only a variant's fields are serialized,
// never the variant name.
// (Plain `//` comments: the `JsonSchema` derive copies `///` docs into schema descriptions.)
#[derive(Debug, Serialize, JsonSchema)]
#[schemars(untagged)]
#[serde(untagged)]
pub enum TestInfo {
    // A group whose results are nested in `children`
    Group {
        children: Vec<TestResult>,
    },
    // Outcome of a parse (corpus) test
    ParseTest {
        outcome: TestOutcome,
        // True parse rate, adjusted parse rate
        // The true rate is in bytes/ms. `None` for tests that were not parsed (skipped, or
        // not applicable to this platform). Only the true rate is serialized.
        #[schemars(schema_with = "parse_rate_schema")]
        #[serde(serialize_with = "serialize_parse_rates")]
        parse_rate: Option<(f64, f64)>,
        // Number shown next to the test in the summary
        test_num: usize,
    },
    // Outcome of an assertion-based test (syntax highlighting, tags, queries)
    AssertionTest {
        outcome: TestOutcome,
        test_num: usize,
    },
}
360
361fn serialize_parse_rates<S>(
362    parse_rate: &Option<(f64, f64)>,
363    serializer: S,
364) -> Result<S::Ok, S::Error>
365where
366    S: serde::Serializer,
367{
368    match parse_rate {
369        None => serializer.serialize_none(),
370        Some((first, _)) => serializer.serialize_some(first),
371    }
372}
373
374fn parse_rate_schema(gen: &mut SchemaGenerator) -> Schema {
375    gen.subschema_for::<Option<f64>>()
376}
377
// Result of a single test. Parse tests use the first group of variants, assertion-based tests
// (highlights, tags, queries) use the last two.
// (Plain `//` comments: the `JsonSchema` derive copies `///` docs into schema descriptions.)
#[derive(Debug, Clone, Eq, PartialEq, Serialize, JsonSchema)]
pub enum TestOutcome {
    // Parse outcomes
    Passed,
    Failed,
    // Shown with a blue check mark; presumably a test whose expectation was rewritten in
    // update mode — TODO confirm
    Updated,
    // The test's attributes marked it as skipped
    Skipped,
    // The test's attributes exclude the current platform
    Platform,

    // Highlight/Tag/Query outcomes
    AssertionPassed { assertion_count: usize },
    AssertionFailed { error: String },
}
391
392impl TestSummary {
393    fn fmt_parse_results(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394        let (count, total_adj_parse_time) = self
395            .parse_results
396            .iter()
397            .filter_map(|(_, result)| match result.info {
398                TestInfo::Group { .. } => None,
399                TestInfo::ParseTest { parse_rate, .. } => parse_rate,
400                _ => unreachable!(),
401            })
402            .fold((0usize, 0.0f64), |(count, rate_accum), (_, adj_rate)| {
403                (count + 1, rate_accum + adj_rate)
404            });
405
406        let avg = total_adj_parse_time / count as f64;
407        let std_dev = {
408            let variance = self
409                .parse_results
410                .iter()
411                .filter_map(|(_, result)| match result.info {
412                    TestInfo::Group { .. } => None,
413                    TestInfo::ParseTest { parse_rate, .. } => parse_rate,
414                    _ => unreachable!(),
415                })
416                .map(|(_, rate_i)| (rate_i - avg).powi(2))
417                .sum::<f64>()
418                / count as f64;
419            variance.sqrt()
420        };
421
422        for (depth, entry) in self.parse_results.iter() {
423            write!(f, "{}", "  ".repeat(depth + 1))?;
424            match &entry.info {
425                TestInfo::Group { .. } => writeln!(f, "{}:", entry.name)?,
426                TestInfo::ParseTest {
427                    outcome,
428                    parse_rate,
429                    test_num,
430                } => {
431                    let (color, result_char) = match outcome {
432                        TestOutcome::Passed => (AnsiColor::Green, "✓"),
433                        TestOutcome::Failed => (AnsiColor::Red, "✗"),
434                        TestOutcome::Updated => (AnsiColor::Blue, "✓"),
435                        TestOutcome::Skipped => (AnsiColor::Yellow, "⌀"),
436                        TestOutcome::Platform => (AnsiColor::Magenta, "⌀"),
437                        _ => unreachable!(),
438                    };
439                    let stat_display = match (self.parse_stat_display, parse_rate) {
440                        (TestStats::TotalOnly, _) | (_, None) => String::new(),
441                        (display, Some((true_rate, adj_rate))) => {
442                            let mut stats = if display == TestStats::All {
443                                format!(" ({true_rate:.3} bytes/ms)")
444                            } else {
445                                String::new()
446                            };
447                            // 3 standard deviations below the mean, aka the "Empirical Rule"
448                            if *adj_rate < 3.0f64.mul_add(-std_dev, avg) {
449                                stats += &paint(
450                                    self.color.then_some(AnsiColor::Yellow),
451                                    &format!(
452                                        " -- Warning: Slow parse rate ({true_rate:.3} bytes/ms)"
453                                    ),
454                                );
455                            }
456                            stats
457                        }
458                    };
459                    writeln!(
460                        f,
461                        "{test_num:>3}. {result_char} {}{stat_display}",
462                        paint(self.color.then_some(color), &entry.name),
463                    )?;
464                }
465                TestInfo::AssertionTest { .. } => unreachable!(),
466            }
467        }
468
469        // Parse failure info
470        if !self.parse_failures.is_empty() && self.update && !self.has_parse_errors {
471            writeln!(
472                f,
473                "\n{} update{}:\n",
474                self.parse_failures.len(),
475                if self.parse_failures.len() == 1 {
476                    ""
477                } else {
478                    "s"
479                }
480            )?;
481
482            for (i, TestFailure { name, .. }) in self.parse_failures.iter().enumerate() {
483                writeln!(f, "  {}. {name}", i + 1)?;
484            }
485        } else if !self.parse_failures.is_empty() && !self.overview_only {
486            if !self.has_parse_errors {
487                writeln!(
488                    f,
489                    "\n{} failure{}:",
490                    self.parse_failures.len(),
491                    if self.parse_failures.len() == 1 {
492                        ""
493                    } else {
494                        "s"
495                    }
496                )?;
497            }
498
499            if self.color {
500                DiffKey.fmt(f)?;
501            }
502            for (
503                i,
504                TestFailure {
505                    name,
506                    actual,
507                    expected,
508                    is_cst,
509                },
510            ) in self.parse_failures.iter().enumerate()
511            {
512                if expected == "NO ERROR" {
513                    writeln!(f, "\n  {}. {name}:\n", i + 1)?;
514                    writeln!(f, "  Expected an ERROR node, but got:")?;
515                    let actual = if *is_cst {
516                        actual
517                    } else {
518                        &format_sexp(actual, 2)
519                    };
520                    writeln!(
521                        f,
522                        "  {}",
523                        paint(self.color.then_some(AnsiColor::Red), actual)
524                    )?;
525                } else {
526                    writeln!(f, "\n  {}. {name}:", i + 1)?;
527                    if *is_cst {
528                        writeln!(
529                            f,
530                            "{}",
531                            TestDiff::new(actual, expected).with_color(self.color)
532                        )?;
533                    } else {
534                        writeln!(
535                            f,
536                            "{}",
537                            TestDiff::new(&format_sexp(actual, 2), &format_sexp(expected, 2))
538                                .with_color(self.color,)
539                        )?;
540                    }
541                }
542            }
543        } else {
544            writeln!(f)?;
545        }
546
547        Ok(())
548    }
549}
550
551impl std::fmt::Display for TestSummary {
552    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
553        self.fmt_parse_results(f)?;
554
555        let mut render_assertion_results =
556            |name: &str, results: &TestResultHierarchy| -> std::fmt::Result {
557                writeln!(f, "{name}:")?;
558                for (depth, entry) in results.iter() {
559                    write!(f, "{}", "  ".repeat(depth + 2))?;
560                    match &entry.info {
561                        TestInfo::Group { .. } => writeln!(f, "{}", entry.name)?,
562                        TestInfo::AssertionTest { outcome, test_num } => match outcome {
563                            TestOutcome::AssertionPassed { assertion_count } => writeln!(
564                                f,
565                                "{:>3}. ✓ {} ({assertion_count} assertions)",
566                                test_num,
567                                paint(self.color.then_some(AnsiColor::Green), &entry.name)
568                            )?,
569                            TestOutcome::AssertionFailed { error } => {
570                                writeln!(
571                                    f,
572                                    "{:>3}. ✗ {}",
573                                    test_num,
574                                    paint(self.color.then_some(AnsiColor::Red), &entry.name)
575                                )?;
576                                writeln!(f, "{}  {error}", "  ".repeat(depth + 1))?;
577                            }
578                            _ => unreachable!(),
579                        },
580                        TestInfo::ParseTest { .. } => unreachable!(),
581                    }
582                }
583                Ok(())
584            };
585
586        if !self.highlight_results.root_group.is_empty() {
587            render_assertion_results("syntax highlighting", &self.highlight_results)?;
588        }
589
590        if !self.tag_results.root_group.is_empty() {
591            render_assertion_results("tags", &self.tag_results)?;
592        }
593
594        if !self.query_results.root_group.is_empty() {
595            render_assertion_results("queries", &self.query_results)?;
596        }
597
598        write!(f, "{}", self.parse_stats)?;
599
600        Ok(())
601    }
602}
603
604pub fn run_tests_at_path(
605    parser: &mut Parser,
606    opts: &TestOptions,
607    test_summary: &mut TestSummary,
608) -> Result<()> {
609    let test_entry = parse_tests(&opts.path)?;
610
611    let _log_session = if opts.debug_graph {
612        Some(util::log_graphs(parser, "log.html", opts.open_log)?)
613    } else {
614        None
615    };
616    if opts.debug {
617        parser.set_logger(Some(Box::new(|log_type, message| {
618            if log_type == LogType::Lex {
619                io::stderr().write_all(b"  ").unwrap();
620            }
621            writeln!(&mut io::stderr(), "{message}").unwrap();
622        })));
623    }
624
625    let mut corrected_entries = Vec::new();
626    run_tests(
627        parser,
628        test_entry,
629        opts,
630        test_summary,
631        &mut corrected_entries,
632        true,
633    )?;
634
635    parser.stop_printing_dot_graphs();
636
637    if test_summary.parse_failures.is_empty() || (opts.update && !test_summary.has_parse_errors) {
638        Ok(())
639    } else if opts.update && test_summary.has_parse_errors {
640        Err(anyhow!(indoc! {"
641                Some tests failed to parse with unexpected `ERROR` or `MISSING` nodes, as shown above, and cannot be updated automatically.
642                Either fix the grammar or manually update the tests if this is expected."}))
643    } else {
644        Err(anyhow!(""))
645    }
646}
647
648pub fn check_queries_at_path(language: &Language, path: &Path) -> Result<()> {
649    if path.exists() {
650        for entry in WalkDir::new(path)
651            .into_iter()
652            .filter_map(std::result::Result::ok)
653            .filter(|e| {
654                e.file_type().is_file()
655                    && e.path().extension().and_then(OsStr::to_str) == Some("scm")
656                    && !e.path().starts_with(".")
657            })
658        {
659            let filepath = entry.file_name().to_str().unwrap_or("");
660            let content = fs::read_to_string(entry.path())
661                .with_context(|| format!("Error reading query file {filepath:?}"))?;
662            Query::new(language, &content)
663                .with_context(|| format!("Error in query file {filepath:?}"))?;
664        }
665    }
666    Ok(())
667}
668
669pub struct DiffKey;
670
671impl std::fmt::Display for DiffKey {
672    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
673        write!(
674            f,
675            "\ncorrect / {} / {}",
676            paint(Some(AnsiColor::Green), "expected"),
677            paint(Some(AnsiColor::Red), "unexpected")
678        )?;
679        Ok(())
680    }
681}
682
683impl DiffKey {
684    /// Writes [`DiffKey`] to stdout
685    pub fn print() {
686        println!("{Self}");
687    }
688}
689
/// A line-based diff of a test's `actual` output against its `expected` output, rendered
/// through [`std::fmt::Display`]. Color is on by default; see [`TestDiff::with_color`].
pub struct TestDiff<'a> {
    pub actual: &'a str,
    pub expected: &'a str,
    pub color: bool,
}

impl<'a> TestDiff<'a> {
    /// Creates a colored diff of `actual` against `expected`.
    #[must_use]
    pub const fn new(actual: &'a str, expected: &'a str) -> Self {
        let color = true;
        Self {
            actual,
            expected,
            color,
        }
    }

    /// Returns the same diff with coloring switched on or off.
    #[must_use]
    pub const fn with_color(self, color: bool) -> Self {
        Self { color, ..self }
    }
}
712
713impl std::fmt::Display for TestDiff<'_> {
714    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
715        let diff = TextDiff::from_lines(self.actual, self.expected);
716        for diff in diff.iter_all_changes() {
717            match diff.tag() {
718                ChangeTag::Equal => {
719                    if self.color {
720                        write!(f, "{diff}")?;
721                    } else {
722                        write!(f, " {diff}")?;
723                    }
724                }
725                ChangeTag::Insert => {
726                    if self.color {
727                        write!(
728                            f,
729                            "{}",
730                            paint(Some(AnsiColor::Green), diff.as_str().unwrap())
731                        )?;
732                    } else {
733                        write!(f, "+{diff}")?;
734                    }
735                    if diff.missing_newline() {
736                        writeln!(f)?;
737                    }
738                }
739                ChangeTag::Delete => {
740                    if self.color {
741                        write!(f, "{}", paint(Some(AnsiColor::Red), diff.as_str().unwrap()))?;
742                    } else {
743                        write!(f, "-{diff}")?;
744                    }
745                    if diff.missing_newline() {
746                        writeln!(f)?;
747                    }
748                }
749            }
750        }
751
752        Ok(())
753    }
754}
755
// A parse test whose actual output did not match expectations. `expected` is the literal
// "NO ERROR" when a test marked as expecting an error parsed without one; the summary then
// prints `actual` instead of a diff. `is_cst` records whether the texts are CST renderings
// rather than s-expressions (s-expressions are reformatted before diffing).
// (Plain `//` comments: the `JsonSchema` derive copies `///` docs into schema descriptions.)
#[derive(Debug, Serialize, JsonSchema)]
pub struct TestFailure {
    name: String,
    actual: String,
    expected: String,
    is_cst: bool,
}
763
764impl TestFailure {
765    fn new<T, U, V>(name: T, actual: U, expected: V, is_cst: bool) -> Self
766    where
767        T: Into<String>,
768        U: Into<String>,
769        V: Into<String>,
770    {
771        Self {
772            name: name.into(),
773            actual: actual.into(),
774            expected: expected.into(),
775            is_cst,
776        }
777    }
778}
779
/// The corrected form of one test, collected when tests run in update mode.
///
/// It keeps the test's name, attribute text, and delimiter lengths next to the new input and
/// expected output, presumably so the test file can be rewritten in its original layout —
/// confirm against the writer.
struct TestCorrection {
    name: String,
    input: String,
    output: String,
    attributes_str: String,
    header_delim_len: usize,
    divider_delim_len: usize,
}

impl TestCorrection {
    /// Builds a correction, converting each text argument into an owned `String`.
    fn new<T, U, V, W>(
        name: T,
        input: U,
        output: V,
        attributes_str: W,
        header_delim_len: usize,
        divider_delim_len: usize,
    ) -> Self
    where
        T: Into<String>,
        U: Into<String>,
        V: Into<String>,
        W: Into<String>,
    {
        let name = name.into();
        let input = input.into();
        let output = output.into();
        let attributes_str = attributes_str.into();
        Self {
            name,
            input,
            output,
            attributes_str,
            header_delim_len,
            divider_delim_len,
        }
    }
}
814
815/// This will return false if we want to "fail fast". It will bail and not parse any more tests.
816fn run_tests(
817    parser: &mut Parser,
818    test_entry: TestEntry,
819    opts: &TestOptions,
820    test_summary: &mut TestSummary,
821    corrected_entries: &mut Vec<TestCorrection>,
822    is_root: bool,
823) -> Result<bool> {
824    match test_entry {
825        TestEntry::Example {
826            name,
827            input,
828            output,
829            header_delim_len,
830            divider_delim_len,
831            has_fields,
832            attributes_str,
833            attributes,
834            ..
835        } => {
836            if attributes.skip {
837                test_summary.parse_results.add_case(TestResult {
838                    name: name.clone(),
839                    info: TestInfo::ParseTest {
840                        outcome: TestOutcome::Skipped,
841                        parse_rate: None,
842                        test_num: test_summary.test_num,
843                    },
844                });
845                test_summary.test_num += 1;
846                return Ok(true);
847            }
848
849            if !attributes.platform {
850                test_summary.parse_results.add_case(TestResult {
851                    name: name.clone(),
852                    info: TestInfo::ParseTest {
853                        outcome: TestOutcome::Platform,
854                        parse_rate: None,
855                        test_num: test_summary.test_num,
856                    },
857                });
858                test_summary.test_num += 1;
859                return Ok(true);
860            }
861
862            for (i, language_name) in attributes.languages.iter().enumerate() {
863                if !language_name.is_empty() {
864                    let language = opts
865                        .languages
866                        .get(language_name.as_ref())
867                        .ok_or_else(|| anyhow!("Language not found: {language_name}"))?;
868                    parser.set_language(language)?;
869                }
870                let start = std::time::Instant::now();
871                let tree = parser.parse(&input, None).unwrap();
872                let parse_rate = {
873                    let parse_time = start.elapsed();
874                    let byte_len = tree.root_node().byte_range().len();
875                    let true_parse_rate =
876                        byte_len as f64 / (parse_time.as_nanos() as f64 / 1_000_000.0);
877                    let adj_parse_rate = adjusted_parse_rate(&tree, parse_time);
878
879                    test_summary.parse_stats.total_parses += 1;
880                    test_summary.parse_stats.total_duration += parse_time;
881                    test_summary.parse_stats.total_bytes += byte_len;
882
883                    Some((true_parse_rate, adj_parse_rate))
884                };
885
886                if attributes.error {
887                    if tree.root_node().has_error() {
888                        test_summary.parse_results.add_case(TestResult {
889                            name: name.clone(),
890                            info: TestInfo::ParseTest {
891                                outcome: TestOutcome::Passed,
892                                parse_rate,
893                                test_num: test_summary.test_num,
894                            },
895                        });
896                        test_summary.parse_stats.successful_parses += 1;
897                        if opts.update {
898                            let input = String::from_utf8(input.clone()).unwrap();
899                            let output = if attributes.cst {
900                                output.clone()
901                            } else {
902                                format_sexp(&output, 0)
903                            };
904                            corrected_entries.push(TestCorrection::new(
905                                &name,
906                                input,
907                                output,
908                                &attributes_str,
909                                header_delim_len,
910                                divider_delim_len,
911                            ));
912                        }
913                    } else {
914                        if opts.update {
915                            let input = String::from_utf8(input.clone()).unwrap();
916                            // Keep the original `expected` output if the actual output has no error
917                            let output = if attributes.cst {
918                                output.clone()
919                            } else {
920                                format_sexp(&output, 0)
921                            };
922                            corrected_entries.push(TestCorrection::new(
923                                &name,
924                                input,
925                                output,
926                                &attributes_str,
927                                header_delim_len,
928                                divider_delim_len,
929                            ));
930                        }
931                        test_summary.parse_results.add_case(TestResult {
932                            name: name.clone(),
933                            info: TestInfo::ParseTest {
934                                outcome: TestOutcome::Failed,
935                                parse_rate,
936                                test_num: test_summary.test_num,
937                            },
938                        });
939                        let actual = if attributes.cst {
940                            render_test_cst(&input, &tree)?
941                        } else {
942                            tree.root_node().to_sexp()
943                        };
944                        test_summary.parse_failures.push(TestFailure::new(
945                            &name,
946                            actual,
947                            "NO ERROR",
948                            attributes.cst,
949                        ));
950                    }
951
952                    if attributes.fail_fast {
953                        return Ok(false);
954                    }
955                } else {
956                    let mut actual = if attributes.cst {
957                        render_test_cst(&input, &tree)?
958                    } else {
959                        tree.root_node().to_sexp()
960                    };
961                    if !(attributes.cst || opts.show_fields || has_fields) {
962                        actual = strip_sexp_fields(&actual);
963                    }
964
965                    if actual == output {
966                        test_summary.parse_results.add_case(TestResult {
967                            name: name.clone(),
968                            info: TestInfo::ParseTest {
969                                outcome: TestOutcome::Passed,
970                                parse_rate,
971                                test_num: test_summary.test_num,
972                            },
973                        });
974                        test_summary.parse_stats.successful_parses += 1;
975                        if opts.update {
976                            let input = String::from_utf8(input.clone()).unwrap();
977                            let output = if attributes.cst {
978                                actual
979                            } else {
980                                format_sexp(&output, 0)
981                            };
982                            corrected_entries.push(TestCorrection::new(
983                                &name,
984                                input,
985                                output,
986                                &attributes_str,
987                                header_delim_len,
988                                divider_delim_len,
989                            ));
990                        }
991                    } else {
992                        if opts.update {
993                            let input = String::from_utf8(input.clone()).unwrap();
994                            let (expected_output, actual_output) = if attributes.cst {
995                                (output.clone(), actual.clone())
996                            } else {
997                                (format_sexp(&output, 0), format_sexp(&actual, 0))
998                            };
999
1000                            // Only bail early before updating if the actual is not the output,
1001                            // sometimes users want to test cases that
1002                            // are intended to have errors, hence why this
1003                            // check isn't shown above
1004                            if actual.contains("ERROR") || actual.contains("MISSING") {
1005                                test_summary.has_parse_errors = true;
1006
1007                                // keep the original `expected` output if the actual output has an
1008                                // error
1009                                corrected_entries.push(TestCorrection::new(
1010                                    &name,
1011                                    input,
1012                                    expected_output,
1013                                    &attributes_str,
1014                                    header_delim_len,
1015                                    divider_delim_len,
1016                                ));
1017                            } else {
1018                                corrected_entries.push(TestCorrection::new(
1019                                    &name,
1020                                    input,
1021                                    actual_output,
1022                                    &attributes_str,
1023                                    header_delim_len,
1024                                    divider_delim_len,
1025                                ));
1026                                test_summary.parse_results.add_case(TestResult {
1027                                    name: name.clone(),
1028                                    info: TestInfo::ParseTest {
1029                                        outcome: TestOutcome::Updated,
1030                                        parse_rate,
1031                                        test_num: test_summary.test_num,
1032                                    },
1033                                });
1034                            }
1035                        } else {
1036                            test_summary.parse_results.add_case(TestResult {
1037                                name: name.clone(),
1038                                info: TestInfo::ParseTest {
1039                                    outcome: TestOutcome::Failed,
1040                                    parse_rate,
1041                                    test_num: test_summary.test_num,
1042                                },
1043                            });
1044                        }
1045                        test_summary.parse_failures.push(TestFailure::new(
1046                            &name,
1047                            actual,
1048                            &output,
1049                            attributes.cst,
1050                        ));
1051
1052                        if attributes.fail_fast {
1053                            return Ok(false);
1054                        }
1055                    }
1056                }
1057
1058                if i == attributes.languages.len() - 1 {
1059                    // reset to the first language
1060                    parser.set_language(opts.languages.values().next().unwrap())?;
1061                }
1062            }
1063            test_summary.test_num += 1;
1064        }
1065        TestEntry::Group {
1066            name,
1067            children,
1068            file_path,
1069        } => {
1070            if children.is_empty() {
1071                return Ok(true);
1072            }
1073
1074            let failure_count = test_summary.parse_failures.len();
1075            let mut ran_test_in_group = false;
1076
1077            let matches_filter = |name: &str, file_name: &Option<String>, opts: &TestOptions| {
1078                if let (Some(test_file_path), Some(filter_file_name)) = (file_name, &opts.file_name)
1079                {
1080                    if !filter_file_name.eq(test_file_path) {
1081                        return false;
1082                    }
1083                }
1084                if let Some(include) = &opts.include {
1085                    include.is_match(name)
1086                } else if let Some(exclude) = &opts.exclude {
1087                    !exclude.is_match(name)
1088                } else {
1089                    true
1090                }
1091            };
1092
1093            for child in children {
1094                if let TestEntry::Example {
1095                    ref name,
1096                    ref file_name,
1097                    ref input,
1098                    ref output,
1099                    ref attributes_str,
1100                    header_delim_len,
1101                    divider_delim_len,
1102                    ..
1103                } = child
1104                {
1105                    if !matches_filter(name, file_name, opts) {
1106                        if opts.update {
1107                            let input = String::from_utf8(input.clone()).unwrap();
1108                            let output = format_sexp(output, 0);
1109                            corrected_entries.push(TestCorrection::new(
1110                                name,
1111                                input,
1112                                output,
1113                                attributes_str,
1114                                header_delim_len,
1115                                divider_delim_len,
1116                            ));
1117                        }
1118
1119                        test_summary.test_num += 1;
1120                        continue;
1121                    }
1122                }
1123
1124                if !ran_test_in_group && !is_root {
1125                    test_summary.parse_results.add_group(&name);
1126                    ran_test_in_group = true;
1127                }
1128                if !run_tests(parser, child, opts, test_summary, corrected_entries, false)? {
1129                    // fail fast
1130                    return Ok(false);
1131                }
1132            }
1133            // Now that we're done traversing the children of the current group, pop
1134            // the index
1135            test_summary.parse_results.pop_traversal();
1136
1137            if let Some(file_path) = file_path {
1138                if opts.update && test_summary.parse_failures.len() - failure_count > 0 {
1139                    write_tests(&file_path, corrected_entries)?;
1140                }
1141                corrected_entries.clear();
1142            }
1143        }
1144    }
1145    Ok(true)
1146}
1147
1148/// Convenience wrapper to render a CST for a test entry.
1149fn render_test_cst(input: &[u8], tree: &Tree) -> Result<String> {
1150    let mut rendered_cst: Vec<u8> = Vec::new();
1151    let mut cursor = tree.walk();
1152    let opts = ParseFileOptions {
1153        edits: &[],
1154        output: ParseOutput::Cst,
1155        stats: &mut ParseStats::default(),
1156        print_time: false,
1157        timeout: 0,
1158        debug: ParseDebugType::Quiet,
1159        debug_graph: false,
1160        cancellation_flag: None,
1161        encoding: None,
1162        open_log: false,
1163        no_ranges: false,
1164        parse_theme: &ParseTheme::empty(),
1165    };
1166    render_cst(input, tree, &mut cursor, &opts, &mut rendered_cst)?;
1167    Ok(String::from_utf8_lossy(&rendered_cst).trim().to_string())
1168}
1169
1170// Parse time is interpreted in ns before converting to ms to avoid truncation issues
1171// Parse rates often have several outliers, leading to a large standard deviation. Taking
1172// the log of these rates serves to "flatten" out the distribution, yielding a more
1173// usable standard deviation for finding statistically significant slow parse rates
1174// NOTE: This is just a heuristic
1175#[must_use]
1176pub fn adjusted_parse_rate(tree: &Tree, parse_time: Duration) -> f64 {
1177    f64::ln(
1178        tree.root_node().byte_range().len() as f64 / (parse_time.as_nanos() as f64 / 1_000_000.0),
1179    )
1180}
1181
1182fn write_tests(file_path: &Path, corrected_entries: &[TestCorrection]) -> Result<()> {
1183    let mut buffer = fs::File::create(file_path)?;
1184    write_tests_to_buffer(&mut buffer, corrected_entries)
1185}
1186
1187fn write_tests_to_buffer(
1188    buffer: &mut impl Write,
1189    corrected_entries: &[TestCorrection],
1190) -> Result<()> {
1191    for (
1192        i,
1193        TestCorrection {
1194            name,
1195            input,
1196            output,
1197            attributes_str,
1198            header_delim_len,
1199            divider_delim_len,
1200        },
1201    ) in corrected_entries.iter().enumerate()
1202    {
1203        if i > 0 {
1204            writeln!(buffer)?;
1205        }
1206        writeln!(
1207            buffer,
1208            "{}\n{name}\n{}{}\n{input}\n{}\n\n{}",
1209            "=".repeat(*header_delim_len),
1210            if attributes_str.is_empty() {
1211                attributes_str.clone()
1212            } else {
1213                format!("{attributes_str}\n")
1214            },
1215            "=".repeat(*header_delim_len),
1216            "-".repeat(*divider_delim_len),
1217            output.trim()
1218        )?;
1219    }
1220    Ok(())
1221}
1222
1223pub fn parse_tests(path: &Path) -> io::Result<TestEntry> {
1224    let name = path
1225        .file_stem()
1226        .and_then(|s| s.to_str())
1227        .unwrap_or("")
1228        .to_string();
1229    if path.is_dir() {
1230        let mut children = Vec::new();
1231        for entry in fs::read_dir(path)? {
1232            let entry = entry?;
1233            let hidden = entry.file_name().to_str().unwrap_or("").starts_with('.');
1234            if !hidden {
1235                children.push(entry.path());
1236            }
1237        }
1238        children.sort_by(|a, b| {
1239            a.file_name()
1240                .unwrap_or_default()
1241                .cmp(b.file_name().unwrap_or_default())
1242        });
1243        let children = children
1244            .iter()
1245            .map(|path| parse_tests(path))
1246            .collect::<io::Result<Vec<TestEntry>>>()?;
1247        Ok(TestEntry::Group {
1248            name,
1249            children,
1250            file_path: None,
1251        })
1252    } else {
1253        let content = fs::read_to_string(path)?;
1254        Ok(parse_test_content(name, &content, Some(path.to_path_buf())))
1255    }
1256}
1257
1258#[must_use]
1259pub fn strip_sexp_fields(sexp: &str) -> String {
1260    SEXP_FIELD_REGEX.replace_all(sexp, " (").to_string()
1261}
1262
1263#[must_use]
1264pub fn strip_points(sexp: &str) -> String {
1265    POINT_REGEX.replace_all(sexp, "").to_string()
1266}
1267
1268fn parse_test_content(name: String, content: &str, file_path: Option<PathBuf>) -> TestEntry {
1269    let mut children = Vec::new();
1270    let bytes = content.as_bytes();
1271    let mut prev_name = String::new();
1272    let mut prev_attributes_str = String::new();
1273    let mut prev_header_end = 0;
1274
1275    // Find the first test header in the file, and determine if it has a
1276    // custom suffix. If so, then this suffix will be used to identify
1277    // all subsequent headers and divider lines in the file.
1278    let first_suffix = HEADER_REGEX
1279        .captures(bytes)
1280        .and_then(|c| c.name("suffix1"))
1281        .map(|m| String::from_utf8_lossy(m.as_bytes()));
1282
1283    // Find all of the `===` test headers, which contain the test names.
1284    // Ignore any matches whose suffix does not match the first header
1285    // suffix in the file.
1286    let header_matches = HEADER_REGEX.captures_iter(bytes).filter_map(|c| {
1287        let header_delim_len = c.name("equals").map_or(80, |m| m.as_bytes().len());
1288        let suffix1 = c
1289            .name("suffix1")
1290            .map(|m| String::from_utf8_lossy(m.as_bytes()));
1291        let suffix2 = c
1292            .name("suffix2")
1293            .map(|m| String::from_utf8_lossy(m.as_bytes()));
1294
1295        let (mut skip, mut platform, mut fail_fast, mut error, mut cst, mut languages) =
1296            (false, None, false, false, false, vec![]);
1297
1298        let test_name_and_markers = c
1299            .name("test_name_and_markers")
1300            .map_or("".as_bytes(), |m| m.as_bytes());
1301
1302        let mut test_name = String::new();
1303        let mut attributes_str = String::new();
1304
1305        let mut seen_marker = false;
1306
1307        let test_name_and_markers = str::from_utf8(test_name_and_markers).unwrap();
1308        for line in test_name_and_markers
1309            .split_inclusive('\n')
1310            .filter(|s| !s.is_empty())
1311        {
1312            let trimmed_line = line.trim();
1313            match trimmed_line.split('(').next().unwrap() {
1314                ":skip" => (seen_marker, skip) = (true, true),
1315                ":platform" => {
1316                    if let Some(platforms) = trimmed_line.strip_prefix(':').and_then(|s| {
1317                        s.strip_prefix("platform(")
1318                            .and_then(|s| s.strip_suffix(')'))
1319                    }) {
1320                        seen_marker = true;
1321                        platform = Some(
1322                            platform.unwrap_or(false) || platforms.trim() == std::env::consts::OS,
1323                        );
1324                    }
1325                }
1326                ":fail-fast" => (seen_marker, fail_fast) = (true, true),
1327                ":error" => (seen_marker, error) = (true, true),
1328                ":language" => {
1329                    if let Some(lang) = trimmed_line.strip_prefix(':').and_then(|s| {
1330                        s.strip_prefix("language(")
1331                            .and_then(|s| s.strip_suffix(')'))
1332                    }) {
1333                        seen_marker = true;
1334                        languages.push(lang.into());
1335                    }
1336                }
1337                ":cst" => (seen_marker, cst) = (true, true),
1338                _ if !seen_marker => {
1339                    test_name.push_str(line);
1340                }
1341                _ => {}
1342            }
1343        }
1344        attributes_str.push_str(test_name_and_markers.strip_prefix(&test_name).unwrap());
1345
1346        // prefer skip over error, both shouldn't be set
1347        if skip {
1348            error = false;
1349        }
1350
1351        // add a default language if none are specified, will defer to the first language
1352        if languages.is_empty() {
1353            languages.push("".into());
1354        }
1355
1356        if suffix1 == first_suffix && suffix2 == first_suffix {
1357            let header_range = c.get(0).unwrap().range();
1358            let test_name = if test_name.is_empty() {
1359                None
1360            } else {
1361                Some(test_name.trim_end().to_string())
1362            };
1363            let attributes_str = if attributes_str.is_empty() {
1364                None
1365            } else {
1366                Some(attributes_str.trim_end().to_string())
1367            };
1368            Some((
1369                header_delim_len,
1370                header_range,
1371                test_name,
1372                attributes_str,
1373                TestAttributes {
1374                    skip,
1375                    platform: platform.unwrap_or(true),
1376                    fail_fast,
1377                    error,
1378                    cst,
1379                    languages,
1380                },
1381            ))
1382        } else {
1383            None
1384        }
1385    });
1386
1387    let (mut prev_header_len, mut prev_attributes) = (80, TestAttributes::default());
1388    for (header_delim_len, header_range, test_name, attributes_str, attributes) in header_matches
1389        .chain(Some((
1390            80,
1391            bytes.len()..bytes.len(),
1392            None,
1393            None,
1394            TestAttributes::default(),
1395        )))
1396    {
1397        // Find the longest line of dashes following each test description. That line
1398        // separates the input from the expected output. Ignore any matches whose suffix
1399        // does not match the first suffix in the file.
1400        if prev_header_end > 0 {
1401            let divider_range = DIVIDER_REGEX
1402                .captures_iter(&bytes[prev_header_end..header_range.start])
1403                .filter_map(|m| {
1404                    let divider_delim_len = m.name("hyphens").map_or(80, |m| m.as_bytes().len());
1405                    let suffix = m
1406                        .name("suffix")
1407                        .map(|m| String::from_utf8_lossy(m.as_bytes()));
1408                    if suffix == first_suffix {
1409                        let range = m.get(0).unwrap().range();
1410                        Some((
1411                            divider_delim_len,
1412                            (prev_header_end + range.start)..(prev_header_end + range.end),
1413                        ))
1414                    } else {
1415                        None
1416                    }
1417                })
1418                .max_by_key(|(_, range)| range.len());
1419
1420            if let Some((divider_delim_len, divider_range)) = divider_range {
1421                if let Ok(output) = str::from_utf8(&bytes[divider_range.end..header_range.start]) {
1422                    let mut input = bytes[prev_header_end..divider_range.start].to_vec();
1423
1424                    // Remove trailing newline from the input.
1425                    input.pop();
1426                    if input.last() == Some(&b'\r') {
1427                        input.pop();
1428                    }
1429
1430                    let (output, has_fields) = if prev_attributes.cst {
1431                        (output.trim().to_string(), false)
1432                    } else {
1433                        // Remove all comments
1434                        let output = COMMENT_REGEX.replace_all(output, "").to_string();
1435
1436                        // Normalize the whitespace in the expected output.
1437                        let output = WHITESPACE_REGEX.replace_all(output.trim(), " ");
1438                        let output = output.replace(" )", ")");
1439
1440                        // Identify if the expected output has fields indicated. If not, then
1441                        // fields will not be checked.
1442                        let has_fields = SEXP_FIELD_REGEX.is_match(&output);
1443
1444                        (output, has_fields)
1445                    };
1446
1447                    let file_name = if let Some(ref path) = file_path {
1448                        path.file_name().map(|n| n.to_string_lossy().to_string())
1449                    } else {
1450                        None
1451                    };
1452
1453                    let t = TestEntry::Example {
1454                        name: prev_name,
1455                        input,
1456                        output,
1457                        header_delim_len: prev_header_len,
1458                        divider_delim_len,
1459                        has_fields,
1460                        attributes_str: prev_attributes_str,
1461                        attributes: prev_attributes,
1462                        file_name,
1463                    };
1464
1465                    children.push(t);
1466                }
1467            }
1468        }
1469        prev_attributes = attributes;
1470        prev_name = test_name.unwrap_or_default();
1471        prev_attributes_str = attributes_str.unwrap_or_default();
1472        prev_header_len = header_delim_len;
1473        prev_header_end = header_range.end;
1474    }
1475    TestEntry::Group {
1476        name,
1477        children,
1478        file_path,
1479    }
1480}
1481
1482#[cfg(test)]
1483mod tests {
1484    use serde_json::json;
1485
1486    use crate::tests::get_language;
1487
1488    use super::*;
1489
1490    #[test]
1491    fn test_parse_test_content_simple() {
1492        let entry = parse_test_content(
1493            "the-filename".to_string(),
1494            r"
1495===============
1496The first test
1497===============
1498
1499a b c
1500
1501---
1502
1503(a
1504    (b c))
1505
1506================
1507The second test
1508================
1509d
1510---
1511(d)
1512        "
1513            .trim(),
1514            None,
1515        );
1516
1517        assert_eq!(
1518            entry,
1519            TestEntry::Group {
1520                name: "the-filename".to_string(),
1521                children: vec![
1522                    TestEntry::Example {
1523                        name: "The first test".to_string(),
1524                        input: b"\na b c\n".to_vec(),
1525                        output: "(a (b c))".to_string(),
1526                        header_delim_len: 15,
1527                        divider_delim_len: 3,
1528                        has_fields: false,
1529                        attributes_str: String::new(),
1530                        attributes: TestAttributes::default(),
1531                        file_name: None,
1532                    },
1533                    TestEntry::Example {
1534                        name: "The second test".to_string(),
1535                        input: b"d".to_vec(),
1536                        output: "(d)".to_string(),
1537                        header_delim_len: 16,
1538                        divider_delim_len: 3,
1539                        has_fields: false,
1540                        attributes_str: String::new(),
1541                        attributes: TestAttributes::default(),
1542                        file_name: None,
1543                    },
1544                ],
1545                file_path: None,
1546            }
1547        );
1548    }
1549
1550    #[test]
1551    fn test_parse_test_content_with_dashes_in_source_code() {
1552        let entry = parse_test_content(
1553            "the-filename".to_string(),
1554            r"
1555==================
1556Code with dashes
1557==================
1558abc
1559---
1560defg
1561----
1562hijkl
1563-------
1564
1565(a (b))
1566
1567=========================
1568Code ending with dashes
1569=========================
1570abc
1571-----------
1572-------------------
1573
1574(c (d))
1575        "
1576            .trim(),
1577            None,
1578        );
1579
1580        assert_eq!(
1581            entry,
1582            TestEntry::Group {
1583                name: "the-filename".to_string(),
1584                children: vec![
1585                    TestEntry::Example {
1586                        name: "Code with dashes".to_string(),
1587                        input: b"abc\n---\ndefg\n----\nhijkl".to_vec(),
1588                        output: "(a (b))".to_string(),
1589                        header_delim_len: 18,
1590                        divider_delim_len: 7,
1591                        has_fields: false,
1592                        attributes_str: String::new(),
1593                        attributes: TestAttributes::default(),
1594                        file_name: None,
1595                    },
1596                    TestEntry::Example {
1597                        name: "Code ending with dashes".to_string(),
1598                        input: b"abc\n-----------".to_vec(),
1599                        output: "(c (d))".to_string(),
1600                        header_delim_len: 25,
1601                        divider_delim_len: 19,
1602                        has_fields: false,
1603                        attributes_str: String::new(),
1604                        attributes: TestAttributes::default(),
1605                        file_name: None,
1606                    },
1607                ],
1608                file_path: None,
1609            }
1610        );
1611    }
1612
1613    #[test]
1614    fn test_format_sexp() {
1615        assert_eq!(format_sexp("", 0), "");
1616        assert_eq!(
1617            format_sexp("(a b: (c) (d) e: (f (g (h (MISSING i)))))", 0),
1618            r"
1619(a
1620  b: (c)
1621  (d)
1622  e: (f
1623    (g
1624      (h
1625        (MISSING i)))))
1626"
1627            .trim()
1628        );
1629        assert_eq!(
1630            format_sexp("(program (ERROR (UNEXPECTED ' ')) (identifier))", 0),
1631            r"
1632(program
1633  (ERROR
1634    (UNEXPECTED ' '))
1635  (identifier))
1636"
1637            .trim()
1638        );
1639        assert_eq!(
1640            format_sexp(r#"(source_file (MISSING ")"))"#, 0),
1641            r#"
1642(source_file
1643  (MISSING ")"))
1644        "#
1645            .trim()
1646        );
1647        assert_eq!(
1648            format_sexp(
1649                r"(source_file (ERROR (UNEXPECTED 'f') (UNEXPECTED '+')))",
1650                0
1651            ),
1652            r"
1653(source_file
1654  (ERROR
1655    (UNEXPECTED 'f')
1656    (UNEXPECTED '+')))
1657"
1658            .trim()
1659        );
1660    }
1661
1662    #[test]
1663    fn test_write_tests_to_buffer() {
1664        let mut buffer = Vec::new();
1665        let corrected_entries = vec![
1666            TestCorrection::new(
1667                "title 1".to_string(),
1668                "input 1".to_string(),
1669                "output 1".to_string(),
1670                String::new(),
1671                80,
1672                80,
1673            ),
1674            TestCorrection::new(
1675                "title 2".to_string(),
1676                "input 2".to_string(),
1677                "output 2".to_string(),
1678                String::new(),
1679                80,
1680                80,
1681            ),
1682        ];
1683        write_tests_to_buffer(&mut buffer, &corrected_entries).unwrap();
1684        assert_eq!(
1685            String::from_utf8(buffer).unwrap(),
1686            r"
1687================================================================================
1688title 1
1689================================================================================
1690input 1
1691--------------------------------------------------------------------------------
1692
1693output 1
1694
1695================================================================================
1696title 2
1697================================================================================
1698input 2
1699--------------------------------------------------------------------------------
1700
1701output 2
1702"
1703            .trim_start()
1704            .to_string()
1705        );
1706    }
1707
1708    #[test]
1709    fn test_parse_test_content_with_comments_in_sexp() {
1710        let entry = parse_test_content(
1711            "the-filename".to_string(),
1712            r#"
1713==================
1714sexp with comment
1715==================
1716code
1717---
1718
1719; Line start comment
1720(a (b))
1721
1722==================
1723sexp with comment between
1724==================
1725code
1726---
1727
1728; Line start comment
1729(a
1730; ignore this
1731    (b)
1732    ; also ignore this
1733)
1734
1735=========================
1736sexp with ';'
1737=========================
1738code
1739---
1740
1741(MISSING ";")
1742        "#
1743            .trim(),
1744            None,
1745        );
1746
1747        assert_eq!(
1748            entry,
1749            TestEntry::Group {
1750                name: "the-filename".to_string(),
1751                children: vec![
1752                    TestEntry::Example {
1753                        name: "sexp with comment".to_string(),
1754                        input: b"code".to_vec(),
1755                        output: "(a (b))".to_string(),
1756                        header_delim_len: 18,
1757                        divider_delim_len: 3,
1758                        has_fields: false,
1759                        attributes_str: String::new(),
1760                        attributes: TestAttributes::default(),
1761                        file_name: None,
1762                    },
1763                    TestEntry::Example {
1764                        name: "sexp with comment between".to_string(),
1765                        input: b"code".to_vec(),
1766                        output: "(a (b))".to_string(),
1767                        header_delim_len: 18,
1768                        divider_delim_len: 3,
1769                        has_fields: false,
1770                        attributes_str: String::new(),
1771                        attributes: TestAttributes::default(),
1772                        file_name: None,
1773                    },
1774                    TestEntry::Example {
1775                        name: "sexp with ';'".to_string(),
1776                        input: b"code".to_vec(),
1777                        output: "(MISSING \";\")".to_string(),
1778                        header_delim_len: 25,
1779                        divider_delim_len: 3,
1780                        has_fields: false,
1781                        attributes_str: String::new(),
1782                        attributes: TestAttributes::default(),
1783                        file_name: None,
1784                    }
1785                ],
1786                file_path: None,
1787            }
1788        );
1789    }
1790
1791    #[test]
1792    fn test_parse_test_content_with_suffixes() {
1793        let entry = parse_test_content(
1794            "the-filename".to_string(),
1795            r"
1796==================asdf\()[]|{}*+?^$.-
1797First test
1798==================asdf\()[]|{}*+?^$.-
1799
1800=========================
1801NOT A TEST HEADER
1802=========================
1803-------------------------
1804
1805---asdf\()[]|{}*+?^$.-
1806
1807(a)
1808
1809==================asdf\()[]|{}*+?^$.-
1810Second test
1811==================asdf\()[]|{}*+?^$.-
1812
1813=========================
1814NOT A TEST HEADER
1815=========================
1816-------------------------
1817
1818---asdf\()[]|{}*+?^$.-
1819
1820(a)
1821
1822=========================asdf\()[]|{}*+?^$.-
1823Test name with = symbol
1824=========================asdf\()[]|{}*+?^$.-
1825
1826=========================
1827NOT A TEST HEADER
1828=========================
1829-------------------------
1830
1831---asdf\()[]|{}*+?^$.-
1832
1833(a)
1834
1835==============================asdf\()[]|{}*+?^$.-
1836Test containing equals
1837==============================asdf\()[]|{}*+?^$.-
1838
1839===
1840
1841------------------------------asdf\()[]|{}*+?^$.-
1842
1843(a)
1844
1845==============================asdf\()[]|{}*+?^$.-
1846Subsequent test containing equals
1847==============================asdf\()[]|{}*+?^$.-
1848
1849===
1850
1851------------------------------asdf\()[]|{}*+?^$.-
1852
1853(a)
1854"
1855            .trim(),
1856            None,
1857        );
1858
1859        let expected_input = b"\n=========================\n\
1860            NOT A TEST HEADER\n\
1861            =========================\n\
1862            -------------------------\n"
1863            .to_vec();
1864        pretty_assertions::assert_eq!(
1865            entry,
1866            TestEntry::Group {
1867                name: "the-filename".to_string(),
1868                children: vec![
1869                    TestEntry::Example {
1870                        name: "First test".to_string(),
1871                        input: expected_input.clone(),
1872                        output: "(a)".to_string(),
1873                        header_delim_len: 18,
1874                        divider_delim_len: 3,
1875                        has_fields: false,
1876                        attributes_str: String::new(),
1877                        attributes: TestAttributes::default(),
1878                        file_name: None,
1879                    },
1880                    TestEntry::Example {
1881                        name: "Second test".to_string(),
1882                        input: expected_input.clone(),
1883                        output: "(a)".to_string(),
1884                        header_delim_len: 18,
1885                        divider_delim_len: 3,
1886                        has_fields: false,
1887                        attributes_str: String::new(),
1888                        attributes: TestAttributes::default(),
1889                        file_name: None,
1890                    },
1891                    TestEntry::Example {
1892                        name: "Test name with = symbol".to_string(),
1893                        input: expected_input,
1894                        output: "(a)".to_string(),
1895                        header_delim_len: 25,
1896                        divider_delim_len: 3,
1897                        has_fields: false,
1898                        attributes_str: String::new(),
1899                        attributes: TestAttributes::default(),
1900                        file_name: None,
1901                    },
1902                    TestEntry::Example {
1903                        name: "Test containing equals".to_string(),
1904                        input: "\n===\n".into(),
1905                        output: "(a)".into(),
1906                        header_delim_len: 30,
1907                        divider_delim_len: 30,
1908                        has_fields: false,
1909                        attributes_str: String::new(),
1910                        attributes: TestAttributes::default(),
1911                        file_name: None,
1912                    },
1913                    TestEntry::Example {
1914                        name: "Subsequent test containing equals".to_string(),
1915                        input: "\n===\n".into(),
1916                        output: "(a)".into(),
1917                        header_delim_len: 30,
1918                        divider_delim_len: 30,
1919                        has_fields: false,
1920                        attributes_str: String::new(),
1921                        attributes: TestAttributes::default(),
1922                        file_name: None,
1923                    }
1924                ],
1925                file_path: None,
1926            }
1927        );
1928    }
1929
1930    #[test]
1931    fn test_parse_test_content_with_newlines_in_test_names() {
1932        let entry = parse_test_content(
1933            "the-filename".to_string(),
1934            r"
1935===============
1936name
1937with
1938newlines
1939===============
1940a
1941---
1942(b)
1943
1944====================
1945name with === signs
1946====================
1947code with ----
1948---
1949(d)
1950",
1951            None,
1952        );
1953
1954        assert_eq!(
1955            entry,
1956            TestEntry::Group {
1957                name: "the-filename".to_string(),
1958                file_path: None,
1959                children: vec![
1960                    TestEntry::Example {
1961                        name: "name\nwith\nnewlines".to_string(),
1962                        input: b"a".to_vec(),
1963                        output: "(b)".to_string(),
1964                        header_delim_len: 15,
1965                        divider_delim_len: 3,
1966                        has_fields: false,
1967                        attributes_str: String::new(),
1968                        attributes: TestAttributes::default(),
1969                        file_name: None,
1970                    },
1971                    TestEntry::Example {
1972                        name: "name with === signs".to_string(),
1973                        input: b"code with ----".to_vec(),
1974                        output: "(d)".to_string(),
1975                        header_delim_len: 20,
1976                        divider_delim_len: 3,
1977                        has_fields: false,
1978                        attributes_str: String::new(),
1979                        attributes: TestAttributes::default(),
1980                        file_name: None,
1981                    }
1982                ]
1983            }
1984        );
1985    }
1986
1987    #[test]
1988    fn test_parse_test_with_markers() {
1989        // do one with :skip, we should not see it in the entry output
1990
1991        let entry = parse_test_content(
1992            "the-filename".to_string(),
1993            r"
1994=====================
1995Test with skip marker
1996:skip
1997=====================
1998a
1999---
2000(b)
2001",
2002            None,
2003        );
2004
2005        assert_eq!(
2006            entry,
2007            TestEntry::Group {
2008                name: "the-filename".to_string(),
2009                file_path: None,
2010                children: vec![TestEntry::Example {
2011                    name: "Test with skip marker".to_string(),
2012                    input: b"a".to_vec(),
2013                    output: "(b)".to_string(),
2014                    header_delim_len: 21,
2015                    divider_delim_len: 3,
2016                    has_fields: false,
2017                    attributes_str: ":skip".to_string(),
2018                    attributes: TestAttributes {
2019                        skip: true,
2020                        platform: true,
2021                        fail_fast: false,
2022                        error: false,
2023                        cst: false,
2024                        languages: vec!["".into()]
2025                    },
2026                    file_name: None,
2027                }]
2028            }
2029        );
2030
2031        let entry = parse_test_content(
2032            "the-filename".to_string(),
2033            &format!(
2034                r"
2035=========================
2036Test with platform marker
2037:platform({})
2038:fail-fast
2039=========================
2040a
2041---
2042(b)
2043
2044=============================
2045Test with bad platform marker
2046:platform({})
2047
2048:language(foo)
2049=============================
2050a
2051---
2052(b)
2053
2054====================
2055Test with cst marker
2056:cst
2057====================
20581
2059---
20600:0 - 1:0   source_file
20610:0 - 0:1   expression
20620:0 - 0:1     number_literal `1`
2063",
2064                std::env::consts::OS,
2065                if std::env::consts::OS == "linux" {
2066                    "macos"
2067                } else {
2068                    "linux"
2069                }
2070            ),
2071            None,
2072        );
2073
2074        assert_eq!(
2075            entry,
2076            TestEntry::Group {
2077                name: "the-filename".to_string(),
2078                file_path: None,
2079                children: vec![
2080                    TestEntry::Example {
2081                        name: "Test with platform marker".to_string(),
2082                        input: b"a".to_vec(),
2083                        output: "(b)".to_string(),
2084                        header_delim_len: 25,
2085                        divider_delim_len: 3,
2086                        has_fields: false,
2087                        attributes_str: format!(":platform({})\n:fail-fast", std::env::consts::OS),
2088                        attributes: TestAttributes {
2089                            skip: false,
2090                            platform: true,
2091                            fail_fast: true,
2092                            error: false,
2093                            cst: false,
2094                            languages: vec!["".into()]
2095                        },
2096                        file_name: None,
2097                    },
2098                    TestEntry::Example {
2099                        name: "Test with bad platform marker".to_string(),
2100                        input: b"a".to_vec(),
2101                        output: "(b)".to_string(),
2102                        header_delim_len: 29,
2103                        divider_delim_len: 3,
2104                        has_fields: false,
2105                        attributes_str: if std::env::consts::OS == "linux" {
2106                            ":platform(macos)\n\n:language(foo)".to_string()
2107                        } else {
2108                            ":platform(linux)\n\n:language(foo)".to_string()
2109                        },
2110                        attributes: TestAttributes {
2111                            skip: false,
2112                            platform: false,
2113                            fail_fast: false,
2114                            error: false,
2115                            cst: false,
2116                            languages: vec!["foo".into()]
2117                        },
2118                        file_name: None,
2119                    },
2120                    TestEntry::Example {
2121                        name: "Test with cst marker".to_string(),
2122                        input: b"1".to_vec(),
2123                        output: "0:0 - 1:0   source_file
21240:0 - 0:1   expression
21250:0 - 0:1     number_literal `1`"
2126                            .to_string(),
2127                        header_delim_len: 20,
2128                        divider_delim_len: 3,
2129                        has_fields: false,
2130                        attributes_str: ":cst".to_string(),
2131                        attributes: TestAttributes {
2132                            skip: false,
2133                            platform: true,
2134                            fail_fast: false,
2135                            error: false,
2136                            cst: true,
2137                            languages: vec!["".into()]
2138                        },
2139                        file_name: None,
2140                    }
2141                ]
2142            }
2143        );
2144    }
2145
2146    fn clear_parse_rate(result: &mut TestResult) {
2147        let test_case_info = &mut result.info;
2148        match test_case_info {
2149            TestInfo::ParseTest {
2150                ref mut parse_rate, ..
2151            } => {
2152                assert!(parse_rate.is_some());
2153                *parse_rate = None;
2154            }
2155            TestInfo::Group { .. } | TestInfo::AssertionTest { .. } => {
2156                panic!("Unexpected test result")
2157            }
2158        }
2159    }
2160
2161    #[test]
2162    fn run_tests_simple() {
2163        let mut parser = Parser::new();
2164        let language = get_language("c");
2165        parser
2166            .set_language(&language)
2167            .expect("Failed to set language");
2168        let mut languages = BTreeMap::new();
2169        languages.insert("c", &language);
2170        let opts = TestOptions {
2171            path: PathBuf::from("foo"),
2172            debug: true,
2173            debug_graph: false,
2174            include: None,
2175            exclude: None,
2176            file_name: None,
2177            update: false,
2178            open_log: false,
2179            languages,
2180            color: true,
2181            show_fields: false,
2182            overview_only: false,
2183        };
2184
2185        // NOTE: The following test cases are combined to work around a race condition
2186        // in the loader
2187        {
2188            let test_entry = TestEntry::Group {
2189                name: "foo".to_string(),
2190                file_path: None,
2191                children: vec![TestEntry::Example {
2192                    name: "C Test 1".to_string(),
2193                    input: b"1;\n".to_vec(),
2194                    output: "(translation_unit (expression_statement (number_literal)))"
2195                        .to_string(),
2196                    header_delim_len: 25,
2197                    divider_delim_len: 3,
2198                    has_fields: false,
2199                    attributes_str: String::new(),
2200                    attributes: TestAttributes::default(),
2201                    file_name: None,
2202                }],
2203            };
2204
2205            let mut test_summary = TestSummary::new(true, TestStats::All, false, false, false);
2206            let mut corrected_entries = Vec::new();
2207            run_tests(
2208                &mut parser,
2209                test_entry,
2210                &opts,
2211                &mut test_summary,
2212                &mut corrected_entries,
2213                true,
2214            )
2215            .expect("Failed to run tests");
2216
2217            // parse rates will always be different, so we need to clear out these
2218            // fields to reliably assert equality below
2219            clear_parse_rate(&mut test_summary.parse_results.root_group[0]);
2220            test_summary.parse_stats.total_duration = Duration::from_secs(0);
2221
2222            let json_results = serde_json::to_string(&test_summary).unwrap();
2223
2224            assert_eq!(
2225                json_results,
2226                json!({
2227                  "parse_results": [
2228                    {
2229                      "name": "C Test 1",
2230                      "outcome": "Passed",
2231                      "parse_rate": null,
2232                      "test_num": 1
2233                    }
2234                  ],
2235                  "parse_failures": [],
2236                  "parse_stats": {
2237                    "successful_parses": 1,
2238                    "total_parses": 1,
2239                    "total_bytes": 3,
2240                    "total_duration": {
2241                      "secs": 0,
2242                      "nanos": 0,
2243                    }
2244                  },
2245                  "highlight_results": [],
2246                  "tag_results": [],
2247                  "query_results": []
2248                })
2249                .to_string()
2250            );
2251        }
2252        {
2253            let test_entry = TestEntry::Group {
2254                name: "corpus".to_string(),
2255                file_path: None,
2256                children: vec![
2257                    TestEntry::Group {
2258                        name: "group1".to_string(),
2259                        // This test passes
2260                        children: vec![TestEntry::Example {
2261                            name: "C Test 1".to_string(),
2262                            input: b"1;\n".to_vec(),
2263                            output: "(translation_unit (expression_statement (number_literal)))"
2264                                .to_string(),
2265                            header_delim_len: 25,
2266                            divider_delim_len: 3,
2267                            has_fields: false,
2268                            attributes_str: String::new(),
2269                            attributes: TestAttributes::default(),
2270                            file_name: None,
2271                        }],
2272                        file_path: None,
2273                    },
2274                    TestEntry::Group {
2275                        name: "group2".to_string(),
2276                        children: vec![
2277                            // This test passes
2278                            TestEntry::Example {
2279                                name: "C Test 2".to_string(),
2280                                input: b"1;\n".to_vec(),
2281                                output:
2282                                    "(translation_unit (expression_statement (number_literal)))"
2283                                        .to_string(),
2284                                header_delim_len: 25,
2285                                divider_delim_len: 3,
2286                                has_fields: false,
2287                                attributes_str: String::new(),
2288                                attributes: TestAttributes::default(),
2289                                file_name: None,
2290                            },
2291                            // This test fails, and is marked with fail-fast
2292                            TestEntry::Example {
2293                                name: "C Test 3".to_string(),
2294                                input: b"1;\n".to_vec(),
2295                                output:
2296                                    "(translation_unit (expression_statement (string_literal)))"
2297                                        .to_string(),
2298                                header_delim_len: 25,
2299                                divider_delim_len: 3,
2300                                has_fields: false,
2301                                attributes_str: String::new(),
2302                                attributes: TestAttributes {
2303                                    fail_fast: true,
2304                                    ..Default::default()
2305                                },
2306                                file_name: None,
2307                            },
2308                        ],
2309                        file_path: None,
2310                    },
2311                    // This group never runs because of the previous failure
2312                    TestEntry::Group {
2313                        name: "group3".to_string(),
2314                        // This test fails, and is marked with fail-fast
2315                        children: vec![TestEntry::Example {
2316                            name: "C Test 4".to_string(),
2317                            input: b"1;\n".to_vec(),
2318                            output: "(translation_unit (expression_statement (number_literal)))"
2319                                .to_string(),
2320                            header_delim_len: 25,
2321                            divider_delim_len: 3,
2322                            has_fields: false,
2323                            attributes_str: String::new(),
2324                            attributes: TestAttributes::default(),
2325                            file_name: None,
2326                        }],
2327                        file_path: None,
2328                    },
2329                ],
2330            };
2331
2332            let mut test_summary = TestSummary::new(true, TestStats::All, false, false, false);
2333            let mut corrected_entries = Vec::new();
2334            run_tests(
2335                &mut parser,
2336                test_entry,
2337                &opts,
2338                &mut test_summary,
2339                &mut corrected_entries,
2340                true,
2341            )
2342            .expect("Failed to run tests");
2343
2344            // parse rates will always be different, so we need to clear out these
2345            // fields to reliably assert equality below
2346            {
2347                let test_group_1_info = &mut test_summary.parse_results.root_group[0].info;
2348                match test_group_1_info {
2349                    TestInfo::Group {
2350                        ref mut children, ..
2351                    } => clear_parse_rate(&mut children[0]),
2352                    TestInfo::ParseTest { .. } | TestInfo::AssertionTest { .. } => {
2353                        panic!("Unexpected test result");
2354                    }
2355                }
2356                let test_group_2_info = &mut test_summary.parse_results.root_group[1].info;
2357                match test_group_2_info {
2358                    TestInfo::Group {
2359                        ref mut children, ..
2360                    } => {
2361                        clear_parse_rate(&mut children[0]);
2362                        clear_parse_rate(&mut children[1]);
2363                    }
2364                    TestInfo::ParseTest { .. } | TestInfo::AssertionTest { .. } => {
2365                        panic!("Unexpected test result");
2366                    }
2367                }
2368                test_summary.parse_stats.total_duration = Duration::from_secs(0);
2369            }
2370
2371            let json_results = serde_json::to_string(&test_summary).unwrap();
2372
2373            assert_eq!(
2374                json_results,
2375                json!({
2376                  "parse_results": [
2377                    {
2378                      "name": "group1",
2379                      "children": [
2380                        {
2381                          "name": "C Test 1",
2382                          "outcome": "Passed",
2383                          "parse_rate": null,
2384                          "test_num": 1
2385                        }
2386                      ]
2387                    },
2388                    {
2389                      "name": "group2",
2390                      "children": [
2391                        {
2392                          "name": "C Test 2",
2393                          "outcome": "Passed",
2394                          "parse_rate": null,
2395                          "test_num": 2
2396                        },
2397                        {
2398                          "name": "C Test 3",
2399                          "outcome": "Failed",
2400                          "parse_rate": null,
2401                          "test_num": 3
2402                        }
2403                      ]
2404                    }
2405                  ],
2406                  "parse_failures": [
2407                    {
2408                      "name": "C Test 3",
2409                      "actual": "(translation_unit (expression_statement (number_literal)))",
2410                      "expected": "(translation_unit (expression_statement (string_literal)))",
2411                      "is_cst": false,
2412                    }
2413                  ],
2414                  "parse_stats": {
2415                    "successful_parses": 2,
2416                    "total_parses": 3,
2417                    "total_bytes": 9,
2418                    "total_duration": {
2419                      "secs": 0,
2420                      "nanos": 0,
2421                    }
2422                  },
2423                  "highlight_results": [],
2424                  "tag_results": [],
2425                  "query_results": []
2426                })
2427                .to_string()
2428            );
2429        }
2430    }
2431}