sqruff_lib_core/parser/grammar/
sequence.rs

1use std::iter::zip;
2use std::ops::{Deref, DerefMut};
3
4use ahash::AHashSet;
5
6use crate::dialects::syntax::{SyntaxKind, SyntaxSet};
7use crate::errors::SQLParseError;
8use crate::helpers::ToMatchable;
9use crate::parser::context::ParseContext;
10use crate::parser::match_algorithms::{
11    resolve_bracket, skip_start_index_forward_to_code, skip_stop_index_backward_to_code,
12    trim_to_terminator,
13};
14use crate::parser::match_result::{MatchResult, Matched, Span};
15use crate::parser::matchable::{
16    Matchable, MatchableCacheKey, MatchableTrait, next_matchable_cache_key,
17};
18use crate::parser::segments::ErasedSegment;
19use crate::parser::types::ParseMode;
20
21fn flush_metas(
22    tpre_nc_idx: u32,
23    post_nc_idx: u32,
24    meta_buffer: Vec<SyntaxKind>,
25    _segments: &[ErasedSegment],
26) -> Vec<(u32, SyntaxKind)> {
27    let meta_idx = if meta_buffer.iter().all(|it| it.indent_val() >= 0) {
28        tpre_nc_idx
29    } else {
30        post_nc_idx
31    };
32    meta_buffer.into_iter().map(|it| (meta_idx, it)).collect()
33}
34
/// Grammar that matches its `elements` one after another, in order.
#[derive(Debug, Clone)]
pub struct Sequence {
    /// Matchers tried in order by `match_segments`; also the only field compared by `PartialEq`.
    elements: Vec<Matchable>,
    /// Controls how `match_segments` reacts to a failed or exhausted match
    /// (`Strict`, `Greedy` or `GreedyOnceStarted`). `new` sets it to `Strict`.
    pub parse_mode: ParseMode,
    /// When true, `match_segments` skips forward to the next code segment before each element.
    pub allow_gaps: bool,
    /// Value reported by `is_optional()`.
    is_optional: bool,
    /// Own terminators; in the greedy modes they are combined with the parse context's
    /// terminators to bound the region that may be claimed.
    pub terminators: Vec<Matchable>,
    /// Key obtained from `next_matchable_cache_key()` in `new` and reported by `cache_key()`.
    cache_key: MatchableCacheKey,
}
44
45impl Sequence {
46    pub fn disallow_gaps(&mut self) {
47        self.allow_gaps = false;
48    }
49}
50
51impl Sequence {
52    pub fn new(elements: Vec<Matchable>) -> Self {
53        Self {
54            elements,
55            allow_gaps: true,
56            is_optional: false,
57            parse_mode: ParseMode::Strict,
58            terminators: Vec::new(),
59            cache_key: next_matchable_cache_key(),
60        }
61    }
62
63    pub fn optional(&mut self) {
64        self.is_optional = true;
65    }
66
67    pub fn terminators(mut self, terminators: Vec<Matchable>) -> Self {
68        self.terminators = terminators;
69        self
70    }
71
72    pub fn parse_mode(&mut self, mode: ParseMode) {
73        self.parse_mode = mode;
74    }
75
76    pub fn allow_gaps(mut self, allow_gaps: bool) -> Self {
77        self.allow_gaps = allow_gaps;
78        self
79    }
80}
81
82impl PartialEq for Sequence {
83    fn eq(&self, other: &Self) -> bool {
84        zip(&self.elements, &other.elements).all(|(a, b)| a == b)
85    }
86}
87
88impl MatchableTrait for Sequence {
89    fn elements(&self) -> &[Matchable] {
90        &self.elements
91    }
92
    /// Reports whether this sequence was marked optional (see `Sequence::optional`).
    fn is_optional(&self) -> bool {
        self.is_optional
    }
96
97    fn simple(
98        &self,
99        parse_context: &ParseContext,
100        crumbs: Option<Vec<&str>>,
101    ) -> Option<(AHashSet<String>, SyntaxSet)> {
102        let mut simple_raws = AHashSet::new();
103        let mut simple_types = SyntaxSet::EMPTY;
104
105        for opt in &self.elements {
106            let (raws, types) = opt.simple(parse_context, crumbs.clone())?;
107
108            simple_raws.extend(raws);
109            simple_types.extend(types);
110
111            if !opt.is_optional() {
112                return Some((simple_raws, simple_types));
113            }
114        }
115
116        (simple_raws, simple_types).into()
117    }
118
    /// Matches the elements of this sequence, in order, against `segments` starting at `idx`.
    ///
    /// On success the result spans from the starting index to the end of the last
    /// matched element (or further, in the greedy modes, when trailing content is
    /// claimed as unparsable) and carries the collected child matches and meta
    /// insertions. How failures are reported depends on `parse_mode`:
    ///
    /// * `Strict`: any required element that cannot be matched yields an empty match.
    /// * `Greedy`: the claimable region is bounded up front by the terminators; content
    ///   that cannot be matched inside that region is reported as `Unparsable`.
    /// * `GreedyOnceStarted`: behaves like `Strict` until the first element matches,
    ///   then the region is bounded by the terminators and the greedy handling applies.
    ///
    /// Errors from element matchers and from `trim_to_terminator` are propagated.
    fn match_segments(
        &self,
        segments: &[ErasedSegment],
        mut idx: u32,
        parse_context: &mut ParseContext,
    ) -> Result<MatchResult, SQLParseError> {
        // `start_idx` is where the sequence starts; `matched_idx` is the end of what has
        // been matched so far; `max_idx` is the exclusive upper bound we may match up to.
        let start_idx = idx;
        let mut matched_idx = idx;
        let mut max_idx = segments.len() as u32;
        let mut insert_segments = Vec::new();
        let mut child_matches = Vec::new();
        let mut first_match = true;
        // Meta kinds seen since the last matched element, waiting to be positioned.
        let mut meta_buffer = Vec::new();

        // In greedy mode the upper bound is fixed before anything is matched.
        if self.parse_mode == ParseMode::Greedy {
            let terminators =
                [self.terminators.clone(), parse_context.terminators.clone()].concat();

            max_idx = trim_to_terminator(segments, idx, &terminators, parse_context)?;
        }

        for elem in &self.elements {
            // Conditional and indent elements consume no segments: they only contribute
            // meta kinds to the buffer.
            if let Some(indent) = elem.as_conditional() {
                let match_result = indent.match_segments(segments, matched_idx, parse_context)?;
                for (_, submatch) in match_result.insert_segments {
                    meta_buffer.push(submatch);
                }
                continue;
            } else if let Some(indent) = elem.as_indent() {
                meta_buffer.push(indent.kind);
                continue;
            }

            // Position for this element: optionally skip ahead to the next code segment.
            idx = if self.allow_gaps {
                skip_start_index_forward_to_code(segments, matched_idx, max_idx)
            } else {
                matched_idx
            };

            // Nothing left to match against.
            if idx >= max_idx {
                if elem.is_optional() {
                    continue;
                }

                // A required element is missing: in strict mode, or when nothing has
                // matched yet, report no match at all.
                if self.parse_mode == ParseMode::Strict || matched_idx == start_idx {
                    return Ok(MatchResult::empty_at(idx));
                }

                // Otherwise flush the pending metas at the end of what matched and mark
                // the partial sequence as unparsable.
                insert_segments.extend(meta_buffer.into_iter().map(|meta| (matched_idx, meta)));

                return Ok(MatchResult {
                    span: Span {
                        start: start_idx,
                        end: matched_idx,
                    },
                    insert_segments,
                    child_matches,
                    matched: Matched::SyntaxKind(SyntaxKind::Unparsable).into(),
                });
            }

            // Match the element against the segments truncated at `max_idx`.
            let mut elem_match = parse_context.deeper_match(false, &[], |ctx| {
                elem.match_segments(&segments[..max_idx as usize], idx, ctx)
            })?;

            if !elem_match.has_match() {
                if elem.is_optional() {
                    continue;
                }

                if self.parse_mode == ParseMode::Strict {
                    return Ok(MatchResult::empty_at(idx));
                }

                // GreedyOnceStarted has not "started" until something matched.
                if self.parse_mode == ParseMode::GreedyOnceStarted && matched_idx == start_idx {
                    return Ok(MatchResult::empty_at(idx));
                }

                // Nothing matched yet: the whole region up to `max_idx` is unparsable.
                if matched_idx == start_idx {
                    return Ok(MatchResult {
                        span: Span {
                            start: start_idx,
                            end: max_idx,
                        },
                        matched: Matched::SyntaxKind(SyntaxKind::Unparsable).into(),
                        ..MatchResult::default()
                    });
                }

                // Something matched already: keep it and add an unparsable child covering
                // the rest of the region, starting at the next code segment.
                child_matches.push(MatchResult {
                    span: Span {
                        start: skip_start_index_forward_to_code(segments, matched_idx, max_idx),
                        end: max_idx,
                    },
                    matched: Matched::SyntaxKind(SyntaxKind::Unparsable).into(),
                    ..MatchResult::default()
                });

                return Ok(MatchResult {
                    span: Span {
                        start: start_idx,
                        end: max_idx,
                    },
                    insert_segments,
                    child_matches,
                    matched: None,
                });
            }

            // The element matched: position the buffered metas (this empties the outer
            // buffer; the shadowing binding only lives until the end of this iteration).
            let meta_buffer = std::mem::take(&mut meta_buffer);
            insert_segments.append(&mut flush_metas(matched_idx, idx, meta_buffer, segments));

            matched_idx = elem_match.span.end;

            // GreedyOnceStarted: bound the region once, right after the first match.
            if first_match && self.parse_mode == ParseMode::GreedyOnceStarted {
                let terminators =
                    [self.terminators.clone(), parse_context.terminators.clone()].concat();
                max_idx = trim_to_terminator(segments, matched_idx, &terminators, parse_context)?;
                first_match = false;
            }

            // A match that carries its own `matched` value is kept as a child as-is ...
            if elem_match.matched.is_some() {
                child_matches.push(elem_match);
                continue;
            }

            // ... otherwise its children and insertions are merged into this result.
            child_matches.append(&mut elem_match.child_matches);
            insert_segments.append(&mut elem_match.insert_segments);
        }

        // Metas left over after the last element go at the end of what matched.
        insert_segments.extend(meta_buffer.into_iter().map(|meta| (matched_idx, meta)));

        // In the greedy modes, anything left before `max_idx` is claimed as unparsable,
        // with the range narrowed by the skip helpers at both ends.
        if matches!(
            self.parse_mode,
            ParseMode::Greedy | ParseMode::GreedyOnceStarted
        ) && max_idx > matched_idx
        {
            let idx = skip_start_index_forward_to_code(segments, matched_idx, max_idx);
            let stop_idx = skip_stop_index_backward_to_code(segments, max_idx, idx);

            if stop_idx > idx {
                child_matches.push(MatchResult {
                    span: Span {
                        start: idx,
                        end: stop_idx,
                    },
                    matched: Matched::SyntaxKind(SyntaxKind::Unparsable).into(),
                    ..Default::default()
                });
                matched_idx = stop_idx;
            }
        }

        Ok(MatchResult {
            span: Span {
                start: start_idx,
                end: matched_idx,
            },
            matched: None,
            insert_segments,
            child_matches,
        })
    }
282
    /// Returns the cache key stored on this sequence.
    fn cache_key(&self) -> MatchableCacheKey {
        self.cache_key
    }
286
287    fn copy(
288        &self,
289        insert: Option<Vec<Matchable>>,
290        at: Option<usize>,
291        before: Option<Matchable>,
292        remove: Option<Vec<Matchable>>,
293        terminators: Vec<Matchable>,
294        replace_terminators: bool,
295    ) -> Matchable {
296        let mut new_elements = self.elements.clone();
297
298        if let Some(insert_elements) = insert {
299            if let Some(before_element) = before {
300                if let Some(index) = self.elements.iter().position(|e| e == &before_element) {
301                    new_elements.splice(index..index, insert_elements);
302                } else {
303                    panic!("Element for insertion before not found");
304                }
305            } else if let Some(at_index) = at {
306                new_elements.splice(at_index..at_index, insert_elements);
307            } else {
308                new_elements.extend(insert_elements);
309            }
310        }
311
312        if let Some(remove_elements) = remove {
313            new_elements.retain(|elem| !remove_elements.contains(elem));
314        }
315
316        let mut new_grammar = self.clone();
317
318        new_grammar.elements = new_elements;
319        new_grammar.terminators = if replace_terminators {
320            terminators
321        } else {
322            [self.terminators.clone(), terminators].concat()
323        };
324
325        new_grammar.to_matchable()
326    }
327}
328
/// Grammar that matches a bracket pair and applies an inner `Sequence` to its content.
#[derive(Debug, Clone, PartialEq)]
pub struct Bracketed {
    /// Name of the bracket kind to look up in the dialect's bracket set (`new` uses `"round"`).
    pub bracket_type: &'static str,
    /// Name of the dialect bracket set to search (`new` uses `"bracket_pairs"`).
    pub bracket_pairs_set: &'static str,
    /// When true, the content range is narrowed with the skip-to-code helpers on both
    /// sides before the inner sequence is matched. Distinct from the inner sequence's
    /// own `allow_gaps`.
    allow_gaps: bool,
    /// Inner sequence matched against the bracket content; also the `Deref` target.
    pub this: Sequence,
}
336
337impl Bracketed {
338    pub fn new(args: Vec<Matchable>) -> Self {
339        Self {
340            bracket_type: "round",
341            bracket_pairs_set: "bracket_pairs",
342            allow_gaps: true,
343            this: Sequence::new(args),
344        }
345    }
346}
347
/// Result of a bracket lookup: `(start bracket, end bracket, persists flag)` on success,
/// or a human-readable message when the bracket type is not defined.
type BracketInfo = Result<(Matchable, Matchable, bool), String>;
349
350impl Bracketed {
351    pub fn bracket_type(&mut self, bracket_type: &'static str) {
352        self.bracket_type = bracket_type;
353    }
354
355    fn get_bracket_from_dialect(&self, parse_context: &ParseContext) -> BracketInfo {
356        let bracket_pairs = parse_context.dialect().bracket_sets(self.bracket_pairs_set);
357        for (bracket_type, start_ref, end_ref, persists) in bracket_pairs {
358            if bracket_type == self.bracket_type {
359                let start_bracket = parse_context.dialect().r#ref(start_ref);
360                let end_bracket = parse_context.dialect().r#ref(end_ref);
361
362                return Ok((start_bracket, end_bracket, persists));
363            }
364        }
365        Err(format!(
366            "bracket_type {:?} not found in bracket_pairs ({}) of {:?} dialect.",
367            self.bracket_type,
368            self.bracket_pairs_set,
369            parse_context.dialect().name
370        ))
371    }
372}
373
/// Exposes the inner `Sequence` so its fields and methods can be reached directly
/// on a `Bracketed`.
impl Deref for Bracketed {
    type Target = Sequence;

