vibesql_executor/cache/table_extractor.rs

1//! Extract table names from AST for cache invalidation
2
3use std::collections::HashSet;
4
5/// Extract all table names referenced in a SELECT statement
6pub fn extract_tables_from_select(stmt: &vibesql_ast::SelectStmt) -> HashSet<String> {
7    let mut tables = HashSet::new();
8
9    // Extract from FROM clause
10    if let Some(from_clause) = &stmt.from {
11        extract_from_from_clause(from_clause, &mut tables);
12    }
13
14    // Extract from subqueries in SELECT list
15    for select_item in &stmt.select_list {
16        if let vibesql_ast::SelectItem::Expression { expr, .. } = select_item {
17            extract_from_expression(expr, &mut tables);
18        }
19    }
20
21    // Extract from WHERE clause
22    if let Some(where_clause) = &stmt.where_clause {
23        extract_from_expression(where_clause, &mut tables);
24    }
25
26    // Extract from GROUP BY
27    if let Some(group_by) = &stmt.group_by {
28        for expr in group_by.all_expressions() {
29            extract_from_expression(expr, &mut tables);
30        }
31    }
32
33    // Extract from HAVING
34    if let Some(having) = &stmt.having {
35        extract_from_expression(having, &mut tables);
36    }
37
38    // Extract from ORDER BY
39    if let Some(order_by) = &stmt.order_by {
40        for order_item in order_by {
41            extract_from_expression(&order_item.expr, &mut tables);
42        }
43    }
44
45    // Extract from CTEs (WITH clause)
46    if let Some(with_clause) = &stmt.with_clause {
47        for cte in with_clause {
48            let cte_tables = extract_tables_from_select(&cte.query);
49            tables.extend(cte_tables);
50        }
51    }
52
53    // Extract from set operations (UNION, INTERSECT, EXCEPT)
54    if let Some(set_op) = &stmt.set_operation {
55        let right_tables = extract_tables_from_select(&set_op.right);
56        tables.extend(right_tables);
57    }
58
59    tables
60}
61
62/// Extract table names from a FROM clause
63fn extract_from_from_clause(from: &vibesql_ast::FromClause, tables: &mut HashSet<String>) {
64    match from {
65        vibesql_ast::FromClause::Table { name, .. } => {
66            // Handle qualified names (schema.table) - we want just the table name
67            let table_name = if let Some(pos) = name.rfind('.') { &name[pos + 1..] } else { name };
68            tables.insert(table_name.to_string());
69        }
70        vibesql_ast::FromClause::Join { left, right, condition, .. } => {
71            extract_from_from_clause(left, tables);
72            extract_from_from_clause(right, tables);
73            if let Some(cond) = condition {
74                extract_from_expression(cond, tables);
75            }
76        }
77        vibesql_ast::FromClause::Subquery { query, .. } => {
78            let subquery_tables = extract_tables_from_select(query);
79            tables.extend(subquery_tables);
80        }
81        vibesql_ast::FromClause::Values { .. } => {
82            // VALUES clauses don't reference any tables
83        }
84    }
85}
86
87/// Extract table names from an expression (for subqueries)
88fn extract_from_expression(expr: &vibesql_ast::Expression, tables: &mut HashSet<String>) {
89    match expr {
90        vibesql_ast::Expression::ScalarSubquery(stmt) => {
91            let subquery_tables = extract_tables_from_select(stmt);
92            tables.extend(subquery_tables);
93        }
94        vibesql_ast::Expression::BinaryOp { left, right, .. } => {
95            extract_from_expression(left, tables);
96            extract_from_expression(right, tables);
97        }
98        vibesql_ast::Expression::UnaryOp { expr, .. } => {
99            extract_from_expression(expr, tables);
100        }
101        vibesql_ast::Expression::Function { args, .. }
102        | vibesql_ast::Expression::AggregateFunction { args, .. } => {
103            for arg in args {
104                extract_from_expression(arg, tables);
105            }
106        }
107        vibesql_ast::Expression::Case { operand, when_clauses, else_result, .. } => {
108            if let Some(op) = operand {
109                extract_from_expression(op, tables);
110            }
111            for when_clause in when_clauses {
112                for condition in &when_clause.conditions {
113                    extract_from_expression(condition, tables);
114                }
115                extract_from_expression(&when_clause.result, tables);
116            }
117            if let Some(else_expr) = else_result {
118                extract_from_expression(else_expr, tables);
119            }
120        }
121        vibesql_ast::Expression::In { expr, subquery, .. } => {
122            extract_from_expression(expr, tables);
123            let subquery_tables = extract_tables_from_select(subquery);
124            tables.extend(subquery_tables);
125        }
126        vibesql_ast::Expression::InList { expr, values, .. } => {
127            extract_from_expression(expr, tables);
128            for val in values {
129                extract_from_expression(val, tables);
130            }
131        }
132        vibesql_ast::Expression::Exists { subquery, .. } => {
133            let subquery_tables = extract_tables_from_select(subquery);
134            tables.extend(subquery_tables);
135        }
136        vibesql_ast::Expression::Between { expr, low, high, .. } => {
137            extract_from_expression(expr, tables);
138            extract_from_expression(low, tables);
139            extract_from_expression(high, tables);
140        }
141        vibesql_ast::Expression::IsNull { expr, .. } => {
142            extract_from_expression(expr, tables);
143        }
144        vibesql_ast::Expression::IsDistinctFrom { left, right, .. } => {
145            extract_from_expression(left, tables);
146            extract_from_expression(right, tables);
147        }
148        vibesql_ast::Expression::IsTruthValue { expr, .. } => {
149            extract_from_expression(expr, tables);
150        }
151        vibesql_ast::Expression::Cast { expr, .. } => {
152            extract_from_expression(expr, tables);
153        }
154        vibesql_ast::Expression::Like { expr, pattern, .. }
155        | vibesql_ast::Expression::Glob { expr, pattern, .. } => {
156            extract_from_expression(expr, tables);
157            extract_from_expression(pattern, tables);
158        }
159        vibesql_ast::Expression::Position { substring, string, .. } => {
160            extract_from_expression(substring, tables);
161            extract_from_expression(string, tables);
162        }
163        vibesql_ast::Expression::Trim { removal_char, string, .. } => {
164            if let Some(removal) = removal_char {
165                extract_from_expression(removal, tables);
166            }
167            extract_from_expression(string, tables);
168        }
169        vibesql_ast::Expression::Extract { expr, .. } => {
170            extract_from_expression(expr, tables);
171        }
172        vibesql_ast::Expression::QuantifiedComparison { expr, subquery, .. } => {
173            extract_from_expression(expr, tables);
174            let subquery_tables = extract_tables_from_select(subquery);
175            tables.extend(subquery_tables);
176        }
177        vibesql_ast::Expression::Conjunction(children)
178        | vibesql_ast::Expression::Disjunction(children)
179        | vibesql_ast::Expression::RowValueConstructor(children) => {
180            for child in children {
181                extract_from_expression(child, tables);
182            }
183        }
184
185        vibesql_ast::Expression::Collate { expr, .. } => {
186            extract_from_expression(expr, tables);
187        }
188
189        // Leaf expressions - no tables to extract
190        vibesql_ast::Expression::Literal(_)
191        | vibesql_ast::Expression::Placeholder(_)
192        | vibesql_ast::Expression::NumberedPlaceholder(_)
193        | vibesql_ast::Expression::NamedPlaceholder(_)
194        | vibesql_ast::Expression::ColumnRef(_)
195        | vibesql_ast::Expression::Wildcard
196        | vibesql_ast::Expression::CurrentDate
197        | vibesql_ast::Expression::CurrentTime { .. }
198        | vibesql_ast::Expression::CurrentTimestamp { .. }
199        | vibesql_ast::Expression::Interval { .. }
200        | vibesql_ast::Expression::Default
201        | vibesql_ast::Expression::DuplicateKeyValue { .. }
202        | vibesql_ast::Expression::WindowFunction { .. }
203        | vibesql_ast::Expression::NextValue { .. }
204        | vibesql_ast::Expression::MatchAgainst { .. }
205        | vibesql_ast::Expression::PseudoVariable { .. }
206        | vibesql_ast::Expression::SessionVariable { .. } => {}
207    }
208}
209
210/// Extract table names from any statement (for comprehensive cache invalidation)
211pub fn extract_tables_from_statement(stmt: &vibesql_ast::Statement) -> HashSet<String> {
212    match stmt {
213        vibesql_ast::Statement::Select(select) => extract_tables_from_select(select),
214        vibesql_ast::Statement::Insert(insert) => {
215            let mut tables = HashSet::new();
216            // Extract table name being inserted into
217            let table_name = if let Some(pos) = insert.table_name.rfind('.') {
218                &insert.table_name[pos + 1..]
219            } else {
220                &insert.table_name
221            };
222            tables.insert(table_name.to_string());
223
224            // Extract from source (VALUES or SELECT)
225            match &insert.source {
226                vibesql_ast::InsertSource::Values(values) => {
227                    for row in values {
228                        for expr in row {
229                            extract_from_expression(expr, &mut tables);
230                        }
231                    }
232                }
233                vibesql_ast::InsertSource::Select(select) => {
234                    let select_tables = extract_tables_from_select(select);
235                    tables.extend(select_tables);
236                }
237                vibesql_ast::InsertSource::DefaultValues => {
238                    // No expressions to extract from DEFAULT VALUES
239                }
240            }
241
242            tables
243        }
244        vibesql_ast::Statement::Update(update) => {
245            let mut tables = HashSet::new();
246            // Extract table being updated
247            let table_name = if let Some(pos) = update.table_name.rfind('.') {
248                &update.table_name[pos + 1..]
249            } else {
250                &update.table_name
251            };
252            tables.insert(table_name.to_string());
253
254            // Extract from SET assignments
255            for assignment in &update.assignments {
256                extract_from_expression(&assignment.value, &mut tables);
257            }
258
259            // Extract from WHERE clause
260            if let Some(vibesql_ast::WhereClause::Condition(expr)) = &update.where_clause {
261                extract_from_expression(expr, &mut tables);
262            }
263
264            tables
265        }
266        vibesql_ast::Statement::Delete(delete) => {
267            let mut tables = HashSet::new();
268            // Extract table being deleted from
269            let table_name = if let Some(pos) = delete.table_name.rfind('.') {
270                &delete.table_name[pos + 1..]
271            } else {
272                &delete.table_name
273            };
274            tables.insert(table_name.to_string());
275
276            // Extract from WHERE clause
277            if let Some(vibesql_ast::WhereClause::Condition(expr)) = &delete.where_clause {
278                extract_from_expression(expr, &mut tables);
279            }
280
281            tables
282        }
283        // DDL statements don't reference tables in a way that matters for SELECT caching
284        _ => HashSet::new(),
285    }
286}
287
#[cfg(test)]
mod tests {
    use vibesql_parser::Parser;

