sqruff_lib_core/utils/analysis/
query.rs1use std::cell::RefCell;
2use std::rc::Rc;
3
4use smol_str::{SmolStr, StrExt, ToSmolStr};
5
6use super::select::SelectStatementColumnsAndTables;
7use crate::dialects::Dialect;
8use crate::dialects::common::AliasInfo;
9use crate::dialects::syntax::{SyntaxKind, SyntaxSet};
10use crate::helpers::IndexMap;
11use crate::parser::segments::ErasedSegment;
12use crate::utils::analysis::select::get_select_statement_info;
13use crate::utils::functional::segments::Segments;
14
/// Segment kinds that are treated as selectable queries in their own right:
/// `WithCompoundStatement`, `SetExpression` and `SelectStatement`.
///
/// Used when crawling for subqueries (`Query::extract_subqueries`), for the
/// first query under a root segment (`Query::from_root`), and (together with
/// `SUBSELECT_TYPES`) for the body of each CTE.
const SELECTABLE_TYPES: SyntaxSet = SyntaxSet::new(&[
    SyntaxKind::WithCompoundStatement,
    SyntaxKind::SetExpression,
    SyntaxKind::SelectStatement,
]);
20
/// Statement/clause kinds that are not plain selects but are analysed like
/// one: `Query::from_segment` wraps a segment of any of these kinds (like a
/// `SelectStatement`) in a single `Selectable`, and CTE bodies are searched
/// for these kinds alongside `SELECTABLE_TYPES`.
const SUBSELECT_TYPES: SyntaxSet = SyntaxSet::new(&[
    SyntaxKind::MergeStatement,
    SyntaxKind::UpdateStatement,
    SyntaxKind::DeleteStatement,
    SyntaxKind::ValuesClause,
]);
31
/// How a `Query` node was built by `Query::from_segment`.
///
/// `PartialEq`/`Eq` are derived so callers can compare query types directly
/// instead of pattern matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryType {
    /// Built from a `SelectStatement`, a `SetExpression`, or one of the
    /// sub-select statement kinds; such a node never has CTE definitions.
    Simple,
    /// Built from any other segment (expected to be a `WithCompoundStatement`);
    /// its selectables and CTE definitions were collected by crawling it.
    WithCompound,
}
37
38pub struct WildcardInfo {
39 pub segment: ErasedSegment,
40 pub tables: Vec<SmolStr>,
41}
42
/// One selectable unit of a query, paired with the dialect used to analyse it.
///
/// `Query::from_segment` creates these from a `SelectStatement`, from one of
/// the `SUBSELECT_TYPES` kinds, or from the `SelectStatement`s found inside a
/// set expression or `WITH` compound statement.
#[derive(Debug, Clone)]
pub struct Selectable<'me> {
    /// The segment being analysed.
    pub selectable: ErasedSegment,
    /// Dialect handed to `get_select_statement_info` and to nested queries.
    pub dialect: &'me Dialect,
}
48
49impl Selectable<'_> {
50 pub fn find_alias(&self, table: &str) -> Option<AliasInfo> {
51 self.select_info()
52 .as_ref()?
53 .table_aliases
54 .iter()
55 .find(|&t| t.aliased && t.ref_str == table)
56 .cloned()
57 }
58}
59
60impl Selectable<'_> {
61 pub fn wildcard_info(&self) -> Vec<WildcardInfo> {
62 let Some(select_info) = self.select_info() else {
63 return Vec::new();
64 };
65
66 let mut buff = Vec::new();
67 for seg in select_info.select_targets {
68 if seg
69 .0
70 .child(const { &SyntaxSet::new(&[SyntaxKind::WildcardExpression]) })
71 .is_some()
72 {
73 if seg.0.raw().contains('.') {
74 let table = seg
75 .0
76 .raw()
77 .rsplit_once('.')
78 .map(|x| x.0)
79 .unwrap_or_default()
80 .to_smolstr();
81 buff.push(WildcardInfo {
82 segment: seg.0.clone(),
83 tables: vec![table],
84 });
85 } else {
86 let tables = select_info
87 .table_aliases
88 .iter()
89 .filter(|it| !it.ref_str.is_empty())
90 .map(|it| {
91 if it.aliased {
92 it.ref_str.clone()
93 } else {
94 it.from_expression_element.raw().clone()
95 }
96 })
97 .collect();
98 buff.push(WildcardInfo {
99 segment: seg.0.clone(),
100 tables,
101 });
102 }
103 }
104 }
105
106 buff
107 }
108}
109
110impl Selectable<'_> {
111 pub fn select_info(&self) -> Option<SelectStatementColumnsAndTables> {
112 if self.selectable.is_type(SyntaxKind::SelectStatement) {
113 return get_select_statement_info(&self.selectable, self.dialect.into(), false);
114 }
115
116 let values = Segments::new(self.selectable.clone(), None);
117 let alias_expression = values.children(None).find_first(Some(|it: &ErasedSegment| {
118 it.is_type(SyntaxKind::AliasExpression)
119 }));
120 let name = alias_expression
121 .children(None)
122 .find_first(Some(|it: &ErasedSegment| {
123 matches!(
124 it.get_type(),
125 SyntaxKind::NakedIdentifier | SyntaxKind::QuotedIdentifier,
126 )
127 }));
128
129 let alias_info = AliasInfo {
130 ref_str: if name.is_empty() {
131 SmolStr::new_static("")
132 } else {
133 name.first().unwrap().raw().clone()
134 },
135 segment: name.first().cloned(),
136 aliased: !name.is_empty(),
137 from_expression_element: self.selectable.clone(),
138 alias_expression: alias_expression.first().cloned(),
139 object_reference: None,
140 };
141
142 SelectStatementColumnsAndTables {
143 select_statement: self.selectable.clone(),
144 table_aliases: vec![alias_info],
145 standalone_aliases: Vec::new(),
146 reference_buffer: Vec::new(),
147 select_targets: Vec::new(),
148 col_aliases: Vec::new(),
149 using_cols: Vec::new(),
150 }
151 .into()
152 }
153}
154
/// A node in the query tree: a shared, interior-mutable handle to its
/// `QueryInner` state.
///
/// Cloning a `Query` clones the `Rc`, so every clone refers to the same node.
///
/// NOTE(review): each child stores its parent in `QueryInner::parent` while
/// the parent stores the child in `subqueries`/`ctes`, so any tree with
/// children is an `Rc` reference cycle and is never freed. A `Weak` parent
/// link would avoid the leak but changes the public field type.
#[derive(Debug, Clone)]
pub struct Query<'me, T> {
    /// Shared node state; access it through `borrow`/`borrow_mut`.
    pub inner: Rc<RefCell<QueryInner<'me, T>>>,
}
159
160#[derive(Debug, Clone)]
161pub struct QueryInner<'me, T> {
162 pub query_type: QueryType,
163 pub dialect: &'me Dialect,
164 pub selectables: Vec<Selectable<'me>>,
165 pub ctes: IndexMap<SmolStr, Query<'me, T>>,
166 pub parent: Option<Query<'me, T>>,
167 pub subqueries: Vec<Query<'me, T>>,
168 pub cte_definition_segment: Option<ErasedSegment>,
169 pub cte_name_segment: Option<ErasedSegment>,
170 pub payload: T,
171}
172
173impl<'me, T: Clone + Default> Query<'me, T> {
174 pub fn crawl_sources(
175 &self,
176 segment: ErasedSegment,
177
178 pop: bool,
179 lookup_cte: bool,
180 ) -> Vec<Source<'me, T>> {
181 let mut acc = Vec::new();
182
183 for seg in segment.recursive_crawl(
184 const {
185 &SyntaxSet::new(&[
186 SyntaxKind::TableReference,
187 SyntaxKind::SetExpression,
188 SyntaxKind::SelectStatement,
189 SyntaxKind::ValuesClause,
190 ])
191 },
192 false,
193 &SyntaxSet::EMPTY,
194 false,
195 ) {
196 if seg.is_type(SyntaxKind::TableReference) {
197 let _seg = seg.reference();
198 if !_seg.is_qualified()
199 && lookup_cte
200 && let Some(cte) = self.lookup_cte(seg.raw().as_ref(), pop)
201 {
202 acc.push(Source::Query(cte));
203 }
204 acc.push(Source::TableReference(seg.raw().clone()));
205 } else {
206 acc.push(Source::Query(Query::from_segment(
207 &seg,
208 self.inner.borrow().dialect,
209 Some(self.clone()),
210 )))
211 }
212 }
213
214 if acc.is_empty()
215 && let Some(table_expr) =
216 segment.child(const { &SyntaxSet::new(&[SyntaxKind::TableExpression]) })
217 {
218 return vec![Source::TableReference(table_expr.raw().to_smolstr())];
219 }
220
221 acc
222 }
223
224 #[track_caller]
225 pub fn lookup_cte(&self, name: &str, pop: bool) -> Option<Query<'me, T>> {
226 let cte = if pop {
227 self.inner
228 .borrow_mut()
229 .ctes
230 .shift_remove(&name.to_uppercase_smolstr())
231 } else {
232 self.inner
233 .borrow()
234 .ctes
235 .get(&name.to_uppercase_smolstr())
236 .cloned()
237 };
238
239 cte.or_else(move || {
240 self.inner
241 .borrow_mut()
242 .parent
243 .as_mut()
244 .and_then(|it| it.lookup_cte(name, pop))
245 })
246 }
247
248 fn post_init(&self) {
249 let this = self.clone();
250
251 for subquery in &RefCell::borrow(&self.inner).subqueries {
252 RefCell::borrow_mut(&subquery.inner).parent = this.clone().into();
253 }
254
255 for cte in RefCell::borrow(&self.inner).ctes.values().cloned() {
256 RefCell::borrow_mut(&cte.inner).parent = this.clone().into();
257 }
258 }
259}
260
261impl<T: Default + Clone> Query<'_, T> {
262 pub fn children(&self) -> Vec<Self> {
263 self.inner
264 .borrow()
265 .ctes
266 .values()
267 .chain(self.inner.borrow().subqueries.iter())
268 .cloned()
269 .collect()
270 }
271
272 fn extract_subqueries<'a>(selectable: &Selectable, dialect: &'a Dialect) -> Vec<Query<'a, T>> {
273 let mut acc = Vec::new();
274
275 for subselect in selectable.selectable.recursive_crawl(
276 &SELECTABLE_TYPES,
277 false,
278 &SyntaxSet::EMPTY,
279 false,
280 ) {
281 acc.push(Query::from_segment(&subselect, dialect, None));
282 }
283
284 acc
285 }
286
287 pub fn from_root<'a>(
288 root_segment: &ErasedSegment,
289 dialect: &'a Dialect,
290 ) -> Option<Query<'a, T>> {
291 let stmts = root_segment.recursive_crawl(
292 &SELECTABLE_TYPES,
293 true,
294 &SyntaxSet::single(SyntaxKind::MergeStatement),
295 true,
296 );
297 let selectable_segment = stmts.first()?;
298
299 Some(Query::from_segment(selectable_segment, dialect, None))
300 }
301
302 pub fn from_segment<'a>(
303 segment: &ErasedSegment,
304 dialect: &'a Dialect,
305 parent: Option<Query<'a, T>>,
306 ) -> Query<'a, T> {
307 let mut selectables = Vec::new();
308 let mut subqueries = Vec::new();
309 let mut cte_defs: Vec<ErasedSegment> = Vec::new();
310 let mut query_type = QueryType::Simple;
311
312 if segment.is_type(SyntaxKind::SelectStatement)
313 || SUBSELECT_TYPES.contains(segment.get_type())
314 {
315 selectables.push(Selectable {
316 selectable: segment.clone(),
317 dialect,
318 });
319 } else if segment.is_type(SyntaxKind::SetExpression) {
320 selectables.extend(
321 segment
322 .children(const { &SyntaxSet::new(&[SyntaxKind::SelectStatement]) })
323 .cloned()
324 .map(|selectable| Selectable {
325 selectable,
326 dialect,
327 }),
328 )
329 } else {
330 query_type = QueryType::WithCompound;
331
332 for seg in segment.recursive_crawl(
333 const { &SyntaxSet::new(&[SyntaxKind::SelectStatement]) },
334 false,
335 const { &SyntaxSet::single(SyntaxKind::CommonTableExpression) },
336 true,
337 ) {
338 selectables.push(Selectable {
339 selectable: seg,
340 dialect,
341 });
342 }
343
344 for seg in segment.recursive_crawl(
345 const { &SyntaxSet::new(&[SyntaxKind::CommonTableExpression]) },
346 false,
347 const { &SyntaxSet::single(SyntaxKind::WithCompoundStatement) },
348 true,
349 ) {
350 cte_defs.push(seg);
351 }
352 }
353
354 for selectable in &selectables {
355 subqueries.extend(Self::extract_subqueries(selectable, dialect));
356 }
357
358 let outer_query = Query {
359 inner: Rc::new(RefCell::new(QueryInner {
360 query_type,
361 dialect,
362 selectables,
363 ctes: <_>::default(),
364 parent,
365 subqueries,
366 cte_definition_segment: None,
367 cte_name_segment: None,
368 payload: T::default(),
369 })),
370 };
371
372 outer_query.post_init();
373
374 if cte_defs.is_empty() {
375 return outer_query;
376 }
377
378 let mut ctes = IndexMap::default();
379 for cte in cte_defs {
380 let name_seg = cte.segments()[0].clone();
381 let name = name_seg.raw().to_uppercase_smolstr();
382
383 let queries = cte.recursive_crawl(
384 const { &SELECTABLE_TYPES.union(&SUBSELECT_TYPES) },
385 true,
386 &SyntaxSet::EMPTY,
387 true,
388 );
389
390 if queries.is_empty() {
391 continue;
392 };
393
394 let query = &queries[0];
395 let query = Self::from_segment(query, dialect, outer_query.clone().into());
396
397 RefCell::borrow_mut(&query.inner).cte_definition_segment = cte.into();
398 RefCell::borrow_mut(&query.inner).cte_name_segment = name_seg.into();
399
400 ctes.insert(name, query);
401 }
402
403 RefCell::borrow_mut(&outer_query.inner).ctes = ctes;
404 outer_query
405 }
406}
407
408pub enum Source<'a, T> {
409 TableReference(SmolStr),
410 Query(Query<'a, T>),
411}