Skip to main content

sqruff_lib_core/parser/
segments.rs

1pub mod bracketed;
2pub mod file;
3pub mod fix;
4pub mod from;
5pub mod generator;
6pub mod join;
7pub mod meta;
8pub mod object_reference;
9pub mod select;
10pub mod test_functions;
11
12use std::cell::{Cell, OnceCell};
13use std::fmt::Debug;
14use std::hash::{BuildHasher, Hash, Hasher};
15use std::rc::Rc;
16
17use hashbrown::{DefaultHashBuilder, HashMap};
18use itertools::enumerate;
19use smol_str::SmolStr;
20
21use crate::dialects::init::DialectKind;
22use crate::dialects::syntax::{SyntaxKind, SyntaxSet};
23use crate::lint_fix::LintFix;
24use crate::parser::markers::PositionMarker;
25use crate::parser::segments::fix::{FixPatch, SourceFix};
26use crate::parser::segments::object_reference::{ObjectReferenceKind, ObjectReferenceSegment};
27use crate::segments::AnchorEditInfo;
28use crate::templaters::TemplatedFile;
29
30pub struct SegmentBuilder {
31    node_or_token: NodeOrToken,
32}
33
34impl SegmentBuilder {
35    pub fn whitespace(id: u32, raw: &str) -> ErasedSegment {
36        SegmentBuilder::token(id, raw, SyntaxKind::Whitespace).finish()
37    }
38
39    pub fn newline(id: u32, raw: &str) -> ErasedSegment {
40        SegmentBuilder::token(id, raw, SyntaxKind::Newline).finish()
41    }
42
43    pub fn keyword(id: u32, raw: &str) -> ErasedSegment {
44        SegmentBuilder::token(id, raw, SyntaxKind::Keyword).finish()
45    }
46
47    pub fn comma(id: u32) -> ErasedSegment {
48        SegmentBuilder::token(id, ",", SyntaxKind::Comma).finish()
49    }
50
51    pub fn symbol(id: u32, raw: &str) -> ErasedSegment {
52        SegmentBuilder::token(id, raw, SyntaxKind::Symbol).finish()
53    }
54
55    pub fn node(
56        id: u32,
57        syntax_kind: SyntaxKind,
58        dialect: DialectKind,
59        segments: Vec<ErasedSegment>,
60    ) -> Self {
61        SegmentBuilder {
62            node_or_token: NodeOrToken {
63                id,
64                syntax_kind,
65                class_types: class_types(syntax_kind),
66                position_marker: None,
67                code_idx: OnceCell::new(),
68                kind: NodeOrTokenKind::Node(NodeData {
69                    dialect,
70                    segments,
71                    raw: Default::default(),
72                    source_fixes: vec![],
73                    descendant_type_set: Default::default(),
74                    raw_segments_with_ancestors: Default::default(),
75                }),
76                hash: OnceCell::new(),
77            },
78        }
79    }
80
81    pub fn token(id: u32, raw: &str, syntax_kind: SyntaxKind) -> Self {
82        SegmentBuilder {
83            node_or_token: NodeOrToken {
84                id,
85                syntax_kind,
86                code_idx: OnceCell::new(),
87                class_types: class_types(syntax_kind),
88                position_marker: None,
89                kind: NodeOrTokenKind::Token(TokenData { raw: raw.into() }),
90                hash: OnceCell::new(),
91            },
92        }
93    }
94
95    pub fn position_from_segments(mut self) -> Self {
96        let segments = match &self.node_or_token.kind {
97            NodeOrTokenKind::Node(node) => &node.segments[..],
98            NodeOrTokenKind::Token(_) => &[],
99        };
100
101        self.node_or_token.position_marker = pos_marker(segments).into();
102        self
103    }
104
105    pub fn with_position(mut self, position: PositionMarker) -> Self {
106        self.node_or_token.position_marker = Some(position);
107        self
108    }
109
110    pub fn with_source_fixes(mut self, source_fixes: Vec<SourceFix>) -> Self {
111        if let NodeOrTokenKind::Node(ref mut node) = self.node_or_token.kind {
112            node.source_fixes = source_fixes;
113        }
114        self
115    }
116
117    pub fn finish(self) -> ErasedSegment {
118        ErasedSegment {
119            value: Rc::new(self.node_or_token),
120        }
121    }
122}
123
/// Hands out segment ids from a single, interior-mutable counter.
#[derive(Debug, Default)]
pub struct Tables {
    counter: Cell<u32>,
}

