sqruff_lib_core/utils/analysis/
select.rs1use itertools::Itertools;
2use smol_str::{SmolStr, ToSmolStr};
3
4use crate::dialects::Dialect;
5use crate::dialects::common::{AliasInfo, ColumnAliasInfo};
6use crate::dialects::init::DialectKind;
7use crate::dialects::syntax::{SyntaxKind, SyntaxSet};
8use crate::parser::segments::ErasedSegment;
9use crate::parser::segments::from::FromClauseSegment;
10use crate::parser::segments::join::JoinClauseSegment;
11use crate::parser::segments::object_reference::ObjectReferenceSegment;
12use crate::parser::segments::select::SelectClauseElementSegment;
13
14#[derive(Clone)]
15pub struct SelectStatementColumnsAndTables {
16 pub select_statement: ErasedSegment,
17 pub table_aliases: Vec<AliasInfo>,
18 pub standalone_aliases: Vec<SmolStr>,
19 pub reference_buffer: Vec<ObjectReferenceSegment>,
20 pub select_targets: Vec<SelectClauseElementSegment>,
21 pub col_aliases: Vec<ColumnAliasInfo>,
22 pub using_cols: Vec<SmolStr>,
23 pub table_reference_buffer: Vec<ObjectReferenceSegment>,
24}
25
26pub fn get_object_references(segment: &ErasedSegment) -> Vec<ObjectReferenceSegment> {
27 segment
28 .recursive_crawl(
29 const { &SyntaxSet::new(&[SyntaxKind::ObjectReference, SyntaxKind::ColumnReference]) },
30 true,
31 const { &SyntaxSet::single(SyntaxKind::SelectStatement) },
32 true,
33 )
34 .into_iter()
35 .map(|seg| seg.reference())
36 .collect()
37}
38
39pub fn get_select_statement_info(
40 segment: &ErasedSegment,
41 dialect: Option<&Dialect>,
42 early_exit: bool,
43) -> Option<SelectStatementColumnsAndTables> {
44 let (table_aliases, standalone_aliases) = get_aliases_from_select(segment, dialect);
45
46 if early_exit && table_aliases.is_empty() && standalone_aliases.is_empty() {
47 return None;
48 }
49
50 let sc = segment.child(const { &SyntaxSet::new(&[SyntaxKind::SelectClause]) })?;
51 let mut reference_buffer = get_object_references(&sc);
52 let mut table_reference_buffer = Vec::new();
53 for potential_clause in [
54 SyntaxKind::WhereClause,
55 SyntaxKind::GroupbyClause,
56 SyntaxKind::HavingClause,
57 SyntaxKind::OrderbyClause,
58 SyntaxKind::QualifyClause,
59 ] {
60 let clause = segment.child(&SyntaxSet::new(&[potential_clause]));
61 if let Some(clause) = clause {
62 reference_buffer.extend(get_object_references(&clause));
63 }
64 }
65
66 let select_clause = segment
67 .child(const { &SyntaxSet::new(&[SyntaxKind::SelectClause]) })
68 .unwrap();
69 let select_targets =
70 select_clause.children(const { &SyntaxSet::new(&[SyntaxKind::SelectClauseElement]) });
71 let select_targets = select_targets
72 .map(|it| SelectClauseElementSegment(it.clone()))
73 .collect_vec();
74
75 let col_aliases = select_targets
76 .iter()
77 .filter_map(|s| s.alias())
78 .collect_vec();
79
80 let mut using_cols: Vec<SmolStr> = Vec::new();
81 let fc = segment.child(const { &SyntaxSet::new(&[SyntaxKind::FromClause]) });
82
83 if let Some(fc) = fc {
84 for table_expression in fc.recursive_crawl(
85 const { &SyntaxSet::new(&[SyntaxKind::TableExpression]) },
86 true,
87 const { &SyntaxSet::single(SyntaxKind::SelectStatement) },
88 true,
89 ) {
90 for seg in table_expression.segments() {
91 if !seg.is_type(SyntaxKind::TableReference) {
92 reference_buffer.extend(get_object_references(seg));
93 } else if seg.reference().is_qualified() {
94 table_reference_buffer.extend(get_object_references(seg));
95 }
96 }
97 }
98
99 for join_clause in fc.recursive_crawl(
100 const { &SyntaxSet::new(&[SyntaxKind::JoinClause]) },
101 true,
102 const { &SyntaxSet::single(SyntaxKind::SelectStatement) },
103 true,
104 ) {
105 let mut seen_using = false;
106
107 for seg in join_clause.segments() {
108 if seg.is_keyword("USING") {
109 seen_using = true;
110 } else if seg.is_type(SyntaxKind::JoinOnCondition) {
111 for on_seg in seg.segments() {
112 if matches!(
113 on_seg.get_type(),
114 SyntaxKind::Bracketed | SyntaxKind::Expression
115 ) {
116 reference_buffer.extend(get_object_references(seg));
117 }
118 }
119 } else if seen_using && seg.is_type(SyntaxKind::Bracketed) {
120 for subseg in seg.segments() {
121 if subseg.is_type(SyntaxKind::Identifier)
122 || subseg.is_type(SyntaxKind::NakedIdentifier)
123 {
124 using_cols.push(subseg.raw().clone());
125 }
126 }
127 seen_using = false;
128 }
129 }
130 }
131 }
132
133 SelectStatementColumnsAndTables {
134 select_statement: segment.clone(),
135 table_aliases,
136 standalone_aliases,
137 reference_buffer,
138 select_targets,
139 col_aliases,
140 using_cols,
141 table_reference_buffer,
142 }
143 .into()
144}
145
146pub fn get_aliases_from_select(
147 segment: &ErasedSegment,
148 dialect: Option<&Dialect>,
149) -> (Vec<AliasInfo>, Vec<SmolStr>) {
150 let fc = segment.child(const { &SyntaxSet::new(&[SyntaxKind::FromClause]) });
151 let Some(fc) = fc else {
152 return (Vec::new(), Vec::new());
153 };
154
155 let aliases = if fc.is_type(SyntaxKind::FromClause) {
156 FromClauseSegment(fc).eventual_aliases()
157 } else if fc.is_type(SyntaxKind::JoinClause) {
158 JoinClauseSegment(fc).eventual_aliases()
159 } else {
160 unimplemented!()
161 };
162
163 let mut standalone_aliases = Vec::new();
164 standalone_aliases.extend(get_pivot_table_columns(segment, dialect));
165 standalone_aliases.extend(get_lambda_argument_columns(segment, dialect));
166
167 let mut table_aliases = Vec::new();
168 for (table_expr, alias_info) in aliases {
169 if has_value_table_function(table_expr, dialect) {
170 if !standalone_aliases.contains(&alias_info.ref_str) {
171 standalone_aliases.push(alias_info.ref_str);
172 }
173 } else if !table_aliases.contains(&alias_info) {
174 table_aliases.push(alias_info);
175 }
176 }
177
178 (table_aliases, standalone_aliases)
179}
180
181fn has_value_table_function(table_expr: ErasedSegment, dialect: Option<&Dialect>) -> bool {
182 let Some(dialect) = dialect else {
183 return false;
184 };
185
186 for function_name in table_expr.recursive_crawl(
187 const { &SyntaxSet::new(&[SyntaxKind::FunctionName]) },
188 true,
189 &SyntaxSet::EMPTY,
190 true,
191 ) {
192 if dialect
193 .sets("value_table_functions")
194 .contains(function_name.raw().to_uppercase().trim())
195 {
196 return true;
197 }
198 }
199
200 false
201}
202
203fn get_pivot_table_columns(segment: &ErasedSegment, dialect: Option<&Dialect>) -> Vec<SmolStr> {
204 let Some(_dialect) = dialect else {
205 return Vec::new();
206 };
207
208 let fc = segment.recursive_crawl(
209 const { &SyntaxSet::new(&[SyntaxKind::FromPivotExpression]) },
210 true,
211 &SyntaxSet::EMPTY,
212 true,
213 );
214 if !fc.is_empty() {
215 return Vec::new();
216 }
217
218 let mut pivot_table_column_aliases = Vec::new();
219 for pivot_table_column_alias in segment.recursive_crawl(
220 const { &SyntaxSet::new(&[SyntaxKind::PivotColumnReference]) },
221 true,
222 &SyntaxSet::EMPTY,
223 true,
224 ) {
225 let raw = pivot_table_column_alias.raw().clone();
226 if !pivot_table_column_aliases.contains(&raw) {
227 pivot_table_column_aliases.push(raw);
228 }
229 }
230
231 pivot_table_column_aliases
232}
233
234fn get_lambda_argument_columns(segment: &ErasedSegment, dialect: Option<&Dialect>) -> Vec<SmolStr> {
235 let Some(dialect) = dialect else {
236 return Vec::new();
237 };
238
239 if !matches!(dialect.name, DialectKind::Athena | DialectKind::Sparksql) {
240 return Vec::new();
241 }
242
243 let mut lambda_argument_columns = Vec::new();
244 for potential_lambda in segment.recursive_crawl(
245 const { &SyntaxSet::single(SyntaxKind::Expression) },
246 true,
247 &SyntaxSet::EMPTY,
248 true,
249 ) {
250 let Some(potential_arrow) =
251 potential_lambda.child(&SyntaxSet::single(SyntaxKind::BinaryOperator))
252 else {
253 continue;
254 };
255
256 if potential_arrow.raw() == "->" {
257 let arrow_operator = &potential_arrow;
258 let mut argument_segments = potential_lambda
259 .segments()
260 .iter()
261 .take_while(|&it| it != arrow_operator)
262 .filter(|it| {
263 matches!(
264 it.get_type(),
265 SyntaxKind::Bracketed | SyntaxKind::ColumnReference
266 )
267 })
268 .collect_vec();
269
270 assert_eq!(argument_segments.len(), 1);
271 let child_segment = argument_segments.pop().unwrap();
272
273 match child_segment.get_type() {
274 SyntaxKind::Bracketed => {
275 let start_bracket = child_segment
276 .child(&SyntaxSet::single(SyntaxKind::StartBracket))
277 .unwrap();
278
279 if start_bracket.raw() == "(" {
280 let bracketed_arguments = child_segment
281 .children(const { &SyntaxSet::single(SyntaxKind::ColumnReference) });
282
283 lambda_argument_columns.extend(
284 bracketed_arguments
285 .into_iter()
286 .map(|argument| argument.raw().to_smolstr()),
287 )
288 }
289 }
290 SyntaxKind::ColumnReference => {
291 lambda_argument_columns.push(child_segment.raw().to_smolstr())
292 }
293 _ => {}
294 }
295 }
296 }
297
298 lambda_argument_columns
299}