vibesql_executor/cache/
table_extractor.rs1use std::collections::HashSet;
4
5pub fn extract_tables_from_select(stmt: &vibesql_ast::SelectStmt) -> HashSet<String> {
7 let mut tables = HashSet::new();
8
9 if let Some(from_clause) = &stmt.from {
11 extract_from_from_clause(from_clause, &mut tables);
12 }
13
14 for select_item in &stmt.select_list {
16 if let vibesql_ast::SelectItem::Expression { expr, .. } = select_item {
17 extract_from_expression(expr, &mut tables);
18 }
19 }
20
21 if let Some(where_clause) = &stmt.where_clause {
23 extract_from_expression(where_clause, &mut tables);
24 }
25
26 if let Some(group_by) = &stmt.group_by {
28 for expr in group_by.all_expressions() {
29 extract_from_expression(expr, &mut tables);
30 }
31 }
32
33 if let Some(having) = &stmt.having {
35 extract_from_expression(having, &mut tables);
36 }
37
38 if let Some(order_by) = &stmt.order_by {
40 for order_item in order_by {
41 extract_from_expression(&order_item.expr, &mut tables);
42 }
43 }
44
45 if let Some(with_clause) = &stmt.with_clause {
47 for cte in with_clause {
48 let cte_tables = extract_tables_from_select(&cte.query);
49 tables.extend(cte_tables);
50 }
51 }
52
53 if let Some(set_op) = &stmt.set_operation {
55 let right_tables = extract_tables_from_select(&set_op.right);
56 tables.extend(right_tables);
57 }
58
59 tables
60}
61
62fn extract_from_from_clause(from: &vibesql_ast::FromClause, tables: &mut HashSet<String>) {
64 match from {
65 vibesql_ast::FromClause::Table { name, .. } => {
66 let table_name = if let Some(pos) = name.rfind('.') { &name[pos + 1..] } else { name };
68 tables.insert(table_name.to_string());
69 }
70 vibesql_ast::FromClause::Join { left, right, condition, .. } => {
71 extract_from_from_clause(left, tables);
72 extract_from_from_clause(right, tables);
73 if let Some(cond) = condition {
74 extract_from_expression(cond, tables);
75 }
76 }
77 vibesql_ast::FromClause::Subquery { query, .. } => {
78 let subquery_tables = extract_tables_from_select(query);
79 tables.extend(subquery_tables);
80 }
81 }
82}
83
84fn extract_from_expression(expr: &vibesql_ast::Expression, tables: &mut HashSet<String>) {
86 match expr {
87 vibesql_ast::Expression::ScalarSubquery(stmt) => {
88 let subquery_tables = extract_tables_from_select(stmt);
89 tables.extend(subquery_tables);
90 }
91 vibesql_ast::Expression::BinaryOp { left, right, .. } => {
92 extract_from_expression(left, tables);
93 extract_from_expression(right, tables);
94 }
95 vibesql_ast::Expression::UnaryOp { expr, .. } => {
96 extract_from_expression(expr, tables);
97 }
98 vibesql_ast::Expression::Function { args, .. }
99 | vibesql_ast::Expression::AggregateFunction { args, .. } => {
100 for arg in args {
101 extract_from_expression(arg, tables);
102 }
103 }
104 vibesql_ast::Expression::Case { operand, when_clauses, else_result, .. } => {
105 if let Some(op) = operand {
106 extract_from_expression(op, tables);
107 }
108 for when_clause in when_clauses {
109 for condition in &when_clause.conditions {
110 extract_from_expression(condition, tables);
111 }
112 extract_from_expression(&when_clause.result, tables);
113 }
114 if let Some(else_expr) = else_result {
115 extract_from_expression(else_expr, tables);
116 }
117 }
118 vibesql_ast::Expression::In { expr, subquery, .. } => {
119 extract_from_expression(expr, tables);
120 let subquery_tables = extract_tables_from_select(subquery);
121 tables.extend(subquery_tables);
122 }
123 vibesql_ast::Expression::InList { expr, values, .. } => {
124 extract_from_expression(expr, tables);
125 for val in values {
126 extract_from_expression(val, tables);
127 }
128 }
129 vibesql_ast::Expression::Exists { subquery, .. } => {
130 let subquery_tables = extract_tables_from_select(subquery);
131 tables.extend(subquery_tables);
132 }
133 vibesql_ast::Expression::Between { expr, low, high, .. } => {
134 extract_from_expression(expr, tables);
135 extract_from_expression(low, tables);
136 extract_from_expression(high, tables);
137 }
138 vibesql_ast::Expression::IsNull { expr, .. } => {
139 extract_from_expression(expr, tables);
140 }
141 vibesql_ast::Expression::Cast { expr, .. } => {
142 extract_from_expression(expr, tables);
143 }
144 vibesql_ast::Expression::Like { expr, pattern, .. } => {
145 extract_from_expression(expr, tables);
146 extract_from_expression(pattern, tables);
147 }
148 vibesql_ast::Expression::Position { substring, string, .. } => {
149 extract_from_expression(substring, tables);
150 extract_from_expression(string, tables);
151 }
152 vibesql_ast::Expression::Trim { removal_char, string, .. } => {
153 if let Some(removal) = removal_char {
154 extract_from_expression(removal, tables);
155 }
156 extract_from_expression(string, tables);
157 }
158 vibesql_ast::Expression::Extract { expr, .. } => {
159 extract_from_expression(expr, tables);
160 }
161 vibesql_ast::Expression::QuantifiedComparison { expr, subquery, .. } => {
162 extract_from_expression(expr, tables);
163 let subquery_tables = extract_tables_from_select(subquery);
164 tables.extend(subquery_tables);
165 }
166 vibesql_ast::Expression::Conjunction(children) | vibesql_ast::Expression::Disjunction(children) => {
167 for child in children {
168 extract_from_expression(child, tables);
169 }
170 }
171
172 vibesql_ast::Expression::Literal(_)
174 | vibesql_ast::Expression::Placeholder(_)
175 | vibesql_ast::Expression::NumberedPlaceholder(_)
176 | vibesql_ast::Expression::NamedPlaceholder(_)
177 | vibesql_ast::Expression::ColumnRef { .. }
178 | vibesql_ast::Expression::Wildcard
179 | vibesql_ast::Expression::CurrentDate
180 | vibesql_ast::Expression::CurrentTime { .. }
181 | vibesql_ast::Expression::CurrentTimestamp { .. }
182 | vibesql_ast::Expression::Interval { .. }
183 | vibesql_ast::Expression::Default
184 | vibesql_ast::Expression::DuplicateKeyValue { .. }
185 | vibesql_ast::Expression::WindowFunction { .. }
186 | vibesql_ast::Expression::NextValue { .. }
187 | vibesql_ast::Expression::MatchAgainst { .. }
188 | vibesql_ast::Expression::PseudoVariable { .. }
189 | vibesql_ast::Expression::SessionVariable { .. } => {}
190 }
191}
192
193pub fn extract_tables_from_statement(stmt: &vibesql_ast::Statement) -> HashSet<String> {
195 match stmt {
196 vibesql_ast::Statement::Select(select) => extract_tables_from_select(select),
197 vibesql_ast::Statement::Insert(insert) => {
198 let mut tables = HashSet::new();
199 let table_name = if let Some(pos) = insert.table_name.rfind('.') {
201 &insert.table_name[pos + 1..]
202 } else {
203 &insert.table_name
204 };
205 tables.insert(table_name.to_string());
206
207 match &insert.source {
209 vibesql_ast::InsertSource::Values(values) => {
210 for row in values {
211 for expr in row {
212 extract_from_expression(expr, &mut tables);
213 }
214 }
215 }
216 vibesql_ast::InsertSource::Select(select) => {
217 let select_tables = extract_tables_from_select(select);
218 tables.extend(select_tables);
219 }
220 }
221
222 tables
223 }
224 vibesql_ast::Statement::Update(update) => {
225 let mut tables = HashSet::new();
226 let table_name = if let Some(pos) = update.table_name.rfind('.') {
228 &update.table_name[pos + 1..]
229 } else {
230 &update.table_name
231 };
232 tables.insert(table_name.to_string());
233
234 for assignment in &update.assignments {
236 extract_from_expression(&assignment.value, &mut tables);
237 }
238
239 if let Some(vibesql_ast::WhereClause::Condition(expr)) = &update.where_clause {
241 extract_from_expression(expr, &mut tables);
242 }
243
244 tables
245 }
246 vibesql_ast::Statement::Delete(delete) => {
247 let mut tables = HashSet::new();
248 let table_name = if let Some(pos) = delete.table_name.rfind('.') {
250 &delete.table_name[pos + 1..]
251 } else {
252 &delete.table_name
253 };
254 tables.insert(table_name.to_string());
255
256 if let Some(vibesql_ast::WhereClause::Condition(expr)) = &delete.where_clause {
258 extract_from_expression(expr, &mut tables);
259 }
260
261 tables
262 }
263 _ => HashSet::new(),
265 }
266}
267
#[cfg(test)]
mod tests {
    use vibesql_parser::Parser;

