Skip to main content

sqruff_lib_core/utils/analysis/
select.rs

1use itertools::Itertools;
2use smol_str::{SmolStr, ToSmolStr};
3
4use crate::dialects::Dialect;
5use crate::dialects::common::{AliasInfo, ColumnAliasInfo};
6use crate::dialects::init::DialectKind;
7use crate::dialects::syntax::{SyntaxKind, SyntaxSet};
8use crate::parser::segments::ErasedSegment;
9use crate::parser::segments::from::FromClauseSegment;
10use crate::parser::segments::join::JoinClauseSegment;
11use crate::parser::segments::object_reference::ObjectReferenceSegment;
12use crate::parser::segments::select::SelectClauseElementSegment;
13
14#[derive(Clone)]
15pub struct SelectStatementColumnsAndTables {
16    pub select_statement: ErasedSegment,
17    pub table_aliases: Vec<AliasInfo>,
18    pub standalone_aliases: Vec<SmolStr>,
19    pub reference_buffer: Vec<ObjectReferenceSegment>,
20    pub select_targets: Vec<SelectClauseElementSegment>,
21    pub col_aliases: Vec<ColumnAliasInfo>,
22    pub using_cols: Vec<SmolStr>,
23    pub table_reference_buffer: Vec<ObjectReferenceSegment>,
24}
25
26pub fn get_object_references(segment: &ErasedSegment) -> Vec<ObjectReferenceSegment> {
27    segment
28        .recursive_crawl(
29            const { &SyntaxSet::new(&[SyntaxKind::ObjectReference, SyntaxKind::ColumnReference]) },
30            true,
31            const { &SyntaxSet::single(SyntaxKind::SelectStatement) },
32            true,
33        )
34        .into_iter()
35        .map(|seg| seg.reference())
36        .collect()
37}
38
39pub fn get_select_statement_info(
40    segment: &ErasedSegment,
41    dialect: Option<&Dialect>,
42    early_exit: bool,
43) -> Option<SelectStatementColumnsAndTables> {
44    let (table_aliases, standalone_aliases) = get_aliases_from_select(segment, dialect);
45
46    if early_exit && table_aliases.is_empty() && standalone_aliases.is_empty() {
47        return None;
48    }
49
50    let sc = segment.child(const { &SyntaxSet::new(&[SyntaxKind::SelectClause]) })?;
51    let mut reference_buffer = get_object_references(&sc);
52    let mut table_reference_buffer = Vec::new();
53    for potential_clause in [
54        SyntaxKind::WhereClause,
55        SyntaxKind::GroupbyClause,
56        SyntaxKind::HavingClause,
57        SyntaxKind::OrderbyClause,
58        SyntaxKind::QualifyClause,
59    ] {
60        let clause = segment.child(&SyntaxSet::new(&[potential_clause]));
61        if let Some(clause) = clause {
62            reference_buffer.extend(get_object_references(&clause));
63        }
64    }
65
66    let select_clause = segment
67        .child(const { &SyntaxSet::new(&[SyntaxKind::SelectClause]) })
68        .unwrap();
69    let select_targets =
70        select_clause.children(const { &SyntaxSet::new(&[SyntaxKind::SelectClauseElement]) });
71    let select_targets = select_targets
72        .map(|it| SelectClauseElementSegment(it.clone()))
73        .collect_vec();
74
75    let col_aliases = select_targets
76        .iter()
77        .filter_map(|s| s.alias())
78        .collect_vec();
79
80    let mut using_cols: Vec<SmolStr> = Vec::new();
81    let fc = segment.child(const { &SyntaxSet::new(&[SyntaxKind::FromClause]) });
82
83    if let Some(fc) = fc {
84        for table_expression in fc.recursive_crawl(
85            const { &SyntaxSet::new(&[SyntaxKind::TableExpression]) },
86            true,
87            const { &SyntaxSet::single(SyntaxKind::SelectStatement) },
88            true,
89        ) {
90            for seg in table_expression.segments() {
91                if !seg.is_type(SyntaxKind::TableReference) {
92                    reference_buffer.extend(get_object_references(seg));
93                } else if seg.reference().is_qualified() {
94                    table_reference_buffer.extend(get_object_references(seg));
95                }
96            }
97        }
98
99        for join_clause in fc.recursive_crawl(
100            const { &SyntaxSet::new(&[SyntaxKind::JoinClause]) },
101            true,
102            const { &SyntaxSet::single(SyntaxKind::SelectStatement) },
103            true,
104        ) {
105            let mut seen_using = false;
106
107            for seg in join_clause.segments() {
108                if seg.is_keyword("USING") {
109                    seen_using = true;
110                } else if seg.is_type(SyntaxKind::JoinOnCondition) {
111                    for on_seg in seg.segments() {
112                        if matches!(
113                            on_seg.get_type(),
114                            SyntaxKind::Bracketed | SyntaxKind::Expression
115                        ) {
116                            reference_buffer.extend(get_object_references(seg));
117                        }
118                    }
119                } else if seen_using && seg.is_type(SyntaxKind::Bracketed) {
120                    for subseg in seg.segments() {
121                        if subseg.is_type(SyntaxKind::Identifier)
122                            || subseg.is_type(SyntaxKind::NakedIdentifier)
123                        {
124                            using_cols.push(subseg.raw().clone());
125                        }
126                    }
127                    seen_using = false;
128                }
129            }
130        }
131    }
132
133    SelectStatementColumnsAndTables {
134        select_statement: segment.clone(),
135        table_aliases,
136        standalone_aliases,
137        reference_buffer,
138        select_targets,
139        col_aliases,
140        using_cols,
141        table_reference_buffer,
142    }
143    .into()
144}
145
146pub fn get_aliases_from_select(
147    segment: &ErasedSegment,
148    dialect: Option<&Dialect>,
149) -> (Vec<AliasInfo>, Vec<SmolStr>) {
150    let fc = segment.child(const { &SyntaxSet::new(&[SyntaxKind::FromClause]) });
151    let Some(fc) = fc else {
152        return (Vec::new(), Vec::new());
153    };
154
155    let aliases = if fc.is_type(SyntaxKind::FromClause) {
156        FromClauseSegment(fc).eventual_aliases()
157    } else if fc.is_type(SyntaxKind::JoinClause) {
158        JoinClauseSegment(fc).eventual_aliases()
159    } else {
160        unimplemented!()
161    };
162
163    let mut standalone_aliases = Vec::new();
164    standalone_aliases.extend(get_pivot_table_columns(segment, dialect));
165    standalone_aliases.extend(get_lambda_argument_columns(segment, dialect));
166
167    let mut table_aliases = Vec::new();
168    for (table_expr, alias_info) in aliases {
169        if has_value_table_function(table_expr, dialect) {
170            if !standalone_aliases.contains(&alias_info.ref_str) {
171                standalone_aliases.push(alias_info.ref_str);
172            }
173        } else if !table_aliases.contains(&alias_info) {
174            table_aliases.push(alias_info);
175        }
176    }
177
178    (table_aliases, standalone_aliases)
179}
180
181fn has_value_table_function(table_expr: ErasedSegment, dialect: Option<&Dialect>) -> bool {
182    let Some(dialect) = dialect else {
183        return false;
184    };
185
186    for function_name in table_expr.recursive_crawl(
187        const { &SyntaxSet::new(&[SyntaxKind::FunctionName]) },
188        true,
189        &SyntaxSet::EMPTY,
190        true,
191    ) {
192        if dialect
193            .sets("value_table_functions")
194            .contains(function_name.raw().to_uppercase().trim())
195        {
196            return true;
197        }
198    }
199
200    false
201}
202
203fn get_pivot_table_columns(segment: &ErasedSegment, dialect: Option<&Dialect>) -> Vec<SmolStr> {
204    let Some(_dialect) = dialect else {
205        return Vec::new();
206    };
207
208    let fc = segment.recursive_crawl(
209        const { &SyntaxSet::new(&[SyntaxKind::FromPivotExpression]) },
210        true,
211        &SyntaxSet::EMPTY,
212        true,
213    );
214    if !fc.is_empty() {
215        return Vec::new();
216    }
217
218    let mut pivot_table_column_aliases = Vec::new();
219    for pivot_table_column_alias in segment.recursive_crawl(
220        const { &SyntaxSet::new(&[SyntaxKind::PivotColumnReference]) },
221        true,
222        &SyntaxSet::EMPTY,
223        true,
224    ) {
225        let raw = pivot_table_column_alias.raw().clone();
226        if !pivot_table_column_aliases.contains(&raw) {
227            pivot_table_column_aliases.push(raw);
228        }
229    }
230
231    pivot_table_column_aliases
232}
233
234fn get_lambda_argument_columns(segment: &ErasedSegment, dialect: Option<&Dialect>) -> Vec<SmolStr> {
235    let Some(dialect) = dialect else {
236        return Vec::new();
237    };
238
239    if !matches!(dialect.name, DialectKind::Athena | DialectKind::Sparksql) {
240        return Vec::new();
241    }
242
243    let mut lambda_argument_columns = Vec::new();
244    for potential_lambda in segment.recursive_crawl(
245        const { &SyntaxSet::single(SyntaxKind::Expression) },
246        true,
247        &SyntaxSet::EMPTY,
248        true,
249    ) {
250        let Some(potential_arrow) =
251            potential_lambda.child(&SyntaxSet::single(SyntaxKind::BinaryOperator))
252        else {
253            continue;
254        };
255
256        if potential_arrow.raw() == "->" {
257            let arrow_operator = &potential_arrow;
258            let mut argument_segments = potential_lambda
259                .segments()
260                .iter()
261                .take_while(|&it| it != arrow_operator)
262                .filter(|it| {
263                    matches!(
264                        it.get_type(),
265                        SyntaxKind::Bracketed | SyntaxKind::ColumnReference
266                    )
267                })
268                .collect_vec();
269
270            assert_eq!(argument_segments.len(), 1);
271            let child_segment = argument_segments.pop().unwrap();
272
273            match child_segment.get_type() {
274                SyntaxKind::Bracketed => {
275                    let start_bracket = child_segment
276                        .child(&SyntaxSet::single(SyntaxKind::StartBracket))
277                        .unwrap();
278
279                    if start_bracket.raw() == "(" {
280                        let bracketed_arguments = child_segment
281                            .children(const { &SyntaxSet::single(SyntaxKind::ColumnReference) });
282
283                        lambda_argument_columns.extend(
284                            bracketed_arguments
285                                .into_iter()
286                                .map(|argument| argument.raw().to_smolstr()),
287                        )
288                    }
289                }
290                SyntaxKind::ColumnReference => {
291                    lambda_argument_columns.push(child_segment.raw().to_smolstr())
292                }
293                _ => {}
294            }
295        }
296    }
297
298    lambda_argument_columns
299}