    /// Borrows the inner sequence.
    fn deref(&self) -> &Self::Target {
        &self.this
    }
}
381
/// Mutable counterpart of the `Deref` impl: lets the inner `Sequence`'s `&mut self`
/// methods be called on a `Bracketed`.
impl DerefMut for Bracketed {
    /// Mutably borrows the inner sequence.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.this
    }
}
387
388impl MatchableTrait for Bracketed {
389    fn elements(&self) -> &[Matchable] {
390        &self.elements
391    }
392
393    fn is_optional(&self) -> bool {
394        self.this.is_optional()
395    }
396
397    fn simple(
398        &self,
399        parse_context: &ParseContext,
400        crumbs: Option<Vec<&str>>,
401    ) -> Option<(AHashSet<String>, SyntaxSet)> {
402        let (start_bracket, _, _) = self.get_bracket_from_dialect(parse_context).unwrap();
403        start_bracket.simple(parse_context, crumbs)
404    }
405
406    fn match_segments(
407        &self,
408        segments: &[ErasedSegment],
409        idx: u32,
410        parse_context: &mut ParseContext,
411    ) -> Result<MatchResult, SQLParseError> {
412        let (start_bracket, end_bracket, bracket_persists) =
413            self.get_bracket_from_dialect(parse_context).unwrap();
414
415        let start_match = parse_context.deeper_match(false, &[], |ctx| {
416            start_bracket.match_segments(segments, idx, ctx)
417        })?;
418
419        if !start_match.has_match() {
420            return Ok(MatchResult::empty_at(idx));
421        }
422
423        let start_match_span = start_match.span;
424
425        let bracketed_match = resolve_bracket(
426            segments,
427            start_match,
428            start_bracket.clone(),
429            &[start_bracket],
430            std::slice::from_ref(&end_bracket),
431            &[bracket_persists],
432            parse_context,
433            false,
434        )?;
435
436        let mut idx = start_match_span.end;
437        let mut end_idx = bracketed_match.span.end - 1;
438
439        if self.allow_gaps {
440            idx = skip_start_index_forward_to_code(segments, idx, segments.len() as u32);
441            end_idx = skip_stop_index_backward_to_code(segments, end_idx, idx);
442        }
443
444        let mut content_match =
445            parse_context.deeper_match(true, std::slice::from_ref(&end_bracket), |ctx| {
446                self.this
447                    .match_segments(&segments[..end_idx as usize], idx, ctx)
448            })?;
449
450        if content_match.span.end != end_idx && self.parse_mode == ParseMode::Strict {
451            return Ok(MatchResult::empty_at(idx));
452        }
453
454        let intermediate_slice = Span {
455            start: content_match.span.end,
456            end: bracketed_match.span.end - 1,
457        };
458
459        if !self.allow_gaps && intermediate_slice.start == intermediate_slice.end {
460            unimplemented!()
461        }
462
463        let mut child_matches = bracketed_match.child_matches;
464        if content_match.matched.is_some() {
465            child_matches.push(content_match);
466        } else {
467            child_matches.append(&mut content_match.child_matches);
468        }
469
470        Ok(MatchResult {
471            child_matches,
472            ..bracketed_match
473        })
474    }
475
476    fn cache_key(&self) -> MatchableCacheKey {
477        self.this.cache_key()
478    }
479}