vibesql_executor/trigger_execution.rs
1//! Trigger execution logic for firing triggers on DML operations
2
3use std::cell::Cell;
4
5use vibesql_ast::{PseudoTable, TriggerEvent, TriggerGranularity, TriggerTiming};
6use vibesql_catalog::{TableSchema, TriggerDefinition};
7use vibesql_storage::{Database, Row};
8use vibesql_types::SqlValue;
9
10use crate::errors::ExecutorError;
11
/// Maximum trigger recursion depth to prevent infinite loops.
///
/// `RecursionGuard::new` refuses to hand out a guard once the per-thread
/// depth counter has reached this value.
const MAX_TRIGGER_RECURSION_DEPTH: usize = 16;

thread_local! {
    /// Current trigger recursion depth for this thread.
    ///
    /// Incremented by `RecursionGuard::new` and decremented again when the
    /// guard is dropped.
    static TRIGGER_RECURSION_DEPTH: Cell<usize> = const { Cell::new(0) };
}
19
20/// RAII guard for managing trigger recursion depth
21/// Increments depth on creation, decrements on drop
22struct RecursionGuard;
23
24impl RecursionGuard {
25 /// Create a new recursion guard, incrementing the depth
26 ///
27 /// # Returns
28 /// Ok(RecursionGuard) if depth is within limits, Err if limit exceeded
29 fn new() -> Result<Self, ExecutorError> {
30 TRIGGER_RECURSION_DEPTH.with(|depth| {
31 let current = depth.get();
32 if current >= MAX_TRIGGER_RECURSION_DEPTH {
33 Err(ExecutorError::UnsupportedExpression(format!(
34 "Trigger recursion depth limit exceeded (max: {}). Possible infinite trigger loop.",
35 MAX_TRIGGER_RECURSION_DEPTH
36 )))
37 } else {
38 depth.set(current + 1);
39 Ok(RecursionGuard)
40 }
41 })
42 }
43}
44
45impl Drop for RecursionGuard {
46 fn drop(&mut self) {
47 TRIGGER_RECURSION_DEPTH.with(|depth| {
48 depth.set(depth.get().saturating_sub(1));
49 });
50 }
51}
52
53/// Execution context for triggers with OLD/NEW row access
54/// Provides pseudo-variable resolution for trigger bodies
55pub struct TriggerContext<'a> {
56 /// OLD row - available for UPDATE and DELETE triggers
57 pub old_row: Option<&'a Row>,
58 /// NEW row - available for INSERT and UPDATE triggers
59 pub new_row: Option<&'a Row>,
60 /// Table schema for column lookups
61 pub table_schema: &'a TableSchema,
62}
63
64impl<'a> TriggerContext<'a> {
65 /// Resolve a pseudo-variable reference to a SqlValue
66 ///
67 /// # Arguments
68 /// * `pseudo_table` - Which pseudo-table (OLD or NEW)
69 /// * `column` - Column name to retrieve
70 ///
71 /// # Returns
72 /// Ok(SqlValue) with the column value, or Err if invalid
73 ///
74 /// # Errors
75 /// - If OLD/NEW is not available for this trigger type
76 /// - If column doesn't exist in table schema
77 pub fn resolve_pseudo_var(
78 &self,
79 pseudo_table: PseudoTable,
80 column: &str,
81 ) -> Result<SqlValue, ExecutorError> {
82 // Get the appropriate row
83 let row = match pseudo_table {
84 PseudoTable::Old => self.old_row.ok_or_else(|| {
85 ExecutorError::UnsupportedExpression(
86 "OLD pseudo-variable not available in this trigger context".to_string(),
87 )
88 })?,
89 PseudoTable::New => self.new_row.ok_or_else(|| {
90 ExecutorError::UnsupportedExpression(
91 "NEW pseudo-variable not available in this trigger context".to_string(),
92 )
93 })?,
94 };
95
96 // Find column index in schema
97 let col_idx =
98 self.table_schema.columns.iter().position(|c| c.name == column).ok_or_else(|| {
99 ExecutorError::ColumnNotFound {
100 column_name: column.to_string(),
101 table_name: self.table_schema.name.clone(),
102 searched_tables: vec![self.table_schema.name.clone()],
103 available_columns: self
104 .table_schema
105 .columns
106 .iter()
107 .map(|c| c.name.clone())
108 .collect(),
109 }
110 })?;
111
112 // Return the value
113 Ok(row.values[col_idx].clone())
114 }
115}
116
/// Helper struct for trigger firing (execution during DML operations).
///
/// Carries no state: all of its operations are associated functions that take
/// the database as an explicit argument.
pub struct TriggerFirer;
119
120impl TriggerFirer {
121 /// Find triggers for a table and event
122 ///
123 /// # Arguments
124 /// * `db` - Database reference
125 /// * `table_name` - Name of the table to find triggers for
126 /// * `timing` - Trigger timing (BEFORE, AFTER, INSTEAD OF)
127 /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
128 ///
129 /// # Returns
130 /// Vector of trigger definitions matching the criteria, sorted by creation order
131 pub fn find_triggers(
132 db: &Database,
133 table_name: &str,
134 timing: TriggerTiming,
135 event: TriggerEvent,
136 ) -> Vec<TriggerDefinition> {
137 db.catalog
138 .get_triggers_for_table(table_name, Some(event.clone()))
139 .filter(|trigger| trigger.timing == timing && trigger.enabled) // Skip disabled triggers
140 .cloned()
141 .collect()
142 }
143
144 /// Check if an UPDATE OF trigger should fire based on which columns changed
145 ///
146 /// # Arguments
147 /// * `trigger` - Trigger definition
148 /// * `old_row` - OLD row values
149 /// * `new_row` - NEW row values
150 /// * `table_schema` - Table schema for column lookup
151 ///
152 /// # Returns
153 /// true if the trigger should fire, false otherwise
154 fn should_fire_update_of(
155 trigger: &TriggerDefinition,
156 old_row: &Row,
157 new_row: &Row,
158 table_schema: &TableSchema,
159 ) -> bool {
160 match &trigger.event {
161 TriggerEvent::Update(Some(columns)) => {
162 // Check if any of the specified columns changed
163 for col_name in columns {
164 if let Some(col_idx) =
165 table_schema.columns.iter().position(|c| &c.name == col_name)
166 {
167 if col_idx < old_row.values.len()
168 && col_idx < new_row.values.len()
169 && old_row.values[col_idx] != new_row.values[col_idx]
170 {
171 return true; // At least one monitored column changed
172 }
173 }
174 }
175 false // None of the monitored columns changed
176 }
177 _ => true, // Not an UPDATE OF trigger, always fire
178 }
179 }
180
181 /// Execute a single trigger
182 ///
183 /// # Arguments
184 /// * `db` - Mutable database reference
185 /// * `trigger` - Trigger definition to execute
186 /// * `old_row` - OLD row for UPDATE/DELETE (None for INSERT)
187 /// * `new_row` - NEW row for INSERT/UPDATE (None for DELETE)
188 ///
189 /// # Returns
190 /// Ok(()) if trigger executed successfully, Err if execution failed
191 ///
192 /// # Notes
193 /// - For ROW-level triggers, this is called once per affected row
194 /// - For STATEMENT-level triggers, this is called once per statement
195 /// - WHEN conditions are evaluated here
196 pub fn execute_trigger(
197 db: &mut Database,
198 trigger: &TriggerDefinition,
199 old_row: Option<&Row>,
200 new_row: Option<&Row>,
201 ) -> Result<(), ExecutorError> {
202 // 1. Evaluate WHEN condition (if present)
203 if let Some(when_expr) = &trigger.when_condition {
204 let condition_result = Self::evaluate_when_condition(
205 db,
206 &trigger.table_name,
207 when_expr,
208 old_row,
209 new_row,
210 )?;
211
212 // Skip trigger execution if WHEN condition is false
213 if !condition_result {
214 return Ok(());
215 }
216 }
217
218 // 2. Execute trigger action
219 Self::execute_trigger_action(db, trigger, old_row, new_row)?;
220
221 Ok(())
222 }
223
224 /// Evaluate WHEN condition for a trigger
225 ///
226 /// # Arguments
227 /// * `db` - Database reference
228 /// * `table_name` - Name of the table
229 /// * `when_expr` - WHEN condition expression
230 /// * `old_row` - OLD row (for UPDATE/DELETE)
231 /// * `new_row` - NEW row (for INSERT/UPDATE)
232 ///
233 /// # Returns
234 /// Ok(true) if condition evaluates to true, Ok(false) otherwise
235 fn evaluate_when_condition(
236 db: &Database,
237 table_name: &str,
238 when_expr: &vibesql_ast::Expression,
239 old_row: Option<&Row>,
240 new_row: Option<&Row>,
241 ) -> Result<bool, ExecutorError> {
242 // Get table schema
243 let schema = db
244 .catalog
245 .get_table(table_name)
246 .ok_or_else(|| ExecutorError::TableNotFound(table_name.to_string()))?;
247
248 // Use NEW row as the base row for evaluation (prefer NEW over OLD)
249 // The trigger context will handle OLD/NEW pseudo-variable references
250 let row = new_row.or(old_row).ok_or_else(|| {
251 ExecutorError::UnsupportedExpression(
252 "WHEN condition requires a row context".to_string(),
253 )
254 })?;
255
256 // Create trigger context for OLD/NEW pseudo-variable resolution
257 let trigger_context = TriggerContext { old_row, new_row, table_schema: schema };
258
259 // Create evaluator with trigger context
260 let evaluator =
261 crate::ExpressionEvaluator::with_trigger_context(schema, db, &trigger_context);
262 let result = evaluator.eval(when_expr, row)?;
263
264 // Convert to boolean
265 match result {
266 vibesql_types::SqlValue::Boolean(b) => Ok(b),
267 vibesql_types::SqlValue::Null => Ok(false),
268 _ => Err(ExecutorError::UnsupportedExpression(
269 "WHEN condition must evaluate to boolean".to_string(),
270 )),
271 }
272 }
273
274 /// Execute trigger action statements
275 ///
276 /// # Arguments
277 /// * `db` - Mutable database reference
278 /// * `trigger` - Trigger definition
279 /// * `old_row` - OLD row (for UPDATE/DELETE)
280 /// * `new_row` - NEW row (for INSERT/UPDATE)
281 ///
282 /// # Returns
283 /// Ok(()) if action executed successfully, Err if execution failed
284 fn execute_trigger_action(
285 db: &mut Database,
286 trigger: &TriggerDefinition,
287 old_row: Option<&Row>,
288 new_row: Option<&Row>,
289 ) -> Result<(), ExecutorError> {
290 // Extract SQL from trigger action
291 let sql = match &trigger.triggered_action {
292 vibesql_ast::TriggerAction::RawSql(sql) => sql.clone(),
293 };
294
295 // Parse the trigger action SQL
296 let statements = Self::parse_trigger_sql(&sql)?;
297
298 // Get table schema for trigger context (clone to avoid borrow checker issues)
299 // For INSTEAD OF triggers on views, build schema from the view definition
300 let schema = if let Some(table_schema) = db.catalog.get_table(&trigger.table_name) {
301 table_schema.clone()
302 } else if let Some(view_def) = db.catalog.get_view(&trigger.table_name) {
303 // Build a pseudo-schema from the view for OLD/NEW column resolution
304 Self::build_view_schema(db, view_def)?
305 } else {
306 return Err(ExecutorError::TableNotFound(trigger.table_name.clone()));
307 };
308
309 // Create trigger context for OLD/NEW pseudo-variable resolution
310 let trigger_context = TriggerContext { old_row, new_row, table_schema: &schema };
311
312 // Execute each statement in the trigger body with trigger context
313 for statement in statements {
314 Self::execute_statement(db, &statement, &trigger_context)?;
315 }
316
317 Ok(())
318 }
319
320 /// Build a pseudo TableSchema from a view definition for trigger OLD/NEW column resolution
321 fn build_view_schema(
322 db: &Database,
323 view_def: &vibesql_catalog::ViewDefinition,
324 ) -> Result<TableSchema, ExecutorError> {
325 // Execute the view's SELECT query to get column names
326 let select_executor = crate::SelectExecutor::new(db);
327 let result = select_executor.execute_with_columns(&view_def.query)?;
328
329 // Use explicit column names if provided, otherwise derive from SELECT
330 let column_names: Vec<String> = if let Some(ref cols) = view_def.columns {
331 cols.clone()
332 } else {
333 result.columns.clone()
334 };
335
336 // Build columns with a generic data type
337 let columns: Vec<vibesql_catalog::ColumnSchema> = column_names
338 .into_iter()
339 .map(|name| {
340 vibesql_catalog::ColumnSchema::new(
341 name,
342 vibesql_types::DataType::Varchar { max_length: None },
343 true,
344 )
345 })
346 .collect();
347
348 Ok(TableSchema::new(view_def.name.clone(), columns))
349 }
350
351 /// Parse trigger SQL into statements
352 ///
353 /// # Arguments
354 /// * `sql` - Raw SQL string from trigger action
355 ///
356 /// # Returns
357 /// Vector of parsed statements
358 fn parse_trigger_sql(sql: &str) -> Result<Vec<vibesql_ast::Statement>, ExecutorError> {
359 // Strip BEGIN/END wrapper if present
360 let sql = sql.trim();
361 let sql = if sql.to_uppercase().starts_with("BEGIN") {
362 // Remove BEGIN and END
363 let sql = sql[5..].trim();
364 if sql.to_uppercase().ends_with("END") {
365 &sql[..sql.len() - 3]
366 } else {
367 sql
368 }
369 } else {
370 sql
371 };
372
373 // Split by semicolons and parse each statement
374 let mut statements = Vec::new();
375 for stmt_sql in sql.split(';') {
376 let stmt_sql = stmt_sql.trim();
377 if stmt_sql.is_empty() || stmt_sql.starts_with("--") {
378 // Skip empty statements or comments
379 continue;
380 }
381
382 match vibesql_parser::Parser::parse_sql(stmt_sql) {
383 Ok(stmt) => statements.push(stmt),
384 Err(e) => {
385 return Err(ExecutorError::UnsupportedExpression(format!(
386 "Failed to parse trigger SQL: {}",
387 e.message
388 )))
389 }
390 }
391 }
392
393 // If no statements parsed (e.g., trigger body was only comments), that's OK
394 // Just return empty vector
395 Ok(statements)
396 }
397
398 /// Execute a single statement from trigger body
399 ///
400 /// # Arguments
401 /// * `db` - Mutable database reference
402 /// * `statement` - Statement to execute
403 /// * `trigger_context` - Trigger context with OLD/NEW row data
404 ///
405 /// # Returns
406 /// Ok(()) if statement executed successfully
407 fn execute_statement(
408 db: &mut Database,
409 statement: &vibesql_ast::Statement,
410 trigger_context: &TriggerContext,
411 ) -> Result<(), ExecutorError> {
412 use vibesql_ast::Statement;
413
414 match statement {
415 Statement::Insert(insert_stmt) => {
416 // Execute INSERT with trigger context support
417 crate::insert::execute_insert_with_trigger_context(
418 db,
419 insert_stmt,
420 trigger_context,
421 )?;
422 Ok(())
423 }
424 Statement::Update(update_stmt) => {
425 // Execute UPDATE with trigger context support
426 crate::update::execute_update_with_trigger_context(
427 db,
428 update_stmt,
429 trigger_context,
430 )?;
431 Ok(())
432 }
433 Statement::Delete(delete_stmt) => {
434 // Execute DELETE with trigger context support
435 crate::delete::execute_delete_with_trigger_context(
436 db,
437 delete_stmt,
438 trigger_context,
439 )?;
440 Ok(())
441 }
442 Statement::Select(select_stmt) => {
443 // Execute SELECT but ignore results (useful for side effects)
444 // Note: SELECT doesn't need special trigger context handling since
445 // it can reference OLD/NEW through normal expression evaluation
446 let executor = crate::SelectExecutor::new(db);
447 executor.execute_with_columns(select_stmt)?;
448 Ok(())
449 }
450 _ => Err(ExecutorError::UnsupportedExpression(format!(
451 "Statement type not supported in triggers: {:?}",
452 statement
453 ))),
454 }
455 }
456
457 /// Execute all BEFORE ROW-level triggers for an operation
458 ///
459 /// # Arguments
460 /// * `db` - Mutable database reference
461 /// * `table_name` - Name of the table
462 /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
463 /// * `old_row` - OLD row (for UPDATE/DELETE)
464 /// * `new_row` - NEW row (for INSERT/UPDATE)
465 ///
466 /// # Returns
467 /// Ok(()) if all triggers executed successfully
468 pub fn execute_before_triggers(
469 db: &mut Database,
470 table_name: &str,
471 event: TriggerEvent,
472 old_row: Option<&Row>,
473 new_row: Option<&Row>,
474 ) -> Result<(), ExecutorError> {
475 // Check recursion depth before executing any triggers
476 let _guard = RecursionGuard::new()?;
477
478 let triggers = Self::find_triggers(db, table_name, TriggerTiming::Before, event);
479
480 // Get table schema for UPDATE OF checking
481 let table_schema = db
482 .catalog
483 .get_table(table_name)
484 .ok_or_else(|| ExecutorError::TableNotFound(table_name.to_string()))?
485 .clone();
486
487 for trigger in triggers {
488 // Only execute ROW-level triggers in this method
489 if trigger.granularity == TriggerGranularity::Row {
490 // For UPDATE OF triggers, check if monitored columns changed
491 if let (Some(old), Some(new)) = (old_row, new_row) {
492 if !Self::should_fire_update_of(&trigger, old, new, &table_schema) {
493 continue; // Skip this trigger
494 }
495 }
496
497 Self::execute_trigger(db, &trigger, old_row, new_row)?;
498 }
499 }
500
501 Ok(())
502 }
503
504 /// Execute all BEFORE STATEMENT-level triggers for an operation
505 ///
506 /// # Arguments
507 /// * `db` - Mutable database reference
508 /// * `table_name` - Name of the table
509 /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
510 ///
511 /// # Returns
512 /// Ok(()) if all triggers executed successfully
513 pub fn execute_before_statement_triggers(
514 db: &mut Database,
515 table_name: &str,
516 event: TriggerEvent,
517 ) -> Result<(), ExecutorError> {
518 // Check recursion depth before executing any triggers
519 let _guard = RecursionGuard::new()?;
520
521 let triggers = Self::find_triggers(db, table_name, TriggerTiming::Before, event);
522
523 for trigger in triggers {
524 // Only execute STATEMENT-level triggers in this method
525 if trigger.granularity == TriggerGranularity::Statement {
526 // Statement-level triggers don't have OLD/NEW row access
527 Self::execute_trigger(db, &trigger, None, None)?;
528 }
529 }
530
531 Ok(())
532 }
533
534 /// Execute all AFTER ROW-level triggers for an operation
535 ///
536 /// # Arguments
537 /// * `db` - Mutable database reference
538 /// * `table_name` - Name of the table
539 /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
540 /// * `old_row` - OLD row (for UPDATE/DELETE)
541 /// * `new_row` - NEW row (for INSERT/UPDATE)
542 ///
543 /// # Returns
544 /// Ok(()) if all triggers executed successfully
545 pub fn execute_after_triggers(
546 db: &mut Database,
547 table_name: &str,
548 event: TriggerEvent,
549 old_row: Option<&Row>,
550 new_row: Option<&Row>,
551 ) -> Result<(), ExecutorError> {
552 // Check recursion depth before executing any triggers
553 let _guard = RecursionGuard::new()?;
554
555 let triggers = Self::find_triggers(db, table_name, TriggerTiming::After, event);
556
557 // Get table schema for UPDATE OF checking
558 let table_schema = db
559 .catalog
560 .get_table(table_name)
561 .ok_or_else(|| ExecutorError::TableNotFound(table_name.to_string()))?
562 .clone();
563
564 for trigger in triggers {
565 // Only execute ROW-level triggers in this method
566 if trigger.granularity == TriggerGranularity::Row {
567 // For UPDATE OF triggers, check if monitored columns changed
568 if let (Some(old), Some(new)) = (old_row, new_row) {
569 if !Self::should_fire_update_of(&trigger, old, new, &table_schema) {
570 continue; // Skip this trigger
571 }
572 }
573
574 Self::execute_trigger(db, &trigger, old_row, new_row)?;
575 }
576 }
577
578 Ok(())
579 }
580
581 /// Execute all AFTER STATEMENT-level triggers for an operation
582 ///
583 /// # Arguments
584 /// * `db` - Mutable database reference
585 /// * `table_name` - Name of the table
586 /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
587 ///
588 /// # Returns
589 /// Ok(()) if all triggers executed successfully
590 pub fn execute_after_statement_triggers(
591 db: &mut Database,
592 table_name: &str,
593 event: TriggerEvent,
594 ) -> Result<(), ExecutorError> {
595 // Check recursion depth before executing any triggers
596 let _guard = RecursionGuard::new()?;
597
598 let triggers = Self::find_triggers(db, table_name, TriggerTiming::After, event);
599
600 for trigger in triggers {
601 // Only execute STATEMENT-level triggers in this method
602 if trigger.granularity == TriggerGranularity::Statement {
603 // Statement-level triggers don't have OLD/NEW row access
604 Self::execute_trigger(db, &trigger, None, None)?;
605 }
606 }
607
608 Ok(())
609 }
610}