sqruff_lib_core/utils/analysis/query.rs

1use std::cell::RefCell;
2use std::rc::Rc;
3
4use smol_str::{SmolStr, StrExt, ToSmolStr};
5
6use super::select::SelectStatementColumnsAndTables;
7use crate::dialects::Dialect;
8use crate::dialects::common::AliasInfo;
9use crate::dialects::syntax::{SyntaxKind, SyntaxSet};
10use crate::helpers::IndexMap;
11use crate::parser::segments::ErasedSegment;
12use crate::utils::analysis::select::get_select_statement_info;
13use crate::utils::functional::segments::Segments;
14
15const SELECTABLE_TYPES: SyntaxSet = SyntaxSet::new(&[
16    SyntaxKind::WithCompoundStatement,
17    SyntaxKind::SetExpression,
18    SyntaxKind::SelectStatement,
19]);
20
21const SUBSELECT_TYPES: SyntaxSet = SyntaxSet::new(&[
22    SyntaxKind::MergeStatement,
23    SyntaxKind::UpdateStatement,
24    SyntaxKind::DeleteStatement,
25    // NOTE: Values clauses won't have sub selects, but it's
26    // also harmless to look, and they may appear in similar
27    // locations. We include them here because they come through
28    // the same code paths - although are likely to return nothing.
29    SyntaxKind::ValuesClause,
30]);
31
/// Shape of a [`Query`], as decided by `Query::from_segment`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryType {
    /// A `SELECT`, a set expression, or a DML / `VALUES` statement treated as
    /// a selectable.
    Simple,
    /// A `WITH` statement; its CTE queries are stored in `QueryInner::ctes`.
    WithCompound,
}
37
38pub struct WildcardInfo {
39    pub segment: ErasedSegment,
40    pub tables: Vec<SmolStr>,
41}
42
43#[derive(Debug, Clone)]
44pub struct Selectable<'me> {
45    pub selectable: ErasedSegment,
46    pub dialect: &'me Dialect,
47}
48
49impl Selectable<'_> {
50    pub fn find_alias(&self, table: &str) -> Option<AliasInfo> {
51        self.select_info()
52            .as_ref()?
53            .table_aliases
54            .iter()
55            .find(|&t| t.aliased && t.ref_str == table)
56            .cloned()
57    }
58}
59
60impl Selectable<'_> {
61    pub fn wildcard_info(&self) -> Vec<WildcardInfo> {
62        let Some(select_info) = self.select_info() else {
63            return Vec::new();
64        };
65
66        let mut buff = Vec::new();
67        for seg in select_info.select_targets {
68            if seg
69                .0
70                .child(const { &SyntaxSet::new(&[SyntaxKind::WildcardExpression]) })
71                .is_some()
72            {
73                if seg.0.raw().contains('.') {
74                    let table = seg
75                        .0
76                        .raw()
77                        .rsplit_once('.')
78                        .map(|x| x.0)
79                        .unwrap_or_default()
80                        .to_smolstr();
81                    buff.push(WildcardInfo {
82                        segment: seg.0.clone(),
83                        tables: vec![table],
84                    });
85                } else {
86                    let tables = select_info
87                        .table_aliases
88                        .iter()
89                        .filter(|it| !it.ref_str.is_empty())
90                        .map(|it| {
91                            if it.aliased {
92                                it.ref_str.clone()
93                            } else {
94                                it.from_expression_element.raw().clone()
95                            }
96                        })
97                        .collect();
98                    buff.push(WildcardInfo {
99                        segment: seg.0.clone(),
100                        tables,
101                    });
102                }
103            }
104        }
105
106        buff
107    }
108}
109
110impl Selectable<'_> {
111    pub fn select_info(&self) -> Option<SelectStatementColumnsAndTables> {
112        if self.selectable.is_type(SyntaxKind::SelectStatement) {
113            return get_select_statement_info(&self.selectable, self.dialect.into(), false);
114        }
115
116        let values = Segments::new(self.selectable.clone(), None);
117        let alias_expression = values.children(None).find_first(Some(|it: &ErasedSegment| {
118            it.is_type(SyntaxKind::AliasExpression)
119        }));
120        let name = alias_expression
121            .children(None)
122            .find_first(Some(|it: &ErasedSegment| {
123                matches!(
124                    it.get_type(),
125                    SyntaxKind::NakedIdentifier | SyntaxKind::QuotedIdentifier,
126                )
127            }));
128
129        let alias_info = AliasInfo {
130            ref_str: if name.is_empty() {
131                SmolStr::new_static("")
132            } else {
133                name.first().unwrap().raw().clone()
134            },
135            segment: name.first().cloned(),
136            aliased: !name.is_empty(),
137            from_expression_element: self.selectable.clone(),
138            alias_expression: alias_expression.first().cloned(),
139            object_reference: None,
140        };
141
142        SelectStatementColumnsAndTables {
143            select_statement: self.selectable.clone(),
144            table_aliases: vec![alias_info],
145            standalone_aliases: Vec::new(),
146            reference_buffer: Vec::new(),
147            select_targets: Vec::new(),
148            col_aliases: Vec::new(),
149            using_cols: Vec::new(),
150        }
151        .into()
152    }
153}
154
155#[derive(Debug, Clone)]
156pub struct Query<'me, T> {
157    pub inner: Rc<RefCell<QueryInner<'me, T>>>,
158}
159
160#[derive(Debug, Clone)]
161pub struct QueryInner<'me, T> {
162    pub query_type: QueryType,
163    pub dialect: &'me Dialect,
164    pub selectables: Vec<Selectable<'me>>,
165    pub ctes: IndexMap<SmolStr, Query<'me, T>>,
166    pub parent: Option<Query<'me, T>>,
167    pub subqueries: Vec<Query<'me, T>>,
168    pub cte_definition_segment: Option<ErasedSegment>,
169    pub cte_name_segment: Option<ErasedSegment>,
170    pub payload: T,
171}
172
173impl<'me, T: Clone + Default> Query<'me, T> {
174    pub fn crawl_sources(
175        &self,
176        segment: ErasedSegment,
177
178        pop: bool,
179        lookup_cte: bool,
180    ) -> Vec<Source<'me, T>> {
181        let mut acc = Vec::new();
182
183        for seg in segment.recursive_crawl(
184            const {
185                &SyntaxSet::new(&[
186                    SyntaxKind::TableReference,
187                    SyntaxKind::SetExpression,
188                    SyntaxKind::SelectStatement,
189                    SyntaxKind::ValuesClause,
190                ])
191            },
192            false,
193            &SyntaxSet::EMPTY,
194            false,
195        ) {
196            if seg.is_type(SyntaxKind::TableReference) {
197                let _seg = seg.reference();
198                if !_seg.is_qualified()
199                    && lookup_cte
200                    && let Some(cte) = self.lookup_cte(seg.raw().as_ref(), pop)
201                {
202                    acc.push(Source::Query(cte));
203                }
204                acc.push(Source::TableReference(seg.raw().clone()));
205            } else {
206                acc.push(Source::Query(Query::from_segment(
207                    &seg,
208                    self.inner.borrow().dialect,
209                    Some(self.clone()),
210                )))
211            }
212        }
213
214        if acc.is_empty()
215            && let Some(table_expr) =
216                segment.child(const { &SyntaxSet::new(&[SyntaxKind::TableExpression]) })
217        {
218            return vec![Source::TableReference(table_expr.raw().to_smolstr())];
219        }
220
221        acc
222    }
223
224    #[track_caller]
225    pub fn lookup_cte(&self, name: &str, pop: bool) -> Option<Query<'me, T>> {
226        let cte = if pop {
227            self.inner
228                .borrow_mut()
229                .ctes
230                .shift_remove(&name.to_uppercase_smolstr())
231        } else {
232            self.inner
233                .borrow()
234                .ctes
235                .get(&name.to_uppercase_smolstr())
236                .cloned()
237        };
238
239        cte.or_else(move || {
240            self.inner
241                .borrow_mut()
242                .parent
243                .as_mut()
244                .and_then(|it| it.lookup_cte(name, pop))
245        })
246    }
247
248    fn post_init(&self) {
249        let this = self.clone();
250
251        for subquery in &RefCell::borrow(&self.inner).subqueries {
252            RefCell::borrow_mut(&subquery.inner).parent = this.clone().into();
253        }
254
255        for cte in RefCell::borrow(&self.inner).ctes.values().cloned() {
256            RefCell::borrow_mut(&cte.inner).parent = this.clone().into();
257        }
258    }
259}
260
261impl<T: Default + Clone> Query<'_, T> {
262    pub fn children(&self) -> Vec<Self> {
263        self.inner
264            .borrow()
265            .ctes
266            .values()
267            .chain(self.inner.borrow().subqueries.iter())
268            .cloned()
269            .collect()
270    }
271
272    fn extract_subqueries<'a>(selectable: &Selectable, dialect: &'a Dialect) -> Vec<Query<'a, T>> {
273        let mut acc = Vec::new();
274
275        for subselect in selectable.selectable.recursive_crawl(
276            &SELECTABLE_TYPES,
277            false,
278            &SyntaxSet::EMPTY,
279            false,
280        ) {
281            acc.push(Query::from_segment(&subselect, dialect, None));
282        }
283
284        acc
285    }
286
287    pub fn from_root<'a>(
288        root_segment: &ErasedSegment,
289        dialect: &'a Dialect,
290    ) -> Option<Query<'a, T>> {
291        let stmts = root_segment.recursive_crawl(
292            &SELECTABLE_TYPES,
293            true,
294            &SyntaxSet::single(SyntaxKind::MergeStatement),
295            true,
296        );
297        let selectable_segment = stmts.first()?;
298
299        Some(Query::from_segment(selectable_segment, dialect, None))
300    }
301
302    pub fn from_segment<'a>(
303        segment: &ErasedSegment,
304        dialect: &'a Dialect,
305        parent: Option<Query<'a, T>>,
306    ) -> Query<'a, T> {
307        let mut selectables = Vec::new();
308        let mut subqueries = Vec::new();
309        let mut cte_defs: Vec<ErasedSegment> = Vec::new();
310        let mut query_type = QueryType::Simple;
311
312        if segment.is_type(SyntaxKind::SelectStatement)
313            || SUBSELECT_TYPES.contains(segment.get_type())
314        {
315            selectables.push(Selectable {
316                selectable: segment.clone(),
317                dialect,
318            });
319        } else if segment.is_type(SyntaxKind::SetExpression) {
320            selectables.extend(
321                segment
322                    .children(const { &SyntaxSet::new(&[SyntaxKind::SelectStatement]) })
323                    .cloned()
324                    .map(|selectable| Selectable {
325                        selectable,
326                        dialect,
327                    }),
328            )
329        } else {
330            query_type = QueryType::WithCompound;
331
332            for seg in segment.recursive_crawl(
333                const { &SyntaxSet::new(&[SyntaxKind::SelectStatement]) },
334                false,
335                const { &SyntaxSet::single(SyntaxKind::CommonTableExpression) },
336                true,
337            ) {
338                selectables.push(Selectable {
339                    selectable: seg,
340                    dialect,
341                });
342            }
343
344            for seg in segment.recursive_crawl(
345                const { &SyntaxSet::new(&[SyntaxKind::CommonTableExpression]) },
346                false,
347                const { &SyntaxSet::single(SyntaxKind::WithCompoundStatement) },
348                true,
349            ) {
350                cte_defs.push(seg);
351            }
352        }
353
354        for selectable in &selectables {
355            subqueries.extend(Self::extract_subqueries(selectable, dialect));
356        }
357
358        let outer_query = Query {
359            inner: Rc::new(RefCell::new(QueryInner {
360                query_type,
361                dialect,
362                selectables,
363                ctes: <_>::default(),
364                parent,
365                subqueries,
366                cte_definition_segment: None,
367                cte_name_segment: None,
368                payload: T::default(),
369            })),
370        };
371
372        outer_query.post_init();
373
374        if cte_defs.is_empty() {
375            return outer_query;
376        }
377
378        let mut ctes = IndexMap::default();
379        for cte in cte_defs {
380            let name_seg = cte.segments()[0].clone();
381            let name = name_seg.raw().to_uppercase_smolstr();
382
383            let queries = cte.recursive_crawl(
384                const { &SELECTABLE_TYPES.union(&SUBSELECT_TYPES) },
385                true,
386                &SyntaxSet::EMPTY,
387                true,
388            );
389
390            if queries.is_empty() {
391                continue;
392            };
393
394            let query = &queries[0];
395            let query = Self::from_segment(query, dialect, outer_query.clone().into());
396
397            RefCell::borrow_mut(&query.inner).cte_definition_segment = cte.into();
398            RefCell::borrow_mut(&query.inner).cte_name_segment = name_seg.into();
399
400            ctes.insert(name, query);
401        }
402
403        RefCell::borrow_mut(&outer_query.inner).ctes = ctes;
404        outer_query
405    }
406}
407
408pub enum Source<'a, T> {
409    TableReference(SmolStr),
410    Query(Query<'a, T>),
411}