sqruff_lib_core/utils/analysis/
select.rs1use itertools::Itertools;
2use smol_str::{SmolStr, ToSmolStr};
3
4use crate::dialects::Dialect;
5use crate::dialects::common::{AliasInfo, ColumnAliasInfo};
6use crate::dialects::init::DialectKind;
7use crate::dialects::syntax::{SyntaxKind, SyntaxSet};
8use crate::parser::segments::ErasedSegment;
9use crate::parser::segments::from::FromClauseSegment;
10use crate::parser::segments::join::JoinClauseSegment;
11use crate::parser::segments::object_reference::ObjectReferenceSegment;
12use crate::parser::segments::select::SelectClauseElementSegment;
13
14#[derive(Clone)]
15pub struct SelectStatementColumnsAndTables {
16 pub select_statement: ErasedSegment,
17 pub table_aliases: Vec<AliasInfo>,
18 pub standalone_aliases: Vec<SmolStr>,
19 pub reference_buffer: Vec<ObjectReferenceSegment>,
20 pub select_targets: Vec<SelectClauseElementSegment>,
21 pub col_aliases: Vec<ColumnAliasInfo>,
22 pub using_cols: Vec<SmolStr>,
23}
24
25pub fn get_object_references(segment: &ErasedSegment) -> Vec<ObjectReferenceSegment> {
26 segment
27 .recursive_crawl(
28 const { &SyntaxSet::new(&[SyntaxKind::ObjectReference, SyntaxKind::ColumnReference]) },
29 true,
30 const { &SyntaxSet::single(SyntaxKind::SelectStatement) },
31 true,
32 )
33 .into_iter()
34 .map(|seg| seg.reference())
35 .collect()
36}
37
38pub fn get_select_statement_info(
39 segment: &ErasedSegment,
40 dialect: Option<&Dialect>,
41 early_exit: bool,
42) -> Option<SelectStatementColumnsAndTables> {
43 let (table_aliases, standalone_aliases) = get_aliases_from_select(segment, dialect);
44
45 if early_exit && table_aliases.is_empty() && standalone_aliases.is_empty() {
46 return None;
47 }
48
49 let sc = segment.child(const { &SyntaxSet::new(&[SyntaxKind::SelectClause]) })?;
50 let mut reference_buffer = get_object_references(&sc);
51 for potential_clause in [
52 SyntaxKind::WhereClause,
53 SyntaxKind::GroupbyClause,
54 SyntaxKind::HavingClause,
55 SyntaxKind::OrderbyClause,
56 SyntaxKind::QualifyClause,
57 ] {
58 let clause = segment.child(&SyntaxSet::new(&[potential_clause]));
59 if let Some(clause) = clause {
60 reference_buffer.extend(get_object_references(&clause));
61 }
62 }
63
64 let select_clause = segment
65 .child(const { &SyntaxSet::new(&[SyntaxKind::SelectClause]) })
66 .unwrap();
67 let select_targets =
68 select_clause.children(const { &SyntaxSet::new(&[SyntaxKind::SelectClauseElement]) });
69 let select_targets = select_targets
70 .map(|it| SelectClauseElementSegment(it.clone()))
71 .collect_vec();
72
73 let col_aliases = select_targets
74 .iter()
75 .filter_map(|s| s.alias())
76 .collect_vec();
77
78 let mut using_cols: Vec<SmolStr> = Vec::new();
79 let fc = segment.child(const { &SyntaxSet::new(&[SyntaxKind::FromClause]) });
80
81 if let Some(fc) = fc {
82 for join_clause in fc.recursive_crawl(
83 const { &SyntaxSet::new(&[SyntaxKind::JoinClause]) },
84 true,
85 const { &SyntaxSet::single(SyntaxKind::SelectStatement) },
86 true,
87 ) {
88 let mut seen_using = false;
89
90 for seg in join_clause.segments() {
91 if seg.is_keyword("USING") {
92 seen_using = true;
93 } else if seg.is_type(SyntaxKind::JoinOnCondition) {
94 for on_seg in seg.segments() {
95 if matches!(
96 on_seg.get_type(),
97 SyntaxKind::Bracketed | SyntaxKind::Expression
98 ) {
99 reference_buffer.extend(get_object_references(seg));
100 }
101 }
102 } else if seen_using && seg.is_type(SyntaxKind::Bracketed) {
103 for subseg in seg.segments() {
104 if subseg.is_type(SyntaxKind::Identifier)
105 || subseg.is_type(SyntaxKind::NakedIdentifier)
106 {
107 using_cols.push(subseg.raw().clone());
108 }
109 }
110 seen_using = false;
111 }
112 }
113 }
114 }
115
116 SelectStatementColumnsAndTables {
117 select_statement: segment.clone(),
118 table_aliases,
119 standalone_aliases,
120 reference_buffer,
121 select_targets,
122 col_aliases,
123 using_cols,
124 }
125 .into()
126}
127
128pub fn get_aliases_from_select(
129 segment: &ErasedSegment,
130 dialect: Option<&Dialect>,
131) -> (Vec<AliasInfo>, Vec<SmolStr>) {
132 let fc = segment.child(const { &SyntaxSet::new(&[SyntaxKind::FromClause]) });
133 let Some(fc) = fc else {
134 return (Vec::new(), Vec::new());
135 };
136
137 let aliases = if fc.is_type(SyntaxKind::FromClause) {
138 FromClauseSegment(fc).eventual_aliases()
139 } else if fc.is_type(SyntaxKind::JoinClause) {
140 JoinClauseSegment(fc).eventual_aliases()
141 } else {
142 unimplemented!()
143 };
144
145 let mut standalone_aliases = Vec::new();
146 standalone_aliases.extend(get_pivot_table_columns(segment, dialect));
147 standalone_aliases.extend(get_lambda_argument_columns(segment, dialect));
148
149 let mut table_aliases = Vec::new();
150 for (table_expr, alias_info) in aliases {
151 if has_value_table_function(table_expr, dialect) {
152 if !standalone_aliases.contains(&alias_info.ref_str) {
153 standalone_aliases.push(alias_info.ref_str);
154 }
155 } else if !table_aliases.contains(&alias_info) {
156 table_aliases.push(alias_info);
157 }
158 }
159
160 (table_aliases, standalone_aliases)
161}
162
163fn has_value_table_function(table_expr: ErasedSegment, dialect: Option<&Dialect>) -> bool {
164 let Some(dialect) = dialect else {
165 return false;
166 };
167
168 for function_name in table_expr.recursive_crawl(
169 const { &SyntaxSet::new(&[SyntaxKind::FunctionName]) },
170 true,
171 &SyntaxSet::EMPTY,
172 true,
173 ) {
174 if dialect
175 .sets("value_table_functions")
176 .contains(function_name.raw().to_uppercase().trim())
177 {
178 return true;
179 }
180 }
181
182 false
183}
184
185fn get_pivot_table_columns(segment: &ErasedSegment, dialect: Option<&Dialect>) -> Vec<SmolStr> {
186 let Some(_dialect) = dialect else {
187 return Vec::new();
188 };
189
190 let fc = segment.recursive_crawl(
191 const { &SyntaxSet::new(&[SyntaxKind::FromPivotExpression]) },
192 true,
193 &SyntaxSet::EMPTY,
194 true,
195 );
196 if !fc.is_empty() {
197 return Vec::new();
198 }
199
200 let mut pivot_table_column_aliases = Vec::new();
201 for pivot_table_column_alias in segment.recursive_crawl(
202 const { &SyntaxSet::new(&[SyntaxKind::PivotColumnReference]) },
203 true,
204 &SyntaxSet::EMPTY,
205 true,
206 ) {
207 let raw = pivot_table_column_alias.raw().clone();
208 if !pivot_table_column_aliases.contains(&raw) {
209 pivot_table_column_aliases.push(raw);
210 }
211 }
212
213 pivot_table_column_aliases
214}
215
216fn get_lambda_argument_columns(segment: &ErasedSegment, dialect: Option<&Dialect>) -> Vec<SmolStr> {
217 let Some(dialect) = dialect else {
218 return Vec::new();
219 };
220
221 if !matches!(dialect.name, DialectKind::Athena | DialectKind::Sparksql) {
222 return Vec::new();
223 }
224
225 let mut lambda_argument_columns = Vec::new();
226 for potential_lambda in segment.recursive_crawl(
227 const { &SyntaxSet::single(SyntaxKind::Expression) },
228 true,
229 &SyntaxSet::EMPTY,
230 true,
231 ) {
232 let Some(potential_arrow) =
233 potential_lambda.child(&SyntaxSet::single(SyntaxKind::BinaryOperator))
234 else {
235 continue;
236 };
237
238 if potential_arrow.raw() == "->" {
239 let arrow_operator = &potential_arrow;
240 let mut argument_segments = potential_lambda
241 .segments()
242 .iter()
243 .take_while(|&it| it != arrow_operator)
244 .filter(|it| {
245 matches!(
246 it.get_type(),
247 SyntaxKind::Bracketed | SyntaxKind::ColumnReference
248 )
249 })
250 .collect_vec();
251
252 assert_eq!(argument_segments.len(), 1);
253 let child_segment = argument_segments.pop().unwrap();
254
255 match child_segment.get_type() {
256 SyntaxKind::Bracketed => {
257 let start_bracket = child_segment
258 .child(&SyntaxSet::single(SyntaxKind::StartBracket))
259 .unwrap();
260
261 if start_bracket.raw() == "(" {
262 let bracketed_arguments = child_segment
263 .children(const { &SyntaxSet::single(SyntaxKind::ColumnReference) });
264
265 lambda_argument_columns.extend(
266 bracketed_arguments
267 .into_iter()
268 .map(|argument| argument.raw().to_smolstr()),
269 )
270 }
271 }
272 SyntaxKind::ColumnReference => {
273 lambda_argument_columns.push(child_segment.raw().to_smolstr())
274 }
275 _ => {}
276 }
277 }
278 }
279
280 lambda_argument_columns
281}