vibesql_executor/cache/
table_extractor.rs

//! Extract table names from AST nodes for query-cache invalidation.
2
3use std::collections::HashSet;
4
5/// Extract all table names referenced in a SELECT statement
6pub fn extract_tables_from_select(stmt: &vibesql_ast::SelectStmt) -> HashSet<String> {
7    let mut tables = HashSet::new();
8
9    // Extract from FROM clause
10    if let Some(from_clause) = &stmt.from {
11        extract_from_from_clause(from_clause, &mut tables);
12    }
13
14    // Extract from subqueries in SELECT list
15    for select_item in &stmt.select_list {
16        if let vibesql_ast::SelectItem::Expression { expr, .. } = select_item {
17            extract_from_expression(expr, &mut tables);
18        }
19    }
20
21    // Extract from WHERE clause
22    if let Some(where_clause) = &stmt.where_clause {
23        extract_from_expression(where_clause, &mut tables);
24    }
25
26    // Extract from GROUP BY
27    if let Some(group_by) = &stmt.group_by {
28        for expr in group_by.all_expressions() {
29            extract_from_expression(expr, &mut tables);
30        }
31    }
32
33    // Extract from HAVING
34    if let Some(having) = &stmt.having {
35        extract_from_expression(having, &mut tables);
36    }
37
38    // Extract from ORDER BY
39    if let Some(order_by) = &stmt.order_by {
40        for order_item in order_by {
41            extract_from_expression(&order_item.expr, &mut tables);
42        }
43    }
44
45    // Extract from CTEs (WITH clause)
46    if let Some(with_clause) = &stmt.with_clause {
47        for cte in with_clause {
48            let cte_tables = extract_tables_from_select(&cte.query);
49            tables.extend(cte_tables);
50        }
51    }
52
53    // Extract from set operations (UNION, INTERSECT, EXCEPT)
54    if let Some(set_op) = &stmt.set_operation {
55        let right_tables = extract_tables_from_select(&set_op.right);
56        tables.extend(right_tables);
57    }
58
59    tables
60}
61
62/// Extract table names from a FROM clause
63fn extract_from_from_clause(from: &vibesql_ast::FromClause, tables: &mut HashSet<String>) {
64    match from {
65        vibesql_ast::FromClause::Table { name, .. } => {
66            // Handle qualified names (schema.table) - we want just the table name
67            let table_name = if let Some(pos) = name.rfind('.') { &name[pos + 1..] } else { name };
68            tables.insert(table_name.to_string());
69        }
70        vibesql_ast::FromClause::Join { left, right, condition, .. } => {
71            extract_from_from_clause(left, tables);
72            extract_from_from_clause(right, tables);
73            if let Some(cond) = condition {
74                extract_from_expression(cond, tables);
75            }
76        }
77        vibesql_ast::FromClause::Subquery { query, .. } => {
78            let subquery_tables = extract_tables_from_select(query);
79            tables.extend(subquery_tables);
80        }
81    }
82}
83
84/// Extract table names from an expression (for subqueries)
85fn extract_from_expression(expr: &vibesql_ast::Expression, tables: &mut HashSet<String>) {
86    match expr {
87        vibesql_ast::Expression::ScalarSubquery(stmt) => {
88            let subquery_tables = extract_tables_from_select(stmt);
89            tables.extend(subquery_tables);
90        }
91        vibesql_ast::Expression::BinaryOp { left, right, .. } => {
92            extract_from_expression(left, tables);
93            extract_from_expression(right, tables);
94        }
95        vibesql_ast::Expression::UnaryOp { expr, .. } => {
96            extract_from_expression(expr, tables);
97        }
98        vibesql_ast::Expression::Function { args, .. }
99        | vibesql_ast::Expression::AggregateFunction { args, .. } => {
100            for arg in args {
101                extract_from_expression(arg, tables);
102            }
103        }
104        vibesql_ast::Expression::Case { operand, when_clauses, else_result, .. } => {
105            if let Some(op) = operand {
106                extract_from_expression(op, tables);
107            }
108            for when_clause in when_clauses {
109                for condition in &when_clause.conditions {
110                    extract_from_expression(condition, tables);
111                }
112                extract_from_expression(&when_clause.result, tables);
113            }
114            if let Some(else_expr) = else_result {
115                extract_from_expression(else_expr, tables);
116            }
117        }
118        vibesql_ast::Expression::In { expr, subquery, .. } => {
119            extract_from_expression(expr, tables);
120            let subquery_tables = extract_tables_from_select(subquery);
121            tables.extend(subquery_tables);
122        }
123        vibesql_ast::Expression::InList { expr, values, .. } => {
124            extract_from_expression(expr, tables);
125            for val in values {
126                extract_from_expression(val, tables);
127            }
128        }
129        vibesql_ast::Expression::Exists { subquery, .. } => {
130            let subquery_tables = extract_tables_from_select(subquery);
131            tables.extend(subquery_tables);
132        }
133        vibesql_ast::Expression::Between { expr, low, high, .. } => {
134            extract_from_expression(expr, tables);
135            extract_from_expression(low, tables);
136            extract_from_expression(high, tables);
137        }
138        vibesql_ast::Expression::IsNull { expr, .. } => {
139            extract_from_expression(expr, tables);
140        }
141        vibesql_ast::Expression::Cast { expr, .. } => {
142            extract_from_expression(expr, tables);
143        }
144        vibesql_ast::Expression::Like { expr, pattern, .. } => {
145            extract_from_expression(expr, tables);
146            extract_from_expression(pattern, tables);
147        }
148        vibesql_ast::Expression::Position { substring, string, .. } => {
149            extract_from_expression(substring, tables);
150            extract_from_expression(string, tables);
151        }
152        vibesql_ast::Expression::Trim { removal_char, string, .. } => {
153            if let Some(removal) = removal_char {
154                extract_from_expression(removal, tables);
155            }
156            extract_from_expression(string, tables);
157        }
158        vibesql_ast::Expression::Extract { expr, .. } => {
159            extract_from_expression(expr, tables);
160        }
161        vibesql_ast::Expression::QuantifiedComparison { expr, subquery, .. } => {
162            extract_from_expression(expr, tables);
163            let subquery_tables = extract_tables_from_select(subquery);
164            tables.extend(subquery_tables);
165        }
166        vibesql_ast::Expression::Conjunction(children) | vibesql_ast::Expression::Disjunction(children) => {
167            for child in children {
168                extract_from_expression(child, tables);
169            }
170        }
171
172        // Leaf expressions - no tables to extract
173        vibesql_ast::Expression::Literal(_)
174        | vibesql_ast::Expression::Placeholder(_)
175        | vibesql_ast::Expression::NumberedPlaceholder(_)
176        | vibesql_ast::Expression::NamedPlaceholder(_)
177        | vibesql_ast::Expression::ColumnRef { .. }
178        | vibesql_ast::Expression::Wildcard
179        | vibesql_ast::Expression::CurrentDate
180        | vibesql_ast::Expression::CurrentTime { .. }
181        | vibesql_ast::Expression::CurrentTimestamp { .. }
182        | vibesql_ast::Expression::Interval { .. }
183        | vibesql_ast::Expression::Default
184        | vibesql_ast::Expression::DuplicateKeyValue { .. }
185        | vibesql_ast::Expression::WindowFunction { .. }
186        | vibesql_ast::Expression::NextValue { .. }
187        | vibesql_ast::Expression::MatchAgainst { .. }
188        | vibesql_ast::Expression::PseudoVariable { .. }
189        | vibesql_ast::Expression::SessionVariable { .. } => {}
190    }
191}
192
193/// Extract table names from any statement (for comprehensive cache invalidation)
194pub fn extract_tables_from_statement(stmt: &vibesql_ast::Statement) -> HashSet<String> {
195    match stmt {
196        vibesql_ast::Statement::Select(select) => extract_tables_from_select(select),
197        vibesql_ast::Statement::Insert(insert) => {
198            let mut tables = HashSet::new();
199            // Extract table name being inserted into
200            let table_name = if let Some(pos) = insert.table_name.rfind('.') {
201                &insert.table_name[pos + 1..]
202            } else {
203                &insert.table_name
204            };
205            tables.insert(table_name.to_string());
206
207            // Extract from source (VALUES or SELECT)
208            match &insert.source {
209                vibesql_ast::InsertSource::Values(values) => {
210                    for row in values {
211                        for expr in row {
212                            extract_from_expression(expr, &mut tables);
213                        }
214                    }
215                }
216                vibesql_ast::InsertSource::Select(select) => {
217                    let select_tables = extract_tables_from_select(select);
218                    tables.extend(select_tables);
219                }
220            }
221
222            tables
223        }
224        vibesql_ast::Statement::Update(update) => {
225            let mut tables = HashSet::new();
226            // Extract table being updated
227            let table_name = if let Some(pos) = update.table_name.rfind('.') {
228                &update.table_name[pos + 1..]
229            } else {
230                &update.table_name
231            };
232            tables.insert(table_name.to_string());
233
234            // Extract from SET assignments
235            for assignment in &update.assignments {
236                extract_from_expression(&assignment.value, &mut tables);
237            }
238
239            // Extract from WHERE clause
240            if let Some(vibesql_ast::WhereClause::Condition(expr)) = &update.where_clause {
241                extract_from_expression(expr, &mut tables);
242            }
243
244            tables
245        }
246        vibesql_ast::Statement::Delete(delete) => {
247            let mut tables = HashSet::new();
248            // Extract table being deleted from
249            let table_name = if let Some(pos) = delete.table_name.rfind('.') {
250                &delete.table_name[pos + 1..]
251            } else {
252                &delete.table_name
253            };
254            tables.insert(table_name.to_string());
255
256            // Extract from WHERE clause
257            if let Some(vibesql_ast::WhereClause::Condition(expr)) = &delete.where_clause {
258                extract_from_expression(expr, &mut tables);
259            }
260
261            tables
262        }
263        // DDL statements don't reference tables in a way that matters for SELECT caching
264        _ => HashSet::new(),
265    }
266}
267
#[cfg(test)]
mod tests {
    use vibesql_parser::Parser;

