1use ahash::AHashMap;
2use itertools::{Itertools as _, enumerate, multiunzip};
3use smol_str::StrExt;
4
5use super::context::ParseContext;
6use super::match_result::{MatchResult, Matched, Span};
7use super::matchable::{Matchable, MatchableTrait};
8use super::segments::ErasedSegment;
9use crate::dialects::syntax::{SyntaxKind, SyntaxSet};
10use crate::errors::SQLParseError;
11
12pub fn skip_start_index_forward_to_code(
13 segments: &[ErasedSegment],
14 start_idx: u32,
15 max_idx: u32,
16) -> u32 {
17 let mut idx = start_idx;
18 while idx < max_idx {
19 if segments[idx as usize].is_code() {
20 break;
21 }
22 idx += 1;
23 }
24 idx
25}
26
27pub fn skip_stop_index_backward_to_code(
28 segments: &[ErasedSegment],
29 stop_idx: u32,
30 min_idx: u32,
31) -> u32 {
32 let mut idx = stop_idx;
33 while idx > min_idx {
34 if segments[idx as usize - 1].is_code() {
35 break;
36 }
37 idx -= 1;
38 }
39 idx
40}
41
42pub fn first_trimmed_raw(seg: &ErasedSegment) -> String {
43 seg.raw()
44 .to_uppercase_smolstr()
45 .split(char::is_whitespace)
46 .next()
47 .map(ToString::to_string)
48 .unwrap_or_default()
49}
50
51pub fn first_non_whitespace(
52 segments: &[ErasedSegment],
53 start_idx: u32,
54) -> Option<(String, &SyntaxSet)> {
55 for segment in segments.iter().skip(start_idx as usize) {
56 if let Some(raw) = segment.first_non_whitespace_segment_raw_upper() {
57 return Some((raw, segment.class_types()));
58 }
59 }
60
61 None
62}
63
64pub fn prune_options(
65 options: &[Matchable],
66 segments: &[ErasedSegment],
67 parse_context: &mut ParseContext,
68 start_idx: u32,
69) -> Vec<Matchable> {
70 let mut available_options = vec![];
71
72 let Some((first_raw, first_types)) = first_non_whitespace(segments, start_idx) else {
74 return options.to_vec();
75 };
76
77 for opt in options {
78 let Some(simple) = opt.simple(parse_context, None) else {
79 available_options.push(opt.clone());
82 continue;
83 };
84
85 let (simple_raws, simple_types) = simple;
88 let mut matched = false;
89
90 if simple_raws.contains(&first_raw) {
96 available_options.push(opt.clone());
98 matched = true;
99 }
100
101 if !matched && first_types.intersects(&simple_types) {
102 available_options.push(opt.clone());
103 }
104 }
105
106 available_options
107}
108
109pub fn longest_match(
110 segments: &[ErasedSegment],
111 matchers: &[Matchable],
112 idx: u32,
113 parse_context: &mut ParseContext,
114) -> Result<(MatchResult, Option<Matchable>), SQLParseError> {
115 let max_idx = segments.len() as u32;
116
117 if matchers.is_empty() || idx == max_idx {
118 return Ok((MatchResult::empty_at(idx), None));
119 }
120
121 let available_options = prune_options(matchers, segments, parse_context, idx);
122 let available_options_count = available_options.len();
123
124 if available_options.is_empty() {
125 return Ok((MatchResult::empty_at(idx), None));
126 }
127
128 let terminators = parse_context.terminators.clone();
129 let cache_position = segments[idx as usize].get_position_marker().unwrap();
130
131 let loc_key = (
132 segments[idx as usize].raw().clone(),
133 cache_position.working_loc(),
134 segments[idx as usize].get_type(),
135 max_idx,
136 );
137
138 let loc_key = parse_context.loc_key(loc_key);
139
140 let mut best_match = MatchResult::empty_at(idx);
141 let mut best_matcher = None;
142
143 'matcher: for (matcher_idx, matcher) in enumerate(available_options) {
144 let matcher_key = matcher.cache_key();
145 let res_match = parse_context.check_parse_cache(loc_key, matcher_key);
146
147 let res_match = match res_match {
148 Some(res_match) => res_match,
149 None => {
150 let res_match = matcher.match_segments(segments, idx, parse_context)?;
151 parse_context.put_parse_cache(loc_key, matcher_key, res_match.clone());
152 res_match
153 }
154 };
155
156 if res_match.has_match() && res_match.span.end == max_idx {
157 return Ok((res_match, matcher.into()));
158 }
159
160 if res_match.is_better_than(&best_match) {
161 best_match = res_match;
162 best_matcher = matcher.into();
163
164 if matcher_idx == available_options_count - 1 {
165 break 'matcher;
166 } else if !terminators.is_empty() {
167 let next_code_idx = skip_start_index_forward_to_code(
168 segments,
169 best_match.span.end,
170 segments.len() as u32,
171 );
172
173 if next_code_idx == segments.len() as u32 {
174 break 'matcher;
175 }
176
177 for terminator in &terminators {
178 let terminator_match =
179 terminator.match_segments(segments, next_code_idx, parse_context)?;
180
181 if terminator_match.has_match() {
182 break 'matcher;
183 }
184 }
185 }
186 }
187 }
188
189 Ok((best_match, best_matcher))
190}
191
192fn next_match(
193 segments: &[ErasedSegment],
194 idx: u32,
195 matchers: &[Matchable],
196 parse_context: &mut ParseContext,
197) -> Result<(MatchResult, Option<Matchable>), SQLParseError> {
198 let max_idx = segments.len() as u32;
199
200 if idx >= max_idx {
201 return Ok((MatchResult::empty_at(idx), None));
202 }
203
204 let mut raw_simple_map: AHashMap<String, Vec<usize>> = AHashMap::new();
205 let mut type_simple_map: AHashMap<SyntaxKind, Vec<usize>> = AHashMap::new();
206
207 for (idx, matcher) in enumerate(matchers) {
208 let (raws, types) = matcher.simple(parse_context, None).unwrap();
209
210 raw_simple_map.reserve(raws.len());
211 type_simple_map.reserve(types.len());
212
213 for raw in raws {
214 raw_simple_map.entry(raw).or_default().push(idx);
215 }
216
217 for typ in types {
218 type_simple_map.entry(typ).or_default().push(idx);
219 }
220 }
221
222 for idx in idx..max_idx {
223 let seg = &segments[idx as usize];
224 let mut matcher_idxs = raw_simple_map
225 .get(&first_trimmed_raw(seg))
226 .cloned()
227 .unwrap_or_default();
228
229 let keys = type_simple_map.keys().copied().collect();
230 let type_overlap = seg.class_types().clone().intersection(&keys);
231
232 for typ in type_overlap {
233 matcher_idxs.extend(type_simple_map[&typ].clone());
234 }
235
236 if matcher_idxs.is_empty() {
237 continue;
238 }
239
240 matcher_idxs.sort();
241 for matcher_idx in matcher_idxs {
242 let matcher = &matchers[matcher_idx];
243 let match_result = matcher.match_segments(segments, idx, parse_context)?;
244
245 if match_result.has_match() {
246 return Ok((match_result, matcher.clone().into()));
247 }
248 }
249 }
250
251 Ok((MatchResult::empty_at(idx), None))
252}
253
254#[allow(clippy::too_many_arguments)]
255pub fn resolve_bracket(
256 segments: &[ErasedSegment],
257 opening_match: MatchResult,
258 opening_matcher: Matchable,
259 start_brackets: &[Matchable],
260 end_brackets: &[Matchable],
261 bracket_persists: &[bool],
262 parse_context: &mut ParseContext,
263 nested_match: bool,
264) -> Result<MatchResult, SQLParseError> {
265 let type_idx = start_brackets
266 .iter()
267 .position(|it| it == &opening_matcher)
268 .unwrap();
269 let mut matched_idx = opening_match.span.end;
270 let mut child_matches = vec![opening_match.clone()];
271
272 let matchers = [start_brackets, end_brackets].concat();
273 loop {
274 let (match_result, matcher) = next_match(segments, matched_idx, &matchers, parse_context)?;
275
276 if !match_result.has_match() {
277 return Err(SQLParseError {
278 description: "Couldn't find closing bracket for opening bracket.".into(),
279 segment: segments[opening_match.span.start as usize].clone().into(),
280 });
281 }
282
283 let matcher = matcher.unwrap();
284 if end_brackets.contains(&matcher) {
285 let closing_idx = end_brackets.iter().position(|it| it == &matcher).unwrap();
286
287 if closing_idx == type_idx {
288 let match_span = match_result.span;
289 let persists = bracket_persists[type_idx];
290 let insert_segments = vec![
291 (opening_match.span.end, SyntaxKind::Indent),
292 (match_result.span.start, SyntaxKind::Dedent),
293 ];
294
295 child_matches.push(match_result);
296 let match_result = MatchResult {
297 span: Span {
298 start: opening_match.span.start,
299 end: match_span.end,
300 },
301 matched: None,
302 insert_segments,
303 child_matches,
304 };
305
306 if !persists {
307 return Ok(match_result);
308 }
309
310 return Ok(match_result.wrap(Matched::SyntaxKind(SyntaxKind::Bracketed)));
311 }
312
313 return Err(SQLParseError {
314 description: "Found unexpected end bracket!".into(),
315 segment: segments[(match_result.span.end - 1) as usize]
316 .clone()
317 .into(),
318 });
319 }
320
321 let inner_match = resolve_bracket(
322 segments,
323 match_result,
324 matcher,
325 start_brackets,
326 end_brackets,
327 bracket_persists,
328 parse_context,
329 false,
330 )?;
331
332 matched_idx = inner_match.span.end;
333 if nested_match {
334 child_matches.push(inner_match);
335 }
336 }
337}
338
339type BracketMatch = Result<(MatchResult, Option<Matchable>, Vec<MatchResult>), SQLParseError>;
340
341fn next_ex_bracket_match(
342 segments: &[ErasedSegment],
343 idx: u32,
344 matchers: &[Matchable],
345 parse_context: &mut ParseContext,
346 bracket_pairs_set: &'static str,
347) -> BracketMatch {
348 let max_idx = segments.len() as u32;
349
350 if idx >= max_idx {
351 return Ok((MatchResult::empty_at(idx), None, Vec::new()));
352 }
353
354 let (_, start_bracket_refs, end_bracket_refs, bracket_persists): (
355 Vec<_>,
356 Vec<_>,
357 Vec<_>,
358 Vec<_>,
359 ) = multiunzip(parse_context.dialect().bracket_sets(bracket_pairs_set));
360
361 let start_brackets = start_bracket_refs
362 .into_iter()
363 .map(|seg_ref| parse_context.dialect().r#ref(seg_ref))
364 .collect_vec();
365
366 let end_brackets = end_bracket_refs
367 .into_iter()
368 .map(|seg_ref| parse_context.dialect().r#ref(seg_ref))
369 .collect_vec();
370
371 let all_matchers = [matchers, &start_brackets, &end_brackets].concat();
372
373 let mut matched_idx = idx;
374 let mut child_matches: Vec<MatchResult> = Vec::new();
375
376 loop {
377 let (match_result, matcher) =
378 next_match(segments, matched_idx, &all_matchers, parse_context)?;
379 if !match_result.has_match() {
380 return Ok((match_result, matcher.clone(), child_matches));
381 }
382
383 if let Some(matcher) = matcher
384 .as_ref()
385 .filter(|matcher| matchers.contains(matcher))
386 {
387 return Ok((match_result, Some(matcher.clone()), child_matches));
388 }
389
390 if matcher
391 .as_ref()
392 .is_some_and(|matcher| end_brackets.contains(matcher))
393 {
394 return Ok((MatchResult::empty_at(idx), None, Vec::new()));
395 }
396
397 let bracket_match = resolve_bracket(
398 segments,
399 match_result,
400 matcher.unwrap(),
401 &start_brackets,
402 &end_brackets,
403 &bracket_persists,
404 parse_context,
405 true,
406 )?;
407
408 matched_idx = bracket_match.span.end;
409 child_matches.push(bracket_match);
410 }
411}
412
413pub fn greedy_match(
414 segments: &[ErasedSegment],
415 idx: u32,
416 parse_context: &mut ParseContext,
417 matchers: &[Matchable],
418 include_terminator: bool,
419 nested_match: bool,
420) -> Result<MatchResult, SQLParseError> {
421 let mut working_idx = idx;
422 let mut stop_idx: u32;
423 let mut child_matches = Vec::new();
424 let mut matched;
425
426 loop {
427 let (match_result, matcher, inner_matches) =
428 parse_context.deeper_match(false, &[], |ctx| {
429 next_ex_bracket_match(segments, working_idx, matchers, ctx, "bracket_pairs")
430 })?;
431
432 matched = match_result;
433
434 if nested_match {
435 child_matches.extend(inner_matches);
436 }
437
438 if !matched.has_match() {
439 return Ok(MatchResult {
440 span: Span {
441 start: idx,
442 end: segments.len() as u32,
443 },
444 matched: None,
445 insert_segments: Vec::new(),
446 child_matches,
447 });
448 }
449
450 let start_idx = matched.span.start;
451 stop_idx = matched.span.end;
452
453 let matcher = matcher.unwrap();
454 let (strings, types) = matcher.simple(parse_context, None).unwrap();
455
456 if types.is_empty() && strings.iter().all(|s| s.chars().all(|c| c.is_alphabetic())) {
457 let mut allowable_match = start_idx == working_idx;
458
459 for idx in (working_idx..=start_idx).rev() {
460 if segments[idx as usize - 1].is_meta() {
461 continue;
462 }
463
464 allowable_match = matches!(
465 segments[idx as usize - 1].get_type(),
466 SyntaxKind::Whitespace | SyntaxKind::Newline
467 );
468
469 break;
470 }
471
472 if !allowable_match {
473 working_idx = stop_idx;
474 continue;
475 }
476 }
477
478 break;
479 }
480
481 if include_terminator {
482 return Ok(MatchResult {
483 span: Span {
484 start: idx,
485 end: stop_idx,
486 },
487 ..MatchResult::default()
488 });
489 }
490
491 let stop_idx = skip_stop_index_backward_to_code(segments, matched.span.start, idx);
492
493 let span = if idx == stop_idx {
494 Span {
495 start: idx,
496 end: matched.span.start,
497 }
498 } else {
499 Span {
500 start: idx,
501 end: stop_idx,
502 }
503 };
504
505 Ok(MatchResult {
506 span,
507 child_matches,
508 ..Default::default()
509 })
510}
511
512pub fn trim_to_terminator(
513 segments: &[ErasedSegment],
514 idx: u32,
515 terminators: &[Matchable],
516 parse_context: &mut ParseContext,
517) -> Result<u32, SQLParseError> {
518 if idx >= segments.len() as u32 {
519 return Ok(segments.len() as u32);
520 }
521
522 let early_return = parse_context.deeper_match(false, &[], |ctx| {
523 let pruned_terms = prune_options(terminators, segments, ctx, idx);
524
525 for term in pruned_terms {
526 if term.match_segments(segments, idx, ctx)?.has_match() {
527 return Ok(Some(idx));
528 }
529 }
530
531 Ok(None)
532 })?;
533
534 if let Some(idx) = early_return {
535 return Ok(idx);
536 }
537
538 let term_match = parse_context.deeper_match(false, &[], |ctx| {
539 greedy_match(segments, idx, ctx, terminators, false, false)
540 })?;
541
542 Ok(skip_stop_index_backward_to_code(
543 segments,
544 term_match.span.end,
545 idx,
546 ))
547}