sqruff_lib_core/parser/node_matcher.rs

1use std::sync::OnceLock;
2
3use super::matchable::MatchableTrait;
4use crate::dialects::Dialect;
5use crate::dialects::syntax::SyntaxKind;
6use crate::errors::SQLParseError;
7use crate::parser::context::ParseContext;
8use crate::parser::match_result::{MatchResult, Matched};
9use crate::parser::matchable::Matchable;
10use crate::parser::segments::ErasedSegment;
11
/// Builds a `Vec` whose elements are `ToMatchable::to_matchable(elem)` for each
/// argument, in order. A trailing comma is accepted.
///
/// NOTE: `ToMatchable` is referenced by its bare name, not via a `$crate::`
/// path, so the trait must be in scope at every call site. Because the macro is
/// `#[macro_export]`ed, that includes call sites in other crates.
#[macro_export]
macro_rules! vec_of_erased {
    ($($elem:expr),* $(,)?) => {{
        vec![$(ToMatchable::to_matchable($elem)),*]
    }};
}
18
/// A grammar element that produces a node of kind `node_kind`.
///
/// When matching, it either accepts a segment that already has `node_kind`, or
/// matches its inner grammar and wraps the result as a `node_kind` node. The
/// inner grammar is built lazily by `factory` on first request and cached.
#[derive(Clone)]
pub struct NodeMatcher {
    /// Kind reported by `get_type` and used to wrap successful grammar matches.
    node_kind: SyntaxKind,
    /// Cached inner grammar: filled by `factory` on first use, or set directly
    /// by `replace`.
    match_grammar: OnceLock<Matchable>,
    /// Builds the inner grammar for a dialect; only called while
    /// `match_grammar` is still empty.
    factory: fn(&Dialect) -> Matchable,
}
25
26impl NodeMatcher {
27    pub fn new(node_kind: SyntaxKind, build_grammar: fn(&Dialect) -> Matchable) -> Self {
28        Self {
29            node_kind,
30            match_grammar: OnceLock::new(),
31            factory: build_grammar,
32        }
33    }
34
35    pub fn match_grammar(&self, dialect: &Dialect) -> Matchable {
36        self.match_grammar
37            .get_or_init(|| (self.factory)(dialect))
38            .clone()
39    }
40
41    pub fn replace(&mut self, match_grammar: Matchable) {
42        self.match_grammar = OnceLock::new();
43        let _ = self.match_grammar.set(match_grammar);
44    }
45}
46
47impl std::fmt::Debug for NodeMatcher {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        f.debug_struct("NodeMatcher")
50            .field("node_kind", &self.node_kind)
51            .field("match_grammar", &"...")
52            .field("factory", &"...")
53            .finish()
54    }
55}
56
57impl PartialEq for NodeMatcher {
58    fn eq(&self, _other: &Self) -> bool {
59        todo!()
60    }
61}
62
63impl MatchableTrait for NodeMatcher {
64    fn get_type(&self) -> SyntaxKind {
65        self.node_kind
66    }
67
68    fn match_grammar(&self, dialect: &Dialect) -> Option<Matchable> {
69        self.match_grammar(dialect).into()
70    }
71
72    fn elements(&self) -> &[Matchable] {
73        &[]
74    }
75
76    fn match_segments(
77        &self,
78        segments: &[ErasedSegment],
79        idx: u32,
80        parse_context: &mut ParseContext,
81    ) -> Result<MatchResult, SQLParseError> {
82        if idx >= segments.len() as u32 {
83            return Ok(MatchResult::empty_at(idx));
84        }
85
86        if segments[idx as usize].get_type() == self.get_type() {
87            return Ok(MatchResult::from_span(idx, idx + 1));
88        }
89
90        let grammar = self.match_grammar(parse_context.dialect());
91        let match_result = parse_context
92            .deeper_match(false, &[], |ctx| grammar.match_segments(segments, idx, ctx))?;
93
94        Ok(match_result.wrap(Matched::SyntaxKind(self.node_kind)))
95    }
96}