1use std::collections::HashSet;
4
5pub fn extract_tables_from_select(stmt: &vibesql_ast::SelectStmt) -> HashSet<String> {
7 let mut tables = HashSet::new();
8
9 if let Some(from_clause) = &stmt.from {
11 extract_from_from_clause(from_clause, &mut tables);
12 }
13
14 for select_item in &stmt.select_list {
16 if let vibesql_ast::SelectItem::Expression { expr, .. } = select_item {
17 extract_from_expression(expr, &mut tables);
18 }
19 }
20
21 if let Some(where_clause) = &stmt.where_clause {
23 extract_from_expression(where_clause, &mut tables);
24 }
25
26 if let Some(group_by) = &stmt.group_by {
28 for expr in group_by.all_expressions() {
29 extract_from_expression(expr, &mut tables);
30 }
31 }
32
33 if let Some(having) = &stmt.having {
35 extract_from_expression(having, &mut tables);
36 }
37
38 if let Some(order_by) = &stmt.order_by {
40 for order_item in order_by {
41 extract_from_expression(&order_item.expr, &mut tables);
42 }
43 }
44
45 if let Some(with_clause) = &stmt.with_clause {
47 for cte in with_clause {
48 let cte_tables = extract_tables_from_select(&cte.query);
49 tables.extend(cte_tables);
50 }
51 }
52
53 if let Some(set_op) = &stmt.set_operation {
55 let right_tables = extract_tables_from_select(&set_op.right);
56 tables.extend(right_tables);
57 }
58
59 tables
60}
61
62fn extract_from_from_clause(from: &vibesql_ast::FromClause, tables: &mut HashSet<String>) {
64 match from {
65 vibesql_ast::FromClause::Table { name, .. } => {
66 let table_name = if let Some(pos) = name.rfind('.') { &name[pos + 1..] } else { name };
68 tables.insert(table_name.to_string());
69 }
70 vibesql_ast::FromClause::Join { left, right, condition, .. } => {
71 extract_from_from_clause(left, tables);
72 extract_from_from_clause(right, tables);
73 if let Some(cond) = condition {
74 extract_from_expression(cond, tables);
75 }
76 }
77 vibesql_ast::FromClause::Subquery { query, .. } => {
78 let subquery_tables = extract_tables_from_select(query);
79 tables.extend(subquery_tables);
80 }
81 vibesql_ast::FromClause::Values { .. } => {
82 }
84 }
85}
86
87fn extract_from_expression(expr: &vibesql_ast::Expression, tables: &mut HashSet<String>) {
89 match expr {
90 vibesql_ast::Expression::ScalarSubquery(stmt) => {
91 let subquery_tables = extract_tables_from_select(stmt);
92 tables.extend(subquery_tables);
93 }
94 vibesql_ast::Expression::BinaryOp { left, right, .. } => {
95 extract_from_expression(left, tables);
96 extract_from_expression(right, tables);
97 }
98 vibesql_ast::Expression::UnaryOp { expr, .. } => {
99 extract_from_expression(expr, tables);
100 }
101 vibesql_ast::Expression::Function { args, .. }
102 | vibesql_ast::Expression::AggregateFunction { args, .. } => {
103 for arg in args {
104 extract_from_expression(arg, tables);
105 }
106 }
107 vibesql_ast::Expression::Case { operand, when_clauses, else_result, .. } => {
108 if let Some(op) = operand {
109 extract_from_expression(op, tables);
110 }
111 for when_clause in when_clauses {
112 for condition in &when_clause.conditions {
113 extract_from_expression(condition, tables);
114 }
115 extract_from_expression(&when_clause.result, tables);
116 }
117 if let Some(else_expr) = else_result {
118 extract_from_expression(else_expr, tables);
119 }
120 }
121 vibesql_ast::Expression::In { expr, subquery, .. } => {
122 extract_from_expression(expr, tables);
123 let subquery_tables = extract_tables_from_select(subquery);
124 tables.extend(subquery_tables);
125 }
126 vibesql_ast::Expression::InList { expr, values, .. } => {
127 extract_from_expression(expr, tables);
128 for val in values {
129 extract_from_expression(val, tables);
130 }
131 }
132 vibesql_ast::Expression::Exists { subquery, .. } => {
133 let subquery_tables = extract_tables_from_select(subquery);
134 tables.extend(subquery_tables);
135 }
136 vibesql_ast::Expression::Between { expr, low, high, .. } => {
137 extract_from_expression(expr, tables);
138 extract_from_expression(low, tables);
139 extract_from_expression(high, tables);
140 }
141 vibesql_ast::Expression::IsNull { expr, .. } => {
142 extract_from_expression(expr, tables);
143 }
144 vibesql_ast::Expression::IsDistinctFrom { left, right, .. } => {
145 extract_from_expression(left, tables);
146 extract_from_expression(right, tables);
147 }
148 vibesql_ast::Expression::IsTruthValue { expr, .. } => {
149 extract_from_expression(expr, tables);
150 }
151 vibesql_ast::Expression::Cast { expr, .. } => {
152 extract_from_expression(expr, tables);
153 }
154 vibesql_ast::Expression::Like { expr, pattern, .. }
155 | vibesql_ast::Expression::Glob { expr, pattern, .. } => {
156 extract_from_expression(expr, tables);
157 extract_from_expression(pattern, tables);
158 }
159 vibesql_ast::Expression::Position { substring, string, .. } => {
160 extract_from_expression(substring, tables);
161 extract_from_expression(string, tables);
162 }
163 vibesql_ast::Expression::Trim { removal_char, string, .. } => {
164 if let Some(removal) = removal_char {
165 extract_from_expression(removal, tables);
166 }
167 extract_from_expression(string, tables);
168 }
169 vibesql_ast::Expression::Extract { expr, .. } => {
170 extract_from_expression(expr, tables);
171 }
172 vibesql_ast::Expression::QuantifiedComparison { expr, subquery, .. } => {
173 extract_from_expression(expr, tables);
174 let subquery_tables = extract_tables_from_select(subquery);
175 tables.extend(subquery_tables);
176 }
177 vibesql_ast::Expression::Conjunction(children)
178 | vibesql_ast::Expression::Disjunction(children)
179 | vibesql_ast::Expression::RowValueConstructor(children) => {
180 for child in children {
181 extract_from_expression(child, tables);
182 }
183 }
184
185 vibesql_ast::Expression::Collate { expr, .. } => {
186 extract_from_expression(expr, tables);
187 }
188
189 vibesql_ast::Expression::Literal(_)
191 | vibesql_ast::Expression::Placeholder(_)
192 | vibesql_ast::Expression::NumberedPlaceholder(_)
193 | vibesql_ast::Expression::NamedPlaceholder(_)
194 | vibesql_ast::Expression::ColumnRef(_)
195 | vibesql_ast::Expression::Wildcard
196 | vibesql_ast::Expression::CurrentDate
197 | vibesql_ast::Expression::CurrentTime { .. }
198 | vibesql_ast::Expression::CurrentTimestamp { .. }
199 | vibesql_ast::Expression::Interval { .. }
200 | vibesql_ast::Expression::Default
201 | vibesql_ast::Expression::DuplicateKeyValue { .. }
202 | vibesql_ast::Expression::WindowFunction { .. }
203 | vibesql_ast::Expression::NextValue { .. }
204 | vibesql_ast::Expression::MatchAgainst { .. }
205 | vibesql_ast::Expression::PseudoVariable { .. }
206 | vibesql_ast::Expression::SessionVariable { .. } => {}
207 }
208}
209
210pub fn extract_tables_from_statement(stmt: &vibesql_ast::Statement) -> HashSet<String> {
212 match stmt {
213 vibesql_ast::Statement::Select(select) => extract_tables_from_select(select),
214 vibesql_ast::Statement::Insert(insert) => {
215 let mut tables = HashSet::new();
216 let table_name = if let Some(pos) = insert.table_name.rfind('.') {
218 &insert.table_name[pos + 1..]
219 } else {
220 &insert.table_name
221 };
222 tables.insert(table_name.to_string());
223
224 match &insert.source {
226 vibesql_ast::InsertSource::Values(values) => {
227 for row in values {
228 for expr in row {
229 extract_from_expression(expr, &mut tables);
230 }
231 }
232 }
233 vibesql_ast::InsertSource::Select(select) => {
234 let select_tables = extract_tables_from_select(select);
235 tables.extend(select_tables);
236 }
237 vibesql_ast::InsertSource::DefaultValues => {
238 }
240 }
241
242 tables
243 }
244 vibesql_ast::Statement::Update(update) => {
245 let mut tables = HashSet::new();
246 let table_name = if let Some(pos) = update.table_name.rfind('.') {
248 &update.table_name[pos + 1..]
249 } else {
250 &update.table_name
251 };
252 tables.insert(table_name.to_string());
253
254 for assignment in &update.assignments {
256 extract_from_expression(&assignment.value, &mut tables);
257 }
258
259 if let Some(vibesql_ast::WhereClause::Condition(expr)) = &update.where_clause {
261 extract_from_expression(expr, &mut tables);
262 }
263
264 tables
265 }
266 vibesql_ast::Statement::Delete(delete) => {
267 let mut tables = HashSet::new();
268 let table_name = if let Some(pos) = delete.table_name.rfind('.') {
270 &delete.table_name[pos + 1..]
271 } else {
272 &delete.table_name
273 };
274 tables.insert(table_name.to_string());
275
276 if let Some(vibesql_ast::WhereClause::Condition(expr)) = &delete.where_clause {
278 extract_from_expression(expr, &mut tables);
279 }
280
281 tables
282 }
283 _ => HashSet::new(),
285 }
286}
287
#[cfg(test)]
mod tests {
    use vibesql_parser::Parser;