impl Tables {
    /// Returns the current counter value and advances the counter by one.
    ///
    /// The first call on a default-constructed `Tables` returns `0`.
    pub fn next_id(&self) -> u32 {
        let current = self.counter.get();
        // `Cell::replace` stores the successor and hands back the old value.
        self.counter.replace(current + 1)
    }
}
136
137#[derive(Debug, Clone)]
138pub struct ErasedSegment {
139    pub(crate) value: Rc<NodeOrToken>,
140}
141
142impl Hash for ErasedSegment {
143    fn hash<H: Hasher>(&self, state: &mut H) {
144        self.hash_value().hash(state);
145    }
146}
147
148impl Eq for ErasedSegment {}
149
150impl ErasedSegment {
151    pub fn raw(&self) -> &SmolStr {
152        match &self.value.kind {
153            NodeOrTokenKind::Node(node) => node.raw.get_or_init(|| {
154                SmolStr::from_iter(self.segments().iter().map(|segment| segment.raw().as_str()))
155            }),
156            NodeOrTokenKind::Token(token) => &token.raw,
157        }
158    }
159
160    pub fn segments(&self) -> &[ErasedSegment] {
161        match &self.value.kind {
162            NodeOrTokenKind::Node(node) => &node.segments,
163            NodeOrTokenKind::Token(_) => &[],
164        }
165    }
166
167    pub fn get_type(&self) -> SyntaxKind {
168        self.value.syntax_kind
169    }
170
171    pub fn is_type(&self, kind: SyntaxKind) -> bool {
172        self.get_type() == kind
173    }
174
175    pub fn is_meta(&self) -> bool {
176        matches!(
177            self.value.syntax_kind,
178            SyntaxKind::Indent | SyntaxKind::Implicit | SyntaxKind::Dedent | SyntaxKind::EndOfFile
179        )
180    }
181
182    pub fn is_code(&self) -> bool {
183        match &self.value.kind {
184            NodeOrTokenKind::Node(node) => node.segments.iter().any(|s| s.is_code()),
185            NodeOrTokenKind::Token(_) => {
186                !self.is_comment() && !self.is_whitespace() && !self.is_meta()
187            }
188        }
189    }
190
191    pub fn get_raw_segments(&self) -> Vec<ErasedSegment> {
192        self.recursive_crawl_all(false)
193            .into_iter()
194            .filter(|it| it.segments().is_empty())
195            .collect()
196    }
197
198    #[cfg(feature = "stringify")]
199    pub fn stringify(&self, code_only: bool) -> String {
200        serde_yaml::to_string(&self.to_serialised(code_only, true)).unwrap()
201    }
202
203    pub fn child(&self, seg_types: &SyntaxSet) -> Option<ErasedSegment> {
204        self.segments()
205            .iter()
206            .find(|seg| seg_types.contains(seg.get_type()))
207            .cloned()
208    }
209
210    pub fn recursive_crawl(
211        &self,
212        types: &SyntaxSet,
213        recurse_into: bool,
214        no_recursive_types: &SyntaxSet,
215        allow_self: bool,
216    ) -> Vec<ErasedSegment> {
217        let mut acc = Vec::new();
218
219        let matches = allow_self && self.class_types().intersects(types);
220        if matches {
221            acc.push(self.clone());
222        }
223
224        if !self.descendant_type_set().intersects(types) {
225            return acc;
226        }
227
228        if recurse_into || !matches {
229            for seg in self.segments() {
230                if no_recursive_types.is_empty() || !no_recursive_types.contains(seg.get_type()) {
231                    let segments =
232                        seg.recursive_crawl(types, recurse_into, no_recursive_types, true);
233                    acc.extend(segments);
234                }
235            }
236        }
237
238        acc
239    }
240}
241
242impl ErasedSegment {
243    #[track_caller]
244    pub fn new(&self, segments: Vec<ErasedSegment>) -> ErasedSegment {
245        match &self.value.kind {
246            NodeOrTokenKind::Node(node) => {
247                let mut builder = SegmentBuilder::node(
248                    self.value.id,
249                    self.value.syntax_kind,
250                    node.dialect,
251                    segments,
252                )
253                .with_position(self.get_position_marker().unwrap().clone());
254                // Preserve source_fixes during tree rebuilds
255                if !node.source_fixes.is_empty() {
256                    builder = builder.with_source_fixes(node.source_fixes.clone());
257                }
258                builder.finish()
259            }
260            NodeOrTokenKind::Token(_) => self.deep_clone(),
261        }
262    }
263
264    fn change_segments(&self, segments: Vec<ErasedSegment>) -> ErasedSegment {
265        let NodeOrTokenKind::Node(node) = &self.value.kind else {
266            unimplemented!()
267        };
268
269        ErasedSegment {
270            value: Rc::new(NodeOrToken {
271                id: self.value.id,
272                syntax_kind: self.value.syntax_kind,
273                class_types: self.value.class_types.clone(),
274                position_marker: None,
275                code_idx: OnceCell::new(),
276                kind: NodeOrTokenKind::Node(NodeData {
277                    dialect: node.dialect,
278                    segments,
279                    raw: node.raw.clone(),
280                    source_fixes: node.source_fixes.clone(),
281                    descendant_type_set: node.descendant_type_set.clone(),
282                    raw_segments_with_ancestors: node.raw_segments_with_ancestors.clone(),
283                }),
284                hash: OnceCell::new(),
285            }),
286        }
287    }
288
289    pub fn indent_val(&self) -> i8 {
290        self.value.syntax_kind.indent_val()
291    }
292
293    pub fn can_start_end_non_code(&self) -> bool {
294        matches!(
295            self.value.syntax_kind,
296            SyntaxKind::File | SyntaxKind::Unparsable
297        )
298    }
299
300    pub(crate) fn dialect(&self) -> DialectKind {
301        match &self.value.kind {
302            NodeOrTokenKind::Node(node) => node.dialect,
303            NodeOrTokenKind::Token(_) => todo!(),
304        }
305    }
306
307    pub fn get_start_loc(&self) -> (usize, usize) {
308        match self.get_position_marker() {
309            Some(pos_marker) => pos_marker.working_loc(),
310            None => unreachable!("{self:?} has no PositionMarker"),
311        }
312    }
313
314    pub fn get_end_loc(&self) -> (usize, usize) {
315        match self.get_position_marker() {
316            Some(pos_marker) => pos_marker.working_loc_after(self.raw()),
317            None => {
318                unreachable!("{self:?} has no PositionMarker")
319            }
320        }
321    }
322
323    pub fn is_templated(&self) -> bool {
324        if let Some(pos_marker) = self.get_position_marker() {
325            pos_marker.source_slice.start != pos_marker.source_slice.end && !pos_marker.is_literal()
326        } else {
327            panic!("PosMarker must be set");
328        }
329    }
330
331    pub fn iter_segments(&self, expanding: &SyntaxSet, pass_through: bool) -> Vec<ErasedSegment> {
332        let capacity = if expanding.is_empty() {
333            self.segments().len()
334        } else {
335            0
336        };
337        let mut result = Vec::with_capacity(capacity);
338        for segment in self.segments() {
339            if expanding.contains(segment.get_type()) {
340                let expanding = if pass_through {
341                    expanding
342                } else {
343                    &SyntaxSet::EMPTY
344                };
345                result.append(&mut segment.iter_segments(expanding, false));
346            } else {
347                result.push(segment.clone());
348            }
349        }
350        result
351    }
352
353    pub(crate) fn code_indices(&self) -> Rc<Vec<usize>> {
354        self.value
355            .code_idx
356            .get_or_init(|| {
357                Rc::from(
358                    self.segments()
359                        .iter()
360                        .enumerate()
361                        .filter(|(_, seg)| seg.is_code())
362                        .map(|(idx, _)| idx)
363                        .collect::<Vec<_>>(),
364                )
365            })
366            .clone()
367    }
368
369    pub fn children(
370        &self,
371        seg_types: &'static SyntaxSet,
372    ) -> impl Iterator<Item = &ErasedSegment> + '_ {
373        self.segments()
374            .iter()
375            .filter(move |seg| seg_types.contains(seg.get_type()))
376    }
377
378    pub fn iter_patches(&self, templated_file: &TemplatedFile) -> Vec<FixPatch> {
379        let mut acc = Vec::new();
380
381        let templated_raw = &templated_file.templated_str.as_ref().unwrap()
382            [self.get_position_marker().unwrap().templated_slice.clone()];
383
384        // Always collect source fixes from this segment first
385        acc.extend(self.iter_source_fix_patches(templated_file));
386
387        // Check if any descendants have source_fixes
388        let has_descendant_source_fixes = self
389            .recursive_crawl_all(false)
390            .iter()
391            .any(|s| !s.get_source_fixes().is_empty());
392
393        if self.raw() == templated_raw {
394            if has_descendant_source_fixes {
395                // Tree raw hasn't changed - only source fix patches are needed.
396                // Avoid generating gap patches that could span template boundaries.
397                // This matches SQLFluff's behavior in _iter_templated_patches.
398                for descendant in self.recursive_crawl_all(false).into_iter().skip(1) {
399                    acc.extend(descendant.iter_source_fix_patches(templated_file));
400                }
401            }
402            return acc;
403        }
404
405        if self.get_position_marker().is_none() {
406            return Vec::new();
407        }
408
409        let pos_marker = self.get_position_marker().unwrap();
410        if pos_marker.is_literal() && !has_descendant_source_fixes {
411            acc.extend(self.iter_source_fix_patches(templated_file));
412            acc.push(FixPatch::new(
413                pos_marker.templated_slice.clone(),
414                self.raw().clone(),
415                // SyntaxKind::Literal.into(),
416                pos_marker.source_slice.clone(),
417                templated_file.templated_str.as_ref().unwrap()[pos_marker.templated_slice.clone()]
418                    .to_string(),
419                templated_file.source_str[pos_marker.source_slice.clone()].to_string(),
420            ));
421        } else if self.segments().is_empty() {
422            return acc;
423        } else {
424            let mut segments = self.segments();
425
426            while !segments.is_empty()
427                && matches!(
428                    segments.last().unwrap().get_type(),
429                    SyntaxKind::EndOfFile
430                        | SyntaxKind::Indent
431                        | SyntaxKind::Dedent
432                        | SyntaxKind::Implicit
433                )
434            {
435                segments = &segments[..segments.len() - 1];
436            }
437
438            let pos = self.get_position_marker().unwrap();
439            let mut source_idx = pos.source_slice.start;
440            let mut templated_idx = pos.templated_slice.start;
441            let mut insert_buff = String::new();
442
443            for segment in segments {
444                let pos_marker = segment.get_position_marker().unwrap();
445                if !segment.raw().is_empty() && pos_marker.is_point() {
446                    insert_buff.push_str(segment.raw().as_ref());
447                    continue;
448                }
449
450                let start_diff = pos_marker.templated_slice.start - templated_idx;
451
452                if start_diff > 0 || !insert_buff.is_empty() {
453                    let fixed_raw = std::mem::take(&mut insert_buff);
454                    let raw_segments = segment.get_raw_segments();
455                    let first_segment_pos = raw_segments[0].get_position_marker().unwrap();
456
457                    // The slices must never go backwards so the end of the slice
458                    // must be >= the start. This can happen when source positions
459                    // are non-monotonic due to template expansion.
460                    acc.push(FixPatch::new(
461                        templated_idx..first_segment_pos.templated_slice.start.max(templated_idx),
462                        fixed_raw.into(),
463                        source_idx..first_segment_pos.source_slice.start.max(source_idx),
464                        String::new(),
465                        String::new(),
466                    ));
467                }
468
469                acc.extend(segment.iter_patches(templated_file));
470
471                source_idx = pos_marker.source_slice.end;
472                templated_idx = pos_marker.templated_slice.end;
473            }
474
475            let end_diff = pos.templated_slice.end - templated_idx;
476            if end_diff != 0 || !insert_buff.is_empty() {
477                let source_slice = source_idx..pos.source_slice.end;
478                let templated_slice = templated_idx..pos.templated_slice.end;
479
480                let templated_str = templated_file.templated_str.as_ref().unwrap()
481                    [templated_slice.clone()]
482                .to_owned();
483                let source_str = templated_file.source_str[source_slice.clone()].to_owned();
484
485                acc.push(FixPatch::new(
486                    templated_slice,
487                    insert_buff.into(),
488                    source_slice,
489                    templated_str,
490                    source_str,
491                ));
492            }
493        }
494
495        acc
496    }
497
498    pub fn descendant_type_set(&self) -> &SyntaxSet {
499        match &self.value.kind {
500            NodeOrTokenKind::Node(node) => node.descendant_type_set.get_or_init(|| {
501                self.segments()
502                    .iter()
503                    .flat_map(|segment| {
504                        segment
505                            .descendant_type_set()
506                            .clone()
507                            .union(segment.class_types())
508                    })
509                    .collect()
510            }),
511            NodeOrTokenKind::Token(_) => const { &SyntaxSet::EMPTY },
512        }
513    }
514
515    pub fn is_comment(&self) -> bool {
516        matches!(
517            self.value.syntax_kind,
518            SyntaxKind::Comment
519                | SyntaxKind::InlineComment
520                | SyntaxKind::BlockComment
521                | SyntaxKind::NotebookStart
522        )
523    }
524
525    pub fn is_whitespace(&self) -> bool {
526        matches!(
527            self.value.syntax_kind,
528            SyntaxKind::Whitespace | SyntaxKind::Newline
529        )
530    }
531
532    pub fn is_indent(&self) -> bool {
533        matches!(
534            self.value.syntax_kind,
535            SyntaxKind::Indent | SyntaxKind::Implicit | SyntaxKind::Dedent
536        )
537    }
538
539    pub fn get_position_marker(&self) -> Option<&PositionMarker> {
540        self.value.position_marker.as_ref()
541    }
542
543    pub(crate) fn iter_source_fix_patches(&self, templated_file: &TemplatedFile) -> Vec<FixPatch> {
544        let source_fixes = self.get_source_fixes();
545        let mut patches = Vec::with_capacity(source_fixes.len());
546
547        for source_fix in &source_fixes {
548            patches.push(FixPatch::new(
549                source_fix.templated_slice.clone(),
550                source_fix.edit.clone(),
551                // String::from("source"),
552                source_fix.source_slice.clone(),
553                templated_file.templated_str.clone().unwrap()[source_fix.templated_slice.clone()]
554                    .to_string(),
555                templated_file.source_str[source_fix.source_slice.clone()].to_string(),
556            ));
557        }
558
559        patches
560    }
561
562    pub fn id(&self) -> u32 {
563        self.value.id
564    }
565
566    /// Return any source fixes as list.
567    pub fn get_source_fixes(&self) -> Vec<SourceFix> {
568        match &self.value.kind {
569            NodeOrTokenKind::Node(node) => node.source_fixes.clone(),
570            NodeOrTokenKind::Token(_) => Vec::new(),
571        }
572    }
573
574    /// Return all source fixes from this segment and all its descendants.
575    pub fn get_all_source_fixes(&self) -> Vec<SourceFix> {
576        let mut fixes = self.get_source_fixes();
577        for segment in self.segments() {
578            fixes.extend(segment.get_all_source_fixes());
579        }
580        fixes
581    }
582
583    pub fn edit(
584        &self,
585        id: u32,
586        raw: Option<String>,
587        _source_fixes: Option<Vec<SourceFix>>,
588    ) -> ErasedSegment {
589        match &self.value.kind {
590            NodeOrTokenKind::Node(_node) => {
591                todo!()
592            }
593            NodeOrTokenKind::Token(token) => {
594                let raw = raw.as_deref().unwrap_or(token.raw.as_ref());
595                SegmentBuilder::token(id, raw, self.value.syntax_kind)
596                    .with_position(self.get_position_marker().unwrap().clone())
597                    .finish()
598            }
599        }
600    }
601
602    pub fn class_types(&self) -> &SyntaxSet {
603        &self.value.class_types
604    }
605
606    pub(crate) fn first_non_whitespace_segment_raw_upper(&self) -> Option<String> {
607        for seg in self.get_raw_segments() {
608            if !seg.raw().is_empty() {
609                return Some(seg.raw().to_uppercase());
610            }
611        }
612        None
613    }
614
615    pub fn is(&self, other: &ErasedSegment) -> bool {
616        Rc::ptr_eq(&self.value, &other.value)
617    }
618
619    pub fn addr(&self) -> usize {
620        Rc::as_ptr(&self.value).addr()
621    }
622
623    pub fn direct_descendant_type_set(&self) -> SyntaxSet {
624        self.segments()
625            .iter()
626            .fold(SyntaxSet::EMPTY, |set, it| set.union(it.class_types()))
627    }
628
629    pub fn is_keyword(&self, p0: &str) -> bool {
630        self.is_type(SyntaxKind::Keyword) && self.raw().eq_ignore_ascii_case(p0)
631    }
632
633    pub fn hash_value(&self) -> u64 {
634        *self.value.hash.get_or_init(|| {
635            let mut hasher = DefaultHashBuilder::default().build_hasher();
636            self.get_type().hash(&mut hasher);
637            self.raw().hash(&mut hasher);
638
639            if let Some(marker) = &self.get_position_marker() {
640                marker.source_position().hash(&mut hasher);
641            } else {
642                None::<usize>.hash(&mut hasher);
643            }
644
645            hasher.finish()
646        })
647    }
648
649    pub fn deep_clone(&self) -> Self {
650        Self {
651            value: Rc::new(self.value.as_ref().clone()),
652        }
653    }
654
655    #[track_caller]
656    pub(crate) fn get_mut(&mut self) -> &mut NodeOrToken {
657        Rc::get_mut(&mut self.value).unwrap()
658    }
659
660    #[track_caller]
661    pub(crate) fn make_mut(&mut self) -> &mut NodeOrToken {
662        Rc::make_mut(&mut self.value)
663    }
664
665    pub fn reference(&self) -> ObjectReferenceSegment {
666        ObjectReferenceSegment(
667            self.clone(),
668            match self.get_type() {
669                SyntaxKind::TableReference => ObjectReferenceKind::Table,
670                SyntaxKind::WildcardIdentifier => ObjectReferenceKind::WildcardIdentifier,
671                _ => ObjectReferenceKind::Object,
672            },
673        )
674    }
675
676    pub fn recursive_crawl_all(&self, reverse: bool) -> Vec<ErasedSegment> {
677        let mut result = Vec::with_capacity(self.segments().len() + 1);
678
679        if reverse {
680            for seg in self.segments().iter().rev() {
681                result.append(&mut seg.recursive_crawl_all(reverse));
682            }
683            result.push(self.clone());
684        } else {
685            result.push(self.clone());
686            for seg in self.segments() {
687                result.append(&mut seg.recursive_crawl_all(reverse));
688            }
689        }
690
691        result
692    }
693
694    pub fn raw_segments_with_ancestors(&self) -> &[(ErasedSegment, Vec<PathStep>)] {
695        match &self.value.kind {
696            NodeOrTokenKind::Node(node) => node.raw_segments_with_ancestors.get_or_init(|| {
697                let mut buffer: Vec<(ErasedSegment, Vec<PathStep>)> =
698                    Vec::with_capacity(self.segments().len());
699                let code_idxs = self.code_indices();
700
701                for (idx, seg) in self.segments().iter().enumerate() {
702                    let new_step = vec![PathStep {
703                        segment: self.clone(),
704                        idx,
705                        len: self.segments().len(),
706                        code_idxs: code_idxs.clone(),
707                    }];
708
709                    // Use seg.get_segments().is_empty() as a workaround to check if the segment is
710                    // a SyntaxKind::Raw type. In the original Python code, this was achieved
711                    // using seg.is_type(SyntaxKind::Raw). Here, we assume that a SyntaxKind::Raw
712                    // segment is characterized by having no sub-segments.
713
714                    if seg.segments().is_empty() {
715                        buffer.push((seg.clone(), new_step));
716                    } else {
717                        let extended =
718                            seg.raw_segments_with_ancestors()
719                                .iter()
720                                .map(|(raw_seg, stack)| {
721                                    let mut new_step = new_step.clone();
722                                    new_step.extend_from_slice(stack);
723                                    (raw_seg.clone(), new_step)
724                                });
725
726                        buffer.extend(extended);
727                    }
728                }
729
730                buffer
731            }),
732            NodeOrTokenKind::Token(_) => &[],
733        }
734    }
735
736    pub fn path_to(&self, other: &ErasedSegment) -> Vec<PathStep> {
737        let midpoint = other;
738
739        for (idx, seg) in enumerate(self.segments()) {
740            let mut steps = vec![PathStep {
741                segment: self.clone(),
742                idx,
743                len: self.segments().len(),
744                code_idxs: self.code_indices(),
745            }];
746
747            if seg.eq(midpoint) {
748                return steps;
749            }
750
751            let res = seg.path_to(midpoint);
752
753            if !res.is_empty() {
754                steps.extend(res);
755                return steps;
756            }
757        }
758
759        Vec::new()
760    }
761
762    pub fn apply_fixes(
763        &self,
764        fixes: &mut HashMap<u32, AnchorEditInfo>,
765    ) -> (ErasedSegment, Vec<ErasedSegment>, Vec<ErasedSegment>) {
766        if fixes.is_empty() || self.segments().is_empty() {
767            return (self.clone(), Vec::new(), Vec::new());
768        }
769
770        let mut seg_buffer = Vec::new();
771        let mut has_applied_fixes = false;
772        let mut _requires_validate = false;
773
774        for seg in self.segments() {
775            // Look for uuid match.
776            // This handles potential positioning ambiguity.
777
778            let Some(mut anchor_info) = fixes.remove(&seg.id()) else {
779                seg_buffer.push(seg.clone());
780                continue;
781            };
782
783            if anchor_info.fixes.len() == 2
784                && matches!(anchor_info.fixes[0], LintFix::CreateAfter { .. })
785            {
786                anchor_info.fixes.reverse();
787            }
788
789            let fixes_count = anchor_info.fixes.len();
790            for lint_fix in anchor_info.fixes {
791                has_applied_fixes = true;
792
793                // Deletes are easy.
794                if matches!(lint_fix, LintFix::Delete { .. }) {
795                    // We're just getting rid of this segment.
796                    _requires_validate = true;
797                    // NOTE: We don't add the segment in this case.
798                    continue;
799                }
800
801                // Otherwise it must be a replace or a create.
802                assert!(matches!(
803                    lint_fix,
804                    LintFix::Replace { .. }
805                        | LintFix::CreateBefore { .. }
806                        | LintFix::CreateAfter { .. }
807                ));
808
809                match lint_fix {
810                    LintFix::CreateAfter { edit, .. } => {
811                        if fixes_count == 1 {
812                            // In the case of a creation after that is not part
813                            // of a create_before/create_after pair, also add
814                            // this segment before the edit.
815                            seg_buffer.push(seg.clone());
816                        }
817                        for s in edit {
818                            seg_buffer.push(s);
819                        }
820                        _requires_validate = true;
821                    }
822                    LintFix::CreateBefore { edit, .. } => {
823                        for s in edit {
824                            seg_buffer.push(s);
825                        }
826                        seg_buffer.push(seg.clone());
827                        _requires_validate = true;
828                    }
829                    LintFix::Replace { edit, .. } => {
830                        let mut consumed_pos = false;
831                        let is_single_same_type =
832                            edit.len() == 1 && edit[0].class_types() == seg.class_types();
833
834                        for mut s in edit {
835                            if !consumed_pos && s.raw() == seg.raw() {
836                                consumed_pos = true;
837                                s.make_mut()
838                                    .set_position_marker(seg.get_position_marker().cloned());
839                            }
840                            seg_buffer.push(s);
841                        }
842
843                        if !is_single_same_type {
844                            _requires_validate = true;
845                        }
846                    }
847                    LintFix::Delete { .. } => {
848                        // Already handled above
849                        unreachable!()
850                    }
851                }
852            }
853        }
854
855        if has_applied_fixes {
856            seg_buffer =
857                position_segments(&seg_buffer, self.get_position_marker().as_ref().unwrap());
858        }
859
860        let seg_queue = seg_buffer;
861        let mut seg_buffer = Vec::new();
862        for seg in seg_queue {
863            let (mid, pre, post) = seg.apply_fixes(fixes);
864
865            seg_buffer.extend(pre);
866            seg_buffer.push(mid);
867            seg_buffer.extend(post);
868        }
869
870        let seg_buffer =
871            position_segments(&seg_buffer, self.get_position_marker().as_ref().unwrap());
872        (self.new(seg_buffer), Vec::new(), Vec::new())
873    }
874}
875
/// Conversion of a segment tree into a small, serde-friendly representation.
///
/// Only compiled for tests or when the `serde` feature is enabled.
#[cfg(any(test, feature = "serde"))]
pub mod serde {
    use serde::ser::SerializeMap;
    use serde::{Deserialize, Serialize};

