sqruff_lib_core/utils/analysis/
select.rs

1use itertools::Itertools;
2use smol_str::{SmolStr, ToSmolStr};
3
4use crate::dialects::Dialect;
5use crate::dialects::common::{AliasInfo, ColumnAliasInfo};
6use crate::dialects::init::DialectKind;
7use crate::dialects::syntax::{SyntaxKind, SyntaxSet};
8use crate::parser::segments::ErasedSegment;
9use crate::parser::segments::from::FromClauseSegment;
10use crate::parser::segments::join::JoinClauseSegment;
11use crate::parser::segments::object_reference::ObjectReferenceSegment;
12use crate::parser::segments::select::SelectClauseElementSegment;
13
14#[derive(Clone)]
15pub struct SelectStatementColumnsAndTables {
16    pub select_statement: ErasedSegment,
17    pub table_aliases: Vec<AliasInfo>,
18    pub standalone_aliases: Vec<SmolStr>,
19    pub reference_buffer: Vec<ObjectReferenceSegment>,
20    pub select_targets: Vec<SelectClauseElementSegment>,
21    pub col_aliases: Vec<ColumnAliasInfo>,
22    pub using_cols: Vec<SmolStr>,
23}
24
25pub fn get_object_references(segment: &ErasedSegment) -> Vec<ObjectReferenceSegment> {
26    segment
27        .recursive_crawl(
28            const { &SyntaxSet::new(&[SyntaxKind::ObjectReference, SyntaxKind::ColumnReference]) },
29            true,
30            const { &SyntaxSet::single(SyntaxKind::SelectStatement) },
31            true,
32        )
33        .into_iter()
34        .map(|seg| seg.reference())
35        .collect()
36}
37
38pub fn get_select_statement_info(
39    segment: &ErasedSegment,
40    dialect: Option<&Dialect>,
41    early_exit: bool,
42) -> Option<SelectStatementColumnsAndTables> {
43    let (table_aliases, standalone_aliases) = get_aliases_from_select(segment, dialect);
44
45    if early_exit && table_aliases.is_empty() && standalone_aliases.is_empty() {
46        return None;
47    }
48
49    let sc = segment.child(const { &SyntaxSet::new(&[SyntaxKind::SelectClause]) })?;
50    let mut reference_buffer = get_object_references(&sc);
51    for potential_clause in [
52        SyntaxKind::WhereClause,
53        SyntaxKind::GroupbyClause,
54        SyntaxKind::HavingClause,
55        SyntaxKind::OrderbyClause,
56        SyntaxKind::QualifyClause,
57    ] {
58        let clause = segment.child(&SyntaxSet::new(&[potential_clause]));
59        if let Some(clause) = clause {
60            reference_buffer.extend(get_object_references(&clause));
61        }
62    }
63
64    let select_clause = segment
65        .child(const { &SyntaxSet::new(&[SyntaxKind::SelectClause]) })
66        .unwrap();
67    let select_targets =
68        select_clause.children(const { &SyntaxSet::new(&[SyntaxKind::SelectClauseElement]) });
69    let select_targets = select_targets
70        .map(|it| SelectClauseElementSegment(it.clone()))
71        .collect_vec();
72
73    let col_aliases = select_targets
74        .iter()
75        .filter_map(|s| s.alias())
76        .collect_vec();
77
78    let mut using_cols: Vec<SmolStr> = Vec::new();
79    let fc = segment.child(const { &SyntaxSet::new(&[SyntaxKind::FromClause]) });
80
81    if let Some(fc) = fc {
82        for join_clause in fc.recursive_crawl(
83            const { &SyntaxSet::new(&[SyntaxKind::JoinClause]) },
84            true,
85            const { &SyntaxSet::single(SyntaxKind::SelectStatement) },
86            true,
87        ) {
88            let mut seen_using = false;
89
90            for seg in join_clause.segments() {
91                if seg.is_keyword("USING") {
92                    seen_using = true;
93                } else if seg.is_type(SyntaxKind::JoinOnCondition) {
94                    for on_seg in seg.segments() {
95                        if matches!(
96                            on_seg.get_type(),
97                            SyntaxKind::Bracketed | SyntaxKind::Expression
98                        ) {
99                            reference_buffer.extend(get_object_references(seg));
100                        }
101                    }
102                } else if seen_using && seg.is_type(SyntaxKind::Bracketed) {
103                    for subseg in seg.segments() {
104                        if subseg.is_type(SyntaxKind::Identifier)
105                            || subseg.is_type(SyntaxKind::NakedIdentifier)
106                        {
107                            using_cols.push(subseg.raw().clone());
108                        }
109                    }
110                    seen_using = false;
111                }
112            }
113        }
114    }
115
116    SelectStatementColumnsAndTables {
117        select_statement: segment.clone(),
118        table_aliases,
119        standalone_aliases,
120        reference_buffer,
121        select_targets,
122        col_aliases,
123        using_cols,
124    }
125    .into()
126}
127
128pub fn get_aliases_from_select(
129    segment: &ErasedSegment,
130    dialect: Option<&Dialect>,
131) -> (Vec<AliasInfo>, Vec<SmolStr>) {
132    let fc = segment.child(const { &SyntaxSet::new(&[SyntaxKind::FromClause]) });
133    let Some(fc) = fc else {
134        return (Vec::new(), Vec::new());
135    };
136
137    let aliases = if fc.is_type(SyntaxKind::FromClause) {
138        FromClauseSegment(fc).eventual_aliases()
139    } else if fc.is_type(SyntaxKind::JoinClause) {
140        JoinClauseSegment(fc).eventual_aliases()
141    } else {
142        unimplemented!()
143    };
144
145    let mut standalone_aliases = Vec::new();
146    standalone_aliases.extend(get_pivot_table_columns(segment, dialect));
147    standalone_aliases.extend(get_lambda_argument_columns(segment, dialect));
148
149    let mut table_aliases = Vec::new();
150    for (table_expr, alias_info) in aliases {
151        if has_value_table_function(table_expr, dialect) {
152            if !standalone_aliases.contains(&alias_info.ref_str) {
153                standalone_aliases.push(alias_info.ref_str);
154            }
155        } else if !table_aliases.contains(&alias_info) {
156            table_aliases.push(alias_info);
157        }
158    }
159
160    (table_aliases, standalone_aliases)
161}
162
163fn has_value_table_function(table_expr: ErasedSegment, dialect: Option<&Dialect>) -> bool {
164    let Some(dialect) = dialect else {
165        return false;
166    };
167
168    for function_name in table_expr.recursive_crawl(
169        const { &SyntaxSet::new(&[SyntaxKind::FunctionName]) },
170        true,
171        &SyntaxSet::EMPTY,
172        true,
173    ) {
174        if dialect
175            .sets("value_table_functions")
176            .contains(function_name.raw().to_uppercase().trim())
177        {
178            return true;
179        }
180    }
181
182    false
183}
184
185fn get_pivot_table_columns(segment: &ErasedSegment, dialect: Option<&Dialect>) -> Vec<SmolStr> {
186    let Some(_dialect) = dialect else {
187        return Vec::new();
188    };
189
190    let fc = segment.recursive_crawl(
191        const { &SyntaxSet::new(&[SyntaxKind::FromPivotExpression]) },
192        true,
193        &SyntaxSet::EMPTY,
194        true,
195    );
196    if !fc.is_empty() {
197        return Vec::new();
198    }
199
200    let mut pivot_table_column_aliases = Vec::new();
201    for pivot_table_column_alias in segment.recursive_crawl(
202        const { &SyntaxSet::new(&[SyntaxKind::PivotColumnReference]) },
203        true,
204        &SyntaxSet::EMPTY,
205        true,
206    ) {
207        let raw = pivot_table_column_alias.raw().clone();
208        if !pivot_table_column_aliases.contains(&raw) {
209            pivot_table_column_aliases.push(raw);
210        }
211    }
212
213    pivot_table_column_aliases
214}
215
216fn get_lambda_argument_columns(segment: &ErasedSegment, dialect: Option<&Dialect>) -> Vec<SmolStr> {
217    let Some(dialect) = dialect else {
218        return Vec::new();
219    };
220
221    if !matches!(dialect.name, DialectKind::Athena | DialectKind::Sparksql) {
222        return Vec::new();
223    }
224
225    let mut lambda_argument_columns = Vec::new();
226    for potential_lambda in segment.recursive_crawl(
227        const { &SyntaxSet::single(SyntaxKind::Expression) },
228        true,
229        &SyntaxSet::EMPTY,
230        true,
231    ) {
232        let Some(potential_arrow) =
233            potential_lambda.child(&SyntaxSet::single(SyntaxKind::BinaryOperator))
234        else {
235            continue;
236        };
237
238        if potential_arrow.raw() == "->" {
239            let arrow_operator = &potential_arrow;
240            let mut argument_segments = potential_lambda
241                .segments()
242                .iter()
243                .take_while(|&it| it != arrow_operator)
244                .filter(|it| {
245                    matches!(
246                        it.get_type(),
247                        SyntaxKind::Bracketed | SyntaxKind::ColumnReference
248                    )
249                })
250                .collect_vec();
251
252            assert_eq!(argument_segments.len(), 1);
253            let child_segment = argument_segments.pop().unwrap();
254
255            match child_segment.get_type() {
256                SyntaxKind::Bracketed => {
257                    let start_bracket = child_segment
258                        .child(&SyntaxSet::single(SyntaxKind::StartBracket))
259                        .unwrap();
260
261                    if start_bracket.raw() == "(" {
262                        let bracketed_arguments = child_segment
263                            .children(const { &SyntaxSet::single(SyntaxKind::ColumnReference) });
264
265                        lambda_argument_columns.extend(
266                            bracketed_arguments
267                                .into_iter()
268                                .map(|argument| argument.raw().to_smolstr()),
269                        )
270                    }
271                }
272                SyntaxKind::ColumnReference => {
273                    lambda_argument_columns.push(child_segment.raw().to_smolstr())
274                }
275                _ => {}
276            }
277        }
278    }
279
280    lambda_argument_columns
281}