vibesql_executor/
trigger_execution.rs

1//! Trigger execution logic for firing triggers on DML operations
2
3use std::cell::Cell;
4
5use vibesql_ast::{PseudoTable, TriggerEvent, TriggerGranularity, TriggerTiming};
6use vibesql_catalog::{TableSchema, TriggerDefinition};
7use vibesql_storage::{Database, Row};
8use vibesql_types::SqlValue;
9
10use crate::errors::ExecutorError;
11
12/// Maximum trigger recursion depth to prevent infinite loops
13const MAX_TRIGGER_RECURSION_DEPTH: usize = 16;
14
15thread_local! {
16    /// Current trigger recursion depth for this thread
17    static TRIGGER_RECURSION_DEPTH: Cell<usize> = const { Cell::new(0) };
18}
19
20/// RAII guard for managing trigger recursion depth
21/// Increments depth on creation, decrements on drop
22struct RecursionGuard;
23
24impl RecursionGuard {
25    /// Create a new recursion guard, incrementing the depth
26    ///
27    /// # Returns
28    /// Ok(RecursionGuard) if depth is within limits, Err if limit exceeded
29    fn new() -> Result<Self, ExecutorError> {
30        TRIGGER_RECURSION_DEPTH.with(|depth| {
31            let current = depth.get();
32            if current >= MAX_TRIGGER_RECURSION_DEPTH {
33                Err(ExecutorError::UnsupportedExpression(format!(
34                    "Trigger recursion depth limit exceeded (max: {}). Possible infinite trigger loop.",
35                    MAX_TRIGGER_RECURSION_DEPTH
36                )))
37            } else {
38                depth.set(current + 1);
39                Ok(RecursionGuard)
40            }
41        })
42    }
43}
44
45impl Drop for RecursionGuard {
46    fn drop(&mut self) {
47        TRIGGER_RECURSION_DEPTH.with(|depth| {
48            depth.set(depth.get().saturating_sub(1));
49        });
50    }
51}
52
53/// Execution context for triggers with OLD/NEW row access
54/// Provides pseudo-variable resolution for trigger bodies
55pub struct TriggerContext<'a> {
56    /// OLD row - available for UPDATE and DELETE triggers
57    pub old_row: Option<&'a Row>,
58    /// NEW row - available for INSERT and UPDATE triggers
59    pub new_row: Option<&'a Row>,
60    /// Table schema for column lookups
61    pub table_schema: &'a TableSchema,
62}
63
64impl<'a> TriggerContext<'a> {
65    /// Resolve a pseudo-variable reference to a SqlValue
66    ///
67    /// # Arguments
68    /// * `pseudo_table` - Which pseudo-table (OLD or NEW)
69    /// * `column` - Column name to retrieve
70    ///
71    /// # Returns
72    /// Ok(SqlValue) with the column value, or Err if invalid
73    ///
74    /// # Errors
75    /// - If OLD/NEW is not available for this trigger type
76    /// - If column doesn't exist in table schema
77    pub fn resolve_pseudo_var(
78        &self,
79        pseudo_table: PseudoTable,
80        column: &str,
81    ) -> Result<SqlValue, ExecutorError> {
82        // Get the appropriate row
83        let row = match pseudo_table {
84            PseudoTable::Old => self.old_row.ok_or_else(|| {
85                ExecutorError::UnsupportedExpression(
86                    "OLD pseudo-variable not available in this trigger context".to_string(),
87                )
88            })?,
89            PseudoTable::New => self.new_row.ok_or_else(|| {
90                ExecutorError::UnsupportedExpression(
91                    "NEW pseudo-variable not available in this trigger context".to_string(),
92                )
93            })?,
94        };
95
96        // Find column index in schema
97        let col_idx =
98            self.table_schema.columns.iter().position(|c| c.name == column).ok_or_else(|| {
99                ExecutorError::ColumnNotFound {
100                    column_name: column.to_string(),
101                    table_name: self.table_schema.name.clone(),
102                    searched_tables: vec![self.table_schema.name.clone()],
103                    available_columns: self
104                        .table_schema
105                        .columns
106                        .iter()
107                        .map(|c| c.name.clone())
108                        .collect(),
109                }
110            })?;
111
112        // Return the value
113        Ok(row.values[col_idx].clone())
114    }
115}
116
/// Stateless namespace for the trigger-firing entry points used while executing DML
/// statements (INSERT / UPDATE / DELETE).
///
/// All of its operations are associated functions; no instance state is kept.
pub struct TriggerFirer;
119
120impl TriggerFirer {
121    /// Find triggers for a table and event
122    ///
123    /// # Arguments
124    /// * `db` - Database reference
125    /// * `table_name` - Name of the table to find triggers for
126    /// * `timing` - Trigger timing (BEFORE, AFTER, INSTEAD OF)
127    /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
128    ///
129    /// # Returns
130    /// Vector of trigger definitions matching the criteria, sorted by creation order
131    pub fn find_triggers(
132        db: &Database,
133        table_name: &str,
134        timing: TriggerTiming,
135        event: TriggerEvent,
136    ) -> Vec<TriggerDefinition> {
137        db.catalog
138            .get_triggers_for_table(table_name, Some(event.clone()))
139            .filter(|trigger| trigger.timing == timing && trigger.enabled) // Skip disabled triggers
140            .cloned()
141            .collect()
142    }
143
144    /// Check if an UPDATE OF trigger should fire based on which columns changed
145    ///
146    /// # Arguments
147    /// * `trigger` - Trigger definition
148    /// * `old_row` - OLD row values
149    /// * `new_row` - NEW row values
150    /// * `table_schema` - Table schema for column lookup
151    ///
152    /// # Returns
153    /// true if the trigger should fire, false otherwise
154    fn should_fire_update_of(
155        trigger: &TriggerDefinition,
156        old_row: &Row,
157        new_row: &Row,
158        table_schema: &TableSchema,
159    ) -> bool {
160        match &trigger.event {
161            TriggerEvent::Update(Some(columns)) => {
162                // Check if any of the specified columns changed
163                for col_name in columns {
164                    if let Some(col_idx) =
165                        table_schema.columns.iter().position(|c| &c.name == col_name)
166                    {
167                        if col_idx < old_row.values.len()
168                            && col_idx < new_row.values.len()
169                            && old_row.values[col_idx] != new_row.values[col_idx]
170                        {
171                            return true; // At least one monitored column changed
172                        }
173                    }
174                }
175                false // None of the monitored columns changed
176            }
177            _ => true, // Not an UPDATE OF trigger, always fire
178        }
179    }
180
181    /// Execute a single trigger
182    ///
183    /// # Arguments
184    /// * `db` - Mutable database reference
185    /// * `trigger` - Trigger definition to execute
186    /// * `old_row` - OLD row for UPDATE/DELETE (None for INSERT)
187    /// * `new_row` - NEW row for INSERT/UPDATE (None for DELETE)
188    ///
189    /// # Returns
190    /// Ok(()) if trigger executed successfully, Err if execution failed
191    ///
192    /// # Notes
193    /// - For ROW-level triggers, this is called once per affected row
194    /// - For STATEMENT-level triggers, this is called once per statement
195    /// - WHEN conditions are evaluated here
196    pub fn execute_trigger(
197        db: &mut Database,
198        trigger: &TriggerDefinition,
199        old_row: Option<&Row>,
200        new_row: Option<&Row>,
201    ) -> Result<(), ExecutorError> {
202        // 1. Evaluate WHEN condition (if present)
203        if let Some(when_expr) = &trigger.when_condition {
204            let condition_result = Self::evaluate_when_condition(
205                db,
206                &trigger.table_name,
207                when_expr,
208                old_row,
209                new_row,
210            )?;
211
212            // Skip trigger execution if WHEN condition is false
213            if !condition_result {
214                return Ok(());
215            }
216        }
217
218        // 2. Execute trigger action
219        Self::execute_trigger_action(db, trigger, old_row, new_row)?;
220
221        Ok(())
222    }
223
224    /// Evaluate WHEN condition for a trigger
225    ///
226    /// # Arguments
227    /// * `db` - Database reference
228    /// * `table_name` - Name of the table
229    /// * `when_expr` - WHEN condition expression
230    /// * `old_row` - OLD row (for UPDATE/DELETE)
231    /// * `new_row` - NEW row (for INSERT/UPDATE)
232    ///
233    /// # Returns
234    /// Ok(true) if condition evaluates to true, Ok(false) otherwise
235    fn evaluate_when_condition(
236        db: &Database,
237        table_name: &str,
238        when_expr: &vibesql_ast::Expression,
239        old_row: Option<&Row>,
240        new_row: Option<&Row>,
241    ) -> Result<bool, ExecutorError> {
242        // Get table schema
243        let schema = db
244            .catalog
245            .get_table(table_name)
246            .ok_or_else(|| ExecutorError::TableNotFound(table_name.to_string()))?;
247
248        // Use NEW row as the base row for evaluation (prefer NEW over OLD)
249        // The trigger context will handle OLD/NEW pseudo-variable references
250        let row = new_row.or(old_row).ok_or_else(|| {
251            ExecutorError::UnsupportedExpression(
252                "WHEN condition requires a row context".to_string(),
253            )
254        })?;
255
256        // Create trigger context for OLD/NEW pseudo-variable resolution
257        let trigger_context = TriggerContext { old_row, new_row, table_schema: schema };
258
259        // Create evaluator with trigger context
260        let evaluator =
261            crate::ExpressionEvaluator::with_trigger_context(schema, db, &trigger_context);
262        let result = evaluator.eval(when_expr, row)?;
263
264        // Convert to boolean
265        match result {
266            vibesql_types::SqlValue::Boolean(b) => Ok(b),
267            vibesql_types::SqlValue::Null => Ok(false),
268            _ => Err(ExecutorError::UnsupportedExpression(
269                "WHEN condition must evaluate to boolean".to_string(),
270            )),
271        }
272    }
273
274    /// Execute trigger action statements
275    ///
276    /// # Arguments
277    /// * `db` - Mutable database reference
278    /// * `trigger` - Trigger definition
279    /// * `old_row` - OLD row (for UPDATE/DELETE)
280    /// * `new_row` - NEW row (for INSERT/UPDATE)
281    ///
282    /// # Returns
283    /// Ok(()) if action executed successfully, Err if execution failed
284    fn execute_trigger_action(
285        db: &mut Database,
286        trigger: &TriggerDefinition,
287        old_row: Option<&Row>,
288        new_row: Option<&Row>,
289    ) -> Result<(), ExecutorError> {
290        // Extract SQL from trigger action
291        let sql = match &trigger.triggered_action {
292            vibesql_ast::TriggerAction::RawSql(sql) => sql.clone(),
293        };
294
295        // Parse the trigger action SQL
296        let statements = Self::parse_trigger_sql(&sql)?;
297
298        // Get table schema for trigger context (clone to avoid borrow checker issues)
299        // For INSTEAD OF triggers on views, build schema from the view definition
300        let schema = if let Some(table_schema) = db.catalog.get_table(&trigger.table_name) {
301            table_schema.clone()
302        } else if let Some(view_def) = db.catalog.get_view(&trigger.table_name) {
303            // Build a pseudo-schema from the view for OLD/NEW column resolution
304            Self::build_view_schema(db, view_def)?
305        } else {
306            return Err(ExecutorError::TableNotFound(trigger.table_name.clone()));
307        };
308
309        // Create trigger context for OLD/NEW pseudo-variable resolution
310        let trigger_context = TriggerContext { old_row, new_row, table_schema: &schema };
311
312        // Execute each statement in the trigger body with trigger context
313        for statement in statements {
314            Self::execute_statement(db, &statement, &trigger_context)?;
315        }
316
317        Ok(())
318    }
319
320    /// Build a pseudo TableSchema from a view definition for trigger OLD/NEW column resolution
321    fn build_view_schema(
322        db: &Database,
323        view_def: &vibesql_catalog::ViewDefinition,
324    ) -> Result<TableSchema, ExecutorError> {
325        // Execute the view's SELECT query to get column names
326        let select_executor = crate::SelectExecutor::new(db);
327        let result = select_executor.execute_with_columns(&view_def.query)?;
328
329        // Use explicit column names if provided, otherwise derive from SELECT
330        let column_names: Vec<String> = if let Some(ref cols) = view_def.columns {
331            cols.clone()
332        } else {
333            result.columns.clone()
334        };
335
336        // Build columns with a generic data type
337        let columns: Vec<vibesql_catalog::ColumnSchema> = column_names
338            .into_iter()
339            .map(|name| {
340                vibesql_catalog::ColumnSchema::new(
341                    name,
342                    vibesql_types::DataType::Varchar { max_length: None },
343                    true,
344                )
345            })
346            .collect();
347
348        Ok(TableSchema::new(view_def.name.clone(), columns))
349    }
350
351    /// Parse trigger SQL into statements
352    ///
353    /// # Arguments
354    /// * `sql` - Raw SQL string from trigger action
355    ///
356    /// # Returns
357    /// Vector of parsed statements
358    fn parse_trigger_sql(sql: &str) -> Result<Vec<vibesql_ast::Statement>, ExecutorError> {
359        // Strip BEGIN/END wrapper if present
360        let sql = sql.trim();
361        let sql = if sql.to_uppercase().starts_with("BEGIN") {
362            // Remove BEGIN and END
363            let sql = sql[5..].trim();
364            if sql.to_uppercase().ends_with("END") {
365                &sql[..sql.len() - 3]
366            } else {
367                sql
368            }
369        } else {
370            sql
371        };
372
373        // Split by semicolons and parse each statement
374        let mut statements = Vec::new();
375        for stmt_sql in sql.split(';') {
376            let stmt_sql = stmt_sql.trim();
377            if stmt_sql.is_empty() || stmt_sql.starts_with("--") {
378                // Skip empty statements or comments
379                continue;
380            }
381
382            match vibesql_parser::Parser::parse_sql(stmt_sql) {
383                Ok(stmt) => statements.push(stmt),
384                Err(e) => {
385                    return Err(ExecutorError::UnsupportedExpression(format!(
386                        "Failed to parse trigger SQL: {}",
387                        e.message
388                    )))
389                }
390            }
391        }
392
393        // If no statements parsed (e.g., trigger body was only comments), that's OK
394        // Just return empty vector
395        Ok(statements)
396    }
397
398    /// Execute a single statement from trigger body
399    ///
400    /// # Arguments
401    /// * `db` - Mutable database reference
402    /// * `statement` - Statement to execute
403    /// * `trigger_context` - Trigger context with OLD/NEW row data
404    ///
405    /// # Returns
406    /// Ok(()) if statement executed successfully
407    fn execute_statement(
408        db: &mut Database,
409        statement: &vibesql_ast::Statement,
410        trigger_context: &TriggerContext,
411    ) -> Result<(), ExecutorError> {
412        use vibesql_ast::Statement;
413
414        match statement {
415            Statement::Insert(insert_stmt) => {
416                // Execute INSERT with trigger context support
417                crate::insert::execute_insert_with_trigger_context(
418                    db,
419                    insert_stmt,
420                    trigger_context,
421                )?;
422                Ok(())
423            }
424            Statement::Update(update_stmt) => {
425                // Execute UPDATE with trigger context support
426                crate::update::execute_update_with_trigger_context(
427                    db,
428                    update_stmt,
429                    trigger_context,
430                )?;
431                Ok(())
432            }
433            Statement::Delete(delete_stmt) => {
434                // Execute DELETE with trigger context support
435                crate::delete::execute_delete_with_trigger_context(
436                    db,
437                    delete_stmt,
438                    trigger_context,
439                )?;
440                Ok(())
441            }
442            Statement::Select(select_stmt) => {
443                // Execute SELECT but ignore results (useful for side effects)
444                // Note: SELECT doesn't need special trigger context handling since
445                // it can reference OLD/NEW through normal expression evaluation
446                let executor = crate::SelectExecutor::new(db);
447                executor.execute_with_columns(select_stmt)?;
448                Ok(())
449            }
450            _ => Err(ExecutorError::UnsupportedExpression(format!(
451                "Statement type not supported in triggers: {:?}",
452                statement
453            ))),
454        }
455    }
456
457    /// Execute all BEFORE ROW-level triggers for an operation
458    ///
459    /// # Arguments
460    /// * `db` - Mutable database reference
461    /// * `table_name` - Name of the table
462    /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
463    /// * `old_row` - OLD row (for UPDATE/DELETE)
464    /// * `new_row` - NEW row (for INSERT/UPDATE)
465    ///
466    /// # Returns
467    /// Ok(()) if all triggers executed successfully
468    pub fn execute_before_triggers(
469        db: &mut Database,
470        table_name: &str,
471        event: TriggerEvent,
472        old_row: Option<&Row>,
473        new_row: Option<&Row>,
474    ) -> Result<(), ExecutorError> {
475        // Check recursion depth before executing any triggers
476        let _guard = RecursionGuard::new()?;
477
478        let triggers = Self::find_triggers(db, table_name, TriggerTiming::Before, event);
479
480        // Get table schema for UPDATE OF checking
481        let table_schema = db
482            .catalog
483            .get_table(table_name)
484            .ok_or_else(|| ExecutorError::TableNotFound(table_name.to_string()))?
485            .clone();
486
487        for trigger in triggers {
488            // Only execute ROW-level triggers in this method
489            if trigger.granularity == TriggerGranularity::Row {
490                // For UPDATE OF triggers, check if monitored columns changed
491                if let (Some(old), Some(new)) = (old_row, new_row) {
492                    if !Self::should_fire_update_of(&trigger, old, new, &table_schema) {
493                        continue; // Skip this trigger
494                    }
495                }
496
497                Self::execute_trigger(db, &trigger, old_row, new_row)?;
498            }
499        }
500
501        Ok(())
502    }
503
504    /// Execute all BEFORE STATEMENT-level triggers for an operation
505    ///
506    /// # Arguments
507    /// * `db` - Mutable database reference
508    /// * `table_name` - Name of the table
509    /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
510    ///
511    /// # Returns
512    /// Ok(()) if all triggers executed successfully
513    pub fn execute_before_statement_triggers(
514        db: &mut Database,
515        table_name: &str,
516        event: TriggerEvent,
517    ) -> Result<(), ExecutorError> {
518        // Check recursion depth before executing any triggers
519        let _guard = RecursionGuard::new()?;
520
521        let triggers = Self::find_triggers(db, table_name, TriggerTiming::Before, event);
522
523        for trigger in triggers {
524            // Only execute STATEMENT-level triggers in this method
525            if trigger.granularity == TriggerGranularity::Statement {
526                // Statement-level triggers don't have OLD/NEW row access
527                Self::execute_trigger(db, &trigger, None, None)?;
528            }
529        }
530
531        Ok(())
532    }
533
534    /// Execute all AFTER ROW-level triggers for an operation
535    ///
536    /// # Arguments
537    /// * `db` - Mutable database reference
538    /// * `table_name` - Name of the table
539    /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
540    /// * `old_row` - OLD row (for UPDATE/DELETE)
541    /// * `new_row` - NEW row (for INSERT/UPDATE)
542    ///
543    /// # Returns
544    /// Ok(()) if all triggers executed successfully
545    pub fn execute_after_triggers(
546        db: &mut Database,
547        table_name: &str,
548        event: TriggerEvent,
549        old_row: Option<&Row>,
550        new_row: Option<&Row>,
551    ) -> Result<(), ExecutorError> {
552        // Check recursion depth before executing any triggers
553        let _guard = RecursionGuard::new()?;
554
555        let triggers = Self::find_triggers(db, table_name, TriggerTiming::After, event);
556
557        // Get table schema for UPDATE OF checking
558        let table_schema = db
559            .catalog
560            .get_table(table_name)
561            .ok_or_else(|| ExecutorError::TableNotFound(table_name.to_string()))?
562            .clone();
563
564        for trigger in triggers {
565            // Only execute ROW-level triggers in this method
566            if trigger.granularity == TriggerGranularity::Row {
567                // For UPDATE OF triggers, check if monitored columns changed
568                if let (Some(old), Some(new)) = (old_row, new_row) {
569                    if !Self::should_fire_update_of(&trigger, old, new, &table_schema) {
570                        continue; // Skip this trigger
571                    }
572                }
573
574                Self::execute_trigger(db, &trigger, old_row, new_row)?;
575            }
576        }
577
578        Ok(())
579    }
580
581    /// Execute all AFTER STATEMENT-level triggers for an operation
582    ///
583    /// # Arguments
584    /// * `db` - Mutable database reference
585    /// * `table_name` - Name of the table
586    /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
587    ///
588    /// # Returns
589    /// Ok(()) if all triggers executed successfully
590    pub fn execute_after_statement_triggers(
591        db: &mut Database,
592        table_name: &str,
593        event: TriggerEvent,
594    ) -> Result<(), ExecutorError> {
595        // Check recursion depth before executing any triggers
596        let _guard = RecursionGuard::new()?;
597
598        let triggers = Self::find_triggers(db, table_name, TriggerTiming::After, event);
599
600        for trigger in triggers {
601            // Only execute STATEMENT-level triggers in this method
602            if trigger.granularity == TriggerGranularity::Statement {
603                // Statement-level triggers don't have OLD/NEW row access
604                Self::execute_trigger(db, &trigger, None, None)?;
605            }
606        }
607
608        Ok(())
609    }
610}