sqruff_lib_core/parser/segments/
object_reference.rs

1use itertools::{Itertools, enumerate};
2use smol_str::{SmolStr, ToSmolStr};
3
4use crate::dialects::init::DialectKind;
5use crate::dialects::syntax::{SyntaxKind, SyntaxSet};
6use crate::parser::segments::ErasedSegment;
7
/// How far from the end of a dotted reference a part sits.
///
/// The discriminant is used as a 1-based distance from the *last* part of
/// the reference: `Object` is the last part, `Table` the one before it and
/// `Schema` the one before that. The derived ordering follows the
/// discriminants (`Object < Table < Schema`).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ObjectReferenceLevel {
    /// The final part of the reference.
    Object = 1,
    /// The second-to-last part of the reference.
    Table = 2,
    /// The third-to-last part of the reference.
    Schema = 3,
}
14
/// One dot-separated component of an object reference.
#[derive(Clone, Debug)]
pub struct ObjectReferencePart {
    /// Text of this component.
    pub part: String,
    /// The segment(s) whose raw text contributed to `part`.
    pub segments: Vec<ErasedSegment>,
}
20
/// A segment paired with the kind of reference it should be read as.
///
/// Field `0` is the wrapped segment that gets crawled for reference parts;
/// field `1` selects which crawl strategy `iter_raw_references` applies.
#[derive(Clone)]
pub struct ObjectReferenceSegment(pub ErasedSegment, pub ObjectReferenceKind);
23
/// Selects how an [`ObjectReferenceSegment`] is split into parts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ObjectReferenceKind {
    /// Parts are taken from identifier segments only.
    Object,
    /// Like `Object`, except that under the BigQuery dialect a dedicated
    /// splitting routine is used.
    Table,
    /// Parts are taken from identifier segments and `*` segments.
    WildcardIdentifier,
}
30
31impl ObjectReferenceSegment {
32    pub fn is_qualified(&self) -> bool {
33        self.iter_raw_references().len() > 1
34    }
35
36    pub fn qualification(&self) -> &'static str {
37        if self.is_qualified() {
38            "qualified"
39        } else {
40            "unqualified"
41        }
42    }
43
44    pub fn extract_possible_references(
45        &self,
46        level: ObjectReferenceLevel,
47        dialect: DialectKind,
48    ) -> Vec<ObjectReferencePart> {
49        let refs = self.iter_raw_references();
50
51        match dialect {
52            DialectKind::Bigquery => {
53                if level == ObjectReferenceLevel::Schema && refs.len() >= 3 {
54                    return vec![refs[0].clone()];
55                }
56
57                if level == ObjectReferenceLevel::Table {
58                    return refs.into_iter().take(3).collect_vec();
59                }
60
61                if level == ObjectReferenceLevel::Object && refs.len() >= 3 {
62                    return vec![refs[1].clone(), refs[2].clone()];
63                }
64
65                self.extract_possible_references(level, DialectKind::Ansi)
66            }
67            _ => {
68                let level = level as usize;
69                if refs.len() >= level && level > 0 {
70                    refs.get(refs.len() - level).cloned().into_iter().collect()
71                } else {
72                    vec![]
73                }
74            }
75        }
76    }
77
78    pub fn extract_possible_multipart_references(
79        &self,
80        levels: &[ObjectReferenceLevel],
81    ) -> Vec<Vec<ObjectReferencePart>> {
82        self.extract_possible_multipart_references_inner(levels, self.0.dialect())
83    }
84
85    pub fn extract_possible_multipart_references_inner(
86        &self,
87        levels: &[ObjectReferenceLevel],
88        dialect_kind: DialectKind,
89    ) -> Vec<Vec<ObjectReferencePart>> {
90        match dialect_kind {
91            DialectKind::Bigquery => {
92                let levels_tmp: Vec<_> = levels.iter().map(|level| *level as usize).collect();
93                let min_level: usize = *levels_tmp.iter().min().unwrap();
94                let max_level: usize = *levels_tmp.iter().max().unwrap();
95                let refs = self.iter_raw_references();
96
97                if max_level == ObjectReferenceLevel::Schema as usize && refs.len() >= 3 {
98                    return vec![refs[0..=max_level - min_level].to_vec()];
99                }
100
101                self.extract_possible_multipart_references_inner(levels, DialectKind::Ansi)
102            }
103            _ => {
104                let refs = self.iter_raw_references();
105                let mut sorted_levels = levels.to_vec();
106                sorted_levels.sort_unstable();
107
108                if let (Some(&min_level), Some(&max_level)) =
109                    (sorted_levels.first(), sorted_levels.last())
110                    && refs.len() >= max_level as usize
111                {
112                    let start = refs.len() - max_level as usize;
113                    let end = refs.len() - min_level as usize + 1;
114                    if start < end {
115                        return vec![refs[start..end].to_vec()];
116                    }
117                }
118                vec![]
119            }
120        }
121    }
122
123    pub fn iter_raw_references(&self) -> Vec<ObjectReferencePart> {
124        match self.1 {
125            ObjectReferenceKind::Table if self.0.dialect() == DialectKind::Bigquery => {
126                let mut acc = Vec::new();
127                let mut parts = Vec::new();
128                let mut elems_for_parts = Vec::new();
129
130                let mut flush =
131                    |parts: &mut Vec<SmolStr>, elems_for_parts: &mut Vec<ErasedSegment>| {
132                        acc.push(ObjectReferencePart {
133                            part: std::mem::take(parts).iter().join(""),
134                            segments: std::mem::take(elems_for_parts),
135                        });
136                    };
137
138                for elem in self.0.recursive_crawl(
139                    const {
140                        &SyntaxSet::new(&[
141                            SyntaxKind::Identifier,
142                            SyntaxKind::NakedIdentifier,
143                            SyntaxKind::QuotedIdentifier,
144                            SyntaxKind::Literal,
145                            SyntaxKind::Dash,
146                            SyntaxKind::Dot,
147                            SyntaxKind::Star,
148                        ])
149                    },
150                    true,
151                    &SyntaxSet::EMPTY,
152                    true,
153                ) {
154                    if !elem.is_type(SyntaxKind::Dot) {
155                        if elem.is_type(SyntaxKind::Identifier)
156                            || elem.is_type(SyntaxKind::NakedIdentifier)
157                            || elem.is_type(SyntaxKind::QuotedIdentifier)
158                        {
159                            let raw = elem.raw();
160                            let elem_raw = raw.trim_matches('`');
161                            let elem_subparts = elem_raw.split(".").collect_vec();
162                            let elem_subparts_count = elem_subparts.len();
163
164                            for (idx, part) in enumerate(elem_subparts) {
165                                parts.push(part.to_smolstr());
166                                elems_for_parts.push(elem.clone());
167
168                                if idx != elem_subparts_count - 1 {
169                                    flush(&mut parts, &mut elems_for_parts);
170                                }
171                            }
172                        } else {
173                            parts.push(elem.raw().to_smolstr());
174                            elems_for_parts.push(elem);
175                        }
176                    } else {
177                        flush(&mut parts, &mut elems_for_parts);
178                    }
179                }
180
181                if !parts.is_empty() {
182                    flush(&mut parts, &mut elems_for_parts);
183                }
184
185                acc
186            }
187            ObjectReferenceKind::Object | ObjectReferenceKind::Table => {
188                let mut acc = Vec::new();
189
190                for elem in self.0.recursive_crawl(
191                    const {
192                        &SyntaxSet::new(&[
193                            SyntaxKind::Identifier,
194                            SyntaxKind::NakedIdentifier,
195                            SyntaxKind::QuotedIdentifier,
196                        ])
197                    },
198                    true,
199                    &SyntaxSet::EMPTY,
200                    true,
201                ) {
202                    acc.extend(self.iter_reference_parts(elem));
203                }
204
205                acc
206            }
207            ObjectReferenceKind::WildcardIdentifier => {
208                let mut acc = Vec::new();
209
210                for elem in self.0.recursive_crawl(
211                    const {
212                        &SyntaxSet::new(&[
213                            SyntaxKind::Identifier,
214                            SyntaxKind::Star,
215                            SyntaxKind::NakedIdentifier,
216                            SyntaxKind::QuotedIdentifier,
217                        ])
218                    },
219                    true,
220                    &SyntaxSet::EMPTY,
221                    true,
222                ) {
223                    acc.extend(self.iter_reference_parts(elem));
224                }
225
226                acc
227            }
228        }
229    }
230
231    fn iter_reference_parts(&self, elem: ErasedSegment) -> Vec<ObjectReferencePart> {
232        let mut acc = Vec::new();
233
234        let raw = elem.raw();
235        let parts = raw.split('.');
236
237        for part in parts {
238            acc.push(ObjectReferencePart {
239                part: part.into(),
240                segments: vec![elem.clone()],
241            });
242        }
243
244        acc
245    }
246}