    use super::*;

    /// Parse `sql` (which must be a SELECT) and return the tables it reads.
    fn tables_in_select(sql: &str) -> HashSet<String> {
        match Parser::parse_sql(sql).unwrap() {
            vibesql_ast::Statement::Select(select) => extract_tables_from_select(&select),
            _ => panic!("Expected SELECT statement"),
        }
    }

    /// Parse any statement and return the tables it touches.
    fn tables_in_statement(sql: &str) -> HashSet<String> {
        extract_tables_from_statement(&Parser::parse_sql(sql).unwrap())
    }

    /// Assert that `tables` holds exactly the `expected` names.
    fn assert_exactly(tables: &HashSet<String>, expected: &[&str]) {
        assert_eq!(tables.len(), expected.len());
        for name in expected {
            assert!(tables.contains(*name), "missing table {name}: {tables:?}");
        }
    }

    #[test]
    fn test_extract_simple_select() {
        // The parser upper-cases unquoted identifiers.
        assert_exactly(&tables_in_select("SELECT * FROM users"), &["USERS"]);
    }

    #[test]
    fn test_extract_join() {
        // Both sides of the join are reported.
        let tables = tables_in_select("SELECT * FROM users u JOIN orders o ON u.id = o.user_id");
        assert_exactly(&tables, &["USERS", "ORDERS"]);
    }

    #[test]
    fn test_extract_qualified_table_name() {
        // The schema qualifier is dropped; only the table name remains.
        assert_exactly(&tables_in_select("SELECT * FROM public.users"), &["USERS"]);
    }

    #[test]
    fn test_extract_subquery_in_from() {
        let tables = tables_in_select("SELECT * FROM (SELECT * FROM users) AS u");
        assert_exactly(&tables, &["USERS"]);
    }

    #[test]
    fn test_extract_from_insert() {
        let tables = tables_in_statement("INSERT INTO users VALUES (1, 'Alice')");
        assert_exactly(&tables, &["USERS"]);
    }

    #[test]
    fn test_extract_from_update() {
        let tables = tables_in_statement("UPDATE users SET name = 'Bob' WHERE id = 1");
        assert_exactly(&tables, &["USERS"]);
    }

    #[test]
    fn test_extract_from_delete() {
        let tables = tables_in_statement("DELETE FROM users WHERE id = 1");
        assert_exactly(&tables, &["USERS"]);
    }
}