// sqruff_lib/utils/reflow/rebreak.rs

1use std::cmp::PartialEq;
2use std::str::FromStr;
3
4use sqruff_lib_core::dialects::syntax::SyntaxKind;
5use sqruff_lib_core::helpers::capitalize;
6use sqruff_lib_core::lint_fix::LintFix;
7use sqruff_lib_core::parser::segments::{ErasedSegment, Tables};
8use strum_macros::{AsRefStr, EnumString};
9
10use super::elements::{ReflowElement, ReflowSequenceType};
11use crate::core::rules::LintResult;
12use crate::utils::reflow::depth_map::StackPositionType;
13use crate::utils::reflow::elements::ReflowPoint;
14use crate::utils::reflow::helpers::{deduce_line_indent, fixes_from_results};
15
/// A candidate span of the element buffer which may need rebreaking.
///
/// Produced by [`identify_rebreak_spans`] and expanded into a
/// [`RebreakLocation`] before any changes are made.
#[derive(Debug)]
pub struct RebreakSpan {
    /// The segment whose position relative to line breaks is configured:
    /// either the first segment of a single block, or a parent segment
    /// spanning several blocks.
    pub(crate) target: ErasedSegment,
    /// Index in the element buffer of the first block of the span.
    pub(crate) start_idx: usize,
    /// Index in the element buffer of the last block of the span.
    pub(crate) end_idx: usize,
    /// The configured position (the part of the config before any `:`).
    pub(crate) line_position: LinePosition,
    /// Whether the config ended with the `strict` modifier.
    pub(crate) strict: bool,
}
24
/// Key point indices on one side of a rebreak span.
///
/// All indices refer to elements of the reflow element buffer.
#[derive(Debug)]
pub struct RebreakIndices {
    /// Search direction: `-1` for the preceding side, `1` for the following
    /// side. Currently unused after construction.
    _dir: i32,
    /// The point immediately adjacent to the span.
    adj_pt_idx: isize,
    /// The first point, moving outwards, that contains a newline or that sits
    /// directly before a code element.
    newline_pt_idx: isize,
    /// The point directly before the next code element moving outwards,
    /// passing over any non-code blocks (e.g. comments).
    pre_code_pt_idx: isize,
}
32
33impl RebreakIndices {
34    fn from_elements(elements: &ReflowSequenceType, start_idx: usize, dir: i32) -> Option<Self> {
35        assert!(dir == 1 || dir == -1);
36        let limit = if dir == -1 { 0 } else { elements.len() };
37        let adj_point_idx = start_idx as isize + dir as isize;
38
39        if adj_point_idx < 0 || adj_point_idx >= elements.len() as isize {
40            return None;
41        }
42
43        let mut newline_point_idx = adj_point_idx;
44        while (dir == 1 && newline_point_idx < limit as isize)
45            || (dir == -1 && newline_point_idx >= 0)
46        {
47            if elements[newline_point_idx as usize]
48                .class_types()
49                .contains(SyntaxKind::Newline)
50                || elements[(newline_point_idx + dir as isize) as usize]
51                    .segments()
52                    .iter()
53                    .any(|seg| seg.is_code())
54            {
55                break;
56            }
57            newline_point_idx += 2 * dir as isize;
58        }
59
60        let mut pre_code_point_idx = newline_point_idx;
61        while (dir == 1 && pre_code_point_idx < limit as isize)
62            || (dir == -1 && pre_code_point_idx >= 0)
63        {
64            if elements[(pre_code_point_idx + dir as isize) as usize]
65                .segments()
66                .iter()
67                .any(|seg| seg.is_code())
68            {
69                break;
70            }
71            pre_code_point_idx += 2 * dir as isize;
72        }
73
74        RebreakIndices {
75            _dir: dir,
76            adj_pt_idx: adj_point_idx,
77            newline_pt_idx: newline_point_idx,
78            pre_code_pt_idx: pre_code_point_idx,
79        }
80        .into()
81    }
82}
83
/// A [`RebreakSpan`] expanded with the key point indices on both sides.
#[derive(Debug)]
pub struct RebreakLocation {
    /// The segment being positioned relative to line breaks.
    target: ErasedSegment,
    /// Key indices on the preceding side of the target.
    prev: RebreakIndices,
    /// Key indices on the following side of the target.
    next: RebreakIndices,
    /// The configured position for the target.
    line_position: LinePosition,
    /// Whether the configured position carries the `strict` modifier.
    strict: bool,
}
92
/// Where an element should sit relative to line breaks.
///
/// Parsed from `line_position` layout config. Variants (de)serialise as
/// lowercase strings, e.g. `"leading"`.
#[derive(Debug, PartialEq, Clone, Copy, AsRefStr, EnumString)]
#[strum(serialize_all = "lowercase")]
pub enum LinePosition {
    /// The element should start a line (line breaks go before it).
    Leading,
    /// The element should end a line (line breaks go after it).
    Trailing,
    /// The element should have line breaks both before and after it.
    Alone,
    /// The `strict` modifier (as in `alone:strict`), not a position by itself.
    Strict,
}
101
102impl RebreakLocation {
103    /// Expand a span to a location.
104    pub fn from_span(span: RebreakSpan, elements: &ReflowSequenceType) -> Option<Self> {
105        Self {
106            target: span.target,
107            prev: RebreakIndices::from_elements(elements, span.start_idx, -1)?,
108            next: RebreakIndices::from_elements(elements, span.end_idx, 1)?,
109            line_position: span.line_position,
110            strict: span.strict,
111        }
112        .into()
113    }
114
115    fn has_inappropriate_newlines(&self, elements: &ReflowSequenceType, strict: bool) -> bool {
116        let n_prev_newlines = elements[self.prev.newline_pt_idx as usize].num_newlines();
117        let n_next_newlines = elements[self.next.newline_pt_idx as usize].num_newlines();
118
119        let newlines_on_neither_side = n_prev_newlines + n_next_newlines == 0;
120        let newlines_on_both_sides = n_prev_newlines > 0 && n_next_newlines > 0;
121
122        (newlines_on_neither_side && !strict) || newlines_on_both_sides
123    }
124
125    fn pretty_target_name(&self) -> String {
126        format!("{} {}", self.target.get_type().as_str(), self.target.raw())
127    }
128}
129
130pub fn identify_rebreak_spans(
131    element_buffer: &ReflowSequenceType,
132    root_segment: &ErasedSegment,
133) -> Vec<RebreakSpan> {
134    let mut spans = Vec::new();
135
136    for (idx, elem) in element_buffer
137        .iter()
138        .enumerate()
139        .take(element_buffer.len() - 2)
140        .skip(2)
141    {
142        let ReflowElement::Block(block) = elem else {
143            continue;
144        };
145
146        if let Some(original_line_position) = block.line_position() {
147            let line_position = original_line_position.first().unwrap();
148            spans.push(RebreakSpan {
149                target: elem.segments().first().cloned().unwrap(),
150                start_idx: idx,
151                end_idx: idx,
152                line_position: *line_position,
153                strict: original_line_position.last() == Some(&LinePosition::Strict),
154            });
155        }
156
157        for key in block.line_position_configs().keys() {
158            let mut final_idx = None;
159            if block.depth_info().stack_positions[key].idx != 0 {
160                continue;
161            }
162
163            for end_idx in idx..element_buffer.len() {
164                let end_elem = &element_buffer[end_idx];
165                let ReflowElement::Block(end_block) = end_elem else {
166                    continue;
167                };
168
169                if !end_block.depth_info().stack_positions.contains_key(key) {
170                    final_idx = (end_idx - 2).into();
171                } else if matches!(
172                    end_block.depth_info().stack_positions[key].type_,
173                    Some(StackPositionType::End) | Some(StackPositionType::Solo)
174                ) {
175                    final_idx = end_idx.into();
176                }
177
178                if let Some(final_idx) = final_idx {
179                    let target_depth = block
180                        .depth_info()
181                        .stack_hashes
182                        .iter()
183                        .position(|it| it == key)
184                        .unwrap();
185                    let target = root_segment.path_to(&element_buffer[idx].segments()[0])
186                        [target_depth]
187                        .segment
188                        .clone();
189
190                    let line_position_configs = block.line_position_configs()[key]
191                        .split(':')
192                        .next()
193                        .unwrap();
194                    let line_position = LinePosition::from_str(line_position_configs).unwrap();
195
196                    spans.push(RebreakSpan {
197                        target,
198                        start_idx: idx,
199                        end_idx: final_idx,
200                        line_position,
201                        strict: block.line_position_configs()[key].ends_with("strict"),
202                    });
203
204                    break;
205                }
206            }
207        }
208    }
209
210    spans
211}
212
213pub fn rebreak_sequence(
214    tables: &Tables,
215    elements: ReflowSequenceType,
216    root_segment: &ErasedSegment,
217) -> (ReflowSequenceType, Vec<LintResult>) {
218    let mut lint_results = Vec::new();
219    let mut fixes = Vec::new();
220    let mut elem_buff = elements.clone();
221
222    // Given a sequence we should identify the objects which
223    // make sense to rebreak. That includes any raws with config,
224    // but also and parent segments which have config and we can
225    // find both ends for. Given those spans, we then need to find
226    // the points either side of them and then the blocks either
227    // side to respace them at the same time.
228
229    // 1. First find appropriate spans.
230    let spans = identify_rebreak_spans(&elem_buff, root_segment);
231
232    let mut locations = Vec::new();
233    for span in spans {
234        if let Some(loc) = RebreakLocation::from_span(span, &elements) {
235            locations.push(loc);
236        }
237    }
238
239    // Handle each span:
240    for loc in locations {
241        if loc.has_inappropriate_newlines(&elements, loc.strict) {
242            continue;
243        }
244
245        // if loc.has_templated_newline(elem_buff) {
246        //     continue;
247        // }
248
249        // Points and blocks either side are just offsets from the indices.
250        let prev_point = elem_buff[loc.prev.adj_pt_idx as usize]
251            .as_point()
252            .unwrap()
253            .clone();
254        let next_point = elem_buff[loc.next.adj_pt_idx as usize]
255            .as_point()
256            .unwrap()
257            .clone();
258
259        // So we know we have a preference, is it ok?
260        let new_results = if loc.line_position == LinePosition::Leading {
261            if elem_buff[loc.prev.newline_pt_idx as usize].num_newlines() != 0 {
262                // We're good. It's already leading.
263                continue;
264            }
265
266            // Generate the text for any issues.
267            let pretty_name = loc.pretty_target_name();
268            let _desc = if loc.strict {
269                format!(
270                    "{} should always start a new line.",
271                    capitalize(&pretty_name)
272                )
273            } else {
274                format!("Found trailing {pretty_name}. Expected only leading near line breaks.")
275            };
276
277            if loc.next.adj_pt_idx == loc.next.pre_code_pt_idx
278                && elem_buff[loc.next.newline_pt_idx as usize].num_newlines() == 1
279            {
280                // Simple case. No comments.
281                // Strip newlines from the next point.
282                // Apply the indent to the previous point.
283
284                let desired_indent = next_point.get_indent().unwrap_or_default();
285
286                let (new_results, prev_point) = prev_point.indent_to(
287                    tables,
288                    &desired_indent,
289                    None,
290                    loc.target.clone().into(),
291                    None,
292                    None,
293                );
294
295                let (new_results, next_point) = next_point.respace_point(
296                    tables,
297                    elem_buff[loc.next.adj_pt_idx as usize - 1].as_block(),
298                    elem_buff[loc.next.adj_pt_idx as usize + 1].as_block(),
299                    root_segment,
300                    new_results,
301                    true,
302                    "before",
303                );
304
305                // Update the points in the buffer
306                elem_buff[loc.prev.adj_pt_idx as usize] = prev_point.into();
307                elem_buff[loc.next.adj_pt_idx as usize] = next_point.into();
308
309                new_results
310            } else {
311                fixes.push(LintFix::delete(loc.target.clone()));
312                for seg in elem_buff[loc.prev.adj_pt_idx as usize].segments() {
313                    fixes.push(LintFix::delete(seg.clone()));
314                }
315
316                let (new_results, new_point) = ReflowPoint::new(Vec::new()).respace_point(
317                    tables,
318                    elem_buff[(loc.next.adj_pt_idx - 1) as usize].as_block(),
319                    elem_buff[(loc.next.pre_code_pt_idx + 1) as usize].as_block(),
320                    root_segment,
321                    Vec::new(),
322                    false,
323                    "after",
324                );
325
326                let mut create_anchor = None;
327                for i in 0..loc.next.pre_code_pt_idx {
328                    let idx = loc.next.pre_code_pt_idx - i;
329                    if let Some(elem) = elem_buff.get(idx as usize)
330                        && let Some(segments) = elem.segments().last()
331                    {
332                        create_anchor = Some(segments.clone());
333                        break;
334                    }
335                }
336
337                if create_anchor.is_none() {
338                    panic!("Could not find anchor for creation.");
339                }
340
341                fixes.push(LintFix::create_after(
342                    create_anchor.unwrap(),
343                    vec![loc.target.clone()],
344                    None,
345                ));
346
347                rearrange_and_insert(&mut elem_buff, &loc, new_point);
348
349                new_results
350            }
351        } else if loc.line_position == LinePosition::Trailing {
352            if elem_buff[loc.next.newline_pt_idx as usize].num_newlines() != 0 {
353                continue;
354            }
355
356            let pretty_name = loc.pretty_target_name();
357            let _desc = if loc.strict {
358                format!(
359                    "{} should always be at the end of a line.",
360                    capitalize(&pretty_name)
361                )
362            } else {
363                format!("Found leading {pretty_name}. Expected only trailing near line breaks.")
364            };
365
366            if loc.prev.adj_pt_idx == loc.prev.pre_code_pt_idx
367                && elem_buff[loc.prev.newline_pt_idx as usize].num_newlines() == 1
368            {
369                let (new_results, next_point) = next_point.indent_to(
370                    tables,
371                    prev_point.get_indent().as_deref().unwrap_or_default(),
372                    Some(loc.target.clone()),
373                    None,
374                    None,
375                    None,
376                );
377
378                let (new_results, prev_point) = prev_point.respace_point(
379                    tables,
380                    elem_buff[loc.prev.adj_pt_idx as usize - 1].as_block(),
381                    elem_buff[loc.prev.adj_pt_idx as usize + 1].as_block(),
382                    root_segment,
383                    new_results,
384                    true,
385                    "before",
386                );
387
388                // Update the points in the buffer
389                elem_buff[loc.prev.adj_pt_idx as usize] = prev_point.into();
390                elem_buff[loc.next.adj_pt_idx as usize] = next_point.into();
391
392                new_results
393            } else {
394                fixes.push(LintFix::delete(loc.target.clone()));
395                for seg in elem_buff[loc.next.adj_pt_idx as usize].segments() {
396                    fixes.push(LintFix::delete(seg.clone()));
397                }
398
399                let (new_results, new_point) = ReflowPoint::new(Vec::new()).respace_point(
400                    tables,
401                    elem_buff[(loc.prev.pre_code_pt_idx - 1) as usize].as_block(),
402                    elem_buff[(loc.prev.adj_pt_idx + 1) as usize].as_block(),
403                    root_segment,
404                    Vec::new(),
405                    false,
406                    "before",
407                );
408
409                fixes.push(LintFix::create_before(
410                    elem_buff[loc.prev.pre_code_pt_idx as usize].segments()[0].clone(),
411                    vec![loc.target.clone()],
412                ));
413
414                reorder_and_insert(&mut elem_buff, &loc, new_point);
415
416                new_results
417            }
418        } else if loc.line_position == LinePosition::Alone {
419            let mut new_results = Vec::new();
420
421            if elem_buff[loc.next.newline_pt_idx as usize].num_newlines() == 0 {
422                let (results, next_point) = next_point.indent_to(
423                    tables,
424                    &deduce_line_indent(
425                        loc.target.get_raw_segments().last().unwrap(),
426                        root_segment,
427                    ),
428                    loc.target.clone().into(),
429                    None,
430                    None,
431                    None,
432                );
433
434                new_results = results;
435                elem_buff[loc.next.adj_pt_idx as usize] = next_point.into();
436            }
437
438            if elem_buff[loc.prev.adj_pt_idx as usize].num_newlines() == 0 {
439                let (results, prev_point) = prev_point.indent_to(
440                    tables,
441                    &deduce_line_indent(
442                        loc.target.get_raw_segments().first().unwrap(),
443                        root_segment,
444                    ),
445                    None,
446                    loc.target.clone().into(),
447                    None,
448                    None,
449                );
450
451                new_results = results;
452                elem_buff[loc.prev.adj_pt_idx as usize] = prev_point.into();
453            }
454
455            new_results
456        } else {
457            unimplemented!(
458                "Unexpected line_position config: {}",
459                loc.line_position.as_ref()
460            )
461        };
462
463        let fixes = fixes_from_results(new_results.into_iter())
464            .chain(std::mem::take(&mut fixes))
465            .collect();
466        lint_results.push(LintResult::new(
467            loc.target.clone().into(),
468            fixes,
469            None,
470            None,
471        ));
472    }
473
474    (elem_buff, lint_results)
475}
476
477fn rearrange_and_insert(
478    elem_buff: &mut Vec<ReflowElement>,
479    loc: &RebreakLocation,
480    new_point: ReflowPoint,
481) {
482    let mut new_buff = Vec::with_capacity(elem_buff.len() + 1);
483
484    // First segment: up to loc.prev.adj_pt_idx (exclusive)
485    new_buff.extend_from_slice(&elem_buff[..loc.prev.adj_pt_idx as usize]);
486
487    // Second segment: loc.next.adj_pt_idx to loc.next.pre_code_pt_idx (inclusive)
488    new_buff.extend_from_slice(
489        &elem_buff[loc.next.adj_pt_idx as usize..=loc.next.pre_code_pt_idx as usize],
490    );
491
492    // Third segment: loc.prev.adj_pt_idx + 1 to loc.next.adj_pt_idx (exclusive, the
493    // target)
494    if loc.prev.adj_pt_idx + 1 < loc.next.adj_pt_idx {
495        new_buff.extend_from_slice(
496            &elem_buff[loc.prev.adj_pt_idx as usize + 1..loc.next.adj_pt_idx as usize],
497        );
498    }
499
500    // Insert new_point here
501    new_buff.push(new_point.into());
502
503    // Last segment: after loc.next.pre_code_pt_idx
504    if loc.next.pre_code_pt_idx as usize + 1 < elem_buff.len() {
505        new_buff.extend_from_slice(&elem_buff[loc.next.pre_code_pt_idx as usize + 1..]);
506    }
507
508    // Replace old buffer with the new one
509    *elem_buff = new_buff;
510}
511
512fn reorder_and_insert(
513    elem_buff: &mut Vec<ReflowElement>,
514    loc: &RebreakLocation,
515    new_point: ReflowPoint,
516) {
517    let mut new_buff = Vec::with_capacity(elem_buff.len() + 1);
518
519    // First segment: up to loc.prev.pre_code_pt_idx (exclusive)
520    new_buff.extend_from_slice(&elem_buff[..loc.prev.pre_code_pt_idx as usize]);
521
522    // Insert new_point here
523    new_buff.push(new_point.into());
524
525    // Second segment: loc.prev.adj_pt_idx + 1 to loc.next.adj_pt_idx (exclusive,
526    // the target)
527    if loc.prev.adj_pt_idx + 1 < loc.next.adj_pt_idx {
528        new_buff.extend_from_slice(
529            &elem_buff[loc.prev.adj_pt_idx as usize + 1..loc.next.adj_pt_idx as usize],
530        );
531    }
532
533    // Third segment: loc.prev.pre_code_pt_idx to loc.prev.adj_pt_idx + 1
534    // (inclusive)
535    new_buff.extend_from_slice(
536        &elem_buff[loc.prev.pre_code_pt_idx as usize..=loc.prev.adj_pt_idx as usize],
537    );
538
539    // Last segment: after loc.next.adj_pt_idx
540    if loc.next.adj_pt_idx as usize + 1 < elem_buff.len() {
541        new_buff.extend_from_slice(&elem_buff[loc.next.adj_pt_idx as usize + 1..]);
542    }
543
544    // Replace old buffer with the new one
545    *elem_buff = new_buff;
546}
547
#[cfg(test)]
mod tests {
    use sqruff_lib::core::test_functions::parse_ansi_string;
    use sqruff_lib_core::helpers::enter_panic;
    use sqruff_lib_core::parser::segments::Tables;