    use super::*;

    // Expected names below use the lowercase spelling written in the SQL text.

    /// Build the expected set of table names.
    fn names(expected: &[&str]) -> HashSet<String> {
        expected.iter().map(|name| name.to_string()).collect()
    }

    /// Parse a single SELECT and return the tables it references.
    fn select_tables(sql: &str) -> HashSet<String> {
        match Parser::parse_sql(sql).expect("test SQL should parse") {
            vibesql_ast::Statement::Select(select) => extract_tables_from_select(&select),
            _ => panic!("Expected SELECT statement"),
        }
    }

    /// Parse any statement and return the tables it touches.
    fn statement_tables(sql: &str) -> HashSet<String> {
        let stmt = Parser::parse_sql(sql).expect("test SQL should parse");
        extract_tables_from_statement(&stmt)
    }

    #[test]
    fn test_extract_simple_select() {
        assert_eq!(select_tables("SELECT * FROM users"), names(&["users"]));
    }

    #[test]
    fn test_extract_join() {
        assert_eq!(
            select_tables("SELECT * FROM users u JOIN orders o ON u.id = o.user_id"),
            names(&["users", "orders"])
        );
    }

    #[test]
    fn test_extract_qualified_table_name() {
        // Should extract just the table name, not the schema
        assert_eq!(select_tables("SELECT * FROM public.users"), names(&["users"]));
    }

