// sqruff_lib_core/parser/node_matcher.rs

1use std::sync::OnceLock;
2
3use super::matchable::MatchableTrait;
4use crate::dialects::Dialect;
5use crate::dialects::syntax::SyntaxKind;
6use crate::errors::SQLParseError;
7use crate::parser::context::ParseContext;
8use crate::parser::match_result::{MatchResult, Matched};
9use crate::parser::matchable::Matchable;
10use crate::parser::segments::ErasedSegment;
11
12#[derive(Clone)]
13pub struct NodeMatcher {
14    node_kind: SyntaxKind,
15    match_grammar: OnceLock<Matchable>,
16    factory: fn(&Dialect) -> Matchable,
17}
18
19impl NodeMatcher {
20    pub fn new(node_kind: SyntaxKind, build_grammar: fn(&Dialect) -> Matchable) -> Self {
21        Self {
22            node_kind,
23            match_grammar: OnceLock::new(),
24            factory: build_grammar,
25        }
26    }
27
28    pub fn match_grammar(&self, dialect: &Dialect) -> Matchable {
29        self.match_grammar
30            .get_or_init(|| (self.factory)(dialect))
31            .clone()
32    }
33
34    pub fn replace(&mut self, match_grammar: Matchable) {
35        self.match_grammar = OnceLock::new();
36        let _ = self.match_grammar.set(match_grammar);
37    }
38}
39
40impl std::fmt::Debug for NodeMatcher {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        f.debug_struct("NodeMatcher")
43            .field("node_kind", &self.node_kind)
44            .field("match_grammar", &"...")
45            .field("factory", &"...")
46            .finish()
47    }
48}
49
50impl PartialEq for NodeMatcher {
51    fn eq(&self, _other: &Self) -> bool {
52        todo!()
53    }
54}
55
56impl MatchableTrait for NodeMatcher {
57    fn get_type(&self) -> SyntaxKind {
58        self.node_kind
59    }
60
61    fn match_grammar(&self, dialect: &Dialect) -> Option<Matchable> {
62        self.match_grammar(dialect).into()
63    }
64
65    fn elements(&self) -> &[Matchable] {
66        &[]
67    }
68
69    fn match_segments(
70        &self,
71        segments: &[ErasedSegment],
72        idx: u32,
73        parse_context: &mut ParseContext,
74    ) -> Result<MatchResult, SQLParseError> {
75        if idx >= segments.len() as u32 {
76            return Ok(MatchResult::empty_at(idx));
77        }
78
79        if segments[idx as usize].get_type() == self.get_type() {
80            return Ok(MatchResult::from_span(idx, idx + 1));
81        }
82
83        let grammar = self.match_grammar(parse_context.dialect());
84        let match_result = parse_context
85            .deeper_match(false, &[], |ctx| grammar.match_segments(segments, idx, ctx))?;
86
87        Ok(match_result.wrap(Matched::SyntaxKind(self.node_kind)))
88    }
89}