    use super::*;

    /// Parses `sql`, requires it to be a `SELECT`, and returns the tables it references.
    fn tables_of_select(sql: &str) -> std::collections::HashSet<String> {
        match Parser::parse_sql(sql).unwrap() {
            vibesql_ast::Statement::Select(select) => extract_tables_from_select(&select),
            _ => panic!("Expected SELECT statement"),
        }
    }

    /// Parses `sql` as any statement kind and returns the tables it touches.
    fn tables_of_statement(sql: &str) -> std::collections::HashSet<String> {
        let parsed = Parser::parse_sql(sql).unwrap();
        extract_tables_from_statement(&parsed)
    }

    #[test]
    fn test_extract_simple_select() {
        let tables = tables_of_select("SELECT * FROM users");
        assert_eq!(tables.len(), 1);
        assert!(tables.contains("users"));
    }

    #[test]
    fn test_extract_join() {
        let tables =
            tables_of_select("SELECT * FROM users u JOIN orders o ON u.id = o.user_id");
        assert_eq!(tables.len(), 2);
        assert!(tables.contains("users"));
        assert!(tables.contains("orders"));
    }

    #[test]
    fn test_extract_qualified_table_name() {
        let tables = tables_of_select("SELECT * FROM public.users");
        assert_eq!(tables.len(), 1);
        assert!(tables.contains("users"));
    }

    #[test]
    fn test_extract_subquery_in_from() {
        let tables = tables_of_select("SELECT * FROM (SELECT * FROM users) AS u");
        assert_eq!(tables.len(), 1);
        assert!(tables.contains("users"));
    }

    #[test]
    fn test_extract_from_insert() {
        let tables = tables_of_statement("INSERT INTO users VALUES (1, 'Alice')");
        assert_eq!(tables.len(), 1);
        assert!(tables.contains("users"));
    }

    #[test]
    fn test_extract_from_update() {
        let tables = tables_of_statement("UPDATE users SET name = 'Bob' WHERE id = 1");
        assert_eq!(tables.len(), 1);
        assert!(tables.contains("users"));
    }

    #[test]
    fn test_extract_from_delete() {
        let tables = tables_of_statement("DELETE FROM users WHERE id = 1");
        assert_eq!(tables.len(), 1);
        assert!(tables.contains("users"));
    }
}