sqruff_lib_core/parser/grammar/anyof.rs

1use ahash::AHashSet;
2use itertools::{Itertools, chain};
3use nohash_hasher::IntMap;
4
5use super::sequence::{Bracketed, Sequence};
6use crate::dialects::syntax::{SyntaxKind, SyntaxSet};
7use crate::errors::SQLParseError;
8use crate::helpers::ToMatchable;
9use crate::parser::context::ParseContext;
10use crate::parser::match_algorithms::{
11    longest_match, skip_start_index_forward_to_code, trim_to_terminator,
12};
13use crate::parser::match_result::{MatchResult, Matched, Span};
14use crate::parser::matchable::{
15    Matchable, MatchableCacheKey, MatchableTrait, next_matchable_cache_key,
16};
17use crate::parser::segments::base::ErasedSegment;
18use crate::parser::types::ParseMode;
19
20fn parse_mode_match_result(
21    segments: &[ErasedSegment],
22    current_match: MatchResult,
23    max_idx: u32,
24    parse_mode: ParseMode,
25) -> MatchResult {
26    if parse_mode == ParseMode::Strict {
27        return current_match;
28    }
29
30    let stop_idx = current_match.span.end;
31    if stop_idx == max_idx
32        || segments[stop_idx as usize..max_idx as usize]
33            .iter()
34            .all(|it| !it.is_code())
35    {
36        return current_match;
37    }
38
39    let trim_idx = skip_start_index_forward_to_code(segments, stop_idx, segments.len() as u32);
40
41    let unmatched_match = MatchResult {
42        span: Span {
43            start: trim_idx,
44            end: max_idx,
45        },
46        matched: Matched::SyntaxKind(SyntaxKind::Unparsable).into(),
47        ..MatchResult::default()
48    };
49
50    current_match.append(unmatched_match)
51}
52
53pub fn simple(
54    elements: &[Matchable],
55    parse_context: &ParseContext,
56    crumbs: Option<Vec<&str>>,
57) -> Option<(AHashSet<String>, SyntaxSet)> {
58    let option_simples: Vec<Option<(AHashSet<String>, SyntaxSet)>> = elements
59        .iter()
60        .map(|opt| opt.simple(parse_context, crumbs.clone()))
61        .collect();
62
63    if option_simples.iter().any(Option::is_none) {
64        return None;
65    }
66
67    let simple_buff: Vec<(AHashSet<String>, SyntaxSet)> =
68        option_simples.into_iter().flatten().collect();
69
70    let simple_raws: AHashSet<_> = simple_buff
71        .iter()
72        .flat_map(|(raws, _)| raws)
73        .cloned()
74        .collect();
75
76    let simple_types: SyntaxSet = simple_buff
77        .iter()
78        .flat_map(|(_, types)| types.clone())
79        .collect();
80
81    Some((simple_raws, simple_types))
82}
83
/// Grammar that matches its `elements` repeatedly, using `longest_match` to
/// choose among them at each repetition.
///
/// Repetition is bounded overall by `min_times`/`max_times` and per
/// alternative by `max_times_per_element`. `one_of`, `optionally_bracketed`
/// and `any_set_of` in this module build preconfigured instances.
#[derive(Debug, Clone)]
pub struct AnyNumberOf {
    /// If set and it matches at the starting index, the whole grammar returns
    /// an empty match at that index.
    pub exclude: Option<Matchable>,
    /// Alternatives tried at every repetition.
    pub(crate) elements: Vec<Matchable>,
    /// Terminators pushed into the parse context while matching elements; in
    /// greedy mode they also bound the region passed to `trim_to_terminator`.
    pub terminators: Vec<Matchable>,
    /// When `true`, greedy trimming uses only `terminators` (not the context's
    /// terminators), and the flag is forwarded to `deeper_match` while
    /// matching elements.
    pub reset_terminators: bool,
    /// Maximum total number of matches; `None` means unbounded.
    pub max_times: Option<usize>,
    /// Minimum number of matches required; if it is not reached, whatever was
    /// matched so far is discarded.
    pub min_times: usize,
    /// Maximum number of times any single element may match; exceeding it
    /// ends matching (keeping earlier matches). `None` means unbounded.
    pub max_times_per_element: Option<usize>,
    /// Whether non-code segments may be skipped between repetitions.
    pub allow_gaps: bool,
    /// Explicit optional flag; `is_optional` also reports `true` whenever
    /// `min_times == 0`.
    pub(crate) optional: bool,
    /// `Strict` returns matches as-is. Other modes wrap trailing unmatched
    /// code as unparsable; `Greedy` additionally bounds the input with
    /// `trim_to_terminator`.
    pub parse_mode: ParseMode,
    /// Identifier returned by `cache_key()`; assigned from
    /// `next_matchable_cache_key()` in `new`.
    cache_key: MatchableCacheKey,
}
98
99impl PartialEq for AnyNumberOf {
100    fn eq(&self, other: &Self) -> bool {
101        self.elements
102            .iter()
103            .zip(&other.elements)
104            .all(|(lhs, rhs)| lhs == rhs)
105    }
106}
107
108impl AnyNumberOf {
109    pub fn new(elements: Vec<Matchable>) -> Self {
110        Self {
111            elements,
112            exclude: None,
113            max_times: None,
114            min_times: 0,
115            max_times_per_element: None,
116            allow_gaps: true,
117            optional: false,
118            reset_terminators: false,
119            parse_mode: ParseMode::Strict,
120            terminators: Vec::new(),
121            cache_key: next_matchable_cache_key(),
122        }
123    }
124
125    pub fn optional(&mut self) {
126        self.optional = true;
127    }
128
129    pub fn disallow_gaps(&mut self) {
130        self.allow_gaps = false;
131    }
132
133    pub fn max_times(&mut self, max_times: usize) {
134        self.max_times = max_times.into();
135    }
136
137    pub fn min_times(&mut self, min_times: usize) {
138        self.min_times = min_times;
139    }
140}
141
142impl MatchableTrait for AnyNumberOf {
143    fn elements(&self) -> &[Matchable] {
144        &self.elements
145    }
146
147    fn is_optional(&self) -> bool {
148        self.optional || self.min_times == 0
149    }
150
151    fn simple(
152        &self,
153        parse_context: &ParseContext,
154        crumbs: Option<Vec<&str>>,
155    ) -> Option<(AHashSet<String>, SyntaxSet)> {
156        simple(&self.elements, parse_context, crumbs)
157    }
158
159    fn match_segments(
160        &self,
161        segments: &[ErasedSegment],
162        idx: u32,
163        parse_context: &mut ParseContext,
164    ) -> Result<MatchResult, SQLParseError> {
165        if let Some(exclude) = &self.exclude {
166            let match_result = parse_context
167                .deeper_match(false, &[], |ctx| exclude.match_segments(segments, idx, ctx))?;
168
169            if match_result.has_match() {
170                return Ok(MatchResult::empty_at(idx));
171            }
172        }
173
174        let mut n_matches = 0;
175        let mut option_counter: IntMap<_, usize> = self
176            .elements
177            .iter()
178            .map(|elem| (elem.cache_key(), 0))
179            .collect();
180        let mut matched_idx = idx;
181        let mut working_idx = idx;
182        let mut matched = MatchResult::empty_at(idx);
183        let mut max_idx = segments.len() as u32;
184
185        if self.parse_mode == ParseMode::Greedy {
186            let terminators = if self.reset_terminators {
187                self.terminators.clone()
188            } else {
189                chain(self.terminators.clone(), parse_context.terminators.clone()).collect_vec()
190            };
191            max_idx = trim_to_terminator(segments, idx, &terminators, parse_context)?;
192        }
193
194        loop {
195            if (n_matches >= self.min_times && matched_idx >= max_idx)
196                || self.max_times.is_some() && Some(n_matches) >= self.max_times
197            {
198                return Ok(parse_mode_match_result(
199                    segments,
200                    matched,
201                    max_idx,
202                    self.parse_mode,
203                ));
204            }
205
206            if matched_idx >= max_idx {
207                return Ok(MatchResult::empty_at(idx));
208            }
209
210            let (match_result, matched_option) =
211                parse_context.deeper_match(self.reset_terminators, &self.terminators, |ctx| {
212                    longest_match(
213                        &segments[..max_idx as usize],
214                        &self.elements,
215                        working_idx,
216                        ctx,
217                    )
218                })?;
219
220            if !match_result.has_match() {
221                if n_matches < self.min_times {
222                    matched = MatchResult::empty_at(idx);
223                }
224
225                return Ok(parse_mode_match_result(
226                    segments,
227                    matched,
228                    max_idx,
229                    self.parse_mode,
230                ));
231            }
232
233            let matched_option = matched_option.unwrap();
234            let matched_key = matched_option.cache_key();
235
236            if let Some(counter) = option_counter.get_mut(&matched_key) {
237                *counter += 1;
238
239                if self
240                    .max_times_per_element
241                    .is_some_and(|max_times_per_element| *counter > max_times_per_element)
242                {
243                    return Ok(parse_mode_match_result(
244                        segments,
245                        matched,
246                        max_idx,
247                        self.parse_mode,
248                    ));
249                }
250            }
251
252            matched = matched.append(match_result);
253            matched_idx = matched.span.end;
254            working_idx = matched_idx;
255            if self.allow_gaps {
256                working_idx =
257                    skip_start_index_forward_to_code(segments, matched_idx, segments.len() as u32);
258            }
259            n_matches += 1;
260        }
261    }
262
263    fn cache_key(&self) -> MatchableCacheKey {
264        self.cache_key
265    }
266
267    #[track_caller]
268    fn copy(
269        &self,
270        insert: Option<Vec<Matchable>>,
271        at: Option<usize>,
272        before: Option<Matchable>,
273        remove: Option<Vec<Matchable>>,
274        terminators: Vec<Matchable>,
275        replace_terminators: bool,
276    ) -> Matchable {
277        let mut new_elements = self.elements.clone();
278
279        if let Some(insert_elements) = insert {
280            if let Some(before_element) = before {
281                if let Some(index) = self.elements.iter().position(|e| e == &before_element) {
282                    new_elements.splice(index..index, insert_elements);
283                } else {
284                    panic!("Element for insertion before not found");
285                }
286            } else if let Some(at_index) = at {
287                new_elements.splice(at_index..at_index, insert_elements);
288            } else {
289                new_elements.extend(insert_elements);
290            }
291        }
292
293        if let Some(remove_elements) = remove {
294            new_elements.retain(|elem| !remove_elements.contains(elem));
295        }
296
297        let mut new_grammar = self.clone();
298
299        new_grammar.elements = new_elements;
300        new_grammar.terminators = if replace_terminators {
301            terminators
302        } else {
303            [self.terminators.clone(), terminators].concat()
304        };
305
306        new_grammar.to_matchable()
307    }
308}
309
310pub fn one_of(elements: Vec<Matchable>) -> AnyNumberOf {
311    let mut matcher = AnyNumberOf::new(elements);
312    matcher.max_times(1);
313    matcher.min_times(1);
314    matcher
315}
316
317pub fn optionally_bracketed(elements: Vec<Matchable>) -> AnyNumberOf {
318    let mut args = vec![Bracketed::new(elements.clone()).to_matchable()];
319
320    if elements.len() == 1 {
321        args.extend(elements);
322    } else {
323        args.push(Sequence::new(elements).to_matchable());
324    }
325
326    one_of(args)
327}
328
329pub fn any_set_of(elements: Vec<Matchable>) -> AnyNumberOf {
330    let mut any_number_of = AnyNumberOf::new(elements);
331    any_number_of.max_times = None;
332    any_number_of.max_times_per_element = Some(1);
333    any_number_of
334}