    use crate::parser::segments::ErasedSegment;

    /// Value half of a serialised segment: either a raw string or a list of
    /// serialised child segments.
    ///
    /// The enum is `untagged`, so the variant name never appears in the output.
    #[derive(Serialize, Deserialize)]
    #[serde(untagged)]
    pub enum SerialisedSegmentValue {
        Single(String),
        Nested(Vec<TupleSerialisedSegment>),
    }

    /// A `(key, value)` pair describing one segment.
    ///
    /// Serialisation is hand-written (a map with a single entry), while
    /// deserialisation uses the derived tuple-struct implementation.
    #[derive(Deserialize)]
    pub struct TupleSerialisedSegment(String, SerialisedSegmentValue);

    impl Serialize for TupleSerialisedSegment {
        /// Emits the pair as a map holding exactly one `key: value` entry.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            let mut map = serializer.serialize_map(None)?;
            map.serialize_key(&self.0)?;
            map.serialize_value(&self.1)?;
            map.end()
        }
    }

    impl TupleSerialisedSegment {
        /// Builds an entry whose value is a single string.
        pub fn single(key: String, value: String) -> Self {
            Self(key, SerialisedSegmentValue::Single(value))
        }

        /// Misspelled alias of [`TupleSerialisedSegment::single`], retained so
        /// that existing callers keep compiling.
        pub fn sinlge(key: String, value: String) -> Self {
            Self::single(key, value)
        }

        /// Builds an entry whose value is a list of serialised child segments.
        pub fn nested(key: String, segments: Vec<TupleSerialisedSegment>) -> Self {
            Self(key, SerialisedSegmentValue::Nested(segments))
        }
    }

    impl ErasedSegment {
        /// Recursively converts this segment into a [`TupleSerialisedSegment`]
        /// keyed by the segment's type name.
        ///
        /// * `show_raw` - when set, a segment without children is emitted as a
        ///   single entry carrying its raw text.
        /// * `code_only` - when set, children that are not code, or that are
        ///   meta segments, are left out.
        ///
        /// Both flags are passed unchanged to every recursive call.
        pub fn to_serialised(&self, code_only: bool, show_raw: bool) -> TupleSerialisedSegment {
            let type_name = self.get_type().as_str().to_string();

            if show_raw && self.segments().is_empty() {
                return TupleSerialisedSegment::single(type_name, self.raw().to_string());
            }

            let children = self
                .segments()
                .iter()
                // Without `code_only` every child is kept.
                .filter(|seg| !code_only || (seg.is_code() && !seg.is_meta()))
                .map(|seg| seg.to_serialised(code_only, show_raw))
                .collect::<Vec<_>>();

            TupleSerialisedSegment::nested(type_name, children)
        }
    }
}
943
944impl PartialEq for ErasedSegment {
945    fn eq(&self, other: &Self) -> bool {
946        if self.id() == other.id() {
947            return true;
948        }
949
950        let pos_self = self.get_position_marker();
951        let pos_other = other.get_position_marker();
952        if let Some((pos_self, pos_other)) = pos_self.zip(pos_other) {
953            self.get_type() == other.get_type()
954                && pos_self.working_loc() == pos_other.working_loc()
955                && self.raw() == other.raw()
956        } else {
957            false
958        }
959    }
960}
961
962pub fn position_segments(
963    segments: &[ErasedSegment],
964    parent_pos: &PositionMarker,
965) -> Vec<ErasedSegment> {
966    if segments.is_empty() {
967        return Vec::new();
968    }
969
970    let (mut line_no, mut line_pos) = { (parent_pos.working_line_no, parent_pos.working_line_pos) };
971
972    let mut segment_buffer: Vec<ErasedSegment> = Vec::new();
973    for (idx, segment) in enumerate(segments) {
974        let old_position = segment.get_position_marker();
975
976        let mut new_position = match old_position {
977            Some(pos_marker) => pos_marker.clone(),
978            None => {
979                let start_point = if idx > 0 {
980                    let prev_seg = segment_buffer[idx - 1].clone();
981                    Some(prev_seg.get_position_marker().unwrap().end_point_marker())
982                } else {
983                    Some(parent_pos.start_point_marker())
984                };
985
986                let mut end_point = None;
987                for fwd_seg in &segments[idx + 1..] {
988                    if fwd_seg.get_position_marker().is_some() {
989                        end_point = Some(
990                            fwd_seg.get_raw_segments()[0]
991                                .get_position_marker()
992                                .unwrap()
993                                .start_point_marker(),
994                        );
995                        break;
996                    }
997                }
998
999                if let Some((start_point, end_point)) = start_point
1000                    .as_ref()
1001                    .zip(end_point.as_ref())
1002                    .filter(|(start_point, end_point)| start_point != end_point)
1003                {
1004                    PositionMarker::from_points(start_point, end_point)
1005                } else if let Some(start_point) = start_point.as_ref() {
1006                    start_point.clone()
1007                } else if let Some(end_point) = end_point.as_ref() {
1008                    end_point.clone()
1009                } else {
1010                    unimplemented!("Unable to position new segment")
1011                }
1012            }
1013        };
1014
1015        new_position = new_position.with_working_position(line_no, line_pos);
1016        (line_no, line_pos) = PositionMarker::infer_next_position(segment.raw(), line_no, line_pos);
1017
1018        let mut new_seg = if !segment.segments().is_empty() && old_position != Some(&new_position) {
1019            let child_segments = position_segments(segment.segments(), &new_position);
1020            segment.change_segments(child_segments)
1021        } else {
1022            segment.deep_clone()
1023        };
1024
1025        new_seg.get_mut().set_position_marker(new_position.into());
1026        segment_buffer.push(new_seg);
1027    }
1028
1029    segment_buffer
1030}
1031
/// Data held for one segment, regardless of whether it is a node or a token.
///
/// `SegmentBuilder` wraps a value of this type while a segment is being built.
#[derive(Debug, Clone)]
pub struct NodeOrToken {
    // Numeric id of the segment; replaceable through `set_id`.
    id: u32,
    // Syntax kind the segment was created with.
    syntax_kind: SyntaxKind,
    // Set of kinds for the segment; the builder fills it via `class_types(syntax_kind)`.
    class_types: SyntaxSet,
    // Position of the segment, if one has been assigned.
    position_marker: Option<PositionMarker>,
    // Node-specific or token-specific payload.
    kind: NodeOrTokenKind,
    // Write-once cell; presumably caches indices of code children - TODO confirm.
    code_idx: OnceCell<Rc<Vec<usize>>>,
    // Write-once cell; presumably caches the segment's hash - TODO confirm.
    hash: OnceCell<u64>,
}
1042
/// Payload that distinguishes a node (carries [`NodeData`]) from a token
/// (carries [`TokenData`]).
#[derive(Debug, Clone)]
// The two variants differ in size; the clippy lint about that is silenced here.
#[allow(clippy::large_enum_variant)]
pub enum NodeOrTokenKind {
    Node(NodeData),
    Token(TokenData),
}
1049
impl NodeOrToken {
    /// Replaces the stored position marker; `None` clears it.
    pub fn set_position_marker(&mut self, position_marker: Option<PositionMarker>) {
        self.position_marker = position_marker;
    }