    #[test]
    fn test_extract_subquery_in_from() {
        assert_eq!(select_tables("SELECT * FROM (SELECT * FROM users) AS u"), names(&["users"]));
    }

    #[test]
    fn test_extract_subquery_in_where() {
        assert_eq!(
            select_tables("SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)"),
            names(&["users", "orders"])
        );
    }

    #[test]
    fn test_extract_union() {
        assert_eq!(
            select_tables("SELECT id FROM users UNION SELECT user_id FROM orders"),
            names(&["users", "orders"])
        );
    }

    #[test]
    fn test_extract_from_insert() {
        assert_eq!(statement_tables("INSERT INTO users VALUES (1, 'Alice')"), names(&["users"]));
    }

    #[test]
    fn test_extract_from_insert_select() {
        assert_eq!(
            statement_tables("INSERT INTO users_archive SELECT * FROM users"),
            names(&["users_archive", "users"])
        );
    }

    #[test]
    fn test_extract_from_update() {
        assert_eq!(
            statement_tables("UPDATE users SET name = 'Bob' WHERE id = 1"),
            names(&["users"])
        );
    }

    #[test]
    fn test_extract_from_delete() {
        assert_eq!(statement_tables("DELETE FROM users WHERE id = 1"), names(&["users"]));
    }
}