vibesql_executor/
trigger_execution.rs

1//! Trigger execution logic for firing triggers on DML operations
2
3use std::cell::Cell;
4
5use vibesql_ast::{PseudoTable, TriggerEvent, TriggerGranularity, TriggerTiming};
6use vibesql_catalog::{TableSchema, TriggerDefinition};
7use vibesql_storage::{Database, Row};
8use vibesql_types::SqlValue;
9
10use crate::errors::ExecutorError;
11
12/// Maximum trigger recursion depth to prevent infinite loops
13const MAX_TRIGGER_RECURSION_DEPTH: usize = 16;
14
15thread_local! {
16    /// Current trigger recursion depth for this thread
17    static TRIGGER_RECURSION_DEPTH: Cell<usize> = const { Cell::new(0) };
18}
19
20/// RAII guard for managing trigger recursion depth
21/// Increments depth on creation, decrements on drop
22struct RecursionGuard;
23
24impl RecursionGuard {
25    /// Create a new recursion guard, incrementing the depth
26    ///
27    /// # Returns
28    /// Ok(RecursionGuard) if depth is within limits, Err if limit exceeded
29    fn new() -> Result<Self, ExecutorError> {
30        TRIGGER_RECURSION_DEPTH.with(|depth| {
31            let current = depth.get();
32            if current >= MAX_TRIGGER_RECURSION_DEPTH {
33                Err(ExecutorError::UnsupportedExpression(format!(
34                    "Trigger recursion depth limit exceeded (max: {}). Possible infinite trigger loop.",
35                    MAX_TRIGGER_RECURSION_DEPTH
36                )))
37            } else {
38                depth.set(current + 1);
39                Ok(RecursionGuard)
40            }
41        })
42    }
43}
44
45impl Drop for RecursionGuard {
46    fn drop(&mut self) {
47        TRIGGER_RECURSION_DEPTH.with(|depth| {
48            depth.set(depth.get().saturating_sub(1));
49        });
50    }
51}
52
53/// Execution context for triggers with OLD/NEW row access
54/// Provides pseudo-variable resolution for trigger bodies
55pub struct TriggerContext<'a> {
56    /// OLD row - available for UPDATE and DELETE triggers
57    pub old_row: Option<&'a Row>,
58    /// NEW row - available for INSERT and UPDATE triggers
59    pub new_row: Option<&'a Row>,
60    /// Table schema for column lookups
61    pub table_schema: &'a TableSchema,
62}
63
64impl<'a> TriggerContext<'a> {
65    /// Resolve a pseudo-variable reference to a SqlValue
66    ///
67    /// # Arguments
68    /// * `pseudo_table` - Which pseudo-table (OLD or NEW)
69    /// * `column` - Column name to retrieve
70    ///
71    /// # Returns
72    /// Ok(SqlValue) with the column value, or Err if invalid
73    ///
74    /// # Errors
75    /// - If OLD/NEW is not available for this trigger type
76    /// - If column doesn't exist in table schema
77    pub fn resolve_pseudo_var(
78        &self,
79        pseudo_table: PseudoTable,
80        column: &str,
81    ) -> Result<SqlValue, ExecutorError> {
82        // Get the appropriate row
83        let row = match pseudo_table {
84            PseudoTable::Old => self.old_row.ok_or_else(|| {
85                ExecutorError::UnsupportedExpression(
86                    "OLD pseudo-variable not available in this trigger context".to_string(),
87                )
88            })?,
89            PseudoTable::New => self.new_row.ok_or_else(|| {
90                ExecutorError::UnsupportedExpression(
91                    "NEW pseudo-variable not available in this trigger context".to_string(),
92                )
93            })?,
94        };
95
96        // Find column index in schema
97        let col_idx =
98            self.table_schema.columns.iter().position(|c| c.name == column).ok_or_else(|| {
99                ExecutorError::ColumnNotFound {
100                    column_name: column.to_string(),
101                    table_name: self.table_schema.name.clone(),
102                    searched_tables: vec![self.table_schema.name.clone()],
103                    available_columns: self
104                        .table_schema
105                        .columns
106                        .iter()
107                        .map(|c| c.name.clone())
108                        .collect(),
109                }
110            })?;
111
112        // Return the value
113        Ok(row.values[col_idx].clone())
114    }
115}
116
117/// Helper struct for trigger firing (execution during DML operations)
118pub struct TriggerFirer;
119
120impl TriggerFirer {
121    /// Find triggers for a table and event
122    ///
123    /// # Arguments
124    /// * `db` - Database reference
125    /// * `table_name` - Name of the table to find triggers for
126    /// * `timing` - Trigger timing (BEFORE, AFTER, INSTEAD OF)
127    /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
128    ///
129    /// # Returns
130    /// Vector of trigger definitions matching the criteria, sorted by creation order
131    pub fn find_triggers(
132        db: &Database,
133        table_name: &str,
134        timing: TriggerTiming,
135        event: TriggerEvent,
136    ) -> Vec<TriggerDefinition> {
137        db.catalog
138            .get_triggers_for_table(table_name, Some(event.clone()))
139            .filter(|trigger| trigger.timing == timing && trigger.enabled) // Skip disabled triggers
140            .cloned()
141            .collect()
142    }
143
144    /// Check if an UPDATE OF trigger should fire based on which columns changed
145    ///
146    /// # Arguments
147    /// * `trigger` - Trigger definition
148    /// * `old_row` - OLD row values
149    /// * `new_row` - NEW row values
150    /// * `table_schema` - Table schema for column lookup
151    ///
152    /// # Returns
153    /// true if the trigger should fire, false otherwise
154    fn should_fire_update_of(
155        trigger: &TriggerDefinition,
156        old_row: &Row,
157        new_row: &Row,
158        table_schema: &TableSchema,
159    ) -> bool {
160        match &trigger.event {
161            TriggerEvent::Update(Some(columns)) => {
162                // Check if any of the specified columns changed
163                for col_name in columns {
164                    if let Some(col_idx) =
165                        table_schema.columns.iter().position(|c| &c.name == col_name)
166                    {
167                        if col_idx < old_row.values.len()
168                            && col_idx < new_row.values.len()
169                            && old_row.values[col_idx] != new_row.values[col_idx]
170                        {
171                            return true; // At least one monitored column changed
172                        }
173                    }
174                }
175                false // None of the monitored columns changed
176            }
177            _ => true, // Not an UPDATE OF trigger, always fire
178        }
179    }
180
181    /// Execute a single trigger
182    ///
183    /// # Arguments
184    /// * `db` - Mutable database reference
185    /// * `trigger` - Trigger definition to execute
186    /// * `old_row` - OLD row for UPDATE/DELETE (None for INSERT)
187    /// * `new_row` - NEW row for INSERT/UPDATE (None for DELETE)
188    ///
189    /// # Returns
190    /// Ok(()) if trigger executed successfully, Err if execution failed
191    ///
192    /// # Notes
193    /// - For ROW-level triggers, this is called once per affected row
194    /// - For STATEMENT-level triggers, this is called once per statement
195    /// - WHEN conditions are evaluated here
196    pub fn execute_trigger(
197        db: &mut Database,
198        trigger: &TriggerDefinition,
199        old_row: Option<&Row>,
200        new_row: Option<&Row>,
201    ) -> Result<(), ExecutorError> {
202        // 1. Evaluate WHEN condition (if present)
203        if let Some(when_expr) = &trigger.when_condition {
204            let condition_result = Self::evaluate_when_condition(
205                db,
206                &trigger.table_name,
207                when_expr,
208                old_row,
209                new_row,
210            )?;
211
212            // Skip trigger execution if WHEN condition is false
213            if !condition_result {
214                return Ok(());
215            }
216        }
217
218        // 2. Execute trigger action
219        Self::execute_trigger_action(db, trigger, old_row, new_row)?;
220
221        Ok(())
222    }
223
224    /// Evaluate WHEN condition for a trigger
225    ///
226    /// # Arguments
227    /// * `db` - Database reference
228    /// * `table_name` - Name of the table
229    /// * `when_expr` - WHEN condition expression
230    /// * `old_row` - OLD row (for UPDATE/DELETE)
231    /// * `new_row` - NEW row (for INSERT/UPDATE)
232    ///
233    /// # Returns
234    /// Ok(true) if condition evaluates to true, Ok(false) otherwise
235    fn evaluate_when_condition(
236        db: &Database,
237        table_name: &str,
238        when_expr: &vibesql_ast::Expression,
239        old_row: Option<&Row>,
240        new_row: Option<&Row>,
241    ) -> Result<bool, ExecutorError> {
242        // Get table schema
243        let schema = db
244            .catalog
245            .get_table(table_name)
246            .ok_or_else(|| ExecutorError::TableNotFound(table_name.to_string()))?;
247
248        // Use NEW row as the base row for evaluation (prefer NEW over OLD)
249        // The trigger context will handle OLD/NEW pseudo-variable references
250        let row = new_row.or(old_row).ok_or_else(|| {
251            ExecutorError::UnsupportedExpression(
252                "WHEN condition requires a row context".to_string(),
253            )
254        })?;
255
256        // Create trigger context for OLD/NEW pseudo-variable resolution
257        let trigger_context = TriggerContext { old_row, new_row, table_schema: schema };
258
259        // Create evaluator with trigger context
260        let evaluator =
261            crate::ExpressionEvaluator::with_trigger_context(schema, db, &trigger_context);
262        let result = evaluator.eval(when_expr, row)?;
263
264        // Convert to boolean
265        match result {
266            vibesql_types::SqlValue::Boolean(b) => Ok(b),
267            vibesql_types::SqlValue::Null => Ok(false),
268            _ => Err(ExecutorError::UnsupportedExpression(
269                "WHEN condition must evaluate to boolean".to_string(),
270            )),
271        }
272    }
273
274    /// Execute trigger action statements
275    ///
276    /// # Arguments
277    /// * `db` - Mutable database reference
278    /// * `trigger` - Trigger definition
279    /// * `old_row` - OLD row (for UPDATE/DELETE)
280    /// * `new_row` - NEW row (for INSERT/UPDATE)
281    ///
282    /// # Returns
283    /// Ok(()) if action executed successfully, Err if execution failed
284    fn execute_trigger_action(
285        db: &mut Database,
286        trigger: &TriggerDefinition,
287        old_row: Option<&Row>,
288        new_row: Option<&Row>,
289    ) -> Result<(), ExecutorError> {
290        // Extract SQL from trigger action
291        let sql = match &trigger.triggered_action {
292            vibesql_ast::TriggerAction::RawSql(sql) => sql.clone(),
293        };
294
295        // Parse the trigger action SQL
296        let statements = Self::parse_trigger_sql(&sql)?;
297
298        // Get table schema for trigger context (clone to avoid borrow checker issues)
299        let schema = db
300            .catalog
301            .get_table(&trigger.table_name)
302            .ok_or_else(|| ExecutorError::TableNotFound(trigger.table_name.clone()))?
303            .clone();
304
305        // Create trigger context for OLD/NEW pseudo-variable resolution
306        let trigger_context = TriggerContext { old_row, new_row, table_schema: &schema };
307
308        // Execute each statement in the trigger body with trigger context
309        for statement in statements {
310            Self::execute_statement(db, &statement, &trigger_context)?;
311        }
312
313        Ok(())
314    }
315
316    /// Parse trigger SQL into statements
317    ///
318    /// # Arguments
319    /// * `sql` - Raw SQL string from trigger action
320    ///
321    /// # Returns
322    /// Vector of parsed statements
323    fn parse_trigger_sql(sql: &str) -> Result<Vec<vibesql_ast::Statement>, ExecutorError> {
324        // Strip BEGIN/END wrapper if present
325        let sql = sql.trim();
326        let sql = if sql.to_uppercase().starts_with("BEGIN") {
327            // Remove BEGIN and END
328            let sql = sql[5..].trim();
329            if sql.to_uppercase().ends_with("END") {
330                &sql[..sql.len() - 3]
331            } else {
332                sql
333            }
334        } else {
335            sql
336        };
337
338        // Split by semicolons and parse each statement
339        let mut statements = Vec::new();
340        for stmt_sql in sql.split(';') {
341            let stmt_sql = stmt_sql.trim();
342            if stmt_sql.is_empty() || stmt_sql.starts_with("--") {
343                // Skip empty statements or comments
344                continue;
345            }
346
347            match vibesql_parser::Parser::parse_sql(stmt_sql) {
348                Ok(stmt) => statements.push(stmt),
349                Err(e) => {
350                    return Err(ExecutorError::UnsupportedExpression(format!(
351                        "Failed to parse trigger SQL: {}",
352                        e.message
353                    )))
354                }
355            }
356        }
357
358        // If no statements parsed (e.g., trigger body was only comments), that's OK
359        // Just return empty vector
360        Ok(statements)
361    }
362
363    /// Execute a single statement from trigger body
364    ///
365    /// # Arguments
366    /// * `db` - Mutable database reference
367    /// * `statement` - Statement to execute
368    /// * `trigger_context` - Trigger context with OLD/NEW row data
369    ///
370    /// # Returns
371    /// Ok(()) if statement executed successfully
372    fn execute_statement(
373        db: &mut Database,
374        statement: &vibesql_ast::Statement,
375        trigger_context: &TriggerContext,
376    ) -> Result<(), ExecutorError> {
377        use vibesql_ast::Statement;
378
379        match statement {
380            Statement::Insert(insert_stmt) => {
381                // Execute INSERT with trigger context support
382                crate::insert::execute_insert_with_trigger_context(
383                    db,
384                    insert_stmt,
385                    trigger_context,
386                )?;
387                Ok(())
388            }
389            Statement::Update(update_stmt) => {
390                // Execute UPDATE with trigger context support
391                crate::update::execute_update_with_trigger_context(
392                    db,
393                    update_stmt,
394                    trigger_context,
395                )?;
396                Ok(())
397            }
398            Statement::Delete(delete_stmt) => {
399                // Execute DELETE with trigger context support
400                crate::delete::execute_delete_with_trigger_context(
401                    db,
402                    delete_stmt,
403                    trigger_context,
404                )?;
405                Ok(())
406            }
407            Statement::Select(select_stmt) => {
408                // Execute SELECT but ignore results (useful for side effects)
409                // Note: SELECT doesn't need special trigger context handling since
410                // it can reference OLD/NEW through normal expression evaluation
411                let executor = crate::SelectExecutor::new(db);
412                executor.execute_with_columns(select_stmt)?;
413                Ok(())
414            }
415            _ => Err(ExecutorError::UnsupportedExpression(format!(
416                "Statement type not supported in triggers: {:?}",
417                statement
418            ))),
419        }
420    }
421
422    /// Execute all BEFORE ROW-level triggers for an operation
423    ///
424    /// # Arguments
425    /// * `db` - Mutable database reference
426    /// * `table_name` - Name of the table
427    /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
428    /// * `old_row` - OLD row (for UPDATE/DELETE)
429    /// * `new_row` - NEW row (for INSERT/UPDATE)
430    ///
431    /// # Returns
432    /// Ok(()) if all triggers executed successfully
433    pub fn execute_before_triggers(
434        db: &mut Database,
435        table_name: &str,
436        event: TriggerEvent,
437        old_row: Option<&Row>,
438        new_row: Option<&Row>,
439    ) -> Result<(), ExecutorError> {
440        // Check recursion depth before executing any triggers
441        let _guard = RecursionGuard::new()?;
442
443        let triggers = Self::find_triggers(db, table_name, TriggerTiming::Before, event);
444
445        // Get table schema for UPDATE OF checking
446        let table_schema = db
447            .catalog
448            .get_table(table_name)
449            .ok_or_else(|| ExecutorError::TableNotFound(table_name.to_string()))?
450            .clone();
451
452        for trigger in triggers {
453            // Only execute ROW-level triggers in this method
454            if trigger.granularity == TriggerGranularity::Row {
455                // For UPDATE OF triggers, check if monitored columns changed
456                if let (Some(old), Some(new)) = (old_row, new_row) {
457                    if !Self::should_fire_update_of(&trigger, old, new, &table_schema) {
458                        continue; // Skip this trigger
459                    }
460                }
461
462                Self::execute_trigger(db, &trigger, old_row, new_row)?;
463            }
464        }
465
466        Ok(())
467    }
468
469    /// Execute all BEFORE STATEMENT-level triggers for an operation
470    ///
471    /// # Arguments
472    /// * `db` - Mutable database reference
473    /// * `table_name` - Name of the table
474    /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
475    ///
476    /// # Returns
477    /// Ok(()) if all triggers executed successfully
478    pub fn execute_before_statement_triggers(
479        db: &mut Database,
480        table_name: &str,
481        event: TriggerEvent,
482    ) -> Result<(), ExecutorError> {
483        // Check recursion depth before executing any triggers
484        let _guard = RecursionGuard::new()?;
485
486        let triggers = Self::find_triggers(db, table_name, TriggerTiming::Before, event);
487
488        for trigger in triggers {
489            // Only execute STATEMENT-level triggers in this method
490            if trigger.granularity == TriggerGranularity::Statement {
491                // Statement-level triggers don't have OLD/NEW row access
492                Self::execute_trigger(db, &trigger, None, None)?;
493            }
494        }
495
496        Ok(())
497    }
498
499    /// Execute all AFTER ROW-level triggers for an operation
500    ///
501    /// # Arguments
502    /// * `db` - Mutable database reference
503    /// * `table_name` - Name of the table
504    /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
505    /// * `old_row` - OLD row (for UPDATE/DELETE)
506    /// * `new_row` - NEW row (for INSERT/UPDATE)
507    ///
508    /// # Returns
509    /// Ok(()) if all triggers executed successfully
510    pub fn execute_after_triggers(
511        db: &mut Database,
512        table_name: &str,
513        event: TriggerEvent,
514        old_row: Option<&Row>,
515        new_row: Option<&Row>,
516    ) -> Result<(), ExecutorError> {
517        // Check recursion depth before executing any triggers
518        let _guard = RecursionGuard::new()?;
519
520        let triggers = Self::find_triggers(db, table_name, TriggerTiming::After, event);
521
522        // Get table schema for UPDATE OF checking
523        let table_schema = db
524            .catalog
525            .get_table(table_name)
526            .ok_or_else(|| ExecutorError::TableNotFound(table_name.to_string()))?
527            .clone();
528
529        for trigger in triggers {
530            // Only execute ROW-level triggers in this method
531            if trigger.granularity == TriggerGranularity::Row {
532                // For UPDATE OF triggers, check if monitored columns changed
533                if let (Some(old), Some(new)) = (old_row, new_row) {
534                    if !Self::should_fire_update_of(&trigger, old, new, &table_schema) {
535                        continue; // Skip this trigger
536                    }
537                }
538
539                Self::execute_trigger(db, &trigger, old_row, new_row)?;
540            }
541        }
542
543        Ok(())
544    }
545
546    /// Execute all AFTER STATEMENT-level triggers for an operation
547    ///
548    /// # Arguments
549    /// * `db` - Mutable database reference
550    /// * `table_name` - Name of the table
551    /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
552    ///
553    /// # Returns
554    /// Ok(()) if all triggers executed successfully
555    pub fn execute_after_statement_triggers(
556        db: &mut Database,
557        table_name: &str,
558        event: TriggerEvent,
559    ) -> Result<(), ExecutorError> {
560        // Check recursion depth before executing any triggers
561        let _guard = RecursionGuard::new()?;
562
563        let triggers = Self::find_triggers(db, table_name, TriggerTiming::After, event);
564
565        for trigger in triggers {
566            // Only execute STATEMENT-level triggers in this method
567            if trigger.granularity == TriggerGranularity::Statement {
568                // Statement-level triggers don't have OLD/NEW row access
569                Self::execute_trigger(db, &trigger, None, None)?;
570            }
571        }
572
573        Ok(())
574    }
575}