    use crate::utils::reflow::sequence::{ReflowSequence, TargetSide};

    /// Rebreaking a whole parsed statement with the default config should
    /// make operators leading and commas trailing.
    #[test]
    fn test_reflow_sequence_rebreak_root() {
        let cases = [
            // Trivial Case
            ("select 1", "select 1"),
            // These rely on the default config being for leading operators
            ("select 1\n+2", "select 1\n+2"),
            ("select 1+\n2", "select 1\n+ 2"), // NOTE: Implicit respace.
            ("select\n  1 +\n  2", "select\n  1\n  + 2"),
            (
                "select\n  1 +\n  -- comment\n  2",
                "select\n  1\n  -- comment\n  + 2",
            ),
            // These rely on the default config being for trailing commas
            ("select a,b", "select a,b"),
            ("select a\n,b", "select a,\nb"),
            ("select\n  a\n  , b", "select\n  a,\n  b"),
            ("select\n    a\n    , b", "select\n    a,\n    b"),
            ("select\n  a\n    , b", "select\n  a,\n    b"),
            (
                "select\n  a\n  -- comment\n  , b",
                "select\n  a,\n  -- comment\n  b",
            ),
        ];

        let tables = Tables::default();
        for &(input, expected) in &cases {
            // Label any panic with the input that triggered it.
            let _panic = enter_panic(format!("{input:?}"));

            let config = <_>::default();
            let sequence = ReflowSequence::from_root(parse_ansi_string(input), &config);

            assert_eq!(sequence.rebreak(&tables).raw(), expected);
        }
    }

    /// Rebreaking a sequence built around one target only reshapes that
    /// target's surroundings, and leaves incomplete spans untouched.
    #[test]
    fn test_reflow_sequence_rebreak_target() {
        let cases = [
            ("select 1+\n(2+3)", 4, "1+\n(", "1\n+ ("),
            ("select a,\n(b+c)", 4, "a,\n(", "a,\n("),
            ("select a\n  , (b+c)", 6, "a\n  , (", "a,\n  ("),
            // Here we don't have enough context to rebreak it so
            // it should be left unaltered.
            ("select a,\n(b+c)", 6, ",\n(b", ",\n(b"),
            // This intentionally targets an incomplete span.
            ("select a<=b", 4, "a<=", "a<="),
        ];

        let tables = Tables::default();
        for &(input, target_idx, expected_before, expected_after) in &cases {
            let root = parse_ansi_string(input);
            let target = &root.get_raw_segments()[target_idx];
            let config = <_>::default();
            let sequence =
                ReflowSequence::from_around_target(target, root, TargetSide::Both, &config);

            assert_eq!(sequence.raw(), expected_before);
            assert_eq!(sequence.rebreak(&tables).raw(), expected_after);
        }
    }
}