    use super::*;

    /// Parses `sql`, requires it to be a `SELECT`, and returns the tables
    /// `extract_tables_from_select` finds in it.
    fn tables_of_select(sql: &str) -> std::collections::HashSet<String> {
        match Parser::parse_sql(sql).unwrap() {
            vibesql_ast::Statement::Select(select) => extract_tables_from_select(&select),
            _ => panic!("Expected SELECT statement"),
        }
    }

    /// Parses `sql` as any statement and returns the tables
    /// `extract_tables_from_statement` finds in it.
    fn tables_of_statement(sql: &str) -> std::collections::HashSet<String> {
        let parsed = Parser::parse_sql(sql).unwrap();
        extract_tables_from_statement(&parsed)
    }

    /// Asserts that `tables` holds exactly the names listed in `expected`.
    fn assert_exactly(tables: &std::collections::HashSet<String>, expected: &[&str]) {
        assert_eq!(tables.len(), expected.len());
        for name in expected {
            assert!(tables.contains(*name));
        }
    }

    #[test]
    fn test_extract_simple_select() {
        assert_exactly(&tables_of_select("SELECT * FROM users"), &["USERS"]);
    }

    #[test]
    fn test_extract_join() {
        let found = tables_of_select("SELECT * FROM users u JOIN orders o ON u.id = o.user_id");
        assert_exactly(&found, &["USERS", "ORDERS"]);
    }

    #[test]
    fn test_extract_qualified_table_name() {
        assert_exactly(&tables_of_select("SELECT * FROM public.users"), &["USERS"]);
    }

    #[test]
    fn test_extract_subquery_in_from() {
        assert_exactly(&tables_of_select("SELECT * FROM (SELECT * FROM users) AS u"), &["USERS"]);
    }

    #[test]
    fn test_extract_from_insert() {
        assert_exactly(&tables_of_statement("INSERT INTO users VALUES (1, 'Alice')"), &["USERS"]);
    }

    #[test]
    fn test_extract_from_update() {
        let found = tables_of_statement("UPDATE users SET name = 'Bob' WHERE id = 1");
        assert_exactly(&found, &["USERS"]);
    }

    #[test]
    fn test_extract_from_delete() {
        assert_exactly(&tables_of_statement("DELETE FROM users WHERE id = 1"), &["USERS"]);
    }
}