sqruff_lib_core/parser/
match_algorithms.rs

1use ahash::AHashMap;
2use itertools::{Itertools as _, enumerate, multiunzip};
3use smol_str::StrExt;
4
5use super::context::ParseContext;
6use super::match_result::{MatchResult, Matched, Span};
7use super::matchable::{Matchable, MatchableTrait};
8use super::segments::base::ErasedSegment;
9use crate::dialects::syntax::{SyntaxKind, SyntaxSet};
10use crate::errors::SQLParseError;
11
12pub fn skip_start_index_forward_to_code(
13    segments: &[ErasedSegment],
14    start_idx: u32,
15    max_idx: u32,
16) -> u32 {
17    let mut idx = start_idx;
18    while idx < max_idx {
19        if segments[idx as usize].is_code() {
20            break;
21        }
22        idx += 1;
23    }
24    idx
25}
26
27pub fn skip_stop_index_backward_to_code(
28    segments: &[ErasedSegment],
29    stop_idx: u32,
30    min_idx: u32,
31) -> u32 {
32    let mut idx = stop_idx;
33    while idx > min_idx {
34        if segments[idx as usize - 1].is_code() {
35            break;
36        }
37        idx -= 1;
38    }
39    idx
40}
41
42pub fn first_trimmed_raw(seg: &ErasedSegment) -> String {
43    seg.raw()
44        .to_uppercase_smolstr()
45        .split(char::is_whitespace)
46        .next()
47        .map(ToString::to_string)
48        .unwrap_or_default()
49}
50
51pub fn first_non_whitespace(
52    segments: &[ErasedSegment],
53    start_idx: u32,
54) -> Option<(String, &SyntaxSet)> {
55    for segment in segments.iter().skip(start_idx as usize) {
56        if let Some(raw) = segment.first_non_whitespace_segment_raw_upper() {
57            return Some((raw, segment.class_types()));
58        }
59    }
60
61    None
62}
63
64pub fn prune_options(
65    options: &[Matchable],
66    segments: &[ErasedSegment],
67    parse_context: &mut ParseContext,
68    start_idx: u32,
69) -> Vec<Matchable> {
70    let mut available_options = vec![];
71
72    // Find the first code element to match against.
73    let Some((first_raw, first_types)) = first_non_whitespace(segments, start_idx) else {
74        return options.to_vec();
75    };
76
77    for opt in options {
78        let Some(simple) = opt.simple(parse_context, None) else {
79            // This element is not simple, we have to do a
80            // full match with it...
81            available_options.push(opt.clone());
82            continue;
83        };
84
85        // Otherwise we have a simple option, so let's use
86        // it for pruning.
87        let (simple_raws, simple_types) = simple;
88        let mut matched = false;
89
90        // We want to know if the first meaningful element of the str_buff
91        // matches the option, based on either simple _raw_ matching or
92        // simple _type_ matching.
93
94        // Match Raws
95        if simple_raws.contains(&first_raw) {
96            // If we get here, it's matched the FIRST element of the string buffer.
97            available_options.push(opt.clone());
98            matched = true;
99        }
100
101        if !matched && first_types.intersects(&simple_types) {
102            available_options.push(opt.clone());
103        }
104    }
105
106    available_options
107}
108
109pub fn longest_match(
110    segments: &[ErasedSegment],
111    matchers: &[Matchable],
112    idx: u32,
113    parse_context: &mut ParseContext,
114) -> Result<(MatchResult, Option<Matchable>), SQLParseError> {
115    let max_idx = segments.len() as u32;
116
117    if matchers.is_empty() || idx == max_idx {
118        return Ok((MatchResult::empty_at(idx), None));
119    }
120
121    let available_options = prune_options(matchers, segments, parse_context, idx);
122    let available_options_count = available_options.len();
123
124    if available_options.is_empty() {
125        return Ok((MatchResult::empty_at(idx), None));
126    }
127
128    let terminators = parse_context.terminators.clone();
129    let cache_position = segments[idx as usize].get_position_marker().unwrap();
130
131    let loc_key = (
132        segments[idx as usize].raw().clone(),
133        cache_position.working_loc(),
134        segments[idx as usize].get_type(),
135        max_idx,
136    );
137
138    let loc_key = parse_context.loc_key(loc_key);
139
140    let mut best_match = MatchResult::empty_at(idx);
141    let mut best_matcher = None;
142
143    'matcher: for (matcher_idx, matcher) in enumerate(available_options) {
144        let matcher_key = matcher.cache_key();
145        let res_match = parse_context.check_parse_cache(loc_key, matcher_key);
146
147        let res_match = match res_match {
148            Some(res_match) => res_match,
149            None => {
150                let res_match = matcher.match_segments(segments, idx, parse_context)?;
151                parse_context.put_parse_cache(loc_key, matcher_key, res_match.clone());
152                res_match
153            }
154        };
155
156        if res_match.has_match() && res_match.span.end == max_idx {
157            return Ok((res_match, matcher.into()));
158        }
159
160        if res_match.is_better_than(&best_match) {
161            best_match = res_match;
162            best_matcher = matcher.into();
163
164            if matcher_idx == available_options_count - 1 {
165                break 'matcher;
166            } else if !terminators.is_empty() {
167                let next_code_idx = skip_start_index_forward_to_code(
168                    segments,
169                    best_match.span.end,
170                    segments.len() as u32,
171                );
172
173                if next_code_idx == segments.len() as u32 {
174                    break 'matcher;
175                }
176
177                for terminator in &terminators {
178                    let terminator_match =
179                        terminator.match_segments(segments, next_code_idx, parse_context)?;
180
181                    if terminator_match.has_match() {
182                        break 'matcher;
183                    }
184                }
185            }
186        }
187    }
188
189    Ok((best_match, best_matcher))
190}
191
192fn next_match(
193    segments: &[ErasedSegment],
194    idx: u32,
195    matchers: &[Matchable],
196    parse_context: &mut ParseContext,
197) -> Result<(MatchResult, Option<Matchable>), SQLParseError> {
198    let max_idx = segments.len() as u32;
199
200    if idx >= max_idx {
201        return Ok((MatchResult::empty_at(idx), None));
202    }
203
204    let mut raw_simple_map: AHashMap<String, Vec<usize>> = AHashMap::new();
205    let mut type_simple_map: AHashMap<SyntaxKind, Vec<usize>> = AHashMap::new();
206
207    for (idx, matcher) in enumerate(matchers) {
208        let (raws, types) = matcher.simple(parse_context, None).unwrap();
209
210        raw_simple_map.reserve(raws.len());
211        type_simple_map.reserve(types.len());
212
213        for raw in raws {
214            raw_simple_map.entry(raw).or_default().push(idx);
215        }
216
217        for typ in types {
218            type_simple_map.entry(typ).or_default().push(idx);
219        }
220    }
221
222    for idx in idx..max_idx {
223        let seg = &segments[idx as usize];
224        let mut matcher_idxs = raw_simple_map
225            .get(&first_trimmed_raw(seg))
226            .cloned()
227            .unwrap_or_default();
228
229        let keys = type_simple_map.keys().copied().collect();
230        let type_overlap = seg.class_types().clone().intersection(&keys);
231
232        for typ in type_overlap {
233            matcher_idxs.extend(type_simple_map[&typ].clone());
234        }
235
236        if matcher_idxs.is_empty() {
237            continue;
238        }
239
240        matcher_idxs.sort();
241        for matcher_idx in matcher_idxs {
242            let matcher = &matchers[matcher_idx];
243            let match_result = matcher.match_segments(segments, idx, parse_context)?;
244
245            if match_result.has_match() {
246                return Ok((match_result, matcher.clone().into()));
247            }
248        }
249    }
250
251    Ok((MatchResult::empty_at(idx), None))
252}
253
254#[allow(clippy::too_many_arguments)]
255pub fn resolve_bracket(
256    segments: &[ErasedSegment],
257    opening_match: MatchResult,
258    opening_matcher: Matchable,
259    start_brackets: &[Matchable],
260    end_brackets: &[Matchable],
261    bracket_persists: &[bool],
262    parse_context: &mut ParseContext,
263    nested_match: bool,
264) -> Result<MatchResult, SQLParseError> {
265    let type_idx = start_brackets
266        .iter()
267        .position(|it| it == &opening_matcher)
268        .unwrap();
269    let mut matched_idx = opening_match.span.end;
270    let mut child_matches = vec![opening_match.clone()];
271
272    let matchers = [start_brackets, end_brackets].concat();
273    loop {
274        let (match_result, matcher) = next_match(segments, matched_idx, &matchers, parse_context)?;
275
276        if !match_result.has_match() {
277            return Err(SQLParseError {
278                description: "Couldn't find closing bracket for opening bracket.".into(),
279                segment: segments[opening_match.span.start as usize].clone().into(),
280            });
281        }
282
283        let matcher = matcher.unwrap();
284        if end_brackets.contains(&matcher) {
285            let closing_idx = end_brackets.iter().position(|it| it == &matcher).unwrap();
286
287            if closing_idx == type_idx {
288                let match_span = match_result.span;
289                let persists = bracket_persists[type_idx];
290                let insert_segments = vec![
291                    (opening_match.span.end, SyntaxKind::Indent),
292                    (match_result.span.start, SyntaxKind::Dedent),
293                ];
294
295                child_matches.push(match_result);
296                let match_result = MatchResult {
297                    span: Span {
298                        start: opening_match.span.start,
299                        end: match_span.end,
300                    },
301                    matched: None,
302                    insert_segments,
303                    child_matches,
304                };
305
306                if !persists {
307                    return Ok(match_result);
308                }
309
310                return Ok(match_result.wrap(Matched::SyntaxKind(SyntaxKind::Bracketed)));
311            }
312
313            return Err(SQLParseError {
314                description: "Found unexpected end bracket!".into(),
315                segment: segments[(match_result.span.end - 1) as usize]
316                    .clone()
317                    .into(),
318            });
319        }
320
321        let inner_match = resolve_bracket(
322            segments,
323            match_result,
324            matcher,
325            start_brackets,
326            end_brackets,
327            bracket_persists,
328            parse_context,
329            false,
330        )?;
331
332        matched_idx = inner_match.span.end;
333        if nested_match {
334            child_matches.push(inner_match);
335        }
336    }
337}
338
/// Outcome of a bracket-aware search: the match itself, the matcher that produced it (if any),
/// and the bracket matches resolved along the way.
type BracketMatch = Result<(MatchResult, Option<Matchable>, Vec<MatchResult>), SQLParseError>;
340
341fn next_ex_bracket_match(
342    segments: &[ErasedSegment],
343    idx: u32,
344    matchers: &[Matchable],
345    parse_context: &mut ParseContext,
346    bracket_pairs_set: &'static str,
347) -> BracketMatch {
348    let max_idx = segments.len() as u32;
349
350    if idx >= max_idx {
351        return Ok((MatchResult::empty_at(idx), None, Vec::new()));
352    }
353
354    let (_, start_bracket_refs, end_bracket_refs, bracket_persists): (
355        Vec<_>,
356        Vec<_>,
357        Vec<_>,
358        Vec<_>,
359    ) = multiunzip(parse_context.dialect().bracket_sets(bracket_pairs_set));
360
361    let start_brackets = start_bracket_refs
362        .into_iter()
363        .map(|seg_ref| parse_context.dialect().r#ref(seg_ref))
364        .collect_vec();
365
366    let end_brackets = end_bracket_refs
367        .into_iter()
368        .map(|seg_ref| parse_context.dialect().r#ref(seg_ref))
369        .collect_vec();
370
371    let all_matchers = [matchers, &start_brackets, &end_brackets].concat();
372
373    let mut matched_idx = idx;
374    let mut child_matches: Vec<MatchResult> = Vec::new();
375
376    loop {
377        let (match_result, matcher) =
378            next_match(segments, matched_idx, &all_matchers, parse_context)?;
379        if !match_result.has_match() {
380            return Ok((match_result, matcher.clone(), child_matches));
381        }
382
383        if let Some(matcher) = matcher
384            .as_ref()
385            .filter(|matcher| matchers.contains(matcher))
386        {
387            return Ok((match_result, Some(matcher.clone()), child_matches));
388        }
389
390        if matcher
391            .as_ref()
392            .is_some_and(|matcher| end_brackets.contains(matcher))
393        {
394            return Ok((MatchResult::empty_at(idx), None, Vec::new()));
395        }
396
397        let bracket_match = resolve_bracket(
398            segments,
399            match_result,
400            matcher.unwrap(),
401            &start_brackets,
402            &end_brackets,
403            &bracket_persists,
404            parse_context,
405            true,
406        )?;
407
408        matched_idx = bracket_match.span.end;
409        child_matches.push(bracket_match);
410    }
411}
412
413pub fn greedy_match(
414    segments: &[ErasedSegment],
415    idx: u32,
416    parse_context: &mut ParseContext,
417    matchers: &[Matchable],
418    include_terminator: bool,
419    nested_match: bool,
420) -> Result<MatchResult, SQLParseError> {
421    let mut working_idx = idx;
422    let mut stop_idx: u32;
423    let mut child_matches = Vec::new();
424    let mut matched;
425
426    loop {
427        let (match_result, matcher, inner_matches) =
428            parse_context.deeper_match(false, &[], |ctx| {
429                next_ex_bracket_match(segments, working_idx, matchers, ctx, "bracket_pairs")
430            })?;
431
432        matched = match_result;
433
434        if nested_match {
435            child_matches.extend(inner_matches);
436        }
437
438        if !matched.has_match() {
439            return Ok(MatchResult {
440                span: Span {
441                    start: idx,
442                    end: segments.len() as u32,
443                },
444                matched: None,
445                insert_segments: Vec::new(),
446                child_matches,
447            });
448        }
449
450        let start_idx = matched.span.start;
451        stop_idx = matched.span.end;
452
453        let matcher = matcher.unwrap();
454        let (strings, types) = matcher.simple(parse_context, None).unwrap();
455
456        if types.is_empty() && strings.iter().all(|s| s.chars().all(|c| c.is_alphabetic())) {
457            let mut allowable_match = start_idx == working_idx;
458
459            for idx in (working_idx..=start_idx).rev() {
460                if segments[idx as usize - 1].is_meta() {
461                    continue;
462                }
463
464                allowable_match = matches!(
465                    segments[idx as usize - 1].get_type(),
466                    SyntaxKind::Whitespace | SyntaxKind::Newline
467                );
468
469                break;
470            }
471
472            if !allowable_match {
473                working_idx = stop_idx;
474                continue;
475            }
476        }
477
478        break;
479    }
480
481    if include_terminator {
482        return Ok(MatchResult {
483            span: Span {
484                start: idx,
485                end: stop_idx,
486            },
487            ..MatchResult::default()
488        });
489    }
490
491    let stop_idx = skip_stop_index_backward_to_code(segments, matched.span.start, idx);
492
493    let span = if idx == stop_idx {
494        Span {
495            start: idx,
496            end: matched.span.start,
497        }
498    } else {
499        Span {
500            start: idx,
501            end: stop_idx,
502        }
503    };
504
505    Ok(MatchResult {
506        span,
507        child_matches,
508        ..Default::default()
509    })
510}
511
512pub fn trim_to_terminator(
513    segments: &[ErasedSegment],
514    idx: u32,
515    terminators: &[Matchable],
516    parse_context: &mut ParseContext,
517) -> Result<u32, SQLParseError> {
518    if idx >= segments.len() as u32 {
519        return Ok(segments.len() as u32);
520    }
521
522    let early_return = parse_context.deeper_match(false, &[], |ctx| {
523        let pruned_terms = prune_options(terminators, segments, ctx, idx);
524
525        for term in pruned_terms {
526            if term.match_segments(segments, idx, ctx)?.has_match() {
527                return Ok(Some(idx));
528            }
529        }
530
531        Ok(None)
532    })?;
533
534    if let Some(idx) = early_return {
535        return Ok(idx);
536    }
537
538    let term_match = parse_context.deeper_match(false, &[], |ctx| {
539        greedy_match(segments, idx, ctx, terminators, false, false)
540    })?;
541
542    Ok(skip_stop_index_backward_to_code(
543        segments,
544        term_match.span.end,
545        idx,
546    ))
547}