vibesql_executor/
trigger_execution.rs

1//! Trigger execution logic for firing triggers on DML operations
2
3use std::cell::Cell;
4
5use vibesql_ast::{PseudoTable, TriggerEvent, TriggerGranularity, TriggerTiming};
6use vibesql_catalog::{TableSchema, TriggerDefinition};
7use vibesql_storage::{Database, Row};
8use vibesql_types::SqlValue;
9
10use crate::errors::ExecutorError;
11
/// Upper bound on how deeply trigger firing may nest on one thread.
///
/// Exceeding it is treated as a probable infinite trigger loop and aborts the
/// operation with an error instead of overflowing the stack.
const MAX_TRIGGER_RECURSION_DEPTH: usize = 16;

thread_local! {
    /// Number of trigger-firing levels currently active on this thread.
    ///
    /// Starts at zero; raised by `RecursionGuard::new` and lowered when the guard drops.
    static TRIGGER_RECURSION_DEPTH: Cell<usize> = const { Cell::new(0) };
}
19
/// RAII guard that tracks trigger nesting depth on the current thread.
///
/// `RecursionGuard::new` increments the thread-local depth counter and `Drop`
/// decrements it again, so the depth stays raised only while the guard is alive.
/// Bind it to a named variable (e.g. `let _guard = RecursionGuard::new()?;`) so it
/// lives for the whole trigger firing; `let _ = ...` or a bare expression statement
/// drops it immediately and silently disables the recursion limit.
#[must_use = "the recursion depth is released as soon as the guard is dropped; bind it to a variable"]
struct RecursionGuard;
23
24impl RecursionGuard {
25    /// Create a new recursion guard, incrementing the depth
26    ///
27    /// # Returns
28    /// Ok(RecursionGuard) if depth is within limits, Err if limit exceeded
29    fn new() -> Result<Self, ExecutorError> {
30        TRIGGER_RECURSION_DEPTH.with(|depth| {
31            let current = depth.get();
32            if current >= MAX_TRIGGER_RECURSION_DEPTH {
33                Err(ExecutorError::UnsupportedExpression(format!(
34                    "Trigger recursion depth limit exceeded (max: {}). Possible infinite trigger loop.",
35                    MAX_TRIGGER_RECURSION_DEPTH
36                )))
37            } else {
38                depth.set(current + 1);
39                Ok(RecursionGuard)
40            }
41        })
42    }
43}
44
45impl Drop for RecursionGuard {
46    fn drop(&mut self) {
47        TRIGGER_RECURSION_DEPTH.with(|depth| {
48            depth.set(depth.get().saturating_sub(1));
49        });
50    }
51}
52
53/// Execution context for triggers with OLD/NEW row access
54/// Provides pseudo-variable resolution for trigger bodies
55pub struct TriggerContext<'a> {
56    /// OLD row - available for UPDATE and DELETE triggers
57    pub old_row: Option<&'a Row>,
58    /// NEW row - available for INSERT and UPDATE triggers
59    pub new_row: Option<&'a Row>,
60    /// Table schema for column lookups
61    pub table_schema: &'a TableSchema,
62}
63
64impl<'a> TriggerContext<'a> {
65    /// Resolve a pseudo-variable reference to a SqlValue
66    ///
67    /// # Arguments
68    /// * `pseudo_table` - Which pseudo-table (OLD or NEW)
69    /// * `column` - Column name to retrieve
70    ///
71    /// # Returns
72    /// Ok(SqlValue) with the column value, or Err if invalid
73    ///
74    /// # Errors
75    /// - If OLD/NEW is not available for this trigger type
76    /// - If column doesn't exist in table schema
77    pub fn resolve_pseudo_var(
78        &self,
79        pseudo_table: PseudoTable,
80        column: &str,
81    ) -> Result<SqlValue, ExecutorError> {
82        // Get the appropriate row
83        let row = match pseudo_table {
84            PseudoTable::Old => self.old_row.ok_or_else(|| {
85                ExecutorError::UnsupportedExpression(
86                    "OLD pseudo-variable not available in this trigger context".to_string(),
87                )
88            })?,
89            PseudoTable::New => self.new_row.ok_or_else(|| {
90                ExecutorError::UnsupportedExpression(
91                    "NEW pseudo-variable not available in this trigger context".to_string(),
92                )
93            })?,
94        };
95
96        // Find column index in schema
97        let col_idx = self
98            .table_schema
99            .columns
100            .iter()
101            .position(|c| c.name == column)
102            .ok_or_else(|| ExecutorError::ColumnNotFound {
103                column_name: column.to_string(),
104                table_name: self.table_schema.name.clone(),
105                searched_tables: vec![self.table_schema.name.clone()],
106                available_columns: self.table_schema.columns.iter().map(|c| c.name.clone()).collect(),
107            })?;
108
109        // Return the value
110        Ok(row.values[col_idx].clone())
111    }
112}
113
/// Stateless namespace for firing triggers around DML operations.
///
/// Holds no data; every capability is an associated function.
pub struct TriggerFirer;
116
117impl TriggerFirer {
118    /// Find triggers for a table and event
119    ///
120    /// # Arguments
121    /// * `db` - Database reference
122    /// * `table_name` - Name of the table to find triggers for
123    /// * `timing` - Trigger timing (BEFORE, AFTER, INSTEAD OF)
124    /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
125    ///
126    /// # Returns
127    /// Vector of trigger definitions matching the criteria, sorted by creation order
128    pub fn find_triggers(
129        db: &Database,
130        table_name: &str,
131        timing: TriggerTiming,
132        event: TriggerEvent,
133    ) -> Vec<TriggerDefinition> {
134        db.catalog
135            .get_triggers_for_table(table_name, Some(event.clone()))
136            .filter(|trigger| trigger.timing == timing && trigger.enabled) // Skip disabled triggers
137            .cloned()
138            .collect()
139    }
140
141    /// Check if an UPDATE OF trigger should fire based on which columns changed
142    ///
143    /// # Arguments
144    /// * `trigger` - Trigger definition
145    /// * `old_row` - OLD row values
146    /// * `new_row` - NEW row values
147    /// * `table_schema` - Table schema for column lookup
148    ///
149    /// # Returns
150    /// true if the trigger should fire, false otherwise
151    fn should_fire_update_of(
152        trigger: &TriggerDefinition,
153        old_row: &Row,
154        new_row: &Row,
155        table_schema: &TableSchema,
156    ) -> bool {
157        match &trigger.event {
158            TriggerEvent::Update(Some(columns)) => {
159                // Check if any of the specified columns changed
160                for col_name in columns {
161                    if let Some(col_idx) = table_schema.columns.iter().position(|c| &c.name == col_name) {
162                        if col_idx < old_row.values.len() && col_idx < new_row.values.len()
163                            && old_row.values[col_idx] != new_row.values[col_idx] {
164                                return true; // At least one monitored column changed
165                            }
166                    }
167                }
168                false // None of the monitored columns changed
169            }
170            _ => true, // Not an UPDATE OF trigger, always fire
171        }
172    }
173
174    /// Execute a single trigger
175    ///
176    /// # Arguments
177    /// * `db` - Mutable database reference
178    /// * `trigger` - Trigger definition to execute
179    /// * `old_row` - OLD row for UPDATE/DELETE (None for INSERT)
180    /// * `new_row` - NEW row for INSERT/UPDATE (None for DELETE)
181    ///
182    /// # Returns
183    /// Ok(()) if trigger executed successfully, Err if execution failed
184    ///
185    /// # Notes
186    /// - For ROW-level triggers, this is called once per affected row
187    /// - For STATEMENT-level triggers, this is called once per statement
188    /// - WHEN conditions are evaluated here
189    pub fn execute_trigger(
190        db: &mut Database,
191        trigger: &TriggerDefinition,
192        old_row: Option<&Row>,
193        new_row: Option<&Row>,
194    ) -> Result<(), ExecutorError> {
195        // 1. Evaluate WHEN condition (if present)
196        if let Some(when_expr) = &trigger.when_condition {
197            let condition_result = Self::evaluate_when_condition(
198                db,
199                &trigger.table_name,
200                when_expr,
201                old_row,
202                new_row,
203            )?;
204
205            // Skip trigger execution if WHEN condition is false
206            if !condition_result {
207                return Ok(());
208            }
209        }
210
211        // 2. Execute trigger action
212        Self::execute_trigger_action(db, trigger, old_row, new_row)?;
213
214        Ok(())
215    }
216
217    /// Evaluate WHEN condition for a trigger
218    ///
219    /// # Arguments
220    /// * `db` - Database reference
221    /// * `table_name` - Name of the table
222    /// * `when_expr` - WHEN condition expression
223    /// * `old_row` - OLD row (for UPDATE/DELETE)
224    /// * `new_row` - NEW row (for INSERT/UPDATE)
225    ///
226    /// # Returns
227    /// Ok(true) if condition evaluates to true, Ok(false) otherwise
228    fn evaluate_when_condition(
229        db: &Database,
230        table_name: &str,
231        when_expr: &vibesql_ast::Expression,
232        old_row: Option<&Row>,
233        new_row: Option<&Row>,
234    ) -> Result<bool, ExecutorError> {
235        // Get table schema
236        let schema = db
237            .catalog
238            .get_table(table_name)
239            .ok_or_else(|| ExecutorError::TableNotFound(table_name.to_string()))?;
240
241        // Use NEW row as the base row for evaluation (prefer NEW over OLD)
242        // The trigger context will handle OLD/NEW pseudo-variable references
243        let row = new_row.or(old_row).ok_or_else(|| {
244            ExecutorError::UnsupportedExpression("WHEN condition requires a row context".to_string())
245        })?;
246
247        // Create trigger context for OLD/NEW pseudo-variable resolution
248        let trigger_context = TriggerContext {
249            old_row,
250            new_row,
251            table_schema: schema,
252        };
253
254        // Create evaluator with trigger context
255        let evaluator = crate::ExpressionEvaluator::with_trigger_context(schema, db, &trigger_context);
256        let result = evaluator.eval(when_expr, row)?;
257
258        // Convert to boolean
259        match result {
260            vibesql_types::SqlValue::Boolean(b) => Ok(b),
261            vibesql_types::SqlValue::Null => Ok(false),
262            _ => Err(ExecutorError::UnsupportedExpression(
263                "WHEN condition must evaluate to boolean".to_string(),
264            )),
265        }
266    }
267
268    /// Execute trigger action statements
269    ///
270    /// # Arguments
271    /// * `db` - Mutable database reference
272    /// * `trigger` - Trigger definition
273    /// * `old_row` - OLD row (for UPDATE/DELETE)
274    /// * `new_row` - NEW row (for INSERT/UPDATE)
275    ///
276    /// # Returns
277    /// Ok(()) if action executed successfully, Err if execution failed
278    fn execute_trigger_action(
279        db: &mut Database,
280        trigger: &TriggerDefinition,
281        old_row: Option<&Row>,
282        new_row: Option<&Row>,
283    ) -> Result<(), ExecutorError> {
284        // Extract SQL from trigger action
285        let sql = match &trigger.triggered_action {
286            vibesql_ast::TriggerAction::RawSql(sql) => sql.clone(),
287        };
288
289        // Parse the trigger action SQL
290        let statements = Self::parse_trigger_sql(&sql)?;
291
292        // Get table schema for trigger context (clone to avoid borrow checker issues)
293        let schema = db
294            .catalog
295            .get_table(&trigger.table_name)
296            .ok_or_else(|| ExecutorError::TableNotFound(trigger.table_name.clone()))?
297            .clone();
298
299        // Create trigger context for OLD/NEW pseudo-variable resolution
300        let trigger_context = TriggerContext {
301            old_row,
302            new_row,
303            table_schema: &schema,
304        };
305
306        // Execute each statement in the trigger body with trigger context
307        for statement in statements {
308            Self::execute_statement(db, &statement, &trigger_context)?;
309        }
310
311        Ok(())
312    }
313
314    /// Parse trigger SQL into statements
315    ///
316    /// # Arguments
317    /// * `sql` - Raw SQL string from trigger action
318    ///
319    /// # Returns
320    /// Vector of parsed statements
321    fn parse_trigger_sql(sql: &str) -> Result<Vec<vibesql_ast::Statement>, ExecutorError> {
322        // Strip BEGIN/END wrapper if present
323        let sql = sql.trim();
324        let sql = if sql.to_uppercase().starts_with("BEGIN") {
325            // Remove BEGIN and END
326            let sql = sql[5..].trim();
327            if sql.to_uppercase().ends_with("END") {
328                &sql[..sql.len() - 3]
329            } else {
330                sql
331            }
332        } else {
333            sql
334        };
335
336        // Split by semicolons and parse each statement
337        let mut statements = Vec::new();
338        for stmt_sql in sql.split(';') {
339            let stmt_sql = stmt_sql.trim();
340            if stmt_sql.is_empty() || stmt_sql.starts_with("--") {
341                // Skip empty statements or comments
342                continue;
343            }
344
345            match vibesql_parser::Parser::parse_sql(stmt_sql) {
346                Ok(stmt) => statements.push(stmt),
347                Err(e) => {
348                    return Err(ExecutorError::UnsupportedExpression(format!(
349                        "Failed to parse trigger SQL: {}",
350                        e.message
351                    )))
352                }
353            }
354        }
355
356        // If no statements parsed (e.g., trigger body was only comments), that's OK
357        // Just return empty vector
358        Ok(statements)
359    }
360
361    /// Execute a single statement from trigger body
362    ///
363    /// # Arguments
364    /// * `db` - Mutable database reference
365    /// * `statement` - Statement to execute
366    /// * `trigger_context` - Trigger context with OLD/NEW row data
367    ///
368    /// # Returns
369    /// Ok(()) if statement executed successfully
370    fn execute_statement(
371        db: &mut Database,
372        statement: &vibesql_ast::Statement,
373        trigger_context: &TriggerContext,
374    ) -> Result<(), ExecutorError> {
375        use vibesql_ast::Statement;
376
377        match statement {
378            Statement::Insert(insert_stmt) => {
379                // Execute INSERT with trigger context support
380                crate::insert::execute_insert_with_trigger_context(db, insert_stmt, trigger_context)?;
381                Ok(())
382            }
383            Statement::Update(update_stmt) => {
384                // Execute UPDATE with trigger context support
385                crate::update::execute_update_with_trigger_context(db, update_stmt, trigger_context)?;
386                Ok(())
387            }
388            Statement::Delete(delete_stmt) => {
389                // Execute DELETE with trigger context support
390                crate::delete::execute_delete_with_trigger_context(db, delete_stmt, trigger_context)?;
391                Ok(())
392            }
393            Statement::Select(select_stmt) => {
394                // Execute SELECT but ignore results (useful for side effects)
395                // Note: SELECT doesn't need special trigger context handling since
396                // it can reference OLD/NEW through normal expression evaluation
397                let executor = crate::SelectExecutor::new(db);
398                executor.execute_with_columns(select_stmt)?;
399                Ok(())
400            }
401            _ => Err(ExecutorError::UnsupportedExpression(format!(
402                "Statement type not supported in triggers: {:?}",
403                statement
404            ))),
405        }
406    }
407
408    /// Execute all BEFORE ROW-level triggers for an operation
409    ///
410    /// # Arguments
411    /// * `db` - Mutable database reference
412    /// * `table_name` - Name of the table
413    /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
414    /// * `old_row` - OLD row (for UPDATE/DELETE)
415    /// * `new_row` - NEW row (for INSERT/UPDATE)
416    ///
417    /// # Returns
418    /// Ok(()) if all triggers executed successfully
419    pub fn execute_before_triggers(
420        db: &mut Database,
421        table_name: &str,
422        event: TriggerEvent,
423        old_row: Option<&Row>,
424        new_row: Option<&Row>,
425    ) -> Result<(), ExecutorError> {
426        // Check recursion depth before executing any triggers
427        let _guard = RecursionGuard::new()?;
428
429        let triggers = Self::find_triggers(db, table_name, TriggerTiming::Before, event);
430
431        // Get table schema for UPDATE OF checking
432        let table_schema = db
433            .catalog
434            .get_table(table_name)
435            .ok_or_else(|| ExecutorError::TableNotFound(table_name.to_string()))?
436            .clone();
437
438        for trigger in triggers {
439            // Only execute ROW-level triggers in this method
440            if trigger.granularity == TriggerGranularity::Row {
441                // For UPDATE OF triggers, check if monitored columns changed
442                if let (Some(old), Some(new)) = (old_row, new_row) {
443                    if !Self::should_fire_update_of(&trigger, old, new, &table_schema) {
444                        continue; // Skip this trigger
445                    }
446                }
447
448                Self::execute_trigger(db, &trigger, old_row, new_row)?;
449            }
450        }
451
452        Ok(())
453    }
454
455    /// Execute all BEFORE STATEMENT-level triggers for an operation
456    ///
457    /// # Arguments
458    /// * `db` - Mutable database reference
459    /// * `table_name` - Name of the table
460    /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
461    ///
462    /// # Returns
463    /// Ok(()) if all triggers executed successfully
464    pub fn execute_before_statement_triggers(
465        db: &mut Database,
466        table_name: &str,
467        event: TriggerEvent,
468    ) -> Result<(), ExecutorError> {
469        // Check recursion depth before executing any triggers
470        let _guard = RecursionGuard::new()?;
471
472        let triggers = Self::find_triggers(db, table_name, TriggerTiming::Before, event);
473
474        for trigger in triggers {
475            // Only execute STATEMENT-level triggers in this method
476            if trigger.granularity == TriggerGranularity::Statement {
477                // Statement-level triggers don't have OLD/NEW row access
478                Self::execute_trigger(db, &trigger, None, None)?;
479            }
480        }
481
482        Ok(())
483    }
484
485    /// Execute all AFTER ROW-level triggers for an operation
486    ///
487    /// # Arguments
488    /// * `db` - Mutable database reference
489    /// * `table_name` - Name of the table
490    /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
491    /// * `old_row` - OLD row (for UPDATE/DELETE)
492    /// * `new_row` - NEW row (for INSERT/UPDATE)
493    ///
494    /// # Returns
495    /// Ok(()) if all triggers executed successfully
496    pub fn execute_after_triggers(
497        db: &mut Database,
498        table_name: &str,
499        event: TriggerEvent,
500        old_row: Option<&Row>,
501        new_row: Option<&Row>,
502    ) -> Result<(), ExecutorError> {
503        // Check recursion depth before executing any triggers
504        let _guard = RecursionGuard::new()?;
505
506        let triggers = Self::find_triggers(db, table_name, TriggerTiming::After, event);
507
508        // Get table schema for UPDATE OF checking
509        let table_schema = db
510            .catalog
511            .get_table(table_name)
512            .ok_or_else(|| ExecutorError::TableNotFound(table_name.to_string()))?
513            .clone();
514
515        for trigger in triggers {
516            // Only execute ROW-level triggers in this method
517            if trigger.granularity == TriggerGranularity::Row {
518                // For UPDATE OF triggers, check if monitored columns changed
519                if let (Some(old), Some(new)) = (old_row, new_row) {
520                    if !Self::should_fire_update_of(&trigger, old, new, &table_schema) {
521                        continue; // Skip this trigger
522                    }
523                }
524
525                Self::execute_trigger(db, &trigger, old_row, new_row)?;
526            }
527        }
528
529        Ok(())
530    }
531
532    /// Execute all AFTER STATEMENT-level triggers for an operation
533    ///
534    /// # Arguments
535    /// * `db` - Mutable database reference
536    /// * `table_name` - Name of the table
537    /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
538    ///
539    /// # Returns
540    /// Ok(()) if all triggers executed successfully
541    pub fn execute_after_statement_triggers(
542        db: &mut Database,
543        table_name: &str,
544        event: TriggerEvent,
545    ) -> Result<(), ExecutorError> {
546        // Check recursion depth before executing any triggers
547        let _guard = RecursionGuard::new()?;
548
549        let triggers = Self::find_triggers(db, table_name, TriggerTiming::After, event);
550
551        for trigger in triggers {
552            // Only execute STATEMENT-level triggers in this method
553            if trigger.granularity == TriggerGranularity::Statement {
554                // Statement-level triggers don't have OLD/NEW row access
555                Self::execute_trigger(db, &trigger, None, None)?;
556            }
557        }
558
559        Ok(())
560    }
561}