    /// Replaces the stored id.
    pub fn set_id(&mut self, id: u32) {
        self.id = id;
    }
}
1059
/// Payload of the `Node` variant of [`NodeOrTokenKind`].
#[derive(Debug, Clone)]
pub struct NodeData {
    // Dialect supplied when the node was built.
    dialect: DialectKind,
    // Child segments of this node.
    segments: Vec<ErasedSegment>,
    // Write-once cell; presumably the lazily computed raw text - TODO confirm.
    raw: OnceCell<SmolStr>,
    // Source fixes attached to this node.
    source_fixes: Vec<SourceFix>,
    // Write-once cell; presumably caches the kinds of all descendants - TODO confirm.
    descendant_type_set: OnceCell<SyntaxSet>,
    // Write-once cell; presumably caches each raw segment with its ancestor path - TODO confirm.
    raw_segments_with_ancestors: OnceCell<Vec<(ErasedSegment, Vec<PathStep>)>>,
}
1069
/// Payload of the `Token` variant of [`NodeOrTokenKind`]: the token's raw text.
#[derive(Debug, Clone, PartialEq)]
pub struct TokenData {
    // Raw text of the token.
    raw: SmolStr,
}
1074
1075#[track_caller]
1076pub fn pos_marker(segments: &[ErasedSegment]) -> PositionMarker {
1077    let markers = segments.iter().filter_map(|seg| seg.get_position_marker());
1078
1079    PositionMarker::from_child_markers(markers)
1080}
1081
/// One entry of the ancestor path stored alongside a raw segment in
/// `NodeData::raw_segments_with_ancestors`.
// NOTE(review): field meanings below are inferred from names only - confirm against the code that builds paths.
#[derive(Debug, Clone)]
pub struct PathStep {
    // Presumably the ancestor segment at this step.
    pub segment: ErasedSegment,
    // Presumably the index of the child taken within `segment`.
    pub idx: usize,
    // Presumably the number of children of `segment`.
    pub len: usize,
    // Presumably the indices of the code children of `segment`.
    pub code_idxs: Rc<Vec<usize>>,
}
1089
1090fn class_types(syntax_kind: SyntaxKind) -> SyntaxSet {
1091    match syntax_kind {
1092        SyntaxKind::ColumnReference => SyntaxSet::new(&[SyntaxKind::ObjectReference, syntax_kind]),
1093        SyntaxKind::WildcardIdentifier => {
1094            SyntaxSet::new(&[SyntaxKind::WildcardIdentifier, SyntaxKind::ObjectReference])
1095        }
1096        SyntaxKind::TableReference => SyntaxSet::new(&[SyntaxKind::ObjectReference, syntax_kind]),
1097        _ => SyntaxSet::single(syntax_kind),
1098    }
1099}
1100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lint_fix::LintFix;
    use crate::linter::compute_anchor_edit_info;
    use crate::parser::segments::test_functions::{raw_seg, raw_segments};

    /// Test comparison of raw segments: two tokens built from identical
    /// inputs must compare equal.
    #[test]
    fn test_parser_base_segments_raw_compare() {
        let template: TemplatedFile = "foobar".into();
        let build_token = || {
            SegmentBuilder::token(0, "foobar", SyntaxKind::Word)
                .with_position(PositionMarker::new(
                    0..6,
                    0..6,
                    template.clone(),
                    None,
                    None,
                ))
                .finish()
        };

        let rs1 = build_token();
        let rs2 = build_token();

        assert_eq!(rs1, rs2)
    }

    /// Test raw segments behave as expected.
    // TODO Implement: only the raw text is checked so far.
    #[test]
    fn test_parser_base_segments_raw() {
        let segment = raw_seg();

        assert_eq!(segment.raw(), "foobar");
    }

    /// Test BaseSegment.compute_anchor_edit_info().
    #[test]
    fn test_parser_base_segments_compute_anchor_edit_info() {
        let raw_segs = raw_segments();
        let tables = Tables::default();

        // Replace fix anchored on the first raw segment, swapping in an edited
        // copy whose raw text is `raw`. Each call draws a fresh id.
        let replace_first = |raw: &str| {
            LintFix::replace(
                raw_segs[0].clone(),
                vec![raw_segs[0].edit(tables.next_id(), Some(raw.to_string()), None)],
                None,
            )
        };

        // Construct a fix buffer, intentionally with:
        // - one duplicate.
        // - two different incompatible fixes on the same segment.
        let fixes = vec![replace_first("a"), replace_first("a"), replace_first("b")];

        let mut anchor_edit_info = Default::default();
        compute_anchor_edit_info(&mut anchor_edit_info, fixes);

        // Check the target segment is the only key we have.
        let anchor_id = raw_segs[0].id();
        assert_eq!(
            anchor_edit_info.keys().collect::<Vec<_>>(),
            vec![&anchor_id]
        );

        let anchor_info = anchor_edit_info.get(&anchor_id).unwrap();

        // Check that the duplicate has been deduplicated, i.e. this isn't 3.
        assert_eq!(anchor_info.replace, 2);

        // Check the fixes themselves.
        //   Note: There's no duplicated first fix.
        assert_eq!(anchor_info.fixes[0], replace_first("a"));
        assert_eq!(anchor_info.fixes[1], replace_first("b"));

        // Check the first replace
        assert_eq!(
            anchor_info.fixes[anchor_info.first_replace.unwrap()],
            replace_first("a")
        );
    }
}