vibesql_executor/cache/
table_extractor.rs1use std::collections::HashSet;
4
5pub fn extract_tables_from_select(stmt: &vibesql_ast::SelectStmt) -> HashSet<String> {
7 let mut tables = HashSet::new();
8
9 if let Some(from_clause) = &stmt.from {
11 extract_from_from_clause(from_clause, &mut tables);
12 }
13
14 for select_item in &stmt.select_list {
16 if let vibesql_ast::SelectItem::Expression { expr, .. } = select_item {
17 extract_from_expression(expr, &mut tables);
18 }
19 }
20
21 if let Some(where_clause) = &stmt.where_clause {
23 extract_from_expression(where_clause, &mut tables);
24 }
25
26 if let Some(group_by) = &stmt.group_by {
28 for expr in group_by.all_expressions() {
29 extract_from_expression(expr, &mut tables);
30 }
31 }
32
33 if let Some(having) = &stmt.having {
35 extract_from_expression(having, &mut tables);
36 }
37
38 if let Some(order_by) = &stmt.order_by {
40 for order_item in order_by {
41 extract_from_expression(&order_item.expr, &mut tables);
42 }
43 }
44
45 if let Some(with_clause) = &stmt.with_clause {
47 for cte in with_clause {
48 let cte_tables = extract_tables_from_select(&cte.query);
49 tables.extend(cte_tables);
50 }
51 }
52
53 if let Some(set_op) = &stmt.set_operation {
55 let right_tables = extract_tables_from_select(&set_op.right);
56 tables.extend(right_tables);
57 }
58
59 tables
60}
61
62fn extract_from_from_clause(from: &vibesql_ast::FromClause, tables: &mut HashSet<String>) {
64 match from {
65 vibesql_ast::FromClause::Table { name, .. } => {
66 let table_name = if let Some(pos) = name.rfind('.') { &name[pos + 1..] } else { name };
68 tables.insert(table_name.to_string());
69 }
70 vibesql_ast::FromClause::Join { left, right, condition, .. } => {
71 extract_from_from_clause(left, tables);
72 extract_from_from_clause(right, tables);
73 if let Some(cond) = condition {
74 extract_from_expression(cond, tables);
75 }
76 }
77 vibesql_ast::FromClause::Subquery { query, .. } => {
78 let subquery_tables = extract_tables_from_select(query);
79 tables.extend(subquery_tables);
80 }
81 }
82}
83
84fn extract_from_expression(expr: &vibesql_ast::Expression, tables: &mut HashSet<String>) {
86 match expr {
87 vibesql_ast::Expression::ScalarSubquery(stmt) => {
88 let subquery_tables = extract_tables_from_select(stmt);
89 tables.extend(subquery_tables);
90 }
91 vibesql_ast::Expression::BinaryOp { left, right, .. } => {
92 extract_from_expression(left, tables);
93 extract_from_expression(right, tables);
94 }
95 vibesql_ast::Expression::UnaryOp { expr, .. } => {
96 extract_from_expression(expr, tables);
97 }
98 vibesql_ast::Expression::Function { args, .. }
99 | vibesql_ast::Expression::AggregateFunction { args, .. } => {
100 for arg in args {
101 extract_from_expression(arg, tables);
102 }
103 }
104 vibesql_ast::Expression::Case { operand, when_clauses, else_result, .. } => {
105 if let Some(op) = operand {
106 extract_from_expression(op, tables);
107 }
108 for when_clause in when_clauses {
109 for condition in &when_clause.conditions {
110 extract_from_expression(condition, tables);
111 }
112 extract_from_expression(&when_clause.result, tables);
113 }
114 if let Some(else_expr) = else_result {
115 extract_from_expression(else_expr, tables);
116 }
117 }
118 vibesql_ast::Expression::In { expr, subquery, .. } => {
119 extract_from_expression(expr, tables);
120 let subquery_tables = extract_tables_from_select(subquery);
121 tables.extend(subquery_tables);
122 }
123 vibesql_ast::Expression::InList { expr, values, .. } => {
124 extract_from_expression(expr, tables);
125 for val in values {
126 extract_from_expression(val, tables);
127 }
128 }
129 vibesql_ast::Expression::Exists { subquery, .. } => {
130 let subquery_tables = extract_tables_from_select(subquery);
131 tables.extend(subquery_tables);
132 }
133 vibesql_ast::Expression::Between { expr, low, high, .. } => {
134 extract_from_expression(expr, tables);
135 extract_from_expression(low, tables);
136 extract_from_expression(high, tables);
137 }
138 vibesql_ast::Expression::IsNull { expr, .. } => {
139 extract_from_expression(expr, tables);
140 }
141 vibesql_ast::Expression::Cast { expr, .. } => {
142 extract_from_expression(expr, tables);
143 }
144 vibesql_ast::Expression::Like { expr, pattern, .. } => {
145 extract_from_expression(expr, tables);
146 extract_from_expression(pattern, tables);
147 }
148 vibesql_ast::Expression::Position { substring, string, .. } => {
149 extract_from_expression(substring, tables);
150 extract_from_expression(string, tables);
151 }
152 vibesql_ast::Expression::Trim { removal_char, string, .. } => {
153 if let Some(removal) = removal_char {
154 extract_from_expression(removal, tables);
155 }
156 extract_from_expression(string, tables);
157 }
158 vibesql_ast::Expression::Extract { expr, .. } => {
159 extract_from_expression(expr, tables);
160 }
161 vibesql_ast::Expression::QuantifiedComparison { expr, subquery, .. } => {
162 extract_from_expression(expr, tables);
163 let subquery_tables = extract_tables_from_select(subquery);
164 tables.extend(subquery_tables);
165 }
166 vibesql_ast::Expression::Conjunction(children)
167 | vibesql_ast::Expression::Disjunction(children) => {
168 for child in children {
169 extract_from_expression(child, tables);
170 }
171 }
172
173 vibesql_ast::Expression::Literal(_)
175 | vibesql_ast::Expression::Placeholder(_)
176 | vibesql_ast::Expression::NumberedPlaceholder(_)
177 | vibesql_ast::Expression::NamedPlaceholder(_)
178 | vibesql_ast::Expression::ColumnRef { .. }
179 | vibesql_ast::Expression::Wildcard
180 | vibesql_ast::Expression::CurrentDate
181 | vibesql_ast::Expression::CurrentTime { .. }
182 | vibesql_ast::Expression::CurrentTimestamp { .. }
183 | vibesql_ast::Expression::Interval { .. }
184 | vibesql_ast::Expression::Default
185 | vibesql_ast::Expression::DuplicateKeyValue { .. }
186 | vibesql_ast::Expression::WindowFunction { .. }
187 | vibesql_ast::Expression::NextValue { .. }
188 | vibesql_ast::Expression::MatchAgainst { .. }
189 | vibesql_ast::Expression::PseudoVariable { .. }
190 | vibesql_ast::Expression::SessionVariable { .. } => {}
191 }
192}
193
194pub fn extract_tables_from_statement(stmt: &vibesql_ast::Statement) -> HashSet<String> {
196 match stmt {
197 vibesql_ast::Statement::Select(select) => extract_tables_from_select(select),
198 vibesql_ast::Statement::Insert(insert) => {
199 let mut tables = HashSet::new();
200 let table_name = if let Some(pos) = insert.table_name.rfind('.') {
202 &insert.table_name[pos + 1..]
203 } else {
204 &insert.table_name
205 };
206 tables.insert(table_name.to_string());
207
208 match &insert.source {
210 vibesql_ast::InsertSource::Values(values) => {
211 for row in values {
212 for expr in row {
213 extract_from_expression(expr, &mut tables);
214 }
215 }
216 }
217 vibesql_ast::InsertSource::Select(select) => {
218 let select_tables = extract_tables_from_select(select);
219 tables.extend(select_tables);
220 }
221 }
222
223 tables
224 }
225 vibesql_ast::Statement::Update(update) => {
226 let mut tables = HashSet::new();
227 let table_name = if let Some(pos) = update.table_name.rfind('.') {
229 &update.table_name[pos + 1..]
230 } else {
231 &update.table_name
232 };
233 tables.insert(table_name.to_string());
234
235 for assignment in &update.assignments {
237 extract_from_expression(&assignment.value, &mut tables);
238 }
239
240 if let Some(vibesql_ast::WhereClause::Condition(expr)) = &update.where_clause {
242 extract_from_expression(expr, &mut tables);
243 }
244
245 tables
246 }
247 vibesql_ast::Statement::Delete(delete) => {
248 let mut tables = HashSet::new();
249 let table_name = if let Some(pos) = delete.table_name.rfind('.') {
251 &delete.table_name[pos + 1..]
252 } else {
253 &delete.table_name
254 };
255 tables.insert(table_name.to_string());
256
257 if let Some(vibesql_ast::WhereClause::Condition(expr)) = &delete.where_clause {
259 extract_from_expression(expr, &mut tables);
260 }
261
262 tables
263 }
264 _ => HashSet::new(),
266 }
267}
268
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use vibesql_parser::Parser;

    use super::*;

    /// Parses `sql`, which must be a `SELECT`, and returns the tables it references.
    fn select_tables(sql: &str) -> HashSet<String> {
        match Parser::parse_sql(sql).unwrap() {
            vibesql_ast::Statement::Select(select) => extract_tables_from_select(&select),
            _ => panic!("Expected SELECT statement"),
        }
    }

    /// Parses `sql` and returns the tables reported by `extract_tables_from_statement`.
    fn statement_tables(sql: &str) -> HashSet<String> {
        let stmt = Parser::parse_sql(sql).unwrap();
        extract_tables_from_statement(&stmt)
    }

    /// Asserts that `actual` contains exactly the names in `expected`.
    fn assert_tables(actual: HashSet<String>, expected: &[&str]) {
        let expected: HashSet<String> = expected.iter().map(|s| s.to_string()).collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_extract_simple_select() {
        assert_tables(select_tables("SELECT * FROM users"), &["USERS"]);
    }

    #[test]
    fn test_extract_join() {
        assert_tables(
            select_tables("SELECT * FROM users u JOIN orders o ON u.id = o.user_id"),
            &["USERS", "ORDERS"],
        );
    }

    #[test]
    fn test_extract_qualified_table_name() {
        assert_tables(select_tables("SELECT * FROM public.users"), &["USERS"]);
    }

    #[test]
    fn test_extract_subquery_in_from() {
        assert_tables(select_tables("SELECT * FROM (SELECT * FROM users) AS u"), &["USERS"]);
    }

    #[test]
    fn test_extract_in_subquery_in_where() {
        assert_tables(
            select_tables("SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)"),
            &["USERS", "ORDERS"],
        );
    }

    #[test]
    fn test_extract_exists_subquery() {
        assert_tables(
            select_tables(
                "SELECT * FROM users WHERE EXISTS \
                 (SELECT 1 FROM orders WHERE orders.user_id = users.id)",
            ),
            &["USERS", "ORDERS"],
        );
    }

    #[test]
    fn test_extract_scalar_subquery_in_select_list() {
        assert_tables(
            select_tables("SELECT (SELECT MAX(id) FROM orders) FROM users"),
            &["USERS", "ORDERS"],
        );
    }

    #[test]
    fn test_extract_union() {
        assert_tables(
            select_tables("SELECT id FROM users UNION SELECT id FROM admins"),
            &["USERS", "ADMINS"],
        );
    }

    #[test]
    fn test_extract_cte_body() {
        // The CTE name itself may also be reported; only the base table is required.
        let tables = select_tables("WITH recent AS (SELECT * FROM orders) SELECT * FROM recent");
        assert!(tables.contains("ORDERS"));
    }

    #[test]
    fn test_extract_from_insert() {
        assert_tables(statement_tables("INSERT INTO users VALUES (1, 'Alice')"), &["USERS"]);
    }

    #[test]
    fn test_extract_from_insert_select() {
        assert_tables(
            statement_tables("INSERT INTO users_archive SELECT * FROM users"),
            &["USERS_ARCHIVE", "USERS"],
        );
    }

    #[test]
    fn test_extract_from_update() {
        assert_tables(
            statement_tables("UPDATE users SET name = 'Bob' WHERE id = 1"),
            &["USERS"],
        );
    }

    #[test]
    fn test_extract_from_delete() {
        assert_tables(statement_tables("DELETE FROM users WHERE id = 1"), &["USERS"]);
    }

    #[test]
    fn test_extract_from_delete_with_subquery() {
        assert_tables(
            statement_tables("DELETE FROM users WHERE id IN (SELECT user_id FROM banned)"),
            &["USERS", "BANNED"],
        );
    }
}