sqruff_lib_core/parser/segments/object_reference.rs

1use itertools::{Itertools, enumerate};
2use smol_str::{SmolStr, ToSmolStr};
3
4use crate::dialects::init::DialectKind;
5use crate::dialects::syntax::{SyntaxKind, SyntaxSet};
6use crate::parser::segments::base::ErasedSegment;
7
/// Position of a reference part, counted from the end of the reference:
/// `Object` is the last part, `Table` the one before it, `Schema` the one
/// before that.
///
/// The discriminants are those 1-based offsets from the end (they are cast to
/// `usize` and subtracted from the part count by the ANSI extraction rules),
/// and the derived ordering follows them: `Object < Table < Schema`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
pub enum ObjectReferenceLevel {
    Object = 1,
    Table = 2,
    Schema = 3,
}
14
15#[derive(Clone, Debug)]
16pub struct ObjectReferencePart {
17    pub part: String,
18    pub segments: Vec<ErasedSegment>,
19}
20
21#[derive(Clone)]
22pub struct ObjectReferenceSegment(pub ErasedSegment, pub ObjectReferenceKind);
23
/// Flavour of reference held by an [`ObjectReferenceSegment`]. It decides
/// which segment types contribute name parts when the reference is split.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ObjectReferenceKind {
    /// A plain object reference built from identifier segments.
    Object,
    /// A table reference. BigQuery tables use dialect-specific splitting.
    Table,
    /// A reference that may also contain a `*` segment.
    WildcardIdentifier,
}
30
31impl ObjectReferenceSegment {
32    pub fn is_qualified(&self) -> bool {
33        self.iter_raw_references().len() > 1
34    }
35
36    pub fn qualification(&self) -> &'static str {
37        if self.is_qualified() {
38            "qualified"
39        } else {
40            "unqualified"
41        }
42    }
43
44    pub fn extract_possible_references(
45        &self,
46        level: ObjectReferenceLevel,
47        dialect: DialectKind,
48    ) -> Vec<ObjectReferencePart> {
49        let refs = self.iter_raw_references();
50
51        match dialect {
52            DialectKind::Bigquery => {
53                if level == ObjectReferenceLevel::Schema && refs.len() >= 3 {
54                    return vec![refs[0].clone()];
55                }
56
57                if level == ObjectReferenceLevel::Table {
58                    return refs.into_iter().take(3).collect_vec();
59                }
60
61                if level == ObjectReferenceLevel::Object && refs.len() >= 3 {
62                    return vec![refs[1].clone(), refs[2].clone()];
63                }
64
65                self.extract_possible_references(level, DialectKind::Ansi)
66            }
67            _ => {
68                let level = level as usize;
69                if refs.len() >= level && level > 0 {
70                    refs.get(refs.len() - level).cloned().into_iter().collect()
71                } else {
72                    vec![]
73                }
74            }
75        }
76    }
77
78    pub fn extract_possible_multipart_references(
79        &self,
80        levels: &[ObjectReferenceLevel],
81    ) -> Vec<Vec<ObjectReferencePart>> {
82        self.extract_possible_multipart_references_inner(levels, self.0.dialect())
83    }
84
85    pub fn extract_possible_multipart_references_inner(
86        &self,
87        levels: &[ObjectReferenceLevel],
88        dialect_kind: DialectKind,
89    ) -> Vec<Vec<ObjectReferencePart>> {
90        match dialect_kind {
91            DialectKind::Bigquery => {
92                let levels_tmp: Vec<_> = levels.iter().map(|level| *level as usize).collect();
93                let min_level: usize = *levels_tmp.iter().min().unwrap();
94                let max_level: usize = *levels_tmp.iter().max().unwrap();
95                let refs = self.iter_raw_references();
96
97                if max_level == ObjectReferenceLevel::Schema as usize && refs.len() >= 3 {
98                    return vec![refs[0..=max_level - min_level].to_vec()];
99                }
100
101                self.extract_possible_multipart_references_inner(levels, DialectKind::Ansi)
102            }
103            _ => {
104                let refs = self.iter_raw_references();
105                let mut sorted_levels = levels.to_vec();
106                sorted_levels.sort_unstable();
107
108                if let (Some(&min_level), Some(&max_level)) =
109                    (sorted_levels.first(), sorted_levels.last())
110                {
111                    if refs.len() >= max_level as usize {
112                        let start = refs.len() - max_level as usize;
113                        let end = refs.len() - min_level as usize + 1;
114                        if start < end {
115                            return vec![refs[start..end].to_vec()];
116                        }
117                    }
118                }
119                vec![]
120            }
121        }
122    }
123
124    pub fn iter_raw_references(&self) -> Vec<ObjectReferencePart> {
125        match self.1 {
126            ObjectReferenceKind::Table if self.0.dialect() == DialectKind::Bigquery => {
127                let mut acc = Vec::new();
128                let mut parts = Vec::new();
129                let mut elems_for_parts = Vec::new();
130
131                let mut flush =
132                    |parts: &mut Vec<SmolStr>, elems_for_parts: &mut Vec<ErasedSegment>| {
133                        acc.push(ObjectReferencePart {
134                            part: std::mem::take(parts).iter().join(""),
135                            segments: std::mem::take(elems_for_parts),
136                        });
137                    };
138
139                for elem in self.0.recursive_crawl(
140                    const {
141                        &SyntaxSet::new(&[
142                            SyntaxKind::Identifier,
143                            SyntaxKind::NakedIdentifier,
144                            SyntaxKind::QuotedIdentifier,
145                            SyntaxKind::Literal,
146                            SyntaxKind::Dash,
147                            SyntaxKind::Dot,
148                            SyntaxKind::Star,
149                        ])
150                    },
151                    true,
152                    &SyntaxSet::EMPTY,
153                    true,
154                ) {
155                    if !elem.is_type(SyntaxKind::Dot) {
156                        if elem.is_type(SyntaxKind::Identifier)
157                            || elem.is_type(SyntaxKind::NakedIdentifier)
158                            || elem.is_type(SyntaxKind::QuotedIdentifier)
159                        {
160                            let raw = elem.raw();
161                            let elem_raw = raw.trim_matches('`');
162                            let elem_subparts = elem_raw.split(".").collect_vec();
163                            let elem_subparts_count = elem_subparts.len();
164
165                            for (idx, part) in enumerate(elem_subparts) {
166                                parts.push(part.to_smolstr());
167                                elems_for_parts.push(elem.clone());
168
169                                if idx != elem_subparts_count - 1 {
170                                    flush(&mut parts, &mut elems_for_parts);
171                                }
172                            }
173                        } else {
174                            parts.push(elem.raw().to_smolstr());
175                            elems_for_parts.push(elem);
176                        }
177                    } else {
178                        flush(&mut parts, &mut elems_for_parts);
179                    }
180                }
181
182                if !parts.is_empty() {
183                    flush(&mut parts, &mut elems_for_parts);
184                }
185
186                acc
187            }
188            ObjectReferenceKind::Object | ObjectReferenceKind::Table => {
189                let mut acc = Vec::new();
190
191                for elem in self.0.recursive_crawl(
192                    const {
193                        &SyntaxSet::new(&[
194                            SyntaxKind::Identifier,
195                            SyntaxKind::NakedIdentifier,
196                            SyntaxKind::QuotedIdentifier,
197                        ])
198                    },
199                    true,
200                    &SyntaxSet::EMPTY,
201                    true,
202                ) {
203                    acc.extend(self.iter_reference_parts(elem));
204                }
205
206                acc
207            }
208            ObjectReferenceKind::WildcardIdentifier => {
209                let mut acc = Vec::new();
210
211                for elem in self.0.recursive_crawl(
212                    const {
213                        &SyntaxSet::new(&[
214                            SyntaxKind::Identifier,
215                            SyntaxKind::Star,
216                            SyntaxKind::NakedIdentifier,
217                            SyntaxKind::QuotedIdentifier,
218                        ])
219                    },
220                    true,
221                    &SyntaxSet::EMPTY,
222                    true,
223                ) {
224                    acc.extend(self.iter_reference_parts(elem));
225                }
226
227                acc
228            }
229        }
230    }
231
232    fn iter_reference_parts(&self, elem: ErasedSegment) -> Vec<ObjectReferencePart> {
233        let mut acc = Vec::new();
234
235        let raw = elem.raw();
236        let parts = raw.split('.');
237
238        for part in parts {
239            acc.push(ObjectReferencePart {
240                part: part.into(),
241                segments: vec![elem.clone()],
242            });
243        }
244
245        acc
246    }
247}