vibesql_executor/trigger_execution.rs
1//! Trigger execution logic for firing triggers on DML operations
2
3use std::cell::Cell;
4
5use vibesql_ast::{PseudoTable, TriggerEvent, TriggerGranularity, TriggerTiming};
6use vibesql_catalog::{TableSchema, TriggerDefinition};
7use vibesql_storage::{Database, Row};
8use vibesql_types::SqlValue;
9
10use crate::errors::ExecutorError;
11
12/// Maximum trigger recursion depth to prevent infinite loops
13const MAX_TRIGGER_RECURSION_DEPTH: usize = 16;
14
15thread_local! {
16 /// Current trigger recursion depth for this thread
17 static TRIGGER_RECURSION_DEPTH: Cell<usize> = const { Cell::new(0) };
18}
19
20/// RAII guard for managing trigger recursion depth
21/// Increments depth on creation, decrements on drop
22struct RecursionGuard;
23
24impl RecursionGuard {
25 /// Create a new recursion guard, incrementing the depth
26 ///
27 /// # Returns
28 /// Ok(RecursionGuard) if depth is within limits, Err if limit exceeded
29 fn new() -> Result<Self, ExecutorError> {
30 TRIGGER_RECURSION_DEPTH.with(|depth| {
31 let current = depth.get();
32 if current >= MAX_TRIGGER_RECURSION_DEPTH {
33 Err(ExecutorError::UnsupportedExpression(format!(
34 "Trigger recursion depth limit exceeded (max: {}). Possible infinite trigger loop.",
35 MAX_TRIGGER_RECURSION_DEPTH
36 )))
37 } else {
38 depth.set(current + 1);
39 Ok(RecursionGuard)
40 }
41 })
42 }
43}
44
45impl Drop for RecursionGuard {
46 fn drop(&mut self) {
47 TRIGGER_RECURSION_DEPTH.with(|depth| {
48 depth.set(depth.get().saturating_sub(1));
49 });
50 }
51}
52
53/// Execution context for triggers with OLD/NEW row access
54/// Provides pseudo-variable resolution for trigger bodies
55pub struct TriggerContext<'a> {
56 /// OLD row - available for UPDATE and DELETE triggers
57 pub old_row: Option<&'a Row>,
58 /// NEW row - available for INSERT and UPDATE triggers
59 pub new_row: Option<&'a Row>,
60 /// Table schema for column lookups
61 pub table_schema: &'a TableSchema,
62}
63
64impl<'a> TriggerContext<'a> {
65 /// Resolve a pseudo-variable reference to a SqlValue
66 ///
67 /// # Arguments
68 /// * `pseudo_table` - Which pseudo-table (OLD or NEW)
69 /// * `column` - Column name to retrieve
70 ///
71 /// # Returns
72 /// Ok(SqlValue) with the column value, or Err if invalid
73 ///
74 /// # Errors
75 /// - If OLD/NEW is not available for this trigger type
76 /// - If column doesn't exist in table schema
77 pub fn resolve_pseudo_var(
78 &self,
79 pseudo_table: PseudoTable,
80 column: &str,
81 ) -> Result<SqlValue, ExecutorError> {
82 // Get the appropriate row
83 let row = match pseudo_table {
84 PseudoTable::Old => self.old_row.ok_or_else(|| {
85 ExecutorError::UnsupportedExpression(
86 "OLD pseudo-variable not available in this trigger context".to_string(),
87 )
88 })?,
89 PseudoTable::New => self.new_row.ok_or_else(|| {
90 ExecutorError::UnsupportedExpression(
91 "NEW pseudo-variable not available in this trigger context".to_string(),
92 )
93 })?,
94 };
95
96 // Find column index in schema
97 let col_idx =
98 self.table_schema.columns.iter().position(|c| c.name == column).ok_or_else(|| {
99 ExecutorError::ColumnNotFound {
100 column_name: column.to_string(),
101 table_name: self.table_schema.name.clone(),
102 searched_tables: vec![self.table_schema.name.clone()],
103 available_columns: self
104 .table_schema
105 .columns
106 .iter()
107 .map(|c| c.name.clone())
108 .collect(),
109 }
110 })?;
111
112 // Return the value
113 Ok(row.values[col_idx].clone())
114 }
115}
116
117/// Helper struct for trigger firing (execution during DML operations)
118pub struct TriggerFirer;
119
120impl TriggerFirer {
121 /// Find triggers for a table and event
122 ///
123 /// # Arguments
124 /// * `db` - Database reference
125 /// * `table_name` - Name of the table to find triggers for
126 /// * `timing` - Trigger timing (BEFORE, AFTER, INSTEAD OF)
127 /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
128 ///
129 /// # Returns
130 /// Vector of trigger definitions matching the criteria, sorted by creation order
131 pub fn find_triggers(
132 db: &Database,
133 table_name: &str,
134 timing: TriggerTiming,
135 event: TriggerEvent,
136 ) -> Vec<TriggerDefinition> {
137 db.catalog
138 .get_triggers_for_table(table_name, Some(event.clone()))
139 .filter(|trigger| trigger.timing == timing && trigger.enabled) // Skip disabled triggers
140 .cloned()
141 .collect()
142 }
143
144 /// Check if an UPDATE OF trigger should fire based on which columns changed
145 ///
146 /// # Arguments
147 /// * `trigger` - Trigger definition
148 /// * `old_row` - OLD row values
149 /// * `new_row` - NEW row values
150 /// * `table_schema` - Table schema for column lookup
151 ///
152 /// # Returns
153 /// true if the trigger should fire, false otherwise
154 fn should_fire_update_of(
155 trigger: &TriggerDefinition,
156 old_row: &Row,
157 new_row: &Row,
158 table_schema: &TableSchema,
159 ) -> bool {
160 match &trigger.event {
161 TriggerEvent::Update(Some(columns)) => {
162 // Check if any of the specified columns changed
163 for col_name in columns {
164 if let Some(col_idx) =
165 table_schema.columns.iter().position(|c| &c.name == col_name)
166 {
167 if col_idx < old_row.values.len()
168 && col_idx < new_row.values.len()
169 && old_row.values[col_idx] != new_row.values[col_idx]
170 {
171 return true; // At least one monitored column changed
172 }
173 }
174 }
175 false // None of the monitored columns changed
176 }
177 _ => true, // Not an UPDATE OF trigger, always fire
178 }
179 }
180
181 /// Execute a single trigger
182 ///
183 /// # Arguments
184 /// * `db` - Mutable database reference
185 /// * `trigger` - Trigger definition to execute
186 /// * `old_row` - OLD row for UPDATE/DELETE (None for INSERT)
187 /// * `new_row` - NEW row for INSERT/UPDATE (None for DELETE)
188 ///
189 /// # Returns
190 /// Ok(()) if trigger executed successfully, Err if execution failed
191 ///
192 /// # Notes
193 /// - For ROW-level triggers, this is called once per affected row
194 /// - For STATEMENT-level triggers, this is called once per statement
195 /// - WHEN conditions are evaluated here
196 pub fn execute_trigger(
197 db: &mut Database,
198 trigger: &TriggerDefinition,
199 old_row: Option<&Row>,
200 new_row: Option<&Row>,
201 ) -> Result<(), ExecutorError> {
202 // 1. Evaluate WHEN condition (if present)
203 if let Some(when_expr) = &trigger.when_condition {
204 let condition_result = Self::evaluate_when_condition(
205 db,
206 &trigger.table_name,
207 when_expr,
208 old_row,
209 new_row,
210 )?;
211
212 // Skip trigger execution if WHEN condition is false
213 if !condition_result {
214 return Ok(());
215 }
216 }
217
218 // 2. Execute trigger action
219 Self::execute_trigger_action(db, trigger, old_row, new_row)?;
220
221 Ok(())
222 }
223
224 /// Evaluate WHEN condition for a trigger
225 ///
226 /// # Arguments
227 /// * `db` - Database reference
228 /// * `table_name` - Name of the table
229 /// * `when_expr` - WHEN condition expression
230 /// * `old_row` - OLD row (for UPDATE/DELETE)
231 /// * `new_row` - NEW row (for INSERT/UPDATE)
232 ///
233 /// # Returns
234 /// Ok(true) if condition evaluates to true, Ok(false) otherwise
235 fn evaluate_when_condition(
236 db: &Database,
237 table_name: &str,
238 when_expr: &vibesql_ast::Expression,
239 old_row: Option<&Row>,
240 new_row: Option<&Row>,
241 ) -> Result<bool, ExecutorError> {
242 // Get table schema
243 let schema = db
244 .catalog
245 .get_table(table_name)
246 .ok_or_else(|| ExecutorError::TableNotFound(table_name.to_string()))?;
247
248 // Use NEW row as the base row for evaluation (prefer NEW over OLD)
249 // The trigger context will handle OLD/NEW pseudo-variable references
250 let row = new_row.or(old_row).ok_or_else(|| {
251 ExecutorError::UnsupportedExpression(
252 "WHEN condition requires a row context".to_string(),
253 )
254 })?;
255
256 // Create trigger context for OLD/NEW pseudo-variable resolution
257 let trigger_context = TriggerContext { old_row, new_row, table_schema: schema };
258
259 // Create evaluator with trigger context
260 let evaluator =
261 crate::ExpressionEvaluator::with_trigger_context(schema, db, &trigger_context);
262 let result = evaluator.eval(when_expr, row)?;
263
264 // Convert to boolean
265 match result {
266 vibesql_types::SqlValue::Boolean(b) => Ok(b),
267 vibesql_types::SqlValue::Null => Ok(false),
268 _ => Err(ExecutorError::UnsupportedExpression(
269 "WHEN condition must evaluate to boolean".to_string(),
270 )),
271 }
272 }
273
274 /// Execute trigger action statements
275 ///
276 /// # Arguments
277 /// * `db` - Mutable database reference
278 /// * `trigger` - Trigger definition
279 /// * `old_row` - OLD row (for UPDATE/DELETE)
280 /// * `new_row` - NEW row (for INSERT/UPDATE)
281 ///
282 /// # Returns
283 /// Ok(()) if action executed successfully, Err if execution failed
284 fn execute_trigger_action(
285 db: &mut Database,
286 trigger: &TriggerDefinition,
287 old_row: Option<&Row>,
288 new_row: Option<&Row>,
289 ) -> Result<(), ExecutorError> {
290 // Extract SQL from trigger action
291 let sql = match &trigger.triggered_action {
292 vibesql_ast::TriggerAction::RawSql(sql) => sql.clone(),
293 };
294
295 // Parse the trigger action SQL
296 let statements = Self::parse_trigger_sql(&sql)?;
297
298 // Get table schema for trigger context (clone to avoid borrow checker issues)
299 let schema = db
300 .catalog
301 .get_table(&trigger.table_name)
302 .ok_or_else(|| ExecutorError::TableNotFound(trigger.table_name.clone()))?
303 .clone();
304
305 // Create trigger context for OLD/NEW pseudo-variable resolution
306 let trigger_context = TriggerContext { old_row, new_row, table_schema: &schema };
307
308 // Execute each statement in the trigger body with trigger context
309 for statement in statements {
310 Self::execute_statement(db, &statement, &trigger_context)?;
311 }
312
313 Ok(())
314 }
315
316 /// Parse trigger SQL into statements
317 ///
318 /// # Arguments
319 /// * `sql` - Raw SQL string from trigger action
320 ///
321 /// # Returns
322 /// Vector of parsed statements
323 fn parse_trigger_sql(sql: &str) -> Result<Vec<vibesql_ast::Statement>, ExecutorError> {
324 // Strip BEGIN/END wrapper if present
325 let sql = sql.trim();
326 let sql = if sql.to_uppercase().starts_with("BEGIN") {
327 // Remove BEGIN and END
328 let sql = sql[5..].trim();
329 if sql.to_uppercase().ends_with("END") {
330 &sql[..sql.len() - 3]
331 } else {
332 sql
333 }
334 } else {
335 sql
336 };
337
338 // Split by semicolons and parse each statement
339 let mut statements = Vec::new();
340 for stmt_sql in sql.split(';') {
341 let stmt_sql = stmt_sql.trim();
342 if stmt_sql.is_empty() || stmt_sql.starts_with("--") {
343 // Skip empty statements or comments
344 continue;
345 }
346
347 match vibesql_parser::Parser::parse_sql(stmt_sql) {
348 Ok(stmt) => statements.push(stmt),
349 Err(e) => {
350 return Err(ExecutorError::UnsupportedExpression(format!(
351 "Failed to parse trigger SQL: {}",
352 e.message
353 )))
354 }
355 }
356 }
357
358 // If no statements parsed (e.g., trigger body was only comments), that's OK
359 // Just return empty vector
360 Ok(statements)
361 }
362
363 /// Execute a single statement from trigger body
364 ///
365 /// # Arguments
366 /// * `db` - Mutable database reference
367 /// * `statement` - Statement to execute
368 /// * `trigger_context` - Trigger context with OLD/NEW row data
369 ///
370 /// # Returns
371 /// Ok(()) if statement executed successfully
372 fn execute_statement(
373 db: &mut Database,
374 statement: &vibesql_ast::Statement,
375 trigger_context: &TriggerContext,
376 ) -> Result<(), ExecutorError> {
377 use vibesql_ast::Statement;
378
379 match statement {
380 Statement::Insert(insert_stmt) => {
381 // Execute INSERT with trigger context support
382 crate::insert::execute_insert_with_trigger_context(
383 db,
384 insert_stmt,
385 trigger_context,
386 )?;
387 Ok(())
388 }
389 Statement::Update(update_stmt) => {
390 // Execute UPDATE with trigger context support
391 crate::update::execute_update_with_trigger_context(
392 db,
393 update_stmt,
394 trigger_context,
395 )?;
396 Ok(())
397 }
398 Statement::Delete(delete_stmt) => {
399 // Execute DELETE with trigger context support
400 crate::delete::execute_delete_with_trigger_context(
401 db,
402 delete_stmt,
403 trigger_context,
404 )?;
405 Ok(())
406 }
407 Statement::Select(select_stmt) => {
408 // Execute SELECT but ignore results (useful for side effects)
409 // Note: SELECT doesn't need special trigger context handling since
410 // it can reference OLD/NEW through normal expression evaluation
411 let executor = crate::SelectExecutor::new(db);
412 executor.execute_with_columns(select_stmt)?;
413 Ok(())
414 }
415 _ => Err(ExecutorError::UnsupportedExpression(format!(
416 "Statement type not supported in triggers: {:?}",
417 statement
418 ))),
419 }
420 }
421
422 /// Execute all BEFORE ROW-level triggers for an operation
423 ///
424 /// # Arguments
425 /// * `db` - Mutable database reference
426 /// * `table_name` - Name of the table
427 /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
428 /// * `old_row` - OLD row (for UPDATE/DELETE)
429 /// * `new_row` - NEW row (for INSERT/UPDATE)
430 ///
431 /// # Returns
432 /// Ok(()) if all triggers executed successfully
433 pub fn execute_before_triggers(
434 db: &mut Database,
435 table_name: &str,
436 event: TriggerEvent,
437 old_row: Option<&Row>,
438 new_row: Option<&Row>,
439 ) -> Result<(), ExecutorError> {
440 // Check recursion depth before executing any triggers
441 let _guard = RecursionGuard::new()?;
442
443 let triggers = Self::find_triggers(db, table_name, TriggerTiming::Before, event);
444
445 // Get table schema for UPDATE OF checking
446 let table_schema = db
447 .catalog
448 .get_table(table_name)
449 .ok_or_else(|| ExecutorError::TableNotFound(table_name.to_string()))?
450 .clone();
451
452 for trigger in triggers {
453 // Only execute ROW-level triggers in this method
454 if trigger.granularity == TriggerGranularity::Row {
455 // For UPDATE OF triggers, check if monitored columns changed
456 if let (Some(old), Some(new)) = (old_row, new_row) {
457 if !Self::should_fire_update_of(&trigger, old, new, &table_schema) {
458 continue; // Skip this trigger
459 }
460 }
461
462 Self::execute_trigger(db, &trigger, old_row, new_row)?;
463 }
464 }
465
466 Ok(())
467 }
468
469 /// Execute all BEFORE STATEMENT-level triggers for an operation
470 ///
471 /// # Arguments
472 /// * `db` - Mutable database reference
473 /// * `table_name` - Name of the table
474 /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
475 ///
476 /// # Returns
477 /// Ok(()) if all triggers executed successfully
478 pub fn execute_before_statement_triggers(
479 db: &mut Database,
480 table_name: &str,
481 event: TriggerEvent,
482 ) -> Result<(), ExecutorError> {
483 // Check recursion depth before executing any triggers
484 let _guard = RecursionGuard::new()?;
485
486 let triggers = Self::find_triggers(db, table_name, TriggerTiming::Before, event);
487
488 for trigger in triggers {
489 // Only execute STATEMENT-level triggers in this method
490 if trigger.granularity == TriggerGranularity::Statement {
491 // Statement-level triggers don't have OLD/NEW row access
492 Self::execute_trigger(db, &trigger, None, None)?;
493 }
494 }
495
496 Ok(())
497 }
498
499 /// Execute all AFTER ROW-level triggers for an operation
500 ///
501 /// # Arguments
502 /// * `db` - Mutable database reference
503 /// * `table_name` - Name of the table
504 /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
505 /// * `old_row` - OLD row (for UPDATE/DELETE)
506 /// * `new_row` - NEW row (for INSERT/UPDATE)
507 ///
508 /// # Returns
509 /// Ok(()) if all triggers executed successfully
510 pub fn execute_after_triggers(
511 db: &mut Database,
512 table_name: &str,
513 event: TriggerEvent,
514 old_row: Option<&Row>,
515 new_row: Option<&Row>,
516 ) -> Result<(), ExecutorError> {
517 // Check recursion depth before executing any triggers
518 let _guard = RecursionGuard::new()?;
519
520 let triggers = Self::find_triggers(db, table_name, TriggerTiming::After, event);
521
522 // Get table schema for UPDATE OF checking
523 let table_schema = db
524 .catalog
525 .get_table(table_name)
526 .ok_or_else(|| ExecutorError::TableNotFound(table_name.to_string()))?
527 .clone();
528
529 for trigger in triggers {
530 // Only execute ROW-level triggers in this method
531 if trigger.granularity == TriggerGranularity::Row {
532 // For UPDATE OF triggers, check if monitored columns changed
533 if let (Some(old), Some(new)) = (old_row, new_row) {
534 if !Self::should_fire_update_of(&trigger, old, new, &table_schema) {
535 continue; // Skip this trigger
536 }
537 }
538
539 Self::execute_trigger(db, &trigger, old_row, new_row)?;
540 }
541 }
542
543 Ok(())
544 }
545
546 /// Execute all AFTER STATEMENT-level triggers for an operation
547 ///
548 /// # Arguments
549 /// * `db` - Mutable database reference
550 /// * `table_name` - Name of the table
551 /// * `event` - Trigger event (INSERT, UPDATE, DELETE)
552 ///
553 /// # Returns
554 /// Ok(()) if all triggers executed successfully
555 pub fn execute_after_statement_triggers(
556 db: &mut Database,
557 table_name: &str,
558 event: TriggerEvent,
559 ) -> Result<(), ExecutorError> {
560 // Check recursion depth before executing any triggers
561 let _guard = RecursionGuard::new()?;
562
563 let triggers = Self::find_triggers(db, table_name, TriggerTiming::After, event);
564
565 for trigger in triggers {
566 // Only execute STATEMENT-level triggers in this method
567 if trigger.granularity == TriggerGranularity::Statement {
568 // Statement-level triggers don't have OLD/NEW row access
569 Self::execute_trigger(db, &trigger, None, None)?;
570 }
571 }
572
573 Ok(())
574 }
575}