// vibesql_executor/cache/table_extractor.rs

//! Extract table names from the AST for cache invalidation.
2
3use std::collections::HashSet;
4
5/// Extract all table names referenced in a SELECT statement
6pub fn extract_tables_from_select(stmt: &vibesql_ast::SelectStmt) -> HashSet<String> {
7    let mut tables = HashSet::new();
8
9    // Extract from FROM clause
10    if let Some(from_clause) = &stmt.from {
11        extract_from_from_clause(from_clause, &mut tables);
12    }
13
14    // Extract from subqueries in SELECT list
15    for select_item in &stmt.select_list {
16        if let vibesql_ast::SelectItem::Expression { expr, .. } = select_item {
17            extract_from_expression(expr, &mut tables);
18        }
19    }
20
21    // Extract from WHERE clause
22    if let Some(where_clause) = &stmt.where_clause {
23        extract_from_expression(where_clause, &mut tables);
24    }
25
26    // Extract from GROUP BY
27    if let Some(group_by) = &stmt.group_by {
28        for expr in group_by.all_expressions() {
29            extract_from_expression(expr, &mut tables);
30        }
31    }
32
33    // Extract from HAVING
34    if let Some(having) = &stmt.having {
35        extract_from_expression(having, &mut tables);
36    }
37
38    // Extract from ORDER BY
39    if let Some(order_by) = &stmt.order_by {
40        for order_item in order_by {
41            extract_from_expression(&order_item.expr, &mut tables);
42        }
43    }
44
45    // Extract from CTEs (WITH clause)
46    if let Some(with_clause) = &stmt.with_clause {
47        for cte in with_clause {
48            let cte_tables = extract_tables_from_select(&cte.query);
49            tables.extend(cte_tables);
50        }
51    }
52
53    // Extract from set operations (UNION, INTERSECT, EXCEPT)
54    if let Some(set_op) = &stmt.set_operation {
55        let right_tables = extract_tables_from_select(&set_op.right);
56        tables.extend(right_tables);
57    }
58
59    tables
60}
61
62/// Extract table names from a FROM clause
63fn extract_from_from_clause(from: &vibesql_ast::FromClause, tables: &mut HashSet<String>) {
64    match from {
65        vibesql_ast::FromClause::Table { name, .. } => {
66            // Handle qualified names (schema.table) - we want just the table name
67            let table_name = if let Some(pos) = name.rfind('.') { &name[pos + 1..] } else { name };
68            tables.insert(table_name.to_string());
69        }
70        vibesql_ast::FromClause::Join { left, right, condition, .. } => {
71            extract_from_from_clause(left, tables);
72            extract_from_from_clause(right, tables);
73            if let Some(cond) = condition {
74                extract_from_expression(cond, tables);
75            }
76        }
77        vibesql_ast::FromClause::Subquery { query, .. } => {
78            let subquery_tables = extract_tables_from_select(query);
79            tables.extend(subquery_tables);
80        }
81    }
82}
83
84/// Extract table names from an expression (for subqueries)
85fn extract_from_expression(expr: &vibesql_ast::Expression, tables: &mut HashSet<String>) {
86    match expr {
87        vibesql_ast::Expression::ScalarSubquery(stmt) => {
88            let subquery_tables = extract_tables_from_select(stmt);
89            tables.extend(subquery_tables);
90        }
91        vibesql_ast::Expression::BinaryOp { left, right, .. } => {
92            extract_from_expression(left, tables);
93            extract_from_expression(right, tables);
94        }
95        vibesql_ast::Expression::UnaryOp { expr, .. } => {
96            extract_from_expression(expr, tables);
97        }
98        vibesql_ast::Expression::Function { args, .. }
99        | vibesql_ast::Expression::AggregateFunction { args, .. } => {
100            for arg in args {
101                extract_from_expression(arg, tables);
102            }
103        }
104        vibesql_ast::Expression::Case { operand, when_clauses, else_result, .. } => {
105            if let Some(op) = operand {
106                extract_from_expression(op, tables);
107            }
108            for when_clause in when_clauses {
109                for condition in &when_clause.conditions {
110                    extract_from_expression(condition, tables);
111                }
112                extract_from_expression(&when_clause.result, tables);
113            }
114            if let Some(else_expr) = else_result {
115                extract_from_expression(else_expr, tables);
116            }
117        }
118        vibesql_ast::Expression::In { expr, subquery, .. } => {
119            extract_from_expression(expr, tables);
120            let subquery_tables = extract_tables_from_select(subquery);
121            tables.extend(subquery_tables);
122        }
123        vibesql_ast::Expression::InList { expr, values, .. } => {
124            extract_from_expression(expr, tables);
125            for val in values {
126                extract_from_expression(val, tables);
127            }
128        }
129        vibesql_ast::Expression::Exists { subquery, .. } => {
130            let subquery_tables = extract_tables_from_select(subquery);
131            tables.extend(subquery_tables);
132        }
133        vibesql_ast::Expression::Between { expr, low, high, .. } => {
134            extract_from_expression(expr, tables);
135            extract_from_expression(low, tables);
136            extract_from_expression(high, tables);
137        }
138        vibesql_ast::Expression::IsNull { expr, .. } => {
139            extract_from_expression(expr, tables);
140        }
141        vibesql_ast::Expression::Cast { expr, .. } => {
142            extract_from_expression(expr, tables);
143        }
144        vibesql_ast::Expression::Like { expr, pattern, .. } => {
145            extract_from_expression(expr, tables);
146            extract_from_expression(pattern, tables);
147        }
148        vibesql_ast::Expression::Position { substring, string, .. } => {
149            extract_from_expression(substring, tables);
150            extract_from_expression(string, tables);
151        }
152        vibesql_ast::Expression::Trim { removal_char, string, .. } => {
153            if let Some(removal) = removal_char {
154                extract_from_expression(removal, tables);
155            }
156            extract_from_expression(string, tables);
157        }
158        vibesql_ast::Expression::Extract { expr, .. } => {
159            extract_from_expression(expr, tables);
160        }
161        vibesql_ast::Expression::QuantifiedComparison { expr, subquery, .. } => {
162            extract_from_expression(expr, tables);
163            let subquery_tables = extract_tables_from_select(subquery);
164            tables.extend(subquery_tables);
165        }
166        vibesql_ast::Expression::Conjunction(children)
167        | vibesql_ast::Expression::Disjunction(children) => {
168            for child in children {
169                extract_from_expression(child, tables);
170            }
171        }
172
173        // Leaf expressions - no tables to extract
174        vibesql_ast::Expression::Literal(_)
175        | vibesql_ast::Expression::Placeholder(_)
176        | vibesql_ast::Expression::NumberedPlaceholder(_)
177        | vibesql_ast::Expression::NamedPlaceholder(_)
178        | vibesql_ast::Expression::ColumnRef { .. }
179        | vibesql_ast::Expression::Wildcard
180        | vibesql_ast::Expression::CurrentDate
181        | vibesql_ast::Expression::CurrentTime { .. }
182        | vibesql_ast::Expression::CurrentTimestamp { .. }
183        | vibesql_ast::Expression::Interval { .. }
184        | vibesql_ast::Expression::Default
185        | vibesql_ast::Expression::DuplicateKeyValue { .. }
186        | vibesql_ast::Expression::WindowFunction { .. }
187        | vibesql_ast::Expression::NextValue { .. }
188        | vibesql_ast::Expression::MatchAgainst { .. }
189        | vibesql_ast::Expression::PseudoVariable { .. }
190        | vibesql_ast::Expression::SessionVariable { .. } => {}
191    }
192}
193
194/// Extract table names from any statement (for comprehensive cache invalidation)
195pub fn extract_tables_from_statement(stmt: &vibesql_ast::Statement) -> HashSet<String> {
196    match stmt {
197        vibesql_ast::Statement::Select(select) => extract_tables_from_select(select),
198        vibesql_ast::Statement::Insert(insert) => {
199            let mut tables = HashSet::new();
200            // Extract table name being inserted into
201            let table_name = if let Some(pos) = insert.table_name.rfind('.') {
202                &insert.table_name[pos + 1..]
203            } else {
204                &insert.table_name
205            };
206            tables.insert(table_name.to_string());
207
208            // Extract from source (VALUES or SELECT)
209            match &insert.source {
210                vibesql_ast::InsertSource::Values(values) => {
211                    for row in values {
212                        for expr in row {
213                            extract_from_expression(expr, &mut tables);
214                        }
215                    }
216                }
217                vibesql_ast::InsertSource::Select(select) => {
218                    let select_tables = extract_tables_from_select(select);
219                    tables.extend(select_tables);
220                }
221            }
222
223            tables
224        }
225        vibesql_ast::Statement::Update(update) => {
226            let mut tables = HashSet::new();
227            // Extract table being updated
228            let table_name = if let Some(pos) = update.table_name.rfind('.') {
229                &update.table_name[pos + 1..]
230            } else {
231                &update.table_name
232            };
233            tables.insert(table_name.to_string());
234
235            // Extract from SET assignments
236            for assignment in &update.assignments {
237                extract_from_expression(&assignment.value, &mut tables);
238            }
239
240            // Extract from WHERE clause
241            if let Some(vibesql_ast::WhereClause::Condition(expr)) = &update.where_clause {
242                extract_from_expression(expr, &mut tables);
243            }
244
245            tables
246        }
247        vibesql_ast::Statement::Delete(delete) => {
248            let mut tables = HashSet::new();
249            // Extract table being deleted from
250            let table_name = if let Some(pos) = delete.table_name.rfind('.') {
251                &delete.table_name[pos + 1..]
252            } else {
253                &delete.table_name
254            };
255            tables.insert(table_name.to_string());
256
257            // Extract from WHERE clause
258            if let Some(vibesql_ast::WhereClause::Condition(expr)) = &delete.where_clause {
259                extract_from_expression(expr, &mut tables);
260            }
261
262            tables
263        }
264        // DDL statements don't reference tables in a way that matters for SELECT caching
265        _ => HashSet::new(),
266    }
267}
268
#[cfg(test)]
mod tests {
    use vibesql_parser::Parser;

    use super::*;

    /// Parse `sql`, require a SELECT, and return the tables it reads.
    fn select_tables(sql: &str) -> HashSet<String> {
        match Parser::parse_sql(sql).unwrap() {
            vibesql_ast::Statement::Select(select) => extract_tables_from_select(&select),
            _ => panic!("Expected SELECT statement"),
        }
    }

    /// Parse `sql` as any statement and return the tables it touches.
    fn statement_tables(sql: &str) -> HashSet<String> {
        let parsed = Parser::parse_sql(sql).unwrap();
        extract_tables_from_statement(&parsed)
    }

    /// Assert that `tables` contains exactly the names in `expected`.
    fn assert_tables(tables: &HashSet<String>, expected: &[&str]) {
        assert_eq!(tables.len(), expected.len());
        for name in expected {
            assert!(tables.contains(*name), "missing table {}", name);
        }
    }

    // The parser upper-cases unquoted identifiers, hence the expected names.

    #[test]
    fn test_extract_simple_select() {
        assert_tables(&select_tables("SELECT * FROM users"), &["USERS"]);
    }

    #[test]
    fn test_extract_join() {
        let found = select_tables("SELECT * FROM users u JOIN orders o ON u.id = o.user_id");
        assert_tables(&found, &["USERS", "ORDERS"]);
    }

    #[test]
    fn test_extract_qualified_table_name() {
        // The schema qualifier is dropped; only the table name remains.
        assert_tables(&select_tables("SELECT * FROM public.users"), &["USERS"]);
    }

    #[test]
    fn test_extract_subquery_in_from() {
        assert_tables(&select_tables("SELECT * FROM (SELECT * FROM users) AS u"), &["USERS"]);
    }

    #[test]
    fn test_extract_from_insert() {
        assert_tables(&statement_tables("INSERT INTO users VALUES (1, 'Alice')"), &["USERS"]);
    }

    #[test]
    fn test_extract_from_update() {
        let found = statement_tables("UPDATE users SET name = 'Bob' WHERE id = 1");
        assert_tables(&found, &["USERS"]);
    }

    #[test]
    fn test_extract_from_delete() {
        assert_tables(&statement_tables("DELETE FROM users WHERE id = 1"), &["USERS"